"""SQLSpec queue backend."""
import asyncio
from contextlib import asynccontextmanager, contextmanager, suppress
from datetime import datetime, timedelta, timezone
from inspect import isawaitable
from logging import getLogger
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID
from sqlspec import SQLSpec, StatementStack
from sqlspec.extensions.events import normalize_event_channel_name, resolve_adapter_name
from sqlspec.utils.sync_tools import async_
from litestar_queues.backends.base import BaseQueueBackend
from litestar_queues.backends.sqlspec.config import DEFAULT_NOTIFICATION_CHANNEL, SQLSpecBackendConfig
from litestar_queues.backends.sqlspec.extension import QUEUE_EXTENSION_NAME, configure_queue_migration_extension
from litestar_queues.backends.sqlspec.schema import (
DEFAULT_TABLE_NAME,
validate_column_map,
validate_native_json_columns,
validate_table_name,
)
from litestar_queues.backends.sqlspec.stores.factory import create_queue_store
from litestar_queues.exceptions import QueueConfigurationError
from litestar_queues.models import (
EnqueueSpec,
QueueBackendCapabilities,
QueuedTaskRecord,
QueueStatistics,
StaleTaskRecoveryResult,
TaskStatus,
)
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator, Sequence
from litestar_queues.config import QueueConfig
__all__ = ("SQLSpecQueueBackend",)

# Task statuses that may be claimed and that trigger a wakeup notification.
_DUE_STATUSES = ("pending", "scheduled")
# Notification backend names reported as durable via
# ``QueueBackendCapabilities.notifications_durable``.
_DURABLE_NOTIFICATION_BACKENDS = frozenset(("aq", "listen_notify_durable", "table_queue", "txeventq"))
# Key under a SQLSpec config's ``extension_config`` that holds events settings.
_EVENT_EXTENSION_NAME = "events"
# Queue-extension setting keys whose dict values are merged into events settings.
_QUEUE_SETTING_EVENT_SETTINGS = ("event_settings", "events")
# Transport name meaning "no push wakeups; workers poll on an interval".
_NOTIFY_TRANSPORT_POLLING = "polling"

# Adapter families that can push worker wakeups. Postgres-over-asyncpg ships the
# durable LISTEN/NOTIFY hybrid; psycopg/psqlpy fall back to the durable table
# queue until their LISTEN/NOTIFY path lands upstream. Everything else polls.
_NOTIFY_DURABLE_ADAPTERS = frozenset(("asyncpg",))
_NOTIFY_TABLE_QUEUE_ADAPTERS = frozenset(("psycopg", "psqlpy"))
def _adapter_notify_transport(adapter_name: "str | None") -> "str":
    """Pick the default worker wakeup transport for a SQLSpec adapter.

    Push wakeups are only advertised for adapters known to deliver them; any
    other adapter name (including ``None``) falls back to interval polling.

    Args:
        adapter_name: Adapter name resolved from the SQLSpec config, or ``None``.

    Returns:
        ``"listen_notify_durable"`` for asyncpg, ``"table_queue"`` for
        psycopg/psqlpy, otherwise ``"polling"``.
    """
    # Checked in priority order: the durable LISTEN/NOTIFY family wins first.
    candidates = (
        (_NOTIFY_DURABLE_ADAPTERS, "listen_notify_durable"),
        (_NOTIFY_TABLE_QUEUE_ADAPTERS, "table_queue"),
    )
    for adapters, transport in candidates:
        if adapter_name in adapters:
            return transport
    return _NOTIFY_TRANSPORT_POLLING
[docs]
class SQLSpecQueueBackend(BaseQueueBackend):
"""SQLSpec-backed queue backend."""
__slots__ = (
"_column_map",
"_create_schema",
"_event_backend",
"_event_channel",
"_event_poll_interval",
"_event_queue_table",
"_event_settings",
"_heartbeat_pool_config",
"_heartbeat_pool_enabled",
"_heartbeat_pool_registered",
"_manage_schema",
"_native_json_columns",
"_notification_backend",
"_notification_channel",
"_notifications_enabled",
"_notifications_requested",
"_notify_transport",
"_opened",
"_owns_event_channel",
"_owns_sqlspec",
"_queue_observability",
"_run_migrations",
"_sqlspec",
"_sqlspec_config",
"_store",
"_table_name",
)
[docs]
def __init__(
    self, config: "QueueConfig | None" = None, *, backend_config: "SQLSpecBackendConfig | None" = None
) -> "None":
    """Capture backend settings; no SQLSpec resources are opened here.

    Args:
        config: Queue-level configuration forwarded to the base backend.
        backend_config: SQLSpec-specific settings. A default
            ``SQLSpecBackendConfig()`` is used when omitted.
    """
    super().__init__(config=config)
    cfg = backend_config or SQLSpecBackendConfig()

    # Storage layout / schema management (validated eagerly).
    self._column_map = validate_column_map(cfg.column_map)
    self._native_json_columns = validate_native_json_columns(frozenset(cfg.native_json_columns))
    self._manage_schema = cfg.manage_schema
    raw_table_name = cfg.table_name
    self._table_name = None if raw_table_name is None else validate_table_name(raw_table_name)
    self._create_schema = cfg.create_schema
    self._run_migrations = cfg.run_migrations

    # SQLSpec manager and database config. The manager's pools are only
    # closed by close() when this backend created the manager itself.
    supplied_sqlspec = cfg.sqlspec
    self._sqlspec = supplied_sqlspec
    self._owns_sqlspec = supplied_sqlspec is None
    self._sqlspec_config = cfg.config

    # Optional dedicated heartbeat pool.
    heartbeat_pool_config = cfg.heartbeat_pool_config
    self._heartbeat_pool_config = heartbeat_pool_config
    self._heartbeat_pool_enabled = heartbeat_pool_config is not None
    self._heartbeat_pool_registered = False

    # Notification / events wiring. An injected event channel is not owned
    # (and therefore not shut down) by this backend.
    supplied_channel = cfg.event_channel
    self._event_channel = supplied_channel
    self._owns_event_channel = supplied_channel is None
    self._notifications_requested = cfg.notifications
    self._notification_channel = cfg.notification_channel
    self._notify_transport = cfg.notify_transport
    self._event_backend = cfg.event_backend
    self._event_queue_table = cfg.event_queue_table
    self._event_poll_interval = cfg.event_poll_interval
    self._event_settings = dict(cfg.event_settings)
    self._queue_observability = cfg.queue_observability
    self._notification_backend: "str | None" = getattr(supplied_channel, "_backend_name", None)
    self._notifications_enabled = supplied_channel is not None

    # State built lazily by open() / _get_store().
    self._store: "Any | None" = None
    self._opened = False
[docs]
async def open(self) -> "bool":
    """Open SQLSpec resources.

    Resolves the SQLSpec manager, the queue table name, notification wiring
    and the optional heartbeat pool, then applies packaged migrations and/or
    creates the schema when enabled. Calling ``open`` on an already-open
    backend is a no-op.

    If migrations or schema creation fail, the backend is marked as not
    opened again, so a later ``open()`` retries them instead of reporting
    success with a missing schema. The original error is re-raised.

    Returns:
        True when SQLSpec resources are ready.
    """
    if self._opened:
        return True
    self._get_or_create_sqlspec()
    self._resolve_table_name()
    self._configure_notifications()
    self._register_heartbeat_pool()
    # NOTE(review): presumably flagged before the schema work so sessions
    # opened by run_migrations()/create_schema() see an opened backend —
    # confirm against _session().
    self._opened = True
    try:
        if self._resolve_run_migrations():
            await self.run_migrations()
        if self._resolve_create_schema():
            await self.create_schema()
    except BaseException:
        # Do not leave a half-initialized backend looking fully open.
        self._opened = False
        raise
    return True
[docs]
async def close(self) -> "None":
    """Close SQLSpec resources.

    Only resources this backend created (an auto-created event channel and
    SQLSpec manager) are shut down; injected ones are left to their owner.

    Every cleanup step is attempted even if an earlier one raises, so a
    failing heartbeat-pool or event-channel shutdown no longer leaks the
    owned SQLSpec pools, and the backend is always marked closed. If several
    steps fail, the last error propagates with the earlier ones chained as
    its context.
    """
    try:
        await self._close_heartbeat_pool()
    finally:
        try:
            if self._owns_event_channel and self._event_channel is not None:
                await self._event_channel.shutdown()
                self._event_channel = None
        finally:
            try:
                if self._owns_sqlspec and self._sqlspec is not None:
                    await self._sqlspec.close_all_pools()
                    self._sqlspec = None
            finally:
                self._opened = False
@property
def capabilities(self) -> "QueueBackendCapabilities":
    """Backend behavior capabilities.

    Reflects the current notification configuration. Durability is true only
    when the resolved notification backend name is one of the known durable
    backends.
    """
    backend_name = self._notification_backend
    is_durable = backend_name in _DURABLE_NOTIFICATION_BACKENDS
    return QueueBackendCapabilities(
        supports_notifications=self._notifications_enabled,
        notification_backend=backend_name,
        notifications_durable=is_durable,
    )
