Source code for litestar_queues.backends.redis.backend

"""Redis queue backend.

Stores queued task records in a Redis-protocol key-value server. The
implementation lives directly on ``RedisQueueBackend``; the Valkey
backend inherits from this class and only swaps the client factory and
``_backend_name`` ClassVar.
"""

import asyncio
import inspect
import json
from contextlib import asynccontextmanager, suppress
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, ClassVar, cast
from uuid import UUID, uuid4

from redis import asyncio as redis_asyncio

from litestar_queues.backends.base import BaseQueueBackend
from litestar_queues.backends.redis.config import RedisBackendConfig as _RedisBackendConfig
from litestar_queues.exceptions import QueueError
from litestar_queues.models import (
    QueueBackendCapabilities,
    QueuedTaskRecord,
    QueueStatistics,
    StaleTaskRecoveryResult,
    TaskStatus,
)

if TYPE_CHECKING:
    from collections.abc import AsyncIterator

    from litestar_queues.config import QueueConfig

__all__ = ("RedisQueueBackend",)

_DUE_STATUSES = {"pending", "scheduled"}
_STATUS_VALUES = {"cancelled", "completed", "failed", "pending", "running", "scheduled"}
_TERMINAL_STATUSES = {"cancelled", "completed", "failed"}
_RELEASE_LOCK_SCRIPT = """
if redis.call('GET', KEYS[1]) == ARGV[1] then
    return redis.call('DEL', KEYS[1])
end
return 0
"""


