Source code for litestar_queues.backends.memory.backend

import asyncio
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any

from litestar_queues.backends.base import BaseQueueBackend
from litestar_queues.models import QueueBackendCapabilities, QueuedTaskRecord, QueueStatistics, StaleTaskRecoveryResult

if TYPE_CHECKING:
    from uuid import UUID

    from litestar_queues.config import QueueConfig

__all__ = ("InMemoryQueueBackend",)


[docs] class InMemoryQueueBackend(BaseQueueBackend): """In-process queue backend for tests, local development, and examples.""" __slots__ = ("_keys", "_lock", "_notification_event", "_records")
[docs] def __init__(self, config: "QueueConfig | None" = None) -> "None": super().__init__(config=config) self._records: "dict[UUID, QueuedTaskRecord]" = {} self._keys: "dict[str, UUID]" = {} self._lock = asyncio.Lock() self._notification_event = asyncio.Event()
@property def capabilities(self) -> "QueueBackendCapabilities": """Backend behavior capabilities.""" return QueueBackendCapabilities( supports_notifications=True, notification_backend="asyncio-event", notifications_durable=False )
[docs] async def enqueue( self, task_name: "str", *, args: "tuple[Any, ...]" = (), kwargs: "dict[str, Any] | None" = None, queue: "str" = "default", priority: "int" = 0, max_retries: "int" = 0, scheduled_at: "datetime | None" = None, key: "str | None" = None, execution_backend: "str" = "local", execution_profile: "str | None" = None, metadata: "dict[str, Any] | None" = None, ) -> "QueuedTaskRecord": async with self._lock: if key is not None: existing_id = self._keys.get(key) if existing_id is not None: existing = self._records.get(existing_id) if existing is not None and not existing.is_terminal: return existing record = QueuedTaskRecord( task_name=task_name, args=args, kwargs=dict(kwargs or {}), queue=queue, execution_backend=execution_backend, execution_profile=execution_profile, status="scheduled" if scheduled_at is not None and scheduled_at > _utc_now() else "pending", priority=priority, max_retries=max_retries, scheduled_at=scheduled_at, key=key, metadata=dict(metadata or {}), ) self._records[record.id] = record if key is not None: self._keys[key] = record.id await self.notify_new_task(record) return record
[docs] async def get_task(self, task_id: "UUID") -> "QueuedTaskRecord | None": return self._records.get(task_id)
[docs] async def get_task_by_key(self, key: "str") -> "QueuedTaskRecord | None": task_id = self._keys.get(key) if task_id is None: return None return self._records.get(task_id)
[docs] async def list_pending( self, *, limit: "int" = 1, queue: "str | None" = None, execution_backend: "str | None" = None ) -> "list[QueuedTaskRecord]": due_records = [ record for record in self._records.values() if record.status in {"pending", "scheduled"} and record.is_due and (queue is None or record.queue == queue) and (execution_backend is None or record.execution_backend == execution_backend) ] due_records.sort(key=lambda record: (-record.priority, record.created_at)) return due_records[:limit]
[docs] async def claim_task(self, task_id: "UUID") -> "QueuedTaskRecord | None": async with self._lock: record = self._records.get(task_id) if record is None or record.status not in {"pending", "scheduled"} or not record.is_due: return None now = _utc_now() record.status = "running" record.started_at = now record.heartbeat_at = now return record
[docs] async def complete_task( self, task_id: "UUID", *, result: "Any" = None, expected_retry_count: "int | None" = None ) -> "QueuedTaskRecord | None": async with self._lock: record = self._records.get(task_id) if record is None: return None if expected_retry_count is not None and ( record.status != "running" or record.retry_count != expected_retry_count ): return None now = _utc_now() record.status = "completed" record.completed_at = now record.heartbeat_at = now record.result = result record.error = None return record
[docs] async def fail_task( self, task_id: "UUID", error: "str", *, retry: "bool" = True, expected_retry_count: "int | None" = None ) -> "QueuedTaskRecord | None": async with self._lock: record = self._records.get(task_id) if record is None: return None if expected_retry_count is not None and ( record.status != "running" or record.retry_count != expected_retry_count ): return None record.error = error if retry and record.retry_count < record.max_retries: record.retry_count += 1 record.status = "pending" record.started_at = None record.heartbeat_at = None return record now = _utc_now() record.status = "failed" record.completed_at = now record.heartbeat_at = now return record
[docs] async def cancel_task(self, task_id: "UUID") -> "bool": async with self._lock: record = self._records.get(task_id) if record is None or record.status not in {"pending", "scheduled"}: return False record.status = "cancelled" record.completed_at = _utc_now() return True
[docs] async def touch_heartbeat(self, task_id: "UUID", *, expected_retry_count: "int | None" = None) -> "bool": record = self._records.get(task_id) if record is None or record.status != "running": return False if expected_retry_count is not None and record.retry_count != expected_retry_count: return False record.heartbeat_at = _utc_now() return True
[docs] async def null_heartbeats(self, task_ids: "list[UUID]", *, expected_retry_count: "int | None" = None) -> "None": task_id_set = set(task_ids) async with self._lock: for task_id, record in self._records.items(): if task_id in task_id_set: if expected_retry_count is not None and record.retry_count != expected_retry_count: continue record.heartbeat_at = None
[docs] async def requeue_stale_running(self, *, stale_after: "timedelta") -> "StaleTaskRecoveryResult": cutoff = _utc_now() - stale_after result = StaleTaskRecoveryResult() async with self._lock: for record in self._records.values(): if record.status != "running": continue if record.heartbeat_at is not None and record.heartbeat_at >= cutoff: result.skipped += 1 continue requeue_on_stale = record.metadata.get("requeue_on_stale", True) is not False if requeue_on_stale and record.retry_count < record.max_retries: record.status = "pending" record.started_at = None record.heartbeat_at = None record.retry_count += 1 result.requeued += 1 continue record.status = "failed" record.completed_at = _utc_now() record.heartbeat_at = None record.error = "Task heartbeat stale" result.failed += 1 result.failed_task_ids.append(record.id) if not requeue_on_stale: result.handler_needed += 1 result.handler_needed_task_ids.append(record.id) return result
[docs] async def set_execution_ref( self, task_id: "UUID", execution_backend: "str", execution_ref: "str", *, execution_profile: "str | None" = None ) -> "QueuedTaskRecord | None": async with self._lock: record = self._records.get(task_id) if record is None: return None record.execution_backend = execution_backend record.execution_profile = execution_profile record.execution_ref = execution_ref return record
[docs] async def set_execution_backend( self, task_id: "UUID", execution_backend: "str", *, execution_profile: "str | None" = None ) -> "QueuedTaskRecord | None": async with self._lock: record = self._records.get(task_id) if record is None: return None record.execution_backend = execution_backend record.execution_profile = execution_profile record.execution_ref = None await self.notify_new_task(record) return record
[docs] async def list_running_external(self, *, limit: "int | None" = None) -> "list[QueuedTaskRecord]": records = [ record for record in self._records.values() if not record.is_terminal and record.execution_ref is not None ] records.sort(key=lambda record: record.started_at or record.created_at) return records[:limit] if limit is not None else records
[docs] async def get_statistics(self) -> "QueueStatistics": statistics = QueueStatistics() for record in self._records.values(): setattr(statistics, record.status, getattr(statistics, record.status) + 1) return statistics
[docs] async def list_completed_by_task( self, task_name: "str", *, since: "datetime | None" = None, limit: "int" = 10 ) -> "list[QueuedTaskRecord]": records = [ record for record in self._records.values() if record.task_name == task_name and record.status == "completed" and record.completed_at is not None and (since is None or record.completed_at >= since) ] records.sort(key=lambda record: record.completed_at or record.created_at, reverse=True) return records[:limit]
[docs] async def cleanup_terminal(self, before: "datetime") -> "int": removed = 0 async with self._lock: for task_id, record in list(self._records.items()): if not record.is_terminal or record.completed_at is None or record.completed_at >= before: continue removed += 1 del self._records[task_id] if record.key is not None and self._keys.get(record.key) == task_id: del self._keys[record.key] return removed
[docs] async def notify_new_task(self, record: "QueuedTaskRecord") -> "None": if record.status in {"pending", "scheduled"}: self._notification_event.set()
[docs] async def wait_for_notifications(self, timeout: "float | None" = None) -> "bool": if self._notification_event.is_set(): self._notification_event.clear() return True try: await asyncio.wait_for(self._notification_event.wait(), timeout=timeout) except asyncio.TimeoutError: return False self._notification_event.clear() return True
[docs] async def clear(self) -> "None": """Clear all in-memory records.""" async with self._lock: self._records.clear() self._keys.clear() self._notification_event.clear()
def _utc_now() -> "datetime": return datetime.now(timezone.utc)