[docs]
async def create_schema(self) -> "None":
    """Create the SQLSpec queue table and indexes.

    Does nothing when schema management is disabled. The DDL statements are
    produced by the store, run one by one in a single session, and a single
    commit is issued afterwards.
    """
    if not self._manage_schema:
        return
    async with self._session() as driver:
        ddl_statements = await _create_schema_statements(self._get_store(), driver)
        for ddl in ddl_statements:
            await driver.execute_script(ddl)
        await driver.commit()
[docs]
async def run_migrations(self) -> "None":
    """Apply packaged SQLSpec migrations.

    Does nothing when schema management is disabled. Otherwise the queue
    migration extension is configured on the SQLSpec config (using the
    resolved table name) before ``migrate_up`` runs.
    """
    if not self._manage_schema:
        return
    migration_target = self._get_sqlspec_config()
    configure_queue_migration_extension(migration_target, table_name=self._resolve_table_name())
    await migration_target.migrate_up(echo=False)
[docs]
async def enqueue(
    self,
    task_name: "str",
    *,
    args: "tuple[Any, ...]" = (),
    kwargs: "dict[str, Any] | None" = None,
    queue: "str" = "default",
    priority: "int" = 0,
    max_retries: "int" = 0,
    scheduled_at: "datetime | None" = None,
    key: "str | None" = None,
    execution_backend: "str" = "local",
    execution_profile: "str | None" = None,
    metadata: "dict[str, Any] | None" = None,
) -> "QueuedTaskRecord":
    """Persist a single task, then pass it to :meth:`notify_new_task`.

    Deduplication: when ``key`` is given and a task with that key exists,
    the existing record is returned unchanged if it is not terminal. A
    terminal holder has its key cleared so the new task can take it over.

    The new task's status is ``"scheduled"`` when ``scheduled_at`` is later
    than now, otherwise ``"pending"``.

    Args:
        task_name: Registered task name.
        args: Positional task arguments.
        kwargs: Keyword task arguments (copied; ``None`` means empty).
        queue: Queue name.
        priority: Task priority.
        max_retries: Maximum retry count for failures.
        scheduled_at: Earliest execution time, if deferred.
        key: Optional deduplication key.
        execution_backend: Execution backend name.
        execution_profile: Optional execution profile name.
        metadata: Extra task metadata (copied; ``None`` means empty).

    Returns:
        The newly inserted record, or the existing non-terminal record that
        already holds ``key``.
    """
    with self._observe_queue_operation("enqueue", queue=queue, task_name=task_name):
        async with self._session() as driver:
            await driver.begin()
            try:
                if key is not None:
                    existing_row = await self._select_task_by_key(driver, key)
                    if existing_row is not None:
                        existing = self._record_from_row(existing_row)
                        if not existing.is_terminal:
                            # Dedup hit: nothing was written, so roll back and
                            # hand back the live task (no metric, no notify).
                            await driver.rollback()
                            return existing
                        # Free the key from the finished task so the insert
                        # below can reuse it.
                        await self._clear_key(driver, existing.id)
                now = _utc_now()
                record = QueuedTaskRecord(
                    task_name=task_name,
                    args=args,
                    kwargs=dict(kwargs or {}),
                    queue=queue,
                    execution_backend=execution_backend,
                    execution_profile=execution_profile,
                    # NOTE(review): assumes scheduled_at and _utc_now() agree on
                    # tz-awareness; mixing naive and aware would raise TypeError.
                    status="scheduled" if scheduled_at is not None and scheduled_at > now else "pending",
                    priority=priority,
                    max_retries=max_retries,
                    scheduled_at=scheduled_at,
                    key=key,
                    metadata=dict(metadata or {}),
                )
                await driver.execute(self._get_store().insert_task(self._params_from_record(record)))
                await driver.commit()
            except Exception:
                # Best-effort rollback; the original error is what matters.
                with suppress(Exception):
                    await driver.rollback()
                raise
    # Metric and notification happen only after a successful commit.
    self._increment_queue_metric("enqueue")
    await self.notify_new_task(record)
    return record
[docs]
async def enqueue_many(self, specs: "Sequence[EnqueueSpec]") -> "list[QueuedTaskRecord]":
    """Persist many tasks via the adapter's fastest bulk path.

    Resolves existing deduplication keys in one round trip, then inserts the
    remaining rows through the native Arrow ingest path
    (:meth:`load_from_records`) when the adapter supports it, otherwise via a
    batched ``execute_many``. Returns records in input order, with existing
    non-terminal keyed tasks returned as-is (no duplicate insert) to match
    the semantics of :meth:`enqueue`.

    Keys held by terminal tasks are cleared before the insert so the new
    tasks can reuse them. Only newly inserted records count toward the
    ``enqueue`` metric and are passed to :meth:`notify_new_task`, after the
    transaction commits.

    Args:
        specs: Task specifications to enqueue. An empty sequence returns ``[]``
            without opening a session.

    Returns:
        Queue task records in the same order as ``specs``.
    """
    if not specs:
        return []
    store = self._get_store()
    # A single "now" is used to plan the whole batch.
    now = _utc_now()
    keyed = [spec.key for spec in specs if spec.key is not None]
    with self._observe_queue_operation("enqueue", task_count=len(specs)):
        async with self._session() as driver:
            await driver.begin()
            try:
                existing_by_key = await self._existing_records_by_key(driver, store, keyed)
                # results: output in input order; to_insert: new records;
                # terminal_keys: ids of terminal tasks whose keys must be freed.
                results, to_insert, terminal_keys = self._plan_bulk_enqueue(specs, existing_by_key, now)
                for task_id in terminal_keys:
                    await driver.execute(store.clear_key(task_id=str(task_id)))
                if to_insert:
                    await self._bulk_insert(driver, store, to_insert)
                await driver.commit()
            except Exception:
                # Best-effort rollback; re-raise the original error.
                with suppress(Exception):
                    await driver.rollback()
                raise
    self._increment_queue_metric("enqueue", float(len(to_insert)))
    for record in to_insert:
        await self.notify_new_task(record)
    return results
[docs]
async def get_task(self, task_id: "UUID") -> "QueuedTaskRecord | None":
    """Fetch one task by id.

    Args:
        task_id: Task identifier.

    Returns:
        The task record, or ``None`` when no row matches ``task_id``.
    """
    async with self._session() as driver:
        found = await self._select_task(driver, task_id)
        if found is None:
            return None
        return self._record_from_row(found)
[docs]
async def get_task_by_key(self, key: "str") -> "QueuedTaskRecord | None":
    """Fetch the task currently holding a deduplication key.

    Args:
        key: Deduplication key.

    Returns:
        The task record, or ``None`` when no row holds ``key``.
    """
    async with self._session() as driver:
        found = await self._select_task_by_key(driver, key)
        if found is None:
            return None
        return self._record_from_row(found)
[docs]
async def list_pending(
    self, *, limit: "int" = 1, queue: "str | None" = None, execution_backend: "str | None" = None
) -> "list[QueuedTaskRecord]":
    """List candidate tasks selected by ``_select_pending_rows``.

    Args:
        limit: Maximum number of rows to return.
        queue: Optional queue filter.
        execution_backend: Optional execution backend filter.

    Returns:
        Task records in the order the selection returned them.
    """
    pending_rows = await self._select_pending_rows(limit=limit, queue=queue, execution_backend=execution_backend)
    return list(map(self._record_from_row, pending_rows))
[docs]
async def claim_task(self, task_id: "UUID") -> "QueuedTaskRecord | None":
    """Try to claim a specific task, moving it to ``running``.

    The row is read, checked to be in a due status and actually due, and
    then the store's claim statement is executed. The claim only succeeds
    when that statement affects a row and a re-read shows ``running``;
    otherwise the transaction is rolled back and ``None`` is returned.

    Args:
        task_id: Task identifier.

    Returns:
        The claimed (running) record, or ``None`` when the task is missing,
        not due, or the claim did not take effect.
    """
    with self._observe_queue_operation("claim", task_id=str(task_id)):
        async with self._session() as driver:
            await driver.begin()
            try:
                row = await self._select_task(driver, task_id)
                if row is None:
                    await driver.rollback()
                    return None
                record = self._record_from_row(row)
                if record.status not in _DUE_STATUSES or not record.is_due:
                    await driver.rollback()
                    return None
                now = _utc_now()
                result = await driver.execute(
                    self._get_store().claim_task(
                        task_id=str(task_id),
                        due_at=self._serialize_datetime(now),
                        heartbeat_at=self._serialize_datetime(now),
                        started_at=self._serialize_datetime(now),
                    )
                )
                # NOTE(review): presumably the claim UPDATE is conditional on
                # status/due time, so 0 rows means another worker claimed it
                # first — confirm against the store SQL.
                if result.rows_affected == 0:
                    await driver.rollback()
                    return None
                # Re-read to confirm the transition actually produced "running".
                updated_row = await self._select_task(driver, task_id)
                if updated_row is None or self._record_from_row(updated_row).status != "running":
                    await driver.rollback()
                    return None
                await driver.commit()
            except Exception:
                # Best-effort rollback; re-raise the original error.
                with suppress(Exception):
                    await driver.rollback()
                raise
    # updated_row is non-None here (the None case returned above).
    claimed = self._record_from_row(updated_row) if updated_row is not None else None
    if claimed is not None:
        self._increment_queue_metric("claim")
    return claimed