[docs] class RedisQueueBackend(BaseQueueBackend): """Queue backend that stores records in a Redis-protocol key-value server.""" _backend_name: "ClassVar[str]" = "redis" __slots__ = ( "_client", "_key_prefix", "_lock_timeout", "_notification_channel", "_notifications", "_owns_client", "_poll_interval", "_url", )
[docs] def __init__( self, config: "QueueConfig | None" = None, *, backend_config: "_RedisBackendConfig | None" = None ) -> "None": super().__init__(config=config) backend_config = backend_config or _RedisBackendConfig() self._client = backend_config.client self._owns_client = self._client is None self._url = backend_config.url self._key_prefix = backend_config.key_prefix.rstrip(":") self._notifications = backend_config.notifications self._notification_channel = backend_config.notification_channel self._lock_timeout = backend_config.lock_timeout self._poll_interval = backend_config.poll_interval
@property def capabilities(self) -> "QueueBackendCapabilities": """Backend behavior capabilities.""" return QueueBackendCapabilities( supports_notifications=self._notifications, notification_backend=f"{self._backend_name}-pubsub" if self._notifications else None, notifications_durable=False, )
[docs] async def open(self) -> "bool": """Open Redis-protocol client resources. Returns: True when the client is ready. """ if self._client is None: self._client = self._create_client(self._url) self._owns_client = True return True
[docs] async def close(self) -> "None": """Close owned Redis-protocol client resources.""" if self._owns_client and self._client is not None: close = getattr(self._client, "aclose", None) or getattr(self._client, "close", None) if close is not None: result = close() if inspect.isawaitable(result): await result self._client = None
[docs] async def enqueue( self, task_name: "str", *, args: "tuple[Any, ...]" = (), kwargs: "dict[str, Any] | None" = None, queue: "str" = "default", priority: "int" = 0, max_retries: "int" = 0, scheduled_at: "datetime | None" = None, key: "str | None" = None, execution_backend: "str" = "local", execution_profile: "str | None" = None, metadata: "dict[str, Any] | None" = None, ) -> "QueuedTaskRecord": """Persist a queued task. Returns: The created or deduplicated queued task record. """ if key is not None: async with self._lock(f"key:{key}", wait=True): existing = await self.get_task_by_key(key) if existing is not None and not existing.is_terminal: return existing if existing is not None: await self._clear_key(existing) record = self._create_record( task_name, args=args, kwargs=kwargs, queue=queue, priority=priority, max_retries=max_retries, scheduled_at=scheduled_at, key=key, execution_backend=execution_backend, execution_profile=execution_profile, metadata=metadata, ) await self._save_record(record) await self._client_hset(self._keys_key, key, str(record.id)) else: record = self._create_record( task_name, args=args, kwargs=kwargs, queue=queue, priority=priority, max_retries=max_retries, scheduled_at=scheduled_at, key=None, execution_backend=execution_backend, execution_profile=execution_profile, metadata=metadata, ) await self._save_record(record) await self.notify_new_task(record) return record
[docs] async def get_task(self, task_id: "UUID") -> "QueuedTaskRecord | None": """Return a queued task by ID.""" mapping = await self._client_hgetall(self._task_key(task_id)) if not mapping: return None return self._record_from_mapping(mapping)
[docs] async def get_task_by_key(self, key: "str") -> "QueuedTaskRecord | None": """Return a queued task by deduplication key.""" task_id = await self._client_hget(self._keys_key, key) if task_id is None: return None return await self.get_task(UUID(str(_decode(task_id))))
[docs] async def list_pending( self, *, limit: "int" = 1, queue: "str | None" = None, execution_backend: "str | None" = None ) -> "list[QueuedTaskRecord]": """Return due pending or scheduled tasks ordered for execution.""" client = await self._get_client() candidate_ids = await client.zrangebyscore(self._pending_key, "-inf", _utc_now().timestamp()) due_records = [ record for record in await self._records_from_ids(candidate_ids) if record.status in _DUE_STATUSES and record.is_due and (queue is None or record.queue == queue) and (execution_backend is None or record.execution_backend == execution_backend) ] due_records.sort(key=lambda record: (-record.priority, record.created_at)) return due_records[:limit]
[docs] async def claim_task(self, task_id: "UUID") -> "QueuedTaskRecord | None": """Atomically claim a pending task. Returns: The claimed record, if it was still due and claimable. """ async with self._lock(f"task:{task_id}", wait=False) as acquired: if not acquired: return None record = await self.get_task(task_id) if record is None or record.status not in _DUE_STATUSES or not record.is_due: return None now = _utc_now() record.status = "running" record.started_at = now record.heartbeat_at = now await self._save_record(record) return record
[docs] async def complete_task( self, task_id: "UUID", *, result: "Any" = None, expected_retry_count: "int | None" = None ) -> "QueuedTaskRecord | None": """Mark a task as completed. Returns: The completed record, if it exists. """ async with self._lock(f"task:{task_id}", wait=True): record = await self.get_task(task_id) if record is None: return None if expected_retry_count is not None and ( record.status != "running" or record.retry_count != expected_retry_count ): return None now = _utc_now() record.status = "completed" record.completed_at = now record.heartbeat_at = now record.result = result record.error = None await self._save_record(record) return record
[docs] async def fail_task( self, task_id: "UUID", error: "str", *, retry: "bool" = True, expected_retry_count: "int | None" = None ) -> "QueuedTaskRecord | None": """Mark a task as failed or retry it. Returns: The updated record, if it exists. """ async with self._lock(f"task:{task_id}", wait=True): record = await self.get_task(task_id) if record is None: return None if expected_retry_count is not None and ( record.status != "running" or record.retry_count != expected_retry_count ): return None record.error = error if retry and record.retry_count < record.max_retries: record.status = "pending" record.retry_count += 1 record.started_at = None record.heartbeat_at = None await self._save_record(record) await self.notify_new_task(record) return record now = _utc_now() record.status = "failed" record.completed_at = now record.heartbeat_at = now await self._save_record(record) return record
[docs] async def cancel_task(self, task_id: "UUID") -> "bool": """Cancel a task if it has not started. Returns: True when the task was cancelled. """ async with self._lock(f"task:{task_id}", wait=True): record = await self.get_task(task_id) if record is None or record.status not in _DUE_STATUSES: return False record.status = "cancelled" record.completed_at = _utc_now() await self._save_record(record) return True
[docs] async def touch_heartbeat(self, task_id: "UUID", *, expected_retry_count: "int | None" = None) -> "bool": """Update the heartbeat timestamp for a running task. Returns: True when the heartbeat was updated. """ async with self._lock(f"task:{task_id}", wait=True): record = await self.get_task(task_id) if record is None or record.status != "running": return False if expected_retry_count is not None and record.retry_count != expected_retry_count: return False record.heartbeat_at = _utc_now() await self._save_record(record) return True
[docs] async def null_heartbeats(self, task_ids: "list[UUID]", *, expected_retry_count: "int | None" = None) -> "None": """Clear heartbeat timestamps for task IDs.""" for task_id in task_ids: async with self._lock(f"task:{task_id}", wait=True): record = await self.get_task(task_id) if record is None: continue if expected_retry_count is not None and record.retry_count != expected_retry_count: continue record.heartbeat_at = None await self._save_record(record)
[docs] async def requeue_stale_running(self, *, stale_after: "timedelta") -> "StaleTaskRecoveryResult": """Requeue running tasks with stale heartbeats. Returns: Summary of recovered records. """ cutoff = _utc_now() - stale_after result = StaleTaskRecoveryResult() for record in await self._list_records(): if record.status != "running": continue if record.heartbeat_at is not None and record.heartbeat_at >= cutoff: result.skipped += 1 continue async with self._lock(f"task:{record.id}", wait=False) as acquired: if not acquired: result.skipped += 1 continue latest = await self.get_task(record.id) if latest is None or latest.status != "running": result.skipped += 1 continue if latest.heartbeat_at is not None and latest.heartbeat_at >= cutoff: result.skipped += 1 continue requeue_on_stale = latest.metadata.get("requeue_on_stale", True) is not False if requeue_on_stale and latest.retry_count < latest.max_retries: latest.status = "pending" latest.started_at = None latest.heartbeat_at = None latest.retry_count += 1 result.requeued += 1 else: latest.status = "failed" latest.completed_at = _utc_now() latest.heartbeat_at = None latest.error = "Task heartbeat stale" result.failed += 1 result.failed_task_ids.append(latest.id) if not requeue_on_stale: result.handler_needed += 1 result.handler_needed_task_ids.append(latest.id) await self._save_record(latest) if latest.status in _DUE_STATUSES: await self.notify_new_task(latest) return result
[docs] async def set_execution_ref( self, task_id: "UUID", execution_backend: "str", execution_ref: "str", *, execution_profile: "str | None" = None ) -> "QueuedTaskRecord | None": """Persist an external execution reference for a running task. Returns: The updated record, if it exists. """ async with self._lock(f"task:{task_id}", wait=True): record = await self.get_task(task_id) if record is None: return None record.execution_backend = execution_backend record.execution_profile = execution_profile record.execution_ref = execution_ref await self._save_record(record) return record
[docs] async def set_execution_backend( self, task_id: "UUID", execution_backend: "str", *, execution_profile: "str | None" = None ) -> "QueuedTaskRecord | None": """Persist an execution backend/profile change for a queued task. Returns: The updated record, if it exists. """ async with self._lock(f"task:{task_id}", wait=True): record = await self.get_task(task_id) if record is None: return None record.execution_backend = execution_backend record.execution_profile = execution_profile record.execution_ref = None await self._save_record(record) await self.notify_new_task(record) return record
[docs] async def list_running_external(self, *, limit: "int | None" = None) -> "list[QueuedTaskRecord]": """Return externally dispatched tasks with references to reconcile.""" records = [ record for record in await self._list_records() if record.status in {"pending", "scheduled", "running"} and record.execution_ref is not None ] records.sort(key=lambda record: (record.started_at or record.created_at, record.created_at)) return records[:limit] if limit is not None else records
[docs] async def get_statistics(self) -> "QueueStatistics": """Return queue status counts.""" statistics = QueueStatistics() for record in await self._list_records(): setattr(statistics, record.status, getattr(statistics, record.status) + 1) return statistics
[docs] async def list_completed_by_task( self, task_name: "str", *, since: "datetime | None" = None, limit: "int" = 10 ) -> "list[QueuedTaskRecord]": """Return recent completed records for a task name.""" records = [ record for record in await self._list_records() if record.task_name == task_name and record.status == "completed" and record.completed_at is not None and (since is None or record.completed_at >= since) ] records.sort(key=lambda record: record.completed_at or record.created_at, reverse=True) return records[:limit]
[docs] async def cleanup_terminal(self, before: "datetime") -> "int": """Delete terminal records completed before a cutoff. Returns: Number of deleted records. """ count = 0 for record in await self._list_records(): if record.status not in _TERMINAL_STATUSES or record.completed_at is None or record.completed_at >= before: continue async with self._lock(f"task:{record.id}", wait=False) as acquired: if not acquired: continue latest = await self.get_task(record.id) if ( latest is None or latest.status not in _TERMINAL_STATUSES or latest.completed_at is None or latest.completed_at >= before ): continue await self._delete_record(latest) count += 1 return count
[docs] async def notify_new_task(self, record: "QueuedTaskRecord") -> "None": """Publish a Redis-protocol pub/sub message when work is available.""" if self._notifications and record.status in _DUE_STATUSES: payload = _json_dumps({ "task_id": str(record.id), "task_name": record.task_name, "queue": record.queue, "execution_backend": record.execution_backend, }) client = await self._get_client() await client.publish(self._notification_channel, payload)
[docs] async def wait_for_notifications(self, timeout: "float | None" = None) -> "bool": """Wait for a Redis-protocol pub/sub message when notifications are enabled. Returns: True when a notification was observed. """ if not self._notifications: return await super().wait_for_notifications(timeout=timeout) client = await self._get_client() pubsub = client.pubsub() await pubsub.subscribe(self._notification_channel) try: return await _wait_for_pubsub_message(pubsub, timeout=timeout) finally: await _close_pubsub(pubsub, self._notification_channel)
def _create_client(self, url: "str") -> "Any": return redis_asyncio.from_url(url, decode_responses=True) async def _get_client(self) -> "Any": if self._client is None: await self.open() return self._client @asynccontextmanager async def _lock(self, lock_name: "str", *, wait: "bool") -> "AsyncIterator[bool]": client = await self._get_client() lock_key = self._lock_key(lock_name) token = uuid4().hex timeout_ms = max(1, int(self._lock_timeout * 1000)) acquired = bool(await client.set(lock_key, token, nx=True, px=timeout_ms)) if not acquired and wait: deadline = asyncio.get_running_loop().time() + self._lock_timeout while not acquired and asyncio.get_running_loop().time() < deadline: await asyncio.sleep(min(self._poll_interval, self._lock_timeout)) acquired = bool(await client.set(lock_key, token, nx=True, px=timeout_ms)) if not acquired: msg = f"Timed out acquiring {self._backend_name} queue lock: {lock_name}" raise QueueError(msg) try: yield acquired finally: if acquired: await self._release_lock(client, lock_key, token) async def _release_lock(self, client: "Any", lock_key: "str", token: "str") -> "None": eval_method = getattr(client, "eval", None) if eval_method is not None: result = eval_method(_RELEASE_LOCK_SCRIPT, 1, lock_key, token) if inspect.isawaitable(result): await result return if _decode(await client.get(lock_key)) == token: await client.delete(lock_key) def _create_record( self, task_name: "str", *, args: "tuple[Any, ...]", kwargs: "dict[str, Any] | None", queue: "str", priority: "int", max_retries: "int", scheduled_at: "datetime | None", key: "str | None", execution_backend: "str", execution_profile: "str | None", metadata: "dict[str, Any] | None", ) -> "QueuedTaskRecord": return QueuedTaskRecord( task_name=task_name, args=args, kwargs=dict(kwargs or {}), queue=queue, execution_backend=execution_backend, execution_profile=execution_profile, status="scheduled" if scheduled_at is not None and scheduled_at > _utc_now() else "pending", priority=priority, max_retries=max_retries, scheduled_at=scheduled_at, key=key, metadata=dict(metadata or {}), ) async def _save_record(self, record: "QueuedTaskRecord") -> "None": client = await self._get_client() await client.hset(self._task_key(record.id), mapping=self._record_to_mapping(record)) await client.sadd(self._tasks_key, str(record.id)) if record.status in _DUE_STATUSES: await client.zadd(self._pending_key, {str(record.id): _score_datetime(record.scheduled_at)}) else: await client.zrem(self._pending_key, str(record.id)) async def _delete_record(self, record: "QueuedTaskRecord") -> "None": client = await self._get_client() await client.delete(self._task_key(record.id)) await client.srem(self._tasks_key, str(record.id)) await client.zrem(self._pending_key, str(record.id)) if record.key is not None and str(_decode(await client.hget(self._keys_key, record.key))) == str(record.id): await client.hdel(self._keys_key, record.key) async def _clear_key(self, record: "QueuedTaskRecord") -> "None": if record.key is not None: await self._client_hdel(self._keys_key, record.key) async def _list_records(self) -> "list[QueuedTaskRecord]": client = await self._get_client() task_ids = await client.smembers(self._tasks_key) return await self._records_from_ids(task_ids) async def _records_from_ids(self, task_ids: "set[Any] | list[Any] | tuple[Any, ...]") -> "list[QueuedTaskRecord]": records: "list[QueuedTaskRecord]" = [] for value in task_ids: record = await self.get_task(UUID(str(_decode(value)))) if record is not None: records.append(record) return records async def _client_hget(self, name: "str", key: "str") -> "Any": client = await self._get_client() return await client.hget(name, key) async def _client_hgetall(self, name: "str") -> "dict[str, Any]": client = await self._get_client() return _decode_mapping(await client.hgetall(name)) async def _client_hset(self, name: "str", key: "str", value: "Any") -> "None": client = await self._get_client() await client.hset(name, key, value) async def _client_hdel(self, name: "str", key: "str") -> "None": client = await self._get_client() await client.hdel(name, key) @property def _tasks_key(self) -> "str": return f"{self._key_prefix}:tasks" @property def _keys_key(self) -> "str": return f"{self._key_prefix}:keys" @property def _pending_key(self) -> "str": return f"{self._key_prefix}:pending" def _task_key(self, task_id: "UUID") -> "str": return f"{self._key_prefix}:task:{task_id}" def _lock_key(self, lock_name: "str") -> "str": return f"{self._key_prefix}:locks:{lock_name}" def _record_to_mapping(self, record: "QueuedTaskRecord") -> "dict[str, str]": return { "id": str(record.id), "task_name": record.task_name, "args": _json_dumps(list(record.args)), "kwargs": _json_dumps(record.kwargs), "queue": record.queue, "execution_backend": record.execution_backend, "execution_profile": record.execution_profile or "", "execution_ref": record.execution_ref or "", "status": record.status, "priority": str(record.priority), "max_retries": str(record.max_retries), "retry_count": str(record.retry_count), "scheduled_at": _serialize_datetime(record.scheduled_at), "created_at": _serialize_datetime(record.created_at), "started_at": _serialize_datetime(record.started_at), "completed_at": _serialize_datetime(record.completed_at), "heartbeat_at": _serialize_datetime(record.heartbeat_at), "result": _json_dumps(record.result), "error": record.error or "", "key": record.key or "", "metadata": _json_dumps(record.metadata), } def _record_from_mapping(self, mapping: "dict[str, Any]") -> "QueuedTaskRecord": return QueuedTaskRecord( id=UUID(str(mapping["id"])), task_name=str(mapping["task_name"]), args=tuple(_json_loads(mapping.get("args"), [])), kwargs=dict(_json_loads(mapping.get("kwargs"), {})), queue=str(mapping.get("queue") or "default"), execution_backend=str(mapping.get("execution_backend") or "local"), execution_profile=str(mapping["execution_profile"]) if mapping.get("execution_profile") else None, execution_ref=str(mapping["execution_ref"]) if mapping.get("execution_ref") else None, status=_coerce_status(mapping.get("status")), priority=int(str(mapping.get("priority") or 0)), max_retries=int(str(mapping.get("max_retries") or 0)), retry_count=int(str(mapping.get("retry_count") or 0)), scheduled_at=_deserialize_datetime(mapping.get("scheduled_at")), created_at=_deserialize_datetime(mapping.get("created_at")) or _utc_now(), started_at=_deserialize_datetime(mapping.get("started_at")), completed_at=_deserialize_datetime(mapping.get("completed_at")), heartbeat_at=_deserialize_datetime(mapping.get("heartbeat_at")), result=_json_loads(mapping.get("result"), None), error=str(mapping["error"]) if mapping.get("error") else None, key=str(mapping["key"]) if mapping.get("key") else None, metadata=dict(_json_loads(mapping.get("metadata"), {})), )
def _utc_now() -> "datetime": return datetime.now(timezone.utc) def _serialize_datetime(value: "datetime | None") -> "str": if value is None: return "" if value.tzinfo is None: value = value.replace(tzinfo=timezone.utc) return value.astimezone(timezone.utc).isoformat() def _deserialize_datetime(value: "Any") -> "datetime | None": value = _decode(value) if not value: return None parsed = datetime.fromisoformat(str(value).replace("Z", "+00:00")) if parsed.tzinfo is None: return parsed.replace(tzinfo=timezone.utc) return parsed.astimezone(timezone.utc) def _score_datetime(value: "datetime | None") -> "float": if value is None or value <= _utc_now(): return 0.0 if value.tzinfo is None: value = value.replace(tzinfo=timezone.utc) return value.astimezone(timezone.utc).timestamp() def _decode(value: "Any") -> "Any": if isinstance(value, bytes): return value.decode() return value def _decode_mapping(mapping: "dict[Any, Any]") -> "dict[str, Any]": return {str(_decode(key)): _decode(value) for key, value in mapping.items()} def _json_default(value: "Any") -> "Any": if isinstance(value, datetime): return _serialize_datetime(value) return str(value) def _json_dumps(value: "Any") -> "str": return json.dumps(value, default=_json_default, separators=(",", ":"), sort_keys=True) def _json_loads(value: "Any", default: "Any") -> "Any": value = _decode(value) if value in {None, ""}: return default return json.loads(str(value)) def _coerce_status(value: "Any") -> "TaskStatus": status = str(_decode(value)) if status not in _STATUS_VALUES: msg = f"Unknown queued task status from Redis-protocol queue backend: {status!r}" raise ValueError(msg) return cast("TaskStatus", status) async def _wait_for_pubsub_message(pubsub: "Any", *, timeout: "float | None") -> "bool": """Drain pubsub responses until a real ``message`` arrives or timeout. ``pubsub.get_message(ignore_subscribe_messages=True)`` returns ``None`` for both "no message in this read window" AND "subscribe-confirmation was filtered". Looping with a deadline distinguishes the two cases. Returns: True when a real published message was observed before the deadline. """ loop = asyncio.get_running_loop() deadline = loop.time() + timeout if timeout is not None else None while True: remaining = None if deadline is None else max(0.0, deadline - loop.time()) if remaining is not None and remaining <= 0.0: return False message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=remaining) if message is not None: return True if deadline is None: return False async def _close_pubsub(pubsub: "Any", channel: "str") -> "None": """Best-effort unsubscribe + close on a pubsub connection.""" unsubscribe = getattr(pubsub, "unsubscribe", None) if unsubscribe is not None: result = unsubscribe(channel) if inspect.isawaitable(result): with suppress(Exception): await result close = getattr(pubsub, "aclose", None) or getattr(pubsub, "close", None) if close is not None: result = close() if inspect.isawaitable(result): with suppress(Exception): await result