Source code for litestar_queues.service

import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any

from typing_extensions import Self

from litestar_queues.config import execution_backend_name
from litestar_queues.events.context import TaskExecutionContext, _bind_task_context, _reset_task_context
from litestar_queues.events.models import QueueEvent
from litestar_queues.exceptions import NonRetryableError
from litestar_queues.task import ScheduleConfig, Task, TaskResult, get_scheduled_tasks, get_task_registry

if TYPE_CHECKING:
    from collections.abc import Mapping
    from types import TracebackType
    from uuid import UUID

    from litestar_queues.backends import BaseQueueBackend
    from litestar_queues.config import QueueConfig
    from litestar_queues.events import QueueEventPublisher
    from litestar_queues.execution import BaseExecutionBackend
    from litestar_queues.models import QueuedTaskRecord, StaleTaskRecoveryResult

__all__ = ("QueueService",)

logger = logging.getLogger(__name__)

_LOG_LEVELS = {
    "critical": logging.CRITICAL,
    "error": logging.ERROR,
    "warning": logging.WARNING,
    "warn": logging.WARNING,
    "info": logging.INFO,
    "debug": logging.DEBUG,
}


[docs] class QueueService: """High-level facade for queue and execution backends.""" __slots__ = ("_config", "_event_publisher", "_execution_backend", "_queue_backend", "_sync_executor")
[docs] def __init__( self, config: "QueueConfig", *, queue_backend: "BaseQueueBackend | None" = None, execution_backend: "BaseExecutionBackend | None" = None, event_publisher: "QueueEventPublisher | None" = None, ) -> "None": """Initialize the queue service.""" self._config = config self._queue_backend = queue_backend self._execution_backend = execution_backend self._event_publisher = event_publisher self._sync_executor: "ThreadPoolExecutor | None" = None
@property def config(self) -> "QueueConfig": """Queue configuration.""" return self._config
[docs] def get_queue_backend(self) -> "BaseQueueBackend": """Return the configured queue backend.""" if self._queue_backend is None: self._queue_backend = self._config.get_queue_backend() return self._queue_backend
[docs] def get_execution_backend(self) -> "BaseExecutionBackend": """Return the configured execution backend.""" if self._execution_backend is None: self._execution_backend = self._config.get_execution_backend() return self._execution_backend
[docs] def get_event_publisher(self) -> "QueueEventPublisher": """Return the configured event publisher.""" if self._event_publisher is None: self._event_publisher = self._config.get_event_publisher() return self._event_publisher
[docs] async def open(self) -> "Self": """Open queue and execution backends. Returns: The opened service. """ await self.get_queue_backend().open() await self.get_execution_backend().open() if self._config.sync_executor_max_workers is not None and self._sync_executor is None: self._sync_executor = ThreadPoolExecutor( max_workers=self._config.sync_executor_max_workers, thread_name_prefix=self._config.sync_executor_thread_name_prefix, ) return self
[docs] async def close(self) -> "None": """Close queue and execution backends.""" if self._execution_backend is not None: await self._execution_backend.close() if self._queue_backend is not None: await self._queue_backend.close() if self._sync_executor is not None: self._sync_executor.shutdown(wait=True, cancel_futures=True) self._sync_executor = None
async def __aenter__(self) -> "Self": await self.open() return self async def __aexit__( self, exc_type: "type[BaseException] | None", # noqa: PYI036 exc_val: "BaseException | None", # noqa: PYI036 exc_tb: "TracebackType | None", # noqa: PYI036 ) -> "None": await self.close()
[docs] async def enqueue( self, task: "str | Task[Any, Any]", *args: "Any", scheduled_at: "datetime | None" = None, run_after: "float | timedelta | None" = None, key: "str | None" = None, queue: "str | None" = None, priority: "int | None" = None, retries: "int | None" = None, timeout: "float | None" = None, execution_backend: "str | None" = None, execution_profile: "str | None" = None, description: "str | None" = None, log_level: "str | None" = None, quiet_success: "bool | None" = None, requeue_on_stale: "bool | None" = None, metadata: "dict[str, Any] | None" = None, **kwargs: "Any", ) -> "TaskResult": """Enqueue a registered task. Returns: A result handle for the queued record. """ task_obj = self.resolve_task(task) effective_key = key if key is not None else task_obj.key coerced_run_after = _coerce_timedelta(run_after) effective_run_after = coerced_run_after if run_after is not None else task_obj.run_after effective_scheduled_at = scheduled_at if effective_scheduled_at is None and effective_run_after is not None: effective_scheduled_at = datetime.now(timezone.utc) + effective_run_after effective_execution_backend = ( execution_backend or task_obj.execution_backend or execution_backend_name(self._config.execution_backend) ) effective_execution_profile = execution_profile if execution_profile is not None else task_obj.execution_profile effective_metadata = task_obj.metadata(metadata) if description is not None: effective_metadata["description"] = description if log_level is not None: effective_metadata["log_level"] = log_level if quiet_success is not None: effective_metadata["quiet_success"] = quiet_success if requeue_on_stale is not None: effective_metadata["requeue_on_stale"] = requeue_on_stale if timeout is not None: effective_metadata["timeout"] = timeout record = await self.get_queue_backend().enqueue( task_obj.name, args=args, kwargs=kwargs, queue=queue if queue is not None else task_obj.queue, priority=priority if priority is not None else task_obj.priority, max_retries=retries if retries is not None else task_obj.retries, scheduled_at=effective_scheduled_at, key=effective_key, execution_backend=effective_execution_backend, execution_profile=effective_execution_profile, metadata=effective_metadata, ) result = TaskResult(record.id, task_obj.name, service=self, record=record) if record.execution_backend == "immediate" and record.status == "pending": claimed = await self.get_queue_backend().claim_task(record.id) if claimed is not None: await self.get_execution_backend().execute(self, claimed) return result
[docs] def resolve_task(self, task: "str | Task[Any, Any]") -> "Task[Any, Any]": """Resolve a task name or wrapper to a registered task. Returns: The registered task wrapper. Raises: KeyError: If a task name is not registered. """ if isinstance(task, Task): return task registry = get_task_registry() try: return registry[task] except KeyError as exc: msg = f"Unknown queue task: {task!r}" raise KeyError(msg) from exc
[docs] async def get_task(self, task_id: "UUID") -> "QueuedTaskRecord | None": """Return a queued task record by ID.""" return await self.get_queue_backend().get_task(task_id)
[docs] async def claim_next( self, *, queue: "str | None" = None, execution_backend: "str | None" = None ) -> "QueuedTaskRecord | None": """Claim the next due queued task. Returns: The claimed task record, if one was available. """ return await self.get_queue_backend().claim_next(queue=queue, execution_backend=execution_backend)
[docs] async def execute_record(self, record: "QueuedTaskRecord", *, worker_id: "str | None" = None) -> "QueuedTaskRecord": """Execute a claimed queue record and persist the lifecycle result. Args: record: The claimed queue record to execute. worker_id: Identity of the worker driving execution, if any. The value is forwarded to ``TaskExecutionContext.worker_id`` so published events carry stable worker provenance. Service-driven executions (no worker) leave this as ``None``. Returns: The updated queue record. Raises: asyncio.CancelledError: If task execution is cancelled. """ task_obj = self.resolve_task(record.task_name) timeout = record.metadata.get("timeout", task_obj.timeout) task_context = TaskExecutionContext( task_id=str(record.id), task_name=record.task_name, queue=record.queue, worker_id=worker_id, execution_backend=record.execution_backend, execution_profile=record.execution_profile, attempt=record.retry_count + 1, event_publisher=self.get_event_publisher(), ) context_token = _bind_task_context(task_context) try: await task_context.lifecycle("task.started") extra_kwargs = await self._resolve_task_dependencies(task_obj, record, task_context) coroutine = task_obj.execute_record( record, task_context=task_context, extra_kwargs=extra_kwargs, sync_executor=self._sync_executor ) result = await asyncio.wait_for(coroutine, timeout=timeout if isinstance(timeout, int | float) else None) except asyncio.CancelledError: await task_context.lifecycle("task.cancelled") self._log_task_event("Queue task cancelled", record, level=logging.WARNING) raise except NonRetryableError as exc: updated = await self.get_queue_backend().fail_task( record.id, str(exc), retry=False, expected_retry_count=record.retry_count ) if updated is None: return await self.publish_claim_lost(record, phase="fail", task_context=task_context) failed = updated payload = {"status": failed.status, "retry_count": failed.retry_count, "will_retry": False} await task_context.lifecycle("task.failed", message=str(exc), payload=payload) self._log_task_event("Queue task failed", failed, level=logging.ERROR, payload=payload) return updated except Exception as exc: updated = await self.get_queue_backend().fail_task( record.id, str(exc), expected_retry_count=record.retry_count ) if updated is None: return await self.publish_claim_lost(record, phase="fail", task_context=task_context) failed = updated payload = { "status": failed.status, "retry_count": failed.retry_count, "will_retry": failed.status == "pending", } await task_context.lifecycle("task.failed", message=str(exc), payload=payload) self._log_task_event( "Queue task failed", failed, level=logging.WARNING if failed.status == "pending" else logging.ERROR, payload=payload, ) if failed.status == "failed": await self._reschedule_if_needed(failed) return failed finally: _reset_task_context(context_token) updated = await self.get_queue_backend().complete_task( record.id, result=result, expected_retry_count=record.retry_count ) if updated is None: return await self.publish_claim_lost(record, phase="complete", task_context=task_context) completed = updated await task_context.lifecycle( "task.completed", payload={"status": completed.status, "retry_count": completed.retry_count} ) self._log_task_completed(completed) await self._reschedule_if_needed(completed) return completed
[docs] async def recover_stale_tasks( self, *, stale_after: "timedelta", worker_id: "str | None" = None ) -> "StaleTaskRecoveryResult": """Recover stale running tasks and publish a worker summary event. Returns: Summary of recovered, failed, skipped, and handler-needed tasks. """ result = await self.get_queue_backend().requeue_stale_running(stale_after=stale_after) if result.requeued or result.failed or result.skipped or result.handler_needed: await self._publish_stale_failed_events(result, worker_id=worker_id) await self.get_event_publisher().publish( QueueEvent( type="worker.stale_recovery", scope="worker", worker_id=worker_id, message="Recovered stale running tasks", payload=result.to_payload(), ) ) return result
[docs] async def initialize_schedules(self) -> "list[QueuedTaskRecord]": """Create queue records for registered recurring schedules. Returns: The created or reused schedule records. """ records: 'list["QueuedTaskRecord"]' = [] queue_backend = self.get_queue_backend() for task_name, schedule in get_scheduled_tasks().items(): task_obj = self.resolve_task(task_name) schedule_metadata = schedule.as_metadata() schedule_key = f"scheduled:{task_name}" existing = await queue_backend.get_task_by_key(schedule_key) if existing is not None and not existing.is_terminal: if existing.metadata.get("schedule") == schedule_metadata: records.append(existing) continue await queue_backend.cancel_task(existing.id) scheduled_at = schedule.get_next_run(use_initial_delay=True) records.append( await queue_backend.enqueue( task_name, key=schedule_key, max_retries=0, scheduled_at=scheduled_at, execution_backend=task_obj.execution_backend or execution_backend_name(self._config.execution_backend), execution_profile=task_obj.execution_profile, metadata=task_obj.metadata({"schedule": schedule_metadata}), ) ) return records
async def _resolve_task_dependencies( self, task: "Task[..., object]", record: "QueuedTaskRecord", task_context: "TaskExecutionContext" ) -> "Mapping[str, object] | None": """Invoke the configured task dependency resolver, if any. Returns: The resolver's kwargs mapping, or ``None`` when no resolver is configured. """ resolver = self._config.task_dependency_resolver if resolver is None: return None return await resolver(task, record, task_context) async def _reschedule_if_needed(self, record: "QueuedTaskRecord") -> "None": schedule_data = record.metadata.get("schedule") if not isinstance(schedule_data, dict) or record.completed_at is None: return schedule = ScheduleConfig( task_name=str(schedule_data["task_name"]), cron=schedule_data.get("cron"), initial_delay=schedule_data.get("initial_delay", 0), interval=schedule_data.get("interval"), jitter=schedule_data.get("jitter", 0), max_instances=int(schedule_data.get("max_instances", 1)), timeout=schedule_data.get("timeout"), timezone=str(schedule_data.get("timezone", "UTC")), ) await self.get_queue_backend().enqueue( record.task_name, key=record.key, queue=record.queue, max_retries=record.max_retries, scheduled_at=schedule.get_next_run(record.completed_at), execution_backend=record.execution_backend, execution_profile=record.execution_profile, metadata={**record.metadata, "schedule": schedule.as_metadata()}, ) async def _current_or_claimed(self, record: "QueuedTaskRecord") -> "QueuedTaskRecord": return await self.get_queue_backend().get_task(record.id) or record
[docs] async def publish_claim_lost( self, record: "QueuedTaskRecord", *, phase: "str", task_context: "TaskExecutionContext | None" = None, worker_id: "str | None" = None, expected_retry_count: "int | None" = None, ) -> "QueuedTaskRecord": """Publish an ownership-loss event and return the current record state. Returns: Current queue task record state. """ current = await self._current_or_claimed(record) expected = record.retry_count if expected_retry_count is None else expected_retry_count payload = { "phase": phase, "expected_retry_count": expected, "current_status": current.status, "current_retry_count": current.retry_count, } message = "Queue task ownership lost" if task_context is not None: await task_context.lifecycle("task.claim_lost", message=message, payload=payload) else: await self.get_event_publisher().publish( QueueEvent( type="task.claim_lost", scope="task", task_id=str(record.id), task_name=record.task_name, queue=record.queue, worker_id=worker_id, execution_backend=record.execution_backend, execution_profile=record.execution_profile, attempt=expected + 1, message=message, payload=payload, ) ) self._log_task_event(message, current, level=logging.WARNING, payload=payload) return current
async def _publish_stale_failed_events( self, result: "StaleTaskRecoveryResult", *, worker_id: "str | None" ) -> "None": handler_needed_ids = set(result.handler_needed_task_ids) for task_id in result.failed_task_ids: record = await self.get_queue_backend().get_task(task_id) if record is None: continue requeue_on_stale = record.metadata.get("requeue_on_stale", True) is not False payload = { "status": record.status, "retry_count": record.retry_count, "max_retries": record.max_retries, "requeue_on_stale": requeue_on_stale, "handler_needed": record.id in handler_needed_ids, } await self.get_event_publisher().publish( QueueEvent( type="task.stale_failed", scope="task", task_id=str(record.id), task_name=record.task_name, queue=record.queue, worker_id=worker_id, execution_backend=record.execution_backend, execution_profile=record.execution_profile, attempt=record.retry_count + 1, message=record.error or "Task heartbeat stale", payload=payload, ) ) self._log_task_event( "Queue task failed after stale heartbeat", record, level=logging.ERROR, payload=payload ) def _log_task_completed(self, record: "QueuedTaskRecord") -> "None": if record.metadata.get("quiet_success") is True: return self._log_task_event("Queue task completed", record, level=_coerce_log_level(record.metadata.get("log_level"))) def _log_task_event( self, message: "str", record: "QueuedTaskRecord", *, level: "int", payload: "Mapping[str, object] | None" = None ) -> "None": logger.log( level, message, extra={ "queue_task_id": str(record.id), "queue_task_name": record.task_name, "queue_task_queue": record.queue, "queue_task_status": record.status, "queue_task_retry_count": record.retry_count, "queue_task_max_retries": record.max_retries, "queue_task_execution_backend": record.execution_backend, "queue_task_execution_profile": record.execution_profile, "queue_task_description": record.metadata.get("description"), "queue_task_event_payload": dict(payload or {}), }, )
def _coerce_timedelta(value: "float | timedelta | None") -> "timedelta | None": if value is None: return None if isinstance(value, timedelta): return value return timedelta(seconds=value) def _coerce_log_level(value: "object", default: "int" = logging.INFO) -> "int": if not isinstance(value, str): return default return _LOG_LEVELS.get(value.lower(), default)