[docs]
async def claim_next(
    self, *, queue: "str | None" = None, execution_backend: "str | None" = None
) -> "QueuedTaskRecord | None":
    """Claim the next available task, if any.

    Stores that support ``SKIP LOCKED`` claim under a row lock in a single
    transaction. Otherwise up to 10 candidate rows are listed and each is
    tried with :meth:`claim_task` until one succeeds. Any candidate may
    already be claimed or no longer due, in which case ``claim_task``
    returns ``None`` and the next one is tried.

    Args:
        queue: Optional queue filter.
        execution_backend: Optional execution backend filter.

    Returns:
        The claimed record, or ``None`` when nothing could be claimed.
    """
    store = self._get_store()
    if store.supports_skip_locked:
        return await self._claim_next_skip_locked(store, queue=queue, execution_backend=execution_backend)
    candidates = await self._select_pending_rows(limit=10, queue=queue, execution_backend=execution_backend)
    for candidate in candidates:
        claimed = await self.claim_task(UUID(str(candidate["id"])))
        if claimed is not None:
            return claimed
    return None
async def _claim_next_skip_locked(
    self, store: "Any", *, queue: "str | None", execution_backend: "str | None"
) -> "QueuedTaskRecord | None":
    """Claim the next due task under ``SELECT ... FOR UPDATE SKIP LOCKED``.

    Locks a single due row and claims it inside one transaction so
    competing workers skip the locked row instead of colliding on the
    optimistic CAS claim. The v1 fenced-claim contract is preserved: a row
    that cannot be transitioned to ``running`` yields ``None``.

    When the store asks for a streamed claim select
    (``claim_select_stream_chunk_size``), only the first row is consumed and
    the stream is closed explicitly before the claim UPDATE runs. Breaking
    out of ``async for`` does not close an async generator, so without this
    its cursor would stay open and be finalized later by garbage collection,
    possibly after the session has been released.

    Args:
        store: Queue store for the active adapter.
        queue: Optional queue filter.
        execution_backend: Optional execution backend filter.

    Returns:
        The claimed task record, if a claim was available.
    """
    with self._observe_queue_operation("claim", queue=queue, execution_backend=execution_backend):
        async with self._session() as driver:
            await driver.begin()
            try:
                now = _utc_now()
                statement = store.select_claimable(
                    now=self._serialize_datetime(now), limit=1, queue=queue, execution_backend=execution_backend
                )
                stream_chunk_size = cast("int | None", getattr(store, "claim_select_stream_chunk_size", None))
                if stream_chunk_size is None:
                    rows = await driver.select(statement)
                    row = cast("dict[str, Any] | None", rows[0] if rows else None)
                else:
                    row = None
                    stream = _select_stream(driver, statement, chunk_size=stream_chunk_size)
                    try:
                        async for claimable_row in stream:
                            row = cast("dict[str, Any]", claimable_row)
                            break
                    finally:
                        # Release the stream's cursor now, while the driver and
                        # transaction are still live.
                        aclose = getattr(stream, "aclose", None)
                        if aclose is not None:
                            await aclose()
                if row is None:
                    await driver.rollback()
                    return None
                record = self._record_from_row(row)
                result = await driver.execute(
                    store.claim_task(
                        task_id=str(record.id),
                        due_at=self._serialize_datetime(now),
                        heartbeat_at=self._serialize_datetime(now),
                        started_at=self._serialize_datetime(now),
                    )
                )
                if result.rows_affected == 0:
                    await driver.rollback()
                    return None
                # Re-read to confirm the row really transitioned to "running".
                updated_row = await self._select_task(driver, record.id)
                if updated_row is None or self._record_from_row(updated_row).status != "running":
                    await driver.rollback()
                    return None
                await driver.commit()
            except Exception:
                # Best-effort rollback; re-raise the original error.
                with suppress(Exception):
                    await driver.rollback()
                raise
    claimed = self._record_from_row(updated_row)
    self._increment_queue_metric("claim")
    return claimed
[docs]
async def complete_task(
    self, task_id: "UUID", *, result: "Any" = None, expected_retry_count: "int | None" = None
) -> "QueuedTaskRecord | None":
    """Mark a task completed and store its result.

    When ``expected_retry_count`` is given, the task must currently be
    ``running`` with exactly that retry count (i.e. the caller still owns
    this attempt). Otherwise the claim is treated as lost: the
    ``claim_lost`` metric is incremented and ``None`` is returned.

    Args:
        task_id: Task identifier.
        result: Task result, serialized via the store's JSON serializer.
        expected_retry_count: Optional fence against a stale attempt.

    Returns:
        The updated record, or ``None`` when the task is missing, the fence
        check failed, or the completion statement matched no rows.
    """
    # Timestamp captured before the session is opened.
    now = _utc_now()
    store = self._get_store()
    with self._observe_queue_operation("complete", task_id=str(task_id)):
        async with self._session() as driver:
            await driver.begin()
            try:
                if expected_retry_count is not None:
                    existing_row = await self._select_task(driver, task_id)
                    if existing_row is None:
                        await driver.rollback()
                        return None
                    existing = self._record_from_row(existing_row)
                    if existing.status != "running" or existing.retry_count != expected_retry_count:
                        # A different attempt (or none) owns the task now.
                        await driver.rollback()
                        self._increment_queue_metric("claim_lost")
                        return None
                updated = await driver.execute(
                    store.complete_task(
                        task_id=str(task_id),
                        completed_at=self._serialize_datetime(now),
                        heartbeat_at=self._serialize_datetime(now),
                        result_json=store.serialize_json("result_json", result),
                    )
                )
                # Only re-read the row when the completion statement matched.
                row = await self._select_task(driver, task_id) if updated.rows_affected else None
                await driver.commit()
            except Exception:
                # Best-effort rollback; re-raise the original error.
                with suppress(Exception):
                    await driver.rollback()
                raise
    completed = self._record_from_row(row) if row is not None else None
    if completed is not None:
        self._increment_queue_metric("complete")
    return completed
[docs]
async def fail_task(
    self, task_id: "UUID", error: "str", *, retry: "bool" = True, expected_retry_count: "int | None" = None
) -> "QueuedTaskRecord | None":
    """Record a task failure, retrying it when allowed.

    The task is retried (its retry count incremented) when ``retry`` is true
    and ``retry_count < max_retries``; otherwise it is failed with
    ``completed_at``/``heartbeat_at`` set to now. The ``retry`` or ``fail``
    metric is incremented accordingly.

    When ``expected_retry_count`` is given, the task must be ``running`` with
    exactly that retry count; otherwise the ``claim_lost`` metric is
    incremented and ``None`` is returned.

    Args:
        task_id: Task identifier.
        error: Error message stored on the task.
        retry: Whether a retry is permitted for this failure.
        expected_retry_count: Optional fence against a stale attempt.

    Returns:
        The updated record, or ``None`` when the task is missing or the fence
        check failed.
    """
    with self._observe_queue_operation("fail", task_id=str(task_id), retry=retry):
        async with self._session() as driver:
            await driver.begin()
            try:
                row = await self._select_task(driver, task_id)
                if row is None:
                    await driver.rollback()
                    return None
                record = self._record_from_row(row)
                if expected_retry_count is not None and (
                    record.status != "running" or record.retry_count != expected_retry_count
                ):
                    # A different attempt (or none) owns the task now.
                    await driver.rollback()
                    self._increment_queue_metric("claim_lost")
                    return None
                metric = "fail"
                if retry and record.retry_count < record.max_retries:
                    await driver.execute(
                        self._get_store().retry_task(
                            task_id=str(task_id), error=error, retry_count=record.retry_count + 1
                        )
                    )
                    metric = "retry"
                else:
                    now = _utc_now()
                    await driver.execute(
                        self._get_store().fail_task(
                            task_id=str(task_id),
                            completed_at=self._serialize_datetime(now),
                            heartbeat_at=self._serialize_datetime(now),
                            error=error,
                        )
                    )
                # Re-read inside the transaction to return the post-update state.
                updated_row = await self._select_task(driver, task_id)
                await driver.commit()
            except Exception:
                # Best-effort rollback; re-raise the original error.
                with suppress(Exception):
                    await driver.rollback()
                raise
    updated_record = self._record_from_row(updated_row) if updated_row is not None else None
    if updated_record is not None:
        self._increment_queue_metric(metric)
    return updated_record
