Source code for litestar_queues.backends.advanced_alchemy.service

"""Advanced Alchemy queue persistence service."""

from datetime import datetime, timedelta, timezone
from typing import Any, cast
from uuid import UUID

from advanced_alchemy.operations import OnConflictUpsert
from advanced_alchemy.service import SQLAlchemyAsyncRepositoryService
from advanced_alchemy.utils.serialization import decode_json as _decode_json
from advanced_alchemy.utils.serialization import encode_json as _encode_json
from sqlalchemy import and_, delete, desc, func, or_, select, update

from litestar_queues.backends.advanced_alchemy.repository import QueueTaskRepository
from litestar_queues.models import QueuedTaskRecord, QueueStatistics, StaleTaskRecoveryResult, TaskStatus

__all__ = ("QueueTaskService",)

_DUE_STATUSES = ("pending", "scheduled")
_TERMINAL_STATUSES = ("completed", "failed", "cancelled")
_SKIP_LOCKED_CLAIM_DIALECTS = frozenset({"mariadb", "mysql", "oracle", "postgresql"})
_NATIVE_KEYED_ENQUEUE_DIALECTS = frozenset({"mariadb", "mysql", "oracle", "postgresql"})


[docs] class QueueTaskService(SQLAlchemyAsyncRepositoryService[Any]): """Persistence operations for Advanced Alchemy queue records."""
[docs] @classmethod def for_model(cls, model_class: "type[Any]") -> 'type["QueueTaskService"]': """Return a service subclass bound to ``model_class``.""" repository_type = QueueTaskRepository.for_model(model_class) return cast( "type[QueueTaskService]", type(f"QueueTaskServiceFor{model_class.__name__}", (cls,), {"repository_type": repository_type}), )
[docs] async def enqueue( self, task_name: "str", *, args: "tuple[Any, ...]", kwargs: "dict[str, Any]", queue: "str", priority: "int", max_retries: "int", scheduled_at: "datetime | None", key: "str | None", execution_backend: "str", execution_profile: "str | None", metadata: "dict[str, Any]", ) -> "QueuedTaskRecord": if key is not None: existing = await self._select_task_by_key(key) if existing is not None: existing_record = self.record_from_model(existing) if not existing_record.is_terminal: return existing_record existing.task_key = None await self.repository.session.flush() now = _utc_now() record = QueuedTaskRecord( task_name=task_name, args=args, kwargs=dict(kwargs), queue=queue, execution_backend=execution_backend, execution_profile=execution_profile, status="scheduled" if scheduled_at is not None and scheduled_at > now else "pending", priority=priority, max_retries=max_retries, scheduled_at=scheduled_at, key=key, metadata=dict(metadata), ) return await self._insert_task_record(record, key=key)
[docs] async def get_task(self, task_id: "UUID") -> "QueuedTaskRecord | None": model = await self._select_task(task_id) return self.record_from_model(model) if model is not None else None
[docs] async def get_task_by_key(self, key: "str") -> "QueuedTaskRecord | None": model = await self._select_task_by_key(key) return self.record_from_model(model) if model is not None else None
[docs] async def list_pending( self, *, limit: "int", queue: "str | None", execution_backend: "str | None" ) -> "list[QueuedTaskRecord]": statement = self._pending_statement(queue=queue, execution_backend=execution_backend).limit(limit) models = await self.get_many(statement=statement) return [self.record_from_model(model) for model in models]
[docs] async def claim_task(self, task_id: "UUID") -> "QueuedTaskRecord | None": now = _utc_now() model_type = self.model_type result = await self.repository.session.execute( update(model_type) .where( model_type.id == task_id, model_type.status.in_(_DUE_STATUSES), or_(model_type.scheduled_at.is_(None), model_type.scheduled_at <= now), ) .values(status="running", started_at=now, heartbeat_at=now) ) if result.rowcount != 1: return None model = await self._select_task(task_id) return self.record_from_model(model) if model is not None else None
[docs] async def claim_next(self, *, queue: "str | None", execution_backend: "str | None") -> "QueuedTaskRecord | None": if _supports_skip_locked_claim(self._dialect_name()): return await self._claim_next_skip_locked(queue=queue, execution_backend=execution_backend) pending = await self.list_pending(limit=10, queue=queue, execution_backend=execution_backend) for record in pending: claimed = await self.claim_task(record.id) if claimed is not None: return claimed return None
async def _claim_next_skip_locked( self, *, queue: "str | None", execution_backend: "str | None" ) -> "QueuedTaskRecord | None": now = _utc_now() dialect_name = self._dialect_name() statement = _build_claim_candidate_statement( self.model_type, queue=queue, execution_backend=execution_backend, now=now, limit=None if dialect_name == "oracle" else 1, skip_locked=True, ) row = (await self.repository.session.execute(statement)).scalars().first() if row is None: return None return await self.claim_task(UUID(str(row.id)))
[docs] async def complete_task( self, task_id: "UUID", *, result: "Any" = None, expected_retry_count: "int | None" = None ) -> "QueuedTaskRecord | None": now = _utc_now() model_type = self.model_type criteria = [model_type.id == task_id] if expected_retry_count is not None: criteria.extend((model_type.status == "running", model_type.retry_count == expected_retry_count)) update_result = await self.repository.session.execute( update(model_type) .where(*criteria) .values( status="completed", completed_at=now, heartbeat_at=now, result_json=_serialize_json(result), error=None ) ) if update_result.rowcount != 1: return None model = await self._select_task(task_id) return self.record_from_model(model) if model is not None else None
[docs] async def fail_task( self, task_id: "UUID", error: "str", *, retry: "bool", expected_retry_count: "int | None" = None ) -> "QueuedTaskRecord | None": model = await self._select_task(task_id) if model is None: return None if expected_retry_count is not None and ( str(model.status) != "running" or int(model.retry_count) != expected_retry_count ): return None model_type = self.model_type if retry and int(model.retry_count) < int(model.max_retries): await self.repository.session.execute( update(model_type) .where(model_type.id == task_id) .values( status="pending", started_at=None, heartbeat_at=None, retry_count=int(model.retry_count) + 1, error=error, ) ) else: now = _utc_now() await self.repository.session.execute( update(model_type) .where(model_type.id == task_id) .values(status="failed", completed_at=now, heartbeat_at=now, error=error) ) updated = await self._select_task(task_id) return self.record_from_model(updated) if updated is not None else None
[docs] async def cancel_task(self, task_id: "UUID") -> "bool": model_type = self.model_type result = await self.repository.session.execute( update(model_type) .where(model_type.id == task_id, model_type.status.in_(_DUE_STATUSES)) .values(status="cancelled", completed_at=_utc_now()) ) return int(result.rowcount or 0) == 1
[docs] async def touch_heartbeat(self, task_id: "UUID", *, expected_retry_count: "int | None" = None) -> "bool": model_type = self.model_type criteria = [model_type.id == task_id, model_type.status == "running"] if expected_retry_count is not None: criteria.append(model_type.retry_count == expected_retry_count) result = await self.repository.session.execute( update(model_type).where(*criteria).values(heartbeat_at=_utc_now()) ) return int(result.rowcount or 0) == 1
[docs] async def null_heartbeats(self, task_ids: "list[UUID]", *, expected_retry_count: "int | None" = None) -> "None": if not task_ids: return model_type = self.model_type criteria = [model_type.id.in_(task_ids)] if expected_retry_count is not None: criteria.append(model_type.retry_count == expected_retry_count) await self.repository.session.execute(update(model_type).where(*criteria).values(heartbeat_at=None))
[docs] async def requeue_stale_running(self, *, stale_after: "timedelta") -> "StaleTaskRecoveryResult": cutoff = _utc_now() - stale_after model_type = self.model_type criteria = [model_type.status == "running"] if stale_after > timedelta(0): criteria.append(or_(model_type.heartbeat_at.is_(None), model_type.heartbeat_at <= cutoff)) models = (await self.repository.session.execute(select(model_type).where(*criteria))).scalars().all() result = StaleTaskRecoveryResult() for model in models: metadata = _deserialize_json(model.metadata_json) requeue_on_stale = metadata.get("requeue_on_stale", True) is not False if requeue_on_stale and int(model.retry_count) < int(model.max_retries): await self.repository.session.execute( update(model_type) .where(model_type.id == model.id) .values( status="pending", started_at=None, heartbeat_at=None, retry_count=int(model.retry_count) + 1, error="Task heartbeat stale", ) ) result.requeued += 1 else: now = _utc_now() await self.repository.session.execute( update(model_type) .where(model_type.id == model.id) .values(status="failed", completed_at=now, heartbeat_at=now, error="Task heartbeat stale") ) result.failed += 1 task_id = UUID(str(model.id)) result.failed_task_ids.append(task_id) if not requeue_on_stale: result.handler_needed += 1 result.handler_needed_task_ids.append(task_id) return result
[docs] async def set_execution_ref( self, task_id: "UUID", execution_backend: "str", execution_ref: "str", *, execution_profile: "str | None" ) -> "QueuedTaskRecord | None": model_type = self.model_type result = await self.repository.session.execute( update(model_type) .where(model_type.id == task_id) .values( execution_backend=execution_backend, execution_profile=execution_profile, execution_ref=execution_ref ) ) if result.rowcount != 1: return None model = await self._select_task(task_id) return self.record_from_model(model) if model is not None else None
[docs] async def set_execution_backend( self, task_id: "UUID", execution_backend: "str", *, execution_profile: "str | None" ) -> "QueuedTaskRecord | None": model_type = self.model_type result = await self.repository.session.execute( update(model_type) .where(model_type.id == task_id) .values(execution_backend=execution_backend, execution_profile=execution_profile, execution_ref=None) ) if result.rowcount != 1: return None model = await self._select_task(task_id) return self.record_from_model(model) if model is not None else None
[docs] async def list_running_external(self, *, limit: "int | None" = None) -> "list[QueuedTaskRecord]": model_type = self.model_type statement = ( select(model_type) .where(model_type.status.in_(("pending", "scheduled", "running")), model_type.execution_ref.is_not(None)) .order_by(model_type.started_at, model_type.created_at) ) if limit is not None: statement = statement.limit(limit) models = await self.get_many(statement=statement) return [self.record_from_model(model) for model in models]
[docs] async def get_statistics(self) -> "QueueStatistics": model_type = self.model_type result = await self.repository.session.execute( select(model_type.status, func.count()).group_by(model_type.status) ) statistics = QueueStatistics() for status, count in result.all(): coerced = _coerce_status(status) setattr(statistics, coerced, int(count)) return statistics
[docs] async def list_completed_by_task( self, task_name: "str", *, since: "datetime | None", limit: "int" ) -> "list[QueuedTaskRecord]": model_type = self.model_type criteria = [model_type.task_name == task_name, model_type.status == "completed"] if since is not None: criteria.append(model_type.completed_at >= since) statement = select(model_type).where(and_(*criteria)).order_by(desc(model_type.completed_at)).limit(limit) models = await self.get_many(statement=statement) return [self.record_from_model(model) for model in models]
[docs] async def cleanup_terminal(self, before: "datetime") -> "int": model_type = self.model_type result = await self.repository.session.execute( delete(model_type).where( model_type.status.in_(_TERMINAL_STATUSES), model_type.completed_at.is_not(None), model_type.completed_at < before, ) ) return int(result.rowcount or 0)
[docs] def model_from_record(self, record: "QueuedTaskRecord") -> "Any": """Convert a backend-neutral record into an Advanced Alchemy model. Returns: The Advanced Alchemy queue task model. """ return self.model_type( id=record.id, task_name=record.task_name, args_json=_serialize_json(list(record.args)), kwargs_json=_serialize_json(record.kwargs), queue=record.queue, execution_backend=record.execution_backend, execution_profile=record.execution_profile, execution_ref=record.execution_ref, status=record.status, priority=record.priority, max_retries=record.max_retries, retry_count=record.retry_count, scheduled_at=record.scheduled_at, created_at=record.created_at, started_at=record.started_at, completed_at=record.completed_at, heartbeat_at=record.heartbeat_at, result_json=_serialize_json(record.result), error=record.error, task_key=record.key, metadata_json=_serialize_json(record.metadata), )
[docs] @staticmethod def record_from_model(model: "Any") -> "QueuedTaskRecord": """Convert an Advanced Alchemy model into a backend-neutral record. Returns: The backend-neutral queued task record. """ args = _deserialize_json(model.args_json) kwargs = _deserialize_json(model.kwargs_json) metadata = _deserialize_json(model.metadata_json) return QueuedTaskRecord( id=UUID(str(model.id)), task_name=model.task_name, args=tuple(args), kwargs=kwargs, queue=model.queue, execution_backend=model.execution_backend, execution_profile=model.execution_profile, execution_ref=model.execution_ref, status=_coerce_status(model.status), priority=int(model.priority), max_retries=int(model.max_retries), retry_count=int(model.retry_count), scheduled_at=_coerce_datetime(model.scheduled_at), created_at=cast("datetime", _coerce_datetime(model.created_at)), started_at=_coerce_datetime(model.started_at), completed_at=_coerce_datetime(model.completed_at), heartbeat_at=_coerce_datetime(model.heartbeat_at), result=_deserialize_json(model.result_json), error=model.error, key=model.task_key, metadata=metadata, )
async def _select_task(self, task_id: "UUID") -> "Any | None": return await self.repository.get_one_or_none(id=task_id) async def _select_task_by_key(self, key: "str") -> "Any | None": return await self.repository.get_one_or_none(task_key=key) async def _insert_task_record(self, record: "QueuedTaskRecord", *, key: "str | None") -> "QueuedTaskRecord": model = self.model_from_record(record) dialect_name = self._dialect_name() if key is not None and _supports_native_keyed_enqueue(dialect_name): values = _model_insert_values(model, self.model_type) statement, params = _build_keyed_enqueue_upsert( cast("Any", self.model_type.__table__), values, dialect_name=cast("str", dialect_name) ) await self.repository.session.execute(statement, {**values, **params}) await self.repository.session.flush() inserted = await self._select_task(record.id) if inserted is not None: return self.record_from_model(inserted) existing = await self._select_task_by_key(key) if existing is not None: return self.record_from_model(existing) return record await self.repository.add(model, auto_commit=False, auto_refresh=False) return record def _dialect_name(self) -> "str | None": bind = self.repository.session.get_bind() dialect = getattr(bind, "dialect", None) return cast("str | None", getattr(dialect, "name", None)) def _pending_statement(self, *, queue: "str | None", execution_backend: "str | None") -> "Any": return _build_claim_candidate_statement( self.model_type, queue=queue, execution_backend=execution_backend, now=_utc_now(), limit=None, skip_locked=False, )
def _supports_skip_locked_claim(dialect_name: "str | None") -> "bool": return dialect_name in _SKIP_LOCKED_CLAIM_DIALECTS def _supports_native_keyed_enqueue(dialect_name: "str | None") -> "bool": return dialect_name in _NATIVE_KEYED_ENQUEUE_DIALECTS def _build_claim_candidate_statement( model_type: "type[Any]", *, queue: "str | None", execution_backend: "str | None", now: "datetime", limit: "int | None", skip_locked: "bool", ) -> "Any": criteria = [ model_type.status.in_(_DUE_STATUSES), or_(model_type.scheduled_at.is_(None), model_type.scheduled_at <= now), ] if queue is not None: criteria.append(model_type.queue == queue) if execution_backend is not None: criteria.append(model_type.execution_backend == execution_backend) statement = select(model_type).where(and_(*criteria)).order_by(desc(model_type.priority), model_type.created_at) if limit is not None: statement = statement.limit(limit) if skip_locked: statement = statement.with_for_update(skip_locked=True) return statement def _build_keyed_enqueue_upsert( table: "Any", values: "dict[str, Any]", *, dialect_name: "str" ) -> "tuple[Any, dict[str, Any]]": if dialect_name == "oracle": return OnConflictUpsert.create_merge_upsert( table=table, values=values, conflict_columns=["task_key"], update_columns=[], dialect_name=dialect_name, validate_identifiers=True, ) update_columns = ["task_key"] return ( OnConflictUpsert.create_upsert( table=table, values=values, conflict_columns=["task_key"], update_columns=update_columns, dialect_name=dialect_name, validate_identifiers=True, ), {}, ) def _model_insert_values(model: "Any", model_type: "type[Any]") -> "dict[str, Any]": values: "dict[str, Any]" = {} for column in model_type.__table__.columns: if not hasattr(model, column.name): continue value = getattr(model, column.name) if value is None and column.name == "updated_at": value = _utc_now() if value is None and (column.default is not None or column.server_default is not None): continue values[column.name] = value return values def _utc_now() -> "datetime": return datetime.now(timezone.utc) def _serialize_json(value: "Any") -> "Any": return _decode_json(str(_encode_json(value))) def _deserialize_json(value: "Any") -> "Any": if value is None: return None if isinstance(value, bytes | bytearray | memoryview): return _decode_json(bytes(value)) if isinstance(value, str): try: return _decode_json(value) except ValueError: return value return value def _coerce_datetime(value: "Any") -> "datetime | None": if value is None: return None if isinstance(value, datetime): if value.tzinfo is None: return value.replace(tzinfo=timezone.utc) return value.astimezone(timezone.utc) parsed = datetime.fromisoformat(str(value).replace("Z", "+00:00")) if parsed.tzinfo is None: return parsed.replace(tzinfo=timezone.utc) return parsed.astimezone(timezone.utc) def _coerce_status(value: "Any") -> "TaskStatus": status = str(value) if status not in {"cancelled", "completed", "failed", "pending", "running", "scheduled"}: msg = f"Unknown queued task status from Advanced Alchemy queue backend: {status!r}" raise ValueError(msg) return cast("TaskStatus", status)