Source code for litestar_queues.execution.cloudrun.entrypoint

import asyncio
import contextlib
import os
from enum import IntEnum
from importlib import import_module
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID

from litestar_queues.config import QueueConfig
from litestar_queues.service import QueueService
from litestar_queues.task import load_task_modules

if TYPE_CHECKING:
    from collections.abc import AsyncIterator, Callable, Mapping

# Public names of this module; these are what ``from ... import *`` exports.
__all__ = ("CloudRunExitCode", "execute_cloudrun_task", "main")


class CloudRunExitCode(IntEnum):
    """Fixed integer exit statuses for a Cloud Run task process.

    Every member maps one execution outcome to a stable number, so the
    outcome can be read from the process exit status alone.
    """

    # Outcome of a record that was actually executed.
    SUCCESS = 0
    FAILURE = 1

    # The task id could not be obtained from the environment or parsed.
    MISSING_TASK_ID = 2
    INVALID_TASK_ID = 3

    # The persisted record or its registered task could not be found.
    MISSING_RECORD = 4
    UNKNOWN_TASK = 5

    # The claim on the record was not obtained or was lost while running.
    CLAIM_LOST = 6

    # Execution was cancelled.
    CANCELLED = 7
async def execute_cloudrun_task(
    *,
    config: "QueueConfig | None" = None,
    service: "QueueService | None" = None,
    service_factory: "Callable[[], Any] | None" = None,
    env: "Mapping[str, str] | None" = None,
) -> "CloudRunExitCode":
    """Execute one persisted queue record in a Cloud Run task process.

    The task id is read from the environment, the record is loaded and
    claimed, and it is then executed while a background heartbeat keeps the
    claim alive. Every outcome is mapped to a ``CloudRunExitCode``.

    Args:
        config: Optional queue configuration. It determines the names of the
            environment variables that are read and is forwarded to the
            service provider.
        service: Optional queue service to use; forwarded to the service
            provider.
        service_factory: Optional zero-argument callable forwarded to the
            service provider.
        env: Mapping to read settings from. ``os.environ`` is used only when
            this is ``None``; an explicitly supplied empty mapping is honoured.

    Returns:
        A deterministic process exit code.
    """
    # ``env or os.environ`` would silently replace an explicit empty mapping
    # with the real process environment, so test for ``None`` instead.
    environ = os.environ if env is None else env

    # The variable name is derived from the ``config`` argument because the
    # service (and therefore its own config) has not been opened yet.
    task_id_raw = environ.get(_env_name(config, "TASK_ID"))
    if not task_id_raw:
        return CloudRunExitCode.MISSING_TASK_ID
    try:
        task_id = UUID(task_id_raw)
    except ValueError:
        return CloudRunExitCode.INVALID_TASK_ID

    async with _provide_service(
        config=config, service=service, service_factory=service_factory, env=environ
    ) as queue:
        _load_configured_task_modules(queue.config, environ)

        record = await queue.get_task(task_id)
        if record is None:
            return CloudRunExitCode.MISSING_RECORD

        try:
            queue.resolve_task(record.task_name)
        except KeyError:
            # The task name is not registered: fail the record without retry.
            await queue.get_queue_backend().fail_task(
                record.id, f"Unknown queue task: {record.task_name!r}", retry=False
            )
            return CloudRunExitCode.UNKNOWN_TASK

        claimed = await queue.get_queue_backend().claim_task(record.id)
        if claimed is None:
            await queue.publish_claim_lost(record, phase="claim")
            return CloudRunExitCode.CLAIM_LOST

        expected_retry_count = claimed.retry_count
        heartbeat_task = asyncio.create_task(
            _heartbeat_loop(queue, claimed.id, expected_retry_count=expected_retry_count)
        )
        execution_task = asyncio.create_task(queue.execute_record(claimed))
        try:
            done, _pending = await asyncio.wait(
                {heartbeat_task, execution_task}, return_when=asyncio.FIRST_COMPLETED
            )
            # ``result()`` re-raises if the heartbeat loop itself failed; a
            # falsy result means the backend rejected the heartbeat.
            if heartbeat_task in done and not heartbeat_task.result():
                execution_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await execution_task
                await queue.publish_claim_lost(
                    claimed, phase="heartbeat", expected_retry_count=expected_retry_count
                )
                return CloudRunExitCode.CLAIM_LOST
            updated = await execution_task
        except asyncio.CancelledError:
            return CloudRunExitCode.CANCELLED
        finally:
            # asyncio.wait() does not cancel the tasks it was given. If we get
            # here through cancellation or a heartbeat error, the execution
            # task is still running; stop and drain it so it cannot outlive
            # the claim release below or the service context. Its outcome is
            # irrelevant on this path, hence ``return_exceptions=True``.
            if not execution_task.done():
                execution_task.cancel()
                await asyncio.gather(execution_task, return_exceptions=True)
            heartbeat_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await heartbeat_task
            await queue.get_queue_backend().null_heartbeats(
                [claimed.id], expected_retry_count=expected_retry_count
            )

        if updated.status == "completed":
            return CloudRunExitCode.SUCCESS
        if updated.status == "cancelled":
            return CloudRunExitCode.CANCELLED
        return CloudRunExitCode.FAILURE