[docs]
async def cancel_task(self, task_id: "UUID") -> "bool":
    """Cancel a task using the store's cancel statement.

    Args:
        task_id: Task identifier.

    Returns:
        True when exactly one row was affected by the cancel statement.
    """
    async with self._session() as driver:
        await driver.begin()
        try:
            cancel_statement = self._get_store().cancel_task(
                task_id=str(task_id), completed_at=self._serialize_datetime(_utc_now())
            )
            outcome = await driver.execute(cancel_statement)
            await driver.commit()
        except Exception:
            # Best-effort rollback; re-raise the original error.
            with suppress(Exception):
                await driver.rollback()
            raise
    affected = int(outcome.rows_affected)
    return affected == 1
[docs]
async def touch_heartbeat(self, task_id: "UUID", *, expected_retry_count: "int | None" = None) -> "bool":
    """Refresh a task's heartbeat timestamp.

    Uses the heartbeat session rather than the main session. With
    ``expected_retry_count``, the task must be ``running`` at exactly that
    retry count, otherwise nothing is updated.

    Args:
        task_id: Task identifier.
        expected_retry_count: Optional fence against a stale attempt.

    Returns:
        True when the heartbeat was updated for the (still running) task.
    """
    async with self._heartbeat_session() as driver:
        await driver.begin()
        try:
            if expected_retry_count is not None:
                row = await self._select_task(driver, task_id)
                if row is None:
                    await driver.rollback()
                    return False
                record = self._record_from_row(row)
                if record.status != "running" or record.retry_count != expected_retry_count:
                    await driver.rollback()
                    return False
            result = await driver.execute(
                self._get_store().touch_heartbeat(
                    task_id=str(task_id), heartbeat_at=self._serialize_datetime(_utc_now())
                )
            )
            rows_affected = int(result.rows_affected)
            if rows_affected == 1:
                touched = True
            elif rows_affected == 0:
                touched = False
            else:
                # Unreliable row count (e.g. a driver reporting -1 for
                # "unknown"): re-read the row and decide from its state.
                touched_row = await self._select_task(driver, task_id)
                touched_record = self._record_from_row(touched_row) if touched_row is not None else None
                touched = (
                    touched_record is not None
                    and touched_record.status == "running"
                    and (expected_retry_count is None or touched_record.retry_count == expected_retry_count)
                )
            await driver.commit()
        except Exception:
            # Best-effort rollback; re-raise the original error.
            with suppress(Exception):
                await driver.rollback()
            raise
    return touched
[docs]
async def null_heartbeats(self, task_ids: "list[UUID]", *, expected_retry_count: "int | None" = None) -> "None":
    """Clear the heartbeat of the given tasks in one statement.

    Uses the heartbeat session. With ``expected_retry_count``, only tasks
    whose current retry count equals it are affected. Each task is looked up
    individually for that filter; missing tasks are skipped.

    NOTE(review): unlike touch_heartbeat, the filter does not check
    ``status == "running"`` — confirm that is intended.

    Args:
        task_ids: Tasks to update; an empty list is a no-op.
        expected_retry_count: Optional fence against a stale attempt.
    """
    if not task_ids:
        return
    async with self._heartbeat_session() as driver:
        await driver.begin()
        try:
            filtered_task_ids = task_ids
            if expected_retry_count is not None:
                filtered_task_ids = []
                for task_id in task_ids:
                    row = await self._select_task(driver, task_id)
                    if row is None:
                        continue
                    record = self._record_from_row(row)
                    if record.retry_count == expected_retry_count:
                        filtered_task_ids.append(task_id)
            if filtered_task_ids:
                await driver.execute(
                    self._get_store().null_heartbeats(task_ids=[str(task_id) for task_id in filtered_task_ids])
                )
            await driver.commit()
        except Exception:
            # Best-effort rollback; re-raise the original error.
            with suppress(Exception):
                await driver.rollback()
            raise
[docs]
async def requeue_stale_running(self, *, stale_after: "timedelta") -> "StaleTaskRecoveryResult":
    """Recover running tasks whose heartbeat is older than ``stale_after``.

    Each stale task is retried (retry count incremented) when it still has
    retries left and its metadata does not set ``requeue_on_stale`` to
    ``False``. Otherwise it is failed with the error "Task heartbeat stale".
    Tasks that opted out via ``requeue_on_stale=False`` are additionally
    reported as ``handler_needed``. All writes are batched into one
    ``StatementStack``.

    NOTE(review): the stale rows are selected without a row lock; whether a
    task that finishes between the SELECT and the stacked writes is still
    updated depends on the store's retry/fail SQL — confirm.

    Args:
        stale_after: Maximum heartbeat age before a running task is stale.

    Returns:
        Counts and ids of requeued, failed, and handler-needed tasks.
    """
    cutoff = _utc_now() - stale_after
    store = self._get_store()
    result = StaleTaskRecoveryResult()
    with self._observe_queue_operation("stale_recovered"):
        async with self._session() as driver:
            rows = await driver.select(store.list_stale_running(cutoff=self._serialize_datetime(cutoff)))
            stack = StatementStack()
            for row in cast("list[dict[str, Any]]", rows):
                record = self._record_from_row(row)
                # Anything other than an explicit False keeps the default (requeue).
                requeue_on_stale = record.metadata.get("requeue_on_stale", True) is not False
                if requeue_on_stale and record.retry_count < record.max_retries:
                    # push_execute returns a new stack, so rebind it.
                    stack = stack.push_execute(
                        store.retry_task(
                            task_id=str(record.id), error="Task heartbeat stale", retry_count=record.retry_count + 1
                        )
                    )
                    result.requeued += 1
                else:
                    now = _utc_now()
                    stack = stack.push_execute(
                        store.fail_task(
                            task_id=str(record.id),
                            completed_at=self._serialize_datetime(now),
                            heartbeat_at=self._serialize_datetime(now),
                            error="Task heartbeat stale",
                        )
                    )
                    result.failed += 1
                    result.failed_task_ids.append(record.id)
                    if not requeue_on_stale:
                        result.handler_needed += 1
                        result.handler_needed_task_ids.append(record.id)
            if stack:
                # Reset any implicit read transaction the SELECT may have opened so
                # SQLSpec's stack runner owns the write transaction.
                with suppress(Exception):
                    await driver.rollback()
                try:
                    await driver.execute_stack(stack)
                except Exception:
                    # Best-effort rollback; re-raise the original error.
                    with suppress(Exception):
                        await driver.rollback()
                    raise
    recovered = result.requeued + result.failed
    if recovered:
        self._increment_queue_metric("stale_recovered", float(recovered))
    if result.requeued:
        self._increment_queue_metric("retry", float(result.requeued))
    if result.failed:
        self._increment_queue_metric("stale_failed", float(result.failed))
    return result
[docs]
async def set_execution_ref(
    self, task_id: "UUID", execution_backend: "str", execution_ref: "str", *, execution_profile: "str | None" = None
) -> "QueuedTaskRecord | None":
    """Record the execution backend, profile and external reference of a task.

    Args:
        task_id: Task identifier.
        execution_backend: Execution backend name.
        execution_ref: Backend-specific reference for the running execution.
        execution_profile: Optional execution profile name.

    Returns:
        The updated record, or ``None`` when the update matched no rows.
    """
    async with self._session() as driver:
        await driver.begin()
        try:
            update_statement = self._get_store().set_execution_ref(
                task_id=str(task_id),
                execution_backend=execution_backend,
                execution_profile=execution_profile,
                execution_ref=execution_ref,
            )
            outcome = await driver.execute(update_statement)
            refreshed = None
            if outcome.rows_affected:
                refreshed = await self._select_task(driver, task_id)
            await driver.commit()
        except Exception:
            # Best-effort rollback; re-raise the original error.
            with suppress(Exception):
                await driver.rollback()
            raise
    if refreshed is None:
        return None
    return self._record_from_row(refreshed)
[docs]
async def set_execution_backend(
    self, task_id: "UUID", execution_backend: "str", *, execution_profile: "str | None" = None
) -> "QueuedTaskRecord | None":
    """Reassign a task's execution backend and profile.

    After a successful update the refreshed record is passed to
    :meth:`notify_new_task`, which only publishes for due statuses.

    Args:
        task_id: Task identifier.
        execution_backend: New execution backend name.
        execution_profile: Optional execution profile name.

    Returns:
        The updated record, or ``None`` when the update matched no rows.
    """
    async with self._session() as driver:
        await driver.begin()
        try:
            update_statement = self._get_store().set_execution_backend(
                task_id=str(task_id), execution_backend=execution_backend, execution_profile=execution_profile
            )
            outcome = await driver.execute(update_statement)
            refreshed = None
            if outcome.rows_affected:
                refreshed = await self._select_task(driver, task_id)
            await driver.commit()
        except Exception:
            # Best-effort rollback; re-raise the original error.
            with suppress(Exception):
                await driver.rollback()
            raise
    if refreshed is None:
        return None
    updated = self._record_from_row(refreshed)
    await self.notify_new_task(updated)
    return updated
[docs]
async def list_running_external(self, *, limit: "int | None" = None) -> "list[QueuedTaskRecord]":
    """List tasks returned by the store's ``list_running_external`` query.

    NOTE(review): by name, these are running tasks dispatched to a non-local
    execution backend — confirm against the store SQL.

    Args:
        limit: Optional maximum number of rows.

    Returns:
        Matching task records.
    """
    async with self._session() as driver:
        fetched = await driver.select(self._get_store().list_running_external(limit=limit))
        typed_rows = cast("list[dict[str, Any]]", fetched)
        return list(map(self._record_from_row, typed_rows))
[docs]
async def get_statistics(self) -> "QueueStatistics":
    """Count tasks per status across the whole queue table.

    Rows are streamed via ``_select_stream`` rather than loaded at once. Each
    row's status is coerced by ``_coerce_status`` and used as the name of the
    counter attribute to increment on ``QueueStatistics``.

    Returns:
        Per-status task counts.
    """
    counts = QueueStatistics()
    async with self._session() as driver:
        async for raw_row in _select_stream(driver, self._get_store().list_all()):
            bucket = _coerce_status(cast("dict[str, Any]", raw_row)["status"])
            current = getattr(counts, bucket)
            setattr(counts, bucket, current + 1)
    return counts
[docs]
async def iter_all(self, *, chunk_size: "int" = 1000) -> "AsyncIterator[QueuedTaskRecord]":
    """Stream every queue record without materializing the full table.

    Uses SQLSpec ``select_stream`` so large administrative scans and exports
    consume rows in chunks of ``chunk_size`` rather than loading the entire
    result set into memory. The backend session stays open for the duration
    of iteration, so callers should consume the iterator promptly.

    The inner row stream is always closed before the session is released,
    including when the consumer stops early and this generator is closed.
    This keeps a cursor from outliving its connection.

    Args:
        chunk_size: Number of rows per streamed chunk.

    Yields:
        Queue task records from the backing SQLSpec table.
    """
    async with self._session() as driver:
        stream = _select_stream(driver, self._get_store().list_all(), chunk_size=chunk_size)
        try:
            async for row in stream:
                yield self._record_from_row(cast("dict[str, Any]", row))
        finally:
            # An abandoned ``async for`` does not close the inner async
            # generator; close it explicitly while the driver is still live.
            aclose = getattr(stream, "aclose", None)
            if aclose is not None:
                await aclose()
[docs]
async def list_completed_by_task(
    self, task_name: "str", *, since: "datetime | None" = None, limit: "int" = 10
) -> "list[QueuedTaskRecord]":
    """List completed runs of a task via the store's ``list_completed_by_task`` query.

    Args:
        task_name: Task name to filter on.
        since: Optional lower time bound, serialized for the store.
        limit: Maximum number of rows.

    Returns:
        Matching task records.
    """
    async with self._session() as driver:
        query = self._get_store().list_completed_by_task(
            task_name=task_name, since=self._serialize_datetime(since), limit=limit
        )
        fetched = cast("list[dict[str, Any]]", await driver.select(query))
        return list(map(self._record_from_row, fetched))
[docs]
async def cleanup_terminal(self, before: "datetime") -> "int":
    """Delete terminal tasks older than ``before``.

    Args:
        before: Cutoff; terminal tasks before it are removed.

    Returns:
        The number of terminal rows counted immediately before the delete
        (0 means no delete was issued).
    """
    store = self._get_store()
    before_str = self._serialize_datetime(before)
    async with self._session() as driver:
        await driver.begin()
        try:
            # Some drivers (e.g. psqlpy) cannot reliably report
            # ``rows_affected`` for DELETE: they return ``-1`` as a
            # sentinel for "unknown". Count first inside the same
            # transaction and report that count instead.
            # NOTE(review): under READ COMMITTED, rows that become terminal
            # between the COUNT and the DELETE may also be deleted, so the
            # returned value can undercount when writers run concurrently.
            count_row = await driver.select_one_or_none(store.count_terminal(before=before_str))
            deleted = int(count_row["terminal_count"]) if count_row is not None else 0
            if deleted > 0:
                await driver.execute(store.cleanup_terminal(before=before_str))
            await driver.commit()
        except Exception:
            # Best-effort rollback; re-raise the original error.
            with suppress(Exception):
                await driver.rollback()
            raise
    return deleted
[docs]
async def notify_new_task(self, record: "QueuedTaskRecord") -> "None":
    """Publish a SQLSpec event when configured queue work becomes available.

    Nothing is published when notifications are disabled, there is no event
    channel, or the record is not in a due status (pending/scheduled).

    Args:
        record: The task that may have become claimable.
    """
    if not self._notifications_enabled:
        return
    if self._event_channel is None:
        return
    if record.status not in _DUE_STATUSES:
        return
    with self._observe_queue_operation("notify", task_id=str(record.id), queue=record.queue):
        channel_name = self._resolve_notification_channel()
        payload = {
            "task_id": str(record.id),
            "task_name": record.task_name,
            "queue": record.queue,
            "execution_backend": record.execution_backend,
        }
        await self._event_channel.publish(channel_name, payload, {"event_type": "litestar_queues.task_available"})
        self._increment_queue_metric("notify")
[docs]
async def wait_for_notifications(self, timeout: "float | None" = None) -> "bool":
    """Wait for a SQLSpec event when queue notifications are configured.

    Falls back to the base implementation when notifications are disabled
    or no event channel exists. Otherwise it waits for the next event on the
    notification channel (bounded by ``timeout`` when given), always closes
    the event stream, and acknowledges the received event.

    Args:
        timeout: Maximum seconds to wait; ``None`` waits indefinitely.

    Returns:
        True when a notification was received.
    """
    if not self._notifications_enabled or self._event_channel is None:
        return await super().wait_for_notifications(timeout=timeout)
    poll_interval = timeout if self._event_poll_interval is None else self._event_poll_interval
    events = self._event_channel.iter_events(self._resolve_notification_channel(), poll_interval=poll_interval)
    try:
        next_event = anext(events)
        if timeout is not None:
            next_event = asyncio.wait_for(next_event, timeout=timeout)
        event = await next_event
    except asyncio.TimeoutError:
        return False
    finally:
        await cast("Any", events).aclose()
    await self._event_channel.ack(event.event_id)
    return True
@staticmethod
def _default_sqlspec_config() -> "Any":
    """Build the fallback database config: an ``AiosqliteConfig`` with defaults.

    The adapter is imported lazily, presumably so aiosqlite is only required
    when this fallback is actually used.
    """
    from sqlspec.adapters.aiosqlite import AiosqliteConfig as FallbackConfig

    return FallbackConfig()
def _resolve_table_name(self) -> "str":
    """Return the queue table name, resolving and caching it on first use.

    Precedence: the explicit backend setting (validated in ``__init__``),
    then the queue extension's ``table_name`` setting, then
    ``DEFAULT_TABLE_NAME``.

    Settings are read from the effective SQLSpec config
    (``_get_sqlspec_config()``) rather than the raw ``_sqlspec_config`` slot.
    ``open()`` calls this before ``_configure_notifications`` resolves the
    config from a supplied SQLSpec manager. Reading the slot directly would
    ignore that config's extension ``table_name`` and cache the default.

    Returns:
        The validated table name.
    """
    if self._table_name is None:
        queue_settings = _queue_extension_settings(self._get_sqlspec_config())
        configured_table_name = _setting(queue_settings, "table_name") or DEFAULT_TABLE_NAME
        self._table_name = validate_table_name(str(configured_table_name))
    return self._table_name
def _resolve_create_schema(self) -> "bool":
    """Decide whether ``open()`` should create the queue schema.

    Always False when schema management is disabled. Otherwise delegates to
    ``_resolve_bool`` with the explicit backend value, the queue extension's
    ``create_schema`` setting, and a default of True.
    """
    if self._manage_schema:
        extension_settings = _queue_extension_settings(self._sqlspec_config)
        return _resolve_bool(self._create_schema, extension_settings, "create_schema", True)
    return False
def _resolve_run_migrations(self) -> "bool":
    """Decide whether ``open()`` should apply packaged migrations.

    Always False when schema management is disabled. Otherwise delegates to
    ``_resolve_bool`` with the explicit backend value, the queue extension's
    ``run_migrations`` setting, and a default of False.
    """
    if self._manage_schema:
        extension_settings = _queue_extension_settings(self._sqlspec_config)
        return _resolve_bool(self._run_migrations, extension_settings, "run_migrations", False)
    return False
def _resolve_notification_channel(self) -> "str":
    """Return the normalized notification channel name and cache it.

    An explicitly set channel wins. Otherwise the queue extension's
    ``notification_channel`` setting is used, falling back to
    ``DEFAULT_NOTIFICATION_CHANNEL``. The name is normalized on every call.
    """
    raw_channel = self._notification_channel
    if raw_channel is None:
        extension_settings = _queue_extension_settings(self._sqlspec_config)
        raw_channel = _setting(extension_settings, "notification_channel") or DEFAULT_NOTIFICATION_CHANNEL
    self._notification_channel = _normalize_notification_channel(str(raw_channel))
    return self._notification_channel