def main() -> "None":
    """Run the Cloud Run task to completion and exit with its status.

    Raises:
        SystemExit: Always; its code is the integer value of the exit code
            returned by the task execution.
    """
    exit_code = asyncio.run(execute_cloudrun_task())
    raise SystemExit(int(exit_code))
async def _heartbeat_loop(queue: "QueueService", task_id: "UUID", *, expected_retry_count: "int") -> "bool":
    """Keep touching the heartbeat of ``task_id`` until the backend rejects it.

    Sleeps for ``queue.config.worker_heartbeat_interval`` before every touch.
    The loop has no successful exit: it either returns ``False`` when
    ``touch_heartbeat`` reports a falsy result, or it is cancelled.

    Returns:
        ``False`` once a heartbeat touch is rejected.
    """
    interval = queue.config.worker_heartbeat_interval
    while True:
        await asyncio.sleep(interval)
        if not await queue.get_queue_backend().touch_heartbeat(task_id, expected_retry_count=expected_retry_count):
            return False


@contextlib.asynccontextmanager
async def _provide_service(
    *,
    config: "QueueConfig | None",
    service: "QueueService | None",
    service_factory: "Callable[[], Any] | None",
    env: "Mapping[str, str]",
) -> "AsyncIterator[QueueService]":
    """Yield the queue service to execute with.

    Resolution order:

    1. ``service``, yielded as-is (it is not entered or exited here).
    2. The result of ``service_factory`` or, failing that, of the factory
       named in the environment. A ``QueueConfig`` result is wrapped in a new
       ``QueueService``; anything else is used directly as an async context
       manager.
    3. A new ``QueueService`` built from ``config`` or a default
       ``QueueConfig``.
    """
    if service is not None:
        yield service
        return

    factory = service_factory or _load_config_factory(config, env)
    if factory is None:
        async with QueueService(config or QueueConfig()) as queue:
            yield queue
        return

    provided = factory()
    # A bare config needs a service around it; a QueueService or any other
    # async context manager is entered directly.
    manager = QueueService(provided) if isinstance(provided, QueueConfig) else provided
    async with manager as queue:
        yield queue


def _load_config_factory(config: "QueueConfig | None", env: "Mapping[str, str]") -> "Callable[[], Any] | None":
    """Import the factory named by the ``CONFIG_FACTORY`` environment variable.

    The value is either ``"package.module:attribute"`` or
    ``"package.module.attribute"``.

    Returns:
        The imported callable, or ``None`` when the variable is unset or empty.

    Raises:
        ValueError: If the value has neither a ``:`` nor a ``.`` separator.
        TypeError: If the imported attribute is not callable.
    """
    env_var = _env_name(config, "CONFIG_FACTORY")
    import_path = env.get(env_var)
    if not import_path:
        return None

    module_path, separator, attribute = import_path.partition(":")
    if not separator:
        # Dotted form: the last component is the attribute name.
        module_path, separator, attribute = import_path.rpartition(".")
        if not separator:
            # Previously this surfaced as an opaque tuple-unpacking ValueError.
            msg = (
                f"Cloud Run config factory {import_path!r} must be given as "
                "'module:attribute' or 'module.attribute'."
            )
            raise ValueError(msg)

    module = import_module(module_path)
    factory = getattr(module, attribute)
    if not callable(factory):
        msg = f"Cloud Run config factory {import_path!r} is not callable."
        raise TypeError(msg)
    return cast("Callable[[], Any]", factory)


def _load_configured_task_modules(config: "QueueConfig", env: "Mapping[str, str]") -> "None":
    """Load task modules listed in ``config`` and in the environment.

    Modules from ``config.task_modules`` come first, followed by the
    comma-separated entries of the ``TASK_MODULES`` environment variable
    (blank entries are ignored). Nothing is loaded when the combined list is
    empty.
    """
    modules = list(config.task_modules)
    env_modules = env.get(_env_name(config, "TASK_MODULES"))
    if env_modules:
        modules.extend(module.strip() for module in env_modules.split(",") if module.strip())
    if modules:
        load_task_modules(tuple(modules), force_reload=True)


def _env_name(config: "QueueConfig | None", suffix: "str") -> "str":
    """Return the environment variable name for ``suffix``.

    When ``config.execution_backend`` has a callable ``env_name`` attribute,
    that callable decides the name; otherwise the name is
    ``LITESTAR_QUEUES_<suffix>``.
    """
    raw_config = config.execution_backend if config is not None else None
    env_name = getattr(raw_config, "env_name", None)
    if callable(env_name):
        return str(env_name(suffix))
    return f"LITESTAR_QUEUES_{suffix}"


if __name__ == "__main__":
    main()