def _configure_notifications(self) -> "None":
    """Enable or disable push wakeups and wire up the events channel.

    Resolves the requested notification mode and effective transport from
    backend/extension settings. When notifications should not be enabled,
    ``_notifications_enabled`` is cleared and ``_notification_backend`` reset.
    Otherwise, if no event channel was injected, events settings are written
    onto the SQLSpec config and an owned event channel is created from it.
    The notification backend name is read from the channel's
    ``_backend_name`` attribute.
    """
    sqlspec_config = self._get_sqlspec_config()
    queue_settings = _queue_extension_settings(sqlspec_config)
    events_settings = _events_extension_settings(sqlspec_config)
    notifications_requested = self._resolve_notifications_requested(queue_settings)
    transport = self._select_notify_transport(sqlspec_config, queue_settings, events_settings)
    if not self._notifications_should_enable(
        notifications_requested, transport, sqlspec_config, queue_settings, events_settings
    ):
        self._notifications_enabled = False
        self._notification_backend = None
        return
    self._notifications_enabled = True
    # Resolve (and cache) the channel name up front.
    self._resolve_notification_channel()
    if self._event_channel is None:
        # Settings must be applied to the config before the channel is built.
        self._apply_event_settings(sqlspec_config, queue_settings, events_settings, transport)
        self._event_channel = self._get_or_create_sqlspec().event_channel(sqlspec_config)
        self._owns_event_channel = True
    else:
        # An injected channel already owns its backend; still resolve the
        # configured poll interval so wait_for_notifications honors it.
        self._resolve_event_poll_interval(queue_settings, events_settings)
    self._notification_backend = cast("str | None", getattr(self._event_channel, "_backend_name", None))
def _resolve_notifications_requested(self, queue_settings: "dict[str, Any]") -> "bool | None":
    """Return the requested notification mode.

    An explicit backend value wins. Otherwise the queue extension's
    ``notifications`` setting (coerced to bool) is used when present, else
    ``None`` ("no preference").

    Args:
        queue_settings: Queue extension settings.
    """
    explicit = self._notifications_requested
    if explicit is not None or "notifications" not in queue_settings:
        return explicit
    return bool(queue_settings["notifications"])
def _select_notify_transport(
    self, sqlspec_config: "Any", queue_settings: "dict[str, Any]", events_settings: "dict[str, Any]"
) -> "str":
    """Resolve the effective wakeup transport.

    Lookup order: explicit ``notify_transport`` (backend, then queue
    extension); then ``event_backend`` (backend, then queue extension), then
    the events extension's ``backend``; finally the per-adapter default from
    ``_adapter_notify_transport``.

    Returns:
        A wakeup transport name (``listen_notify``, ``listen_notify_durable``,
        ``table_queue``, or ``polling``), or an events backend name carried
        over from existing ``event_backend`` configuration.
    """
    chosen = self._notify_transport or _setting(queue_settings, "notify_transport")
    if chosen is None:
        chosen = self._event_backend or _setting(queue_settings, "event_backend") or events_settings.get("backend")
    if chosen is None:
        return _adapter_notify_transport(resolve_adapter_name(sqlspec_config))
    return str(chosen)
def _notifications_should_enable(
    self,
    notifications_requested: "bool | None",
    transport: "str",
    sqlspec_config: "Any",
    queue_settings: "dict[str, Any]",
    events_settings: "dict[str, Any]",
) -> "bool":
    """Decide whether push wakeups are active for the resolved transport.

    Notifications stay opt-in: a bare backend config never auto-enables an
    events channel. An explicit ``notifications=False`` always disables
    them, and an injected event channel always enables them. Otherwise some
    explicit signal is required (a transport/event backend setting, events
    settings, an ``events`` extension entry, or ``notifications=True``), and
    the transport must not be ``polling``.

    Returns:
        True when an events channel should back worker wakeups.
    """
    if notifications_requested is False:
        return False
    if self._event_channel is not None:
        return True
    extension_config = cast("dict[str, Any]", getattr(sqlspec_config, "extension_config", {}) or {})
    signals = (
        self._notify_transport is not None,
        "notify_transport" in queue_settings,
        self._event_backend is not None,
        "event_backend" in queue_settings,
        bool(events_settings),
        _EVENT_EXTENSION_NAME in extension_config,
        notifications_requested is True,
    )
    return any(signals) and transport != _NOTIFY_TRANSPORT_POLLING
def _resolve_event_poll_interval(
    self, queue_settings: "dict[str, Any]", events_settings: "dict[str, Any]"
) -> "None":
    """Resolve and store the event poll interval (as float) when configured.

    Lookup order: explicit backend value, then the queue extension's
    ``event_poll_interval``, then the events settings' ``poll_interval``.
    Leaves the attribute untouched when none is configured.
    """
    interval = self._event_poll_interval
    if interval is None:
        interval = _setting(queue_settings, "event_poll_interval")
    if interval is None:
        interval = events_settings.get("poll_interval")
    if interval is not None:
        self._event_poll_interval = float(interval)
def _apply_event_settings(
    self,
    sqlspec_config: "Any",
    queue_settings: "dict[str, Any]",
    events_settings: "dict[str, Any]",
    transport: "str",
) -> "None":
    """Write merged ``events`` extension settings back onto ``sqlspec_config``.

    Settings are layered from lowest to highest precedence:

    1. The config's existing events settings.
    2. Any dict stored under the queue extension's ``event_settings`` or
       ``events`` keys.
    3. The backend's own ``_event_settings``.

    The resolved ``transport`` always becomes ``backend``. A configured queue
    table and a resolved poll interval are applied on top.
    """
    settings = dict(events_settings)
    for section_name in _QUEUE_SETTING_EVENT_SETTINGS:
        overrides = queue_settings.get(section_name)
        if isinstance(overrides, dict):
            settings.update(overrides)
    settings.update(self._event_settings)
    settings["backend"] = transport

    queue_table = self._event_queue_table or _setting(queue_settings, "event_queue_table")
    if queue_table is not None:
        settings["queue_table"] = str(queue_table)

    self._resolve_event_poll_interval(queue_settings, settings)
    if self._event_poll_interval is not None:
        settings["poll_interval"] = self._event_poll_interval

    extensions = dict(cast("dict[str, Any]", getattr(sqlspec_config, "extension_config", {}) or {}))
    extensions[_EVENT_EXTENSION_NAME] = settings
    sqlspec_config.extension_config = extensions
    # NOTE(review): this re-applies an unchanged copy of the migration config,
    # presumably so SQLSpec re-derives migration state from the updated
    # extension config. Confirm against SQLSpec's set_migration_config.
    current_migrations = dict(cast("dict[str, Any]", getattr(sqlspec_config, "migration_config", {}) or {}))
    sqlspec_config.set_migration_config(current_migrations)
def _get_or_create_sqlspec(self) -> "SQLSpec":
if self._sqlspec is None:
self._sqlspec = SQLSpec()
return self._sqlspec
def _get_sqlspec_config(self) -> "Any":
if self._sqlspec_config is None:
registered_configs = tuple(cast("dict[int, Any]", self._get_or_create_sqlspec().configs).values())
if len(registered_configs) == 1:
self._sqlspec_config = registered_configs[0]
elif len(registered_configs) > 1:
msg = (
"SQLSpecQueueBackend received a SQLSpec manager with multiple configs; "
"pass config to select the queue database."
)
raise QueueConfigurationError(msg)
else:
self._sqlspec_config = self._default_sqlspec_config()
return self._sqlspec_config
def _get_store(self) -> "Any":
if self._store is None:
self._store = create_queue_store(
self._get_sqlspec_config(),
table_name=self._resolve_table_name(),
column_map=self._column_map,
native_json_columns=self._native_json_columns,
manage_schema=self._manage_schema,
)
return self._store
@asynccontextmanager
async def _session(self) -> "AsyncIterator[Any]":
    """Yield a driver bound to the queue's main SQLSpec config.

    Yields:
        An awaitable-method SQLSpec driver from ``_bridge_session``.

    Raises:
        RuntimeError: When ``open()`` has not been called on the backend.
    """
    ready = self._opened and self._sqlspec is not None
    if not ready:
        msg = "SQLSpecQueueBackend.open() must be called before using the backend."
        raise RuntimeError(msg)
    config = self._get_sqlspec_config()
    manager = self._get_or_create_sqlspec()
    async with _bridge_session(manager, config) as driver:
        yield driver
@asynccontextmanager
async def _heartbeat_session(self) -> "AsyncIterator[Any]":
    """Yield a driver for heartbeat writes, using the dedicated pool when available.

    The dedicated heartbeat pool is used only when it is enabled, was
    registered successfully during ``open()``, and has a config. Otherwise
    the main pool from ``_session()`` is used.

    Yields:
        A SQLSpec driver bound to the heartbeat or main pool.

    Raises:
        RuntimeError: When ``open()`` has not been called on the backend.
    """
    if not self._opened or self._sqlspec is None:
        msg = "SQLSpecQueueBackend.open() must be called before using the backend."
        raise RuntimeError(msg)
    dedicated = (
        self._heartbeat_pool_enabled
        and self._heartbeat_pool_registered
        and self._heartbeat_pool_config is not None
    )
    session_cm = _bridge_session(self._sqlspec, self._heartbeat_pool_config) if dedicated else self._session()
    async with session_cm as driver:
        yield driver
def _register_heartbeat_pool(self) -> "None":
"""Register the dedicated heartbeat pool with the SQLSpec manager.
Best effort. On failure the backend logs a warning and continues with
the main pool for heartbeats.
"""
if (
self._heartbeat_pool_enabled
and self._heartbeat_pool_config is not None
and not self._heartbeat_pool_registered
):
try:
self._get_or_create_sqlspec().add_config(self._heartbeat_pool_config)
except Exception:
getLogger("litestar_queues").warning(
"SQLSpecQueueBackend heartbeat pool registration failed; "
"falling back to main pool for heartbeat writes.",
exc_info=True,
)
self._heartbeat_pool_enabled = False
self._heartbeat_pool_registered = False
else:
self._heartbeat_pool_registered = True
async def _close_heartbeat_pool(self) -> "None":
"""Close the dedicated heartbeat pool if the backend opened one."""
if self._heartbeat_pool_registered and self._heartbeat_pool_config is not None:
try:
close_result = self._heartbeat_pool_config.close_pool()
if isawaitable(close_result):
await close_result
except Exception:
getLogger("litestar_queues").debug("SQLSpecQueueBackend heartbeat pool close failed.", exc_info=True)
self._heartbeat_pool_registered = False
async def _select_pending_rows(
    self, *, limit: "int", queue: "str | None", execution_backend: "str | None"
) -> "list[dict[str, Any]]":
    """Fetch up to ``limit`` pending task rows as of the current UTC time.

    The statement comes from the store's ``list_pending``, optionally filtered
    by ``queue`` and ``execution_backend``.
    """
    async with self._session() as driver:
        statement = self._get_store().list_pending(
            now=self._serialize_datetime(_utc_now()),
            limit=limit,
            queue=queue,
            execution_backend=execution_backend,
        )
        pending = await driver.select(statement)
    return cast("list[dict[str, Any]]", pending)
async def _select_task(self, driver: "Any", task_id: "UUID") -> "dict[str, Any] | None":
row = await driver.select_one_or_none(self._get_store().select_task(str(task_id)))
return cast("dict[str, Any] | None", row)
async def _select_task_by_key(self, driver: "Any", key: "str") -> "dict[str, Any] | None":
row = await driver.select_one_or_none(self._get_store().select_task_by_key(key))
return cast("dict[str, Any] | None", row)
async def _clear_key(self, driver: "Any", task_id: "UUID") -> "None":
await driver.execute(self._get_store().clear_key(task_id=str(task_id)))
def _get_observability_runtime(self) -> "Any | None":
if not self._queue_observability:
return None
return self._get_sqlspec_config().get_observability_runtime()
@contextmanager
def _observe_queue_operation(self, operation: "str", **attributes: "Any") -> "Iterator[None]":
    """Wrap a queue operation in an observability span when a runtime is available.

    With no runtime, the wrapped body simply runs. Otherwise a span named
    ``sqlspec.queue.<operation>`` is started. Its attributes are
    ``sqlspec.queue.operation`` plus every non-``None`` keyword attribute,
    each prefixed ``litestar_queues.``. The span is always ended, and any
    ``Exception`` raised by the body is passed as ``error``.

    Args:
        operation: Queue operation name, used in the span name and attributes.
        **attributes: Extra span attributes; entries whose value is ``None`` are dropped.

    Yields:
        None.
    """
    runtime = self._get_observability_runtime()
    if runtime is None:
        # Observability disabled: run the wrapped body without a span.
        yield
        return
    span_attributes = {
        "sqlspec.queue.operation": operation,
        **{f"litestar_queues.{key}": value for key, value in attributes.items() if value is not None},
    }
    span = runtime.start_span(f"sqlspec.queue.{operation}", attributes=span_attributes)
    error: "Exception | None" = None
    try:
        yield
    except Exception as exc:
        # Remember the failure so the span is ended with it, then propagate unchanged.
        error = exc
        raise
    finally:
        # Only end a span that start_span actually returned.
        if span is not None:
            runtime.end_span(span, error=error)
def _increment_queue_metric(self, name: "str", amount: "float" = 1.0) -> "None":
runtime = self._get_observability_runtime()
if runtime is not None and amount:
runtime.increment_metric(f"queue.{name}", amount)
async def _existing_records_by_key(
self, driver: "Any", store: "Any", keys: "list[str]"
) -> "dict[str, QueuedTaskRecord]":
"""Return a map of deduplication key to existing record for the given keys."""
existing: "dict[str, QueuedTaskRecord]" = {}
if not keys:
return existing
rows = await driver.select(store.select_tasks_by_keys(keys))
for row in cast("list[dict[str, Any]]", rows):
record = self._record_from_row(row)
if record.key is not None:
existing[record.key] = record
return existing
def _plan_bulk_enqueue(
self, specs: "Sequence[EnqueueSpec]", existing_by_key: "dict[str, QueuedTaskRecord]", now: "datetime"
) -> "tuple[list[QueuedTaskRecord], list[QueuedTaskRecord], list[UUID]]":
"""Resolve deduplication keys and build records, preserving input order.
Returns the ordered result records, the subset that must be inserted, and
the ids of terminal-key rows whose key must be cleared before insert.
Active (non-terminal) keys, whether already persisted or earlier in the
batch, reuse the existing record instead of inserting a duplicate.
Returns:
Ordered result records, records to insert, and terminal-key ids to clear.
"""
results: "list[QueuedTaskRecord]" = []
to_insert: "list[QueuedTaskRecord]" = []
terminal_keys_to_clear: "list[UUID]" = []
batch_new_by_key: "dict[str, QueuedTaskRecord]" = {}
for spec in specs:
key = spec.key
if key is not None:
reused = self._reuse_for_key(key, existing_by_key, batch_new_by_key, terminal_keys_to_clear)
if reused is not None:
results.append(reused)
continue
record = self._record_from_spec(spec, now)
results.append(record)
to_insert.append(record)
if key is not None:
batch_new_by_key[key] = record
return results, to_insert, terminal_keys_to_clear
@staticmethod
def _reuse_for_key(
key: "str",
existing_by_key: "dict[str, QueuedTaskRecord]",
batch_new_by_key: "dict[str, QueuedTaskRecord]",
terminal_keys_to_clear: "list[UUID]",
) -> "QueuedTaskRecord | None":
"""Return the record to reuse for ``key``, or ``None`` if a new row is needed.
Records a terminal key for clearing so its row can be replaced.
"""
active = existing_by_key.get(key)
if active is not None and not active.is_terminal:
return active
earlier = batch_new_by_key.get(key)
if earlier is not None:
return earlier
if active is not None:
terminal_keys_to_clear.append(active.id)
del existing_by_key[key]
return None
@staticmethod
def _record_from_spec(spec: "EnqueueSpec", now: "datetime") -> "QueuedTaskRecord":
    """Build a new queued task record from an enqueue spec.

    The record is ``scheduled`` when ``spec.scheduled_at`` is strictly after
    ``now``, and ``pending`` otherwise.

    Both instants are normalized with ``_serialize_datetime`` before the
    comparison. That function treats naive values as UTC, matching how they
    are persisted. Without this, a naive ``scheduled_at`` compared against an
    aware ``now`` would raise ``TypeError``.
    """
    due_at = _serialize_datetime(spec.scheduled_at)
    reference = cast("datetime", _serialize_datetime(now))
    status: "TaskStatus" = "scheduled" if due_at is not None and due_at > reference else "pending"
    return QueuedTaskRecord(
        task_name=spec.task_name,
        args=spec.args,
        kwargs=dict(spec.kwargs or {}),
        queue=spec.queue,
        execution_backend=spec.execution_backend,
        execution_profile=spec.execution_profile,
        status=status,
        priority=spec.priority,
        max_retries=spec.max_retries,
        scheduled_at=spec.scheduled_at,
        key=spec.key,
        metadata=dict(spec.metadata or {}),
    )
async def _bulk_insert(self, driver: "Any", store: "Any", records: "list[QueuedTaskRecord]") -> "None":
"""Insert records using the adapter's fastest available bulk tier."""
values = store.bulk_values([self._params_from_record(record) for record in records])
if store.supports_native_bulk_ingest:
await driver.load_from_records(store.table_name, values)
else:
await driver.execute_many(store.insert_tasks_template(), values)
def _serialize_datetime(self, value: "datetime | None") -> "datetime | str | None":
    """Normalize ``value`` to aware UTC for binding.

    The value is rendered as ISO-8601 text when the store sets
    ``bind_datetime_as_text``. ``None`` is passed through unchanged.
    """
    normalized = _serialize_datetime(value)
    if normalized is None or not self._get_store().bind_datetime_as_text:
        return normalized
    return normalized.isoformat()
def _params_from_record(self, record: "QueuedTaskRecord") -> "dict[str, Any]":
store = self._get_store()
return {
"args_json": store.serialize_json("args_json", list(record.args)),
"completed_at": self._serialize_datetime(record.completed_at),
"created_at": self._serialize_datetime(record.created_at),
"error": record.error,
"execution_backend": record.execution_backend,
"execution_profile": record.execution_profile,
"execution_ref": record.execution_ref,
"heartbeat_at": self._serialize_datetime(record.heartbeat_at),
"id": str(record.id),
"kwargs_json": store.serialize_json("kwargs_json", record.kwargs),
"max_retries": record.max_retries,
"metadata_json": store.serialize_json("metadata_json", record.metadata),
"priority": record.priority,
"queue": record.queue,
"result_json": store.serialize_json("result_json", record.result),
"retry_count": record.retry_count,
"scheduled_at": self._serialize_datetime(record.scheduled_at),
"started_at": self._serialize_datetime(record.started_at),
"status": record.status,
"task_key": record.key,
"task_name": record.task_name,
}
def _record_from_row(self, row: "dict[str, Any]") -> "QueuedTaskRecord":
    """Rebuild a ``QueuedTaskRecord`` from a raw queue-table row.

    Columns are decoded as follows:

    * JSON columns use the store's per-column ``deserialize_json``.
    * Timestamps are parsed to aware UTC by ``_deserialize_datetime``.
    * The status is validated by ``_coerce_status``.
    """
    from_json = self._get_store().deserialize_json
    parse_ts = _deserialize_datetime
    args = from_json("args_json", row["args_json"])
    kwargs = from_json("kwargs_json", row["kwargs_json"])
    metadata = from_json("metadata_json", row["metadata_json"])
    return QueuedTaskRecord(
        id=UUID(str(row["id"])),
        task_name=str(row["task_name"]),
        args=tuple(args),
        kwargs=kwargs,
        queue=str(row["queue"]),
        execution_backend=str(row["execution_backend"]),
        execution_profile=cast("str | None", row["execution_profile"]),
        execution_ref=cast("str | None", row["execution_ref"]),
        status=_coerce_status(row["status"]),
        priority=int(row["priority"]),
        max_retries=int(row["max_retries"]),
        retry_count=int(row["retry_count"]),
        scheduled_at=parse_ts(row["scheduled_at"]),
        created_at=cast("datetime", parse_ts(row["created_at"])),
        started_at=parse_ts(row["started_at"]),
        completed_at=parse_ts(row["completed_at"]),
        heartbeat_at=parse_ts(row["heartbeat_at"]),
        result=from_json("result_json", row["result_json"]),
        error=cast("str | None", row["error"]),
        key=cast("str | None", row["task_key"]),
        metadata=metadata,
    )
class _ManagedAsyncDriver:
"""Expose sync SQLSpec driver methods through SQLSpec's managed async bridge."""
__slots__ = ("_driver",)
def __init__(self, driver: "Any") -> "None":
self._driver = driver
def __getattr__(self, name: "str") -> "Any":
attr = getattr(self._driver, name)
if callable(attr):
return async_(attr)
return attr
@asynccontextmanager
async def _bridge_session(sqlspec_manager: "Any", sqlspec_config: "Any") -> "AsyncIterator[Any]":
    """Yield a SQLSpec driver regardless of sync/async config.

    Sync SQLSpec configs (``SqliteConfig``, ``DuckDBConfig``,
    ``MysqlConnectorSyncConfig``, etc.) return sync context managers and sync
    drivers. They are bridged with ``sqlspec.utils.sync_tools.async_`` so
    blocking operations use SQLSpec's managed executor and honor
    ``SQLSPEC_ASYNC_THREAD_LIMIT``.

    Args:
        sqlspec_manager: Manager whose ``provide_session`` opens the session.
        sqlspec_config: Config selecting the database; its ``is_async`` flag
            picks the async or bridged-sync path.

    Yields:
        A SQLSpec driver whose methods can be awaited regardless of whether the
        underlying config is sync or async.
    """
    session_cm = sqlspec_manager.provide_session(sqlspec_config)
    if sqlspec_config.is_async:
        # Async configs already provide an async context manager and driver.
        async with session_cm as driver:
            yield driver
    else:
        # Enter the sync session through async_ so the blocking work runs on SQLSpec's managed executor.
        driver = await async_(session_cm.__enter__)()
        try:
            yield _ManagedAsyncDriver(driver)
        except BaseException as exc:
            # Mirror ``with`` semantics: give the exception to __exit__ and re-raise unless it suppresses it.
            if not await async_(session_cm.__exit__)(type(exc), exc, exc.__traceback__):
                raise
        else:
            # Clean exit: close the sync session with no exception info.
            await async_(session_cm.__exit__)(None, None, None)
async def _select_stream(driver: "Any", statement: "Any", *, chunk_size: "int | None" = None) -> "AsyncIterator[Any]":
    """Yield rows from SQLSpec async and sync stream implementations.

    How rows are produced depends on the driver:

    * A bridged sync driver (``_ManagedAsyncDriver``) runs a plain ``select``
      and yields the returned rows; ``chunk_size`` is not used.
    * Any other driver uses ``select_stream``, forwarding ``chunk_size`` only
      when set. An awaitable result is awaited before iterating.

    Yields:
        Rows returned by the SQLSpec statement.
    """
    if isinstance(driver, _ManagedAsyncDriver):
        for row in await driver.select(statement):
            yield row
        return
    stream_kwargs = {} if chunk_size is None else {"chunk_size": chunk_size}
    stream = driver.select_stream(statement, **stream_kwargs)
    if isawaitable(stream):
        stream = await stream
    async for row in stream:
        yield row
def _utc_now() -> "datetime":
return datetime.now(timezone.utc)
def _extract_count(result: "Any") -> "int":
"""Pull a non-negative ``COUNT(*)`` value off a SQLSpec result.
Returns:
The count value from the first result row, or ``0`` when unavailable.
"""
rows = getattr(result, "data", None) or []
if rows:
row = rows[0]
if isinstance(row, dict):
return int(next(iter(row.values())))
if isinstance(row, (list, tuple)):
return int(row[0])
return 0
def _serialize_datetime(value: "datetime | None") -> "datetime | None":
if value is None:
return None
if value.tzinfo is None:
value = value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)
def _deserialize_datetime(value: "Any") -> "datetime | None":
if value is None:
return None
value_text = str(value)
try:
parsed = datetime.fromisoformat(value_text)
except ValueError:
parsed = datetime.strptime(value_text.upper(), "%d-%b-%y").replace(tzinfo=timezone.utc)
if parsed.tzinfo is None:
return parsed.replace(tzinfo=timezone.utc)
return parsed.astimezone(timezone.utc)
def _coerce_status(value: "Any") -> "TaskStatus":
status = str(value)
if status not in {"cancelled", "completed", "failed", "pending", "running", "scheduled"}:
msg = f"Unknown queued task status from SQLSpec queue backend: {status!r}"
raise ValueError(msg)
return cast("TaskStatus", status)
def _queue_extension_settings(sqlspec_config: "Any | None") -> "dict[str, Any]":
if sqlspec_config is None:
return {}
extension_config = cast("dict[str, Any]", getattr(sqlspec_config, "extension_config", {}) or {})
return dict(cast("dict[str, Any]", extension_config.get(QUEUE_EXTENSION_NAME, {}) or {}))
async def _create_schema_statements(store: "Any", driver: "Any") -> "list[str]":
create_for_driver = getattr(store, "create_statements_for_driver", None)
if callable(create_for_driver):
result = create_for_driver(driver)
if isawaitable(result):
return cast("list[str]", await result)
return cast("list[str]", result)
return cast("list[str]", store.create_statements())
def _events_extension_settings(sqlspec_config: "Any | None") -> "dict[str, Any]":
if sqlspec_config is None:
return {}
extension_config = cast("dict[str, Any]", getattr(sqlspec_config, "extension_config", {}) or {})
return dict(cast("dict[str, Any]", extension_config.get(_EVENT_EXTENSION_NAME, {}) or {}))
def _setting(queue_settings: "dict[str, Any]", *names: "str") -> "Any":
for name in names:
if name in queue_settings:
return queue_settings[name]
return None
def _resolve_bool(value: "bool | None", queue_settings: "dict[str, Any]", key: "str", default: "bool") -> "bool":
if value is not None:
return value
if key in queue_settings:
return bool(queue_settings[key])
return default
def _normalize_notification_channel(channel: "str") -> "str":
    """Normalize ``channel`` with SQLSpec's event channel rules.

    Raises:
        QueueConfigurationError: If SQLSpec rejects the channel name; the
            original error is chained as the cause.
    """
    try:
        normalized = str(normalize_event_channel_name(channel))
    except Exception as exc:
        msg = f"Invalid SQLSpec queue notification channel: {channel!r}"
        raise QueueConfigurationError(msg) from exc
    return normalized