Source code for litestar_queues.backends.sqlspec.backend

"""SQLSpec queue backend."""

import asyncio
from contextlib import asynccontextmanager, contextmanager, suppress
from datetime import datetime, timedelta, timezone
from inspect import isawaitable
from logging import getLogger
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID

from sqlspec import SQLSpec, StatementStack
from sqlspec.extensions.events import normalize_event_channel_name, resolve_adapter_name
from sqlspec.utils.sync_tools import async_

from litestar_queues.backends.base import BaseQueueBackend
from litestar_queues.backends.sqlspec.config import DEFAULT_NOTIFICATION_CHANNEL, SQLSpecBackendConfig
from litestar_queues.backends.sqlspec.extension import QUEUE_EXTENSION_NAME, configure_queue_migration_extension
from litestar_queues.backends.sqlspec.schema import (
    DEFAULT_TABLE_NAME,
    validate_column_map,
    validate_native_json_columns,
    validate_table_name,
)
from litestar_queues.backends.sqlspec.stores.factory import create_queue_store
from litestar_queues.exceptions import QueueConfigurationError
from litestar_queues.models import (
    EnqueueSpec,
    QueueBackendCapabilities,
    QueuedTaskRecord,
    QueueStatistics,
    StaleTaskRecoveryResult,
    TaskStatus,
)

if TYPE_CHECKING:
    from collections.abc import AsyncIterator, Iterator, Sequence

    from litestar_queues.config import QueueConfig

# Public API of this module: only the SQLSpec backend class is exported.
__all__ = ("SQLSpecQueueBackend",)

# Task statuses that make a row eligible for claiming and for wakeup notifications.
_DUE_STATUSES = ("pending", "scheduled")
# Notification backend names reported as durable through
# ``QueueBackendCapabilities.notifications_durable``.
_DURABLE_NOTIFICATION_BACKENDS = frozenset(("aq", "listen_notify_durable", "table_queue", "txeventq"))
_EVENT_EXTENSION_NAME = "events"
# Queue-extension keys that may carry event settings.
_QUEUE_SETTING_EVENT_SETTINGS = ("event_settings", "events")

_NOTIFY_TRANSPORT_POLLING = "polling"
# Adapter families able to push worker wakeups:
#   * asyncpg ships the durable LISTEN/NOTIFY hybrid;
#   * psycopg / psqlpy use the durable table queue until their LISTEN/NOTIFY
#     path lands upstream.
# Every other adapter polls.
_NOTIFY_DURABLE_ADAPTERS = frozenset(("asyncpg",))
_NOTIFY_TABLE_QUEUE_ADAPTERS = frozenset(("psycopg", "psqlpy"))


def _adapter_notify_transport(adapter_name: "str | None") -> "str":
    """Map a SQLSpec adapter name to its default worker wakeup transport.

    Push wakeups are only advertised for adapters known to deliver them; any
    other adapter name (including ``None``) falls back to polling.

    Args:
        adapter_name: Adapter name resolved from the SQLSpec config, or ``None``.

    Returns:
        ``"listen_notify_durable"`` for asyncpg, ``"table_queue"`` for
        psycopg/psqlpy, otherwise ``"polling"``.
    """
    transport = _NOTIFY_TRANSPORT_POLLING
    if adapter_name in _NOTIFY_TABLE_QUEUE_ADAPTERS:
        transport = "table_queue"
    # Checked last so it takes precedence, matching durable-first priority.
    if adapter_name in _NOTIFY_DURABLE_ADAPTERS:
        transport = "listen_notify_durable"
    return transport


class SQLSpecQueueBackend(BaseQueueBackend):
    """SQLSpec-backed queue backend.

    Stores queue tasks in a SQL table accessed through SQLSpec drivers and can
    optionally publish and consume worker wakeup notifications through a
    SQLSpec event channel.
    """

    __slots__ = (
        "_column_map",
        "_create_schema",
        "_event_backend",
        "_event_channel",
        "_event_poll_interval",
        "_event_queue_table",
        "_event_settings",
        "_heartbeat_pool_config",
        "_heartbeat_pool_enabled",
        "_heartbeat_pool_registered",
        "_manage_schema",
        "_native_json_columns",
        "_notification_backend",
        "_notification_channel",
        "_notifications_enabled",
        "_notifications_requested",
        "_notify_transport",
        "_opened",
        "_owns_event_channel",
        "_owns_sqlspec",
        "_queue_observability",
        "_run_migrations",
        "_sqlspec",
        "_sqlspec_config",
        "_store",
        "_table_name",
    )
[docs] def __init__( self, config: "QueueConfig | None" = None, *, backend_config: "SQLSpecBackendConfig | None" = None ) -> "None": super().__init__(config=config) backend_config = backend_config or SQLSpecBackendConfig() self._column_map = validate_column_map(backend_config.column_map) self._native_json_columns = validate_native_json_columns(frozenset(backend_config.native_json_columns)) self._manage_schema = backend_config.manage_schema self._sqlspec = backend_config.sqlspec self._sqlspec_config = backend_config.config self._heartbeat_pool_config = backend_config.heartbeat_pool_config self._heartbeat_pool_enabled = self._heartbeat_pool_config is not None self._heartbeat_pool_registered = False self._owns_sqlspec = self._sqlspec is None self._table_name = ( validate_table_name(backend_config.table_name) if backend_config.table_name is not None else None ) self._create_schema = backend_config.create_schema self._run_migrations = backend_config.run_migrations self._event_channel = backend_config.event_channel self._owns_event_channel = self._event_channel is None self._notifications_requested = backend_config.notifications self._notification_channel = backend_config.notification_channel self._notify_transport = backend_config.notify_transport self._event_backend = backend_config.event_backend self._event_queue_table = backend_config.event_queue_table self._event_poll_interval = backend_config.event_poll_interval self._event_settings = dict(backend_config.event_settings) self._queue_observability = backend_config.queue_observability self._notification_backend: "str | None" = getattr(self._event_channel, "_backend_name", None) self._notifications_enabled = self._event_channel is not None self._store: "Any | None" = None self._opened = False
[docs] async def open(self) -> "bool": """Open SQLSpec resources. Returns: True when SQLSpec resources are ready. """ if self._opened: return True self._get_or_create_sqlspec() self._resolve_table_name() self._configure_notifications() self._register_heartbeat_pool() self._opened = True if self._resolve_run_migrations(): await self.run_migrations() if self._resolve_create_schema(): await self.create_schema() return True
[docs] async def close(self) -> "None": """Close SQLSpec resources.""" await self._close_heartbeat_pool() if self._owns_event_channel and self._event_channel is not None: await self._event_channel.shutdown() self._event_channel = None if self._owns_sqlspec and self._sqlspec is not None: await self._sqlspec.close_all_pools() self._sqlspec = None self._opened = False
@property def capabilities(self) -> "QueueBackendCapabilities": """Backend behavior capabilities.""" notification_backend = self._notification_backend return QueueBackendCapabilities( supports_notifications=self._notifications_enabled, notification_backend=notification_backend, notifications_durable=notification_backend in _DURABLE_NOTIFICATION_BACKENDS, )
[docs] async def create_schema(self) -> "None": """Create the SQLSpec queue table and indexes.""" if self._manage_schema: async with self._session() as driver: for statement in await _create_schema_statements(self._get_store(), driver): await driver.execute_script(statement) await driver.commit()
[docs] async def run_migrations(self) -> "None": """Apply packaged SQLSpec migrations.""" if self._manage_schema: sqlspec_config = self._get_sqlspec_config() configure_queue_migration_extension(sqlspec_config, table_name=self._resolve_table_name()) await sqlspec_config.migrate_up(echo=False)
[docs] async def enqueue( self, task_name: "str", *, args: "tuple[Any, ...]" = (), kwargs: "dict[str, Any] | None" = None, queue: "str" = "default", priority: "int" = 0, max_retries: "int" = 0, scheduled_at: "datetime | None" = None, key: "str | None" = None, execution_backend: "str" = "local", execution_profile: "str | None" = None, metadata: "dict[str, Any] | None" = None, ) -> "QueuedTaskRecord": with self._observe_queue_operation("enqueue", queue=queue, task_name=task_name): async with self._session() as driver: await driver.begin() try: if key is not None: existing_row = await self._select_task_by_key(driver, key) if existing_row is not None: existing = self._record_from_row(existing_row) if not existing.is_terminal: await driver.rollback() return existing await self._clear_key(driver, existing.id) now = _utc_now() record = QueuedTaskRecord( task_name=task_name, args=args, kwargs=dict(kwargs or {}), queue=queue, execution_backend=execution_backend, execution_profile=execution_profile, status="scheduled" if scheduled_at is not None and scheduled_at > now else "pending", priority=priority, max_retries=max_retries, scheduled_at=scheduled_at, key=key, metadata=dict(metadata or {}), ) await driver.execute(self._get_store().insert_task(self._params_from_record(record))) await driver.commit() except Exception: with suppress(Exception): await driver.rollback() raise self._increment_queue_metric("enqueue") await self.notify_new_task(record) return record
[docs] async def enqueue_many(self, specs: "Sequence[EnqueueSpec]") -> "list[QueuedTaskRecord]": """Persist many tasks via the adapter's fastest bulk path. Resolves existing deduplication keys in one round trip, then inserts the remaining rows through the native Arrow ingest path (:meth:`load_from_records`) when the adapter supports it, otherwise via a batched ``execute_many``. Returns records in input order, with existing non-terminal keyed tasks returned as-is (no duplicate insert) to match the semantics of :meth:`enqueue`. Returns: Queue task records in the same order as ``specs``. """ if not specs: return [] store = self._get_store() now = _utc_now() keyed = [spec.key for spec in specs if spec.key is not None] with self._observe_queue_operation("enqueue", task_count=len(specs)): async with self._session() as driver: await driver.begin() try: existing_by_key = await self._existing_records_by_key(driver, store, keyed) results, to_insert, terminal_keys = self._plan_bulk_enqueue(specs, existing_by_key, now) for task_id in terminal_keys: await driver.execute(store.clear_key(task_id=str(task_id))) if to_insert: await self._bulk_insert(driver, store, to_insert) await driver.commit() except Exception: with suppress(Exception): await driver.rollback() raise self._increment_queue_metric("enqueue", float(len(to_insert))) for record in to_insert: await self.notify_new_task(record) return results
[docs] async def get_task(self, task_id: "UUID") -> "QueuedTaskRecord | None": async with self._session() as driver: row = await self._select_task(driver, task_id) return self._record_from_row(row) if row is not None else None
[docs] async def get_task_by_key(self, key: "str") -> "QueuedTaskRecord | None": async with self._session() as driver: row = await self._select_task_by_key(driver, key) return self._record_from_row(row) if row is not None else None
[docs] async def list_pending( self, *, limit: "int" = 1, queue: "str | None" = None, execution_backend: "str | None" = None ) -> "list[QueuedTaskRecord]": rows = await self._select_pending_rows(limit=limit, queue=queue, execution_backend=execution_backend) return [self._record_from_row(row) for row in rows]
[docs] async def claim_task(self, task_id: "UUID") -> "QueuedTaskRecord | None": with self._observe_queue_operation("claim", task_id=str(task_id)): async with self._session() as driver: await driver.begin() try: row = await self._select_task(driver, task_id) if row is None: await driver.rollback() return None record = self._record_from_row(row) if record.status not in _DUE_STATUSES or not record.is_due: await driver.rollback() return None now = _utc_now() result = await driver.execute( self._get_store().claim_task( task_id=str(task_id), due_at=self._serialize_datetime(now), heartbeat_at=self._serialize_datetime(now), started_at=self._serialize_datetime(now), ) ) if result.rows_affected == 0: await driver.rollback() return None updated_row = await self._select_task(driver, task_id) if updated_row is None or self._record_from_row(updated_row).status != "running": await driver.rollback() return None await driver.commit() except Exception: with suppress(Exception): await driver.rollback() raise claimed = self._record_from_row(updated_row) if updated_row is not None else None if claimed is not None: self._increment_queue_metric("claim") return claimed
[docs] async def claim_next( self, *, queue: "str | None" = None, execution_backend: "str | None" = None ) -> "QueuedTaskRecord | None": store = self._get_store() if store.supports_skip_locked: return await self._claim_next_skip_locked(store, queue=queue, execution_backend=execution_backend) rows = await self._select_pending_rows(limit=10, queue=queue, execution_backend=execution_backend) for row in rows: task_id = UUID(str(row["id"])) claimed = await self.claim_task(task_id) if claimed is not None: return claimed return None
async def _claim_next_skip_locked( self, store: "Any", *, queue: "str | None", execution_backend: "str | None" ) -> "QueuedTaskRecord | None": """Claim the next due task under ``SELECT ... FOR UPDATE SKIP LOCKED``. Locks a single due row and claims it inside one transaction so competing workers skip the locked row instead of colliding on the optimistic CAS claim. The v1 fenced-claim contract is preserved: a row that cannot be transitioned to ``running`` yields ``None``. Returns: The claimed task record, if a claim was available. """ with self._observe_queue_operation("claim", queue=queue, execution_backend=execution_backend): async with self._session() as driver: await driver.begin() try: now = _utc_now() statement = store.select_claimable( now=self._serialize_datetime(now), limit=1, queue=queue, execution_backend=execution_backend ) stream_chunk_size = cast("int | None", getattr(store, "claim_select_stream_chunk_size", None)) if stream_chunk_size is None: rows = await driver.select(statement) row = cast("dict[str, Any] | None", rows[0] if rows else None) else: row = None async for claimable_row in _select_stream(driver, statement, chunk_size=stream_chunk_size): row = cast("dict[str, Any]", claimable_row) break if row is None: await driver.rollback() return None record = self._record_from_row(row) result = await driver.execute( store.claim_task( task_id=str(record.id), due_at=self._serialize_datetime(now), heartbeat_at=self._serialize_datetime(now), started_at=self._serialize_datetime(now), ) ) if result.rows_affected == 0: await driver.rollback() return None updated_row = await self._select_task(driver, record.id) if updated_row is None or self._record_from_row(updated_row).status != "running": await driver.rollback() return None await driver.commit() except Exception: with suppress(Exception): await driver.rollback() raise claimed = self._record_from_row(updated_row) self._increment_queue_metric("claim") return claimed
[docs] async def complete_task( self, task_id: "UUID", *, result: "Any" = None, expected_retry_count: "int | None" = None ) -> "QueuedTaskRecord | None": now = _utc_now() store = self._get_store() with self._observe_queue_operation("complete", task_id=str(task_id)): async with self._session() as driver: await driver.begin() try: if expected_retry_count is not None: existing_row = await self._select_task(driver, task_id) if existing_row is None: await driver.rollback() return None existing = self._record_from_row(existing_row) if existing.status != "running" or existing.retry_count != expected_retry_count: await driver.rollback() self._increment_queue_metric("claim_lost") return None updated = await driver.execute( store.complete_task( task_id=str(task_id), completed_at=self._serialize_datetime(now), heartbeat_at=self._serialize_datetime(now), result_json=store.serialize_json("result_json", result), ) ) row = await self._select_task(driver, task_id) if updated.rows_affected else None await driver.commit() except Exception: with suppress(Exception): await driver.rollback() raise completed = self._record_from_row(row) if row is not None else None if completed is not None: self._increment_queue_metric("complete") return completed
[docs] async def fail_task( self, task_id: "UUID", error: "str", *, retry: "bool" = True, expected_retry_count: "int | None" = None ) -> "QueuedTaskRecord | None": with self._observe_queue_operation("fail", task_id=str(task_id), retry=retry): async with self._session() as driver: await driver.begin() try: row = await self._select_task(driver, task_id) if row is None: await driver.rollback() return None record = self._record_from_row(row) if expected_retry_count is not None and ( record.status != "running" or record.retry_count != expected_retry_count ): await driver.rollback() self._increment_queue_metric("claim_lost") return None metric = "fail" if retry and record.retry_count < record.max_retries: await driver.execute( self._get_store().retry_task( task_id=str(task_id), error=error, retry_count=record.retry_count + 1 ) ) metric = "retry" else: now = _utc_now() await driver.execute( self._get_store().fail_task( task_id=str(task_id), completed_at=self._serialize_datetime(now), heartbeat_at=self._serialize_datetime(now), error=error, ) ) updated_row = await self._select_task(driver, task_id) await driver.commit() except Exception: with suppress(Exception): await driver.rollback() raise updated_record = self._record_from_row(updated_row) if updated_row is not None else None if updated_record is not None: self._increment_queue_metric(metric) return updated_record
[docs] async def cancel_task(self, task_id: "UUID") -> "bool": async with self._session() as driver: await driver.begin() try: result = await driver.execute( self._get_store().cancel_task( task_id=str(task_id), completed_at=self._serialize_datetime(_utc_now()) ) ) await driver.commit() except Exception: with suppress(Exception): await driver.rollback() raise return int(result.rows_affected) == 1
[docs] async def touch_heartbeat(self, task_id: "UUID", *, expected_retry_count: "int | None" = None) -> "bool": async with self._heartbeat_session() as driver: await driver.begin() try: if expected_retry_count is not None: row = await self._select_task(driver, task_id) if row is None: await driver.rollback() return False record = self._record_from_row(row) if record.status != "running" or record.retry_count != expected_retry_count: await driver.rollback() return False result = await driver.execute( self._get_store().touch_heartbeat( task_id=str(task_id), heartbeat_at=self._serialize_datetime(_utc_now()) ) ) rows_affected = int(result.rows_affected) if rows_affected == 1: touched = True elif rows_affected == 0: touched = False else: touched_row = await self._select_task(driver, task_id) touched_record = self._record_from_row(touched_row) if touched_row is not None else None touched = ( touched_record is not None and touched_record.status == "running" and (expected_retry_count is None or touched_record.retry_count == expected_retry_count) ) await driver.commit() except Exception: with suppress(Exception): await driver.rollback() raise return touched
[docs] async def null_heartbeats(self, task_ids: "list[UUID]", *, expected_retry_count: "int | None" = None) -> "None": if not task_ids: return async with self._heartbeat_session() as driver: await driver.begin() try: filtered_task_ids = task_ids if expected_retry_count is not None: filtered_task_ids = [] for task_id in task_ids: row = await self._select_task(driver, task_id) if row is None: continue record = self._record_from_row(row) if record.retry_count == expected_retry_count: filtered_task_ids.append(task_id) if filtered_task_ids: await driver.execute( self._get_store().null_heartbeats(task_ids=[str(task_id) for task_id in filtered_task_ids]) ) await driver.commit() except Exception: with suppress(Exception): await driver.rollback() raise
    async def requeue_stale_running(self, *, stale_after: "timedelta") -> "StaleTaskRecoveryResult":
        """Recover tasks reported as stale ``running`` rows older than ``now - stale_after``.

        Each stale task is either retried (when its metadata does not set
        ``requeue_on_stale`` to ``False`` and retries remain) or marked failed
        with error ``"Task heartbeat stale"``. Failed tasks that opted out of
        requeueing are also listed as needing a handler. All writes are queued
        on a ``StatementStack`` and executed together.

        Returns:
            Counts and task ids of requeued, failed, and handler-needed tasks.
        """
        cutoff = _utc_now() - stale_after
        store = self._get_store()
        result = StaleTaskRecoveryResult()
        with self._observe_queue_operation("stale_recovered"):
            async with self._session() as driver:
                # NOTE(review): rows are read without a lock; whether retry_task /
                # fail_task re-check the row's state is up to the store — confirm.
                rows = await driver.select(store.list_stale_running(cutoff=self._serialize_datetime(cutoff)))
                stack = StatementStack()
                for row in cast("list[dict[str, Any]]", rows):
                    record = self._record_from_row(row)
                    # Only an explicit ``False`` opts out; a missing key or any other value requeues.
                    requeue_on_stale = record.metadata.get("requeue_on_stale", True) is not False
                    if requeue_on_stale and record.retry_count < record.max_retries:
                        stack = stack.push_execute(
                            store.retry_task(
                                task_id=str(record.id), error="Task heartbeat stale", retry_count=record.retry_count + 1
                            )
                        )
                        # Counted when queued; the stack has not executed yet.
                        result.requeued += 1
                    else:
                        now = _utc_now()
                        stack = stack.push_execute(
                            store.fail_task(
                                task_id=str(record.id),
                                completed_at=self._serialize_datetime(now),
                                heartbeat_at=self._serialize_datetime(now),
                                error="Task heartbeat stale",
                            )
                        )
                        result.failed += 1
                        result.failed_task_ids.append(record.id)
                        if not requeue_on_stale:
                            result.handler_needed += 1
                            result.handler_needed_task_ids.append(record.id)
                if stack:
                    # Reset any implicit read transaction the SELECT may have opened so
                    # SQLSpec's stack runner owns the write transaction.
                    with suppress(Exception):
                        await driver.rollback()
                    try:
                        await driver.execute_stack(stack)
                    except Exception:
                        with suppress(Exception):
                            await driver.rollback()
                        raise
            recovered = result.requeued + result.failed
            if recovered:
                self._increment_queue_metric("stale_recovered", float(recovered))
            if result.requeued:
                self._increment_queue_metric("retry", float(result.requeued))
            if result.failed:
                self._increment_queue_metric("stale_failed", float(result.failed))
        return result
[docs] async def set_execution_ref( self, task_id: "UUID", execution_backend: "str", execution_ref: "str", *, execution_profile: "str | None" = None ) -> "QueuedTaskRecord | None": async with self._session() as driver: await driver.begin() try: result = await driver.execute( self._get_store().set_execution_ref( task_id=str(task_id), execution_backend=execution_backend, execution_profile=execution_profile, execution_ref=execution_ref, ) ) row = await self._select_task(driver, task_id) if result.rows_affected else None await driver.commit() except Exception: with suppress(Exception): await driver.rollback() raise return self._record_from_row(row) if row is not None else None
[docs] async def set_execution_backend( self, task_id: "UUID", execution_backend: "str", *, execution_profile: "str | None" = None ) -> "QueuedTaskRecord | None": async with self._session() as driver: await driver.begin() try: result = await driver.execute( self._get_store().set_execution_backend( task_id=str(task_id), execution_backend=execution_backend, execution_profile=execution_profile ) ) row = await self._select_task(driver, task_id) if result.rows_affected else None await driver.commit() except Exception: with suppress(Exception): await driver.rollback() raise record = self._record_from_row(row) if row is not None else None if record is not None: await self.notify_new_task(record) return record
[docs] async def list_running_external(self, *, limit: "int | None" = None) -> "list[QueuedTaskRecord]": async with self._session() as driver: rows = await driver.select(self._get_store().list_running_external(limit=limit)) return [self._record_from_row(row) for row in cast("list[dict[str, Any]]", rows)]
[docs] async def get_statistics(self) -> "QueueStatistics": statistics = QueueStatistics() async with self._session() as driver: async for row in _select_stream(driver, self._get_store().list_all()): status = _coerce_status(cast("dict[str, Any]", row)["status"]) setattr(statistics, status, getattr(statistics, status) + 1) return statistics
[docs] async def iter_all(self, *, chunk_size: "int" = 1000) -> "AsyncIterator[QueuedTaskRecord]": """Stream every queue record without materializing the full table. Uses SQLSpec ``select_stream`` so large administrative scans and exports consume rows in chunks of ``chunk_size`` rather than loading the entire result set into memory. The backend session stays open for the duration of iteration, so callers should consume the iterator promptly. Yields: Queue task records from the backing SQLSpec table. """ session = self._session() driver = await session.__aenter__() try: async for row in _select_stream(driver, self._get_store().list_all(), chunk_size=chunk_size): yield self._record_from_row(cast("dict[str, Any]", row)) except BaseException as exc: if not await session.__aexit__(type(exc), exc, exc.__traceback__): raise else: await session.__aexit__(None, None, None)
[docs] async def list_completed_by_task( self, task_name: "str", *, since: "datetime | None" = None, limit: "int" = 10 ) -> "list[QueuedTaskRecord]": async with self._session() as driver: rows = await driver.select( self._get_store().list_completed_by_task( task_name=task_name, since=self._serialize_datetime(since), limit=limit ) ) return [self._record_from_row(row) for row in cast("list[dict[str, Any]]", rows)]
[docs] async def cleanup_terminal(self, before: "datetime") -> "int": store = self._get_store() before_str = self._serialize_datetime(before) async with self._session() as driver: await driver.begin() try: # Some drivers (e.g. psqlpy) cannot reliably report # ``rows_affected`` for DELETE: they return ``-1`` as a # sentinel for "unknown". Count first inside the same # transaction so the cleanup count is always exact. count_row = await driver.select_one_or_none(store.count_terminal(before=before_str)) deleted = int(count_row["terminal_count"]) if count_row is not None else 0 if deleted > 0: await driver.execute(store.cleanup_terminal(before=before_str)) await driver.commit() except Exception: with suppress(Exception): await driver.rollback() raise return deleted
[docs] async def notify_new_task(self, record: "QueuedTaskRecord") -> "None": """Publish a SQLSpec event when configured queue work becomes available.""" if self._notifications_enabled and self._event_channel is not None and record.status in _DUE_STATUSES: with self._observe_queue_operation("notify", task_id=str(record.id), queue=record.queue): await self._event_channel.publish( self._resolve_notification_channel(), { "task_id": str(record.id), "task_name": record.task_name, "queue": record.queue, "execution_backend": record.execution_backend, }, {"event_type": "litestar_queues.task_available"}, ) self._increment_queue_metric("notify")
[docs] async def wait_for_notifications(self, timeout: "float | None" = None) -> "bool": """Wait for a SQLSpec event when queue notifications are configured. Returns: True when a notification was received. """ if not self._notifications_enabled or self._event_channel is None: return await super().wait_for_notifications(timeout=timeout) stream = self._event_channel.iter_events( self._resolve_notification_channel(), poll_interval=self._event_poll_interval if self._event_poll_interval is not None else timeout, ) try: if timeout is None: event = await anext(stream) else: event = await asyncio.wait_for(anext(stream), timeout=timeout) except asyncio.TimeoutError: return False finally: await cast("Any", stream).aclose() await self._event_channel.ack(event.event_id) return True
@staticmethod def _default_sqlspec_config() -> "Any": from sqlspec.adapters.aiosqlite import AiosqliteConfig return AiosqliteConfig() def _resolve_table_name(self) -> "str": if self._table_name is None: queue_settings = _queue_extension_settings(self._sqlspec_config) configured_table_name = _setting(queue_settings, "table_name") or DEFAULT_TABLE_NAME self._table_name = validate_table_name(str(configured_table_name)) return self._table_name def _resolve_create_schema(self) -> "bool": if not self._manage_schema: return False return _resolve_bool( self._create_schema, _queue_extension_settings(self._sqlspec_config), "create_schema", True ) def _resolve_run_migrations(self) -> "bool": if not self._manage_schema: return False return _resolve_bool( self._run_migrations, _queue_extension_settings(self._sqlspec_config), "run_migrations", False ) def _resolve_notification_channel(self) -> "str": if self._notification_channel is not None: self._notification_channel = _normalize_notification_channel(str(self._notification_channel)) else: queue_settings = _queue_extension_settings(self._sqlspec_config) configured_channel = _setting(queue_settings, "notification_channel") or DEFAULT_NOTIFICATION_CHANNEL self._notification_channel = _normalize_notification_channel(str(configured_channel)) return self._notification_channel def _configure_notifications(self) -> "None": sqlspec_config = self._get_sqlspec_config() queue_settings = _queue_extension_settings(sqlspec_config) events_settings = _events_extension_settings(sqlspec_config) notifications_requested = self._resolve_notifications_requested(queue_settings) transport = self._select_notify_transport(sqlspec_config, queue_settings, events_settings) if not self._notifications_should_enable( notifications_requested, transport, sqlspec_config, queue_settings, events_settings ): self._notifications_enabled = False self._notification_backend = None return self._notifications_enabled = True self._resolve_notification_channel() if 
self._event_channel is None: self._apply_event_settings(sqlspec_config, queue_settings, events_settings, transport) self._event_channel = self._get_or_create_sqlspec().event_channel(sqlspec_config) self._owns_event_channel = True else: # An injected channel already owns its backend; still resolve the # configured poll interval so wait_for_notifications honors it. self._resolve_event_poll_interval(queue_settings, events_settings) self._notification_backend = cast("str | None", getattr(self._event_channel, "_backend_name", None)) def _resolve_notifications_requested(self, queue_settings: "dict[str, Any]") -> "bool | None": notifications_requested = self._notifications_requested if notifications_requested is None and "notifications" in queue_settings: notifications_requested = bool(queue_settings["notifications"]) return notifications_requested def _select_notify_transport( self, sqlspec_config: "Any", queue_settings: "dict[str, Any]", events_settings: "dict[str, Any]" ) -> "str": """Resolve the effective wakeup transport. Explicit ``queue_backend_config`` selections win over ``extension_config`` defaults, which in turn win over the per-adapter capability gate. Returns: A wakeup transport name (``listen_notify``, ``listen_notify_durable``, ``table_queue``, or ``polling``), or an events backend name carried over from existing ``event_backend`` configuration. """ explicit = self._notify_transport or _setting(queue_settings, "notify_transport") if explicit is None: explicit = ( self._event_backend or _setting(queue_settings, "event_backend") or events_settings.get("backend") ) if explicit is not None: return str(explicit) return _adapter_notify_transport(resolve_adapter_name(sqlspec_config)) def _notifications_should_enable( self, notifications_requested: "bool | None", transport: "str", sqlspec_config: "Any", queue_settings: "dict[str, Any]", events_settings: "dict[str, Any]", ) -> "bool": """Decide whether push wakeups are active for the resolved transport. 
Notifications stay opt-in: a bare backend config never auto-enables an events channel (which keeps the frozen claim/lease contract and the zero-config polling default intact). When a signal is present but the adapter is gated to ``polling``, wakeups degrade to interval polling. Returns: True when an events channel should back worker wakeups. """ if notifications_requested is False: return False if self._event_channel is not None: return True events_present = _EVENT_EXTENSION_NAME in cast( "dict[str, Any]", getattr(sqlspec_config, "extension_config", {}) or {} ) explicit_signal = ( self._notify_transport is not None or "notify_transport" in queue_settings or self._event_backend is not None or "event_backend" in queue_settings or bool(events_settings) or events_present or notifications_requested is True ) return explicit_signal and transport != _NOTIFY_TRANSPORT_POLLING def _resolve_event_poll_interval( self, queue_settings: "dict[str, Any]", events_settings: "dict[str, Any]" ) -> "None": configured_poll_interval = self._event_poll_interval if configured_poll_interval is None: configured_poll_interval = _setting(queue_settings, "event_poll_interval") if configured_poll_interval is None and "poll_interval" in events_settings: configured_poll_interval = events_settings["poll_interval"] if configured_poll_interval is not None: self._event_poll_interval = float(configured_poll_interval) def _apply_event_settings( self, sqlspec_config: "Any", queue_settings: "dict[str, Any]", events_settings: "dict[str, Any]", transport: "str", ) -> "None": merged_event_settings = dict(events_settings) for name in _QUEUE_SETTING_EVENT_SETTINGS: configured_events = queue_settings.get(name) if isinstance(configured_events, dict): merged_event_settings.update(configured_events) merged_event_settings.update(self._event_settings) merged_event_settings["backend"] = transport configured_queue_table = self._event_queue_table or _setting(queue_settings, "event_queue_table") if 
configured_queue_table is not None: merged_event_settings["queue_table"] = str(configured_queue_table) self._resolve_event_poll_interval(queue_settings, merged_event_settings) if self._event_poll_interval is not None: merged_event_settings["poll_interval"] = self._event_poll_interval extension_config = dict(cast("dict[str, Any]", getattr(sqlspec_config, "extension_config", {}) or {})) extension_config[_EVENT_EXTENSION_NAME] = merged_event_settings sqlspec_config.extension_config = extension_config migration_config = dict(cast("dict[str, Any]", getattr(sqlspec_config, "migration_config", {}) or {})) sqlspec_config.set_migration_config(migration_config) def _get_or_create_sqlspec(self) -> "SQLSpec": if self._sqlspec is None: self._sqlspec = SQLSpec() return self._sqlspec def _get_sqlspec_config(self) -> "Any": if self._sqlspec_config is None: registered_configs = tuple(cast("dict[int, Any]", self._get_or_create_sqlspec().configs).values()) if len(registered_configs) == 1: self._sqlspec_config = registered_configs[0] elif len(registered_configs) > 1: msg = ( "SQLSpecQueueBackend received a SQLSpec manager with multiple configs; " "pass config to select the queue database." ) raise QueueConfigurationError(msg) else: self._sqlspec_config = self._default_sqlspec_config() return self._sqlspec_config def _get_store(self) -> "Any": if self._store is None: self._store = create_queue_store( self._get_sqlspec_config(), table_name=self._resolve_table_name(), column_map=self._column_map, native_json_columns=self._native_json_columns, manage_schema=self._manage_schema, ) return self._store @asynccontextmanager async def _session(self) -> "AsyncIterator[Any]": if not self._opened or self._sqlspec is None: msg = "SQLSpecQueueBackend.open() must be called before using the backend." 
raise RuntimeError(msg) sqlspec_config = self._get_sqlspec_config() async with _bridge_session(self._get_or_create_sqlspec(), sqlspec_config) as driver: yield driver @asynccontextmanager async def _heartbeat_session(self) -> "AsyncIterator[Any]": """Yield a driver bound to the dedicated heartbeat pool when configured. Falls back to the main pool when ``heartbeat_pool_config`` is not set, or when the dedicated pool failed to register at ``open()`` time. Yields: A SQLSpec driver bound to the heartbeat or main pool. Raises: RuntimeError: When ``open()`` has not been called on the backend. """ if not self._opened or self._sqlspec is None: msg = "SQLSpecQueueBackend.open() must be called before using the backend." raise RuntimeError(msg) if self._heartbeat_pool_enabled and self._heartbeat_pool_registered and self._heartbeat_pool_config is not None: async with _bridge_session(self._sqlspec, self._heartbeat_pool_config) as driver: yield driver else: async with self._session() as driver: yield driver def _register_heartbeat_pool(self) -> "None": """Register the dedicated heartbeat pool with the SQLSpec manager. Best effort. On failure the backend logs a warning and continues with the main pool for heartbeats. 
""" if ( self._heartbeat_pool_enabled and self._heartbeat_pool_config is not None and not self._heartbeat_pool_registered ): try: self._get_or_create_sqlspec().add_config(self._heartbeat_pool_config) except Exception: getLogger("litestar_queues").warning( "SQLSpecQueueBackend heartbeat pool registration failed; " "falling back to main pool for heartbeat writes.", exc_info=True, ) self._heartbeat_pool_enabled = False self._heartbeat_pool_registered = False else: self._heartbeat_pool_registered = True async def _close_heartbeat_pool(self) -> "None": """Close the dedicated heartbeat pool if the backend opened one.""" if self._heartbeat_pool_registered and self._heartbeat_pool_config is not None: try: close_result = self._heartbeat_pool_config.close_pool() if isawaitable(close_result): await close_result except Exception: getLogger("litestar_queues").debug("SQLSpecQueueBackend heartbeat pool close failed.", exc_info=True) self._heartbeat_pool_registered = False async def _select_pending_rows( self, *, limit: "int", queue: "str | None", execution_backend: "str | None" ) -> "list[dict[str, Any]]": async with self._session() as driver: rows = await driver.select( self._get_store().list_pending( now=self._serialize_datetime(_utc_now()), limit=limit, queue=queue, execution_backend=execution_backend, ) ) return cast("list[dict[str, Any]]", rows) async def _select_task(self, driver: "Any", task_id: "UUID") -> "dict[str, Any] | None": row = await driver.select_one_or_none(self._get_store().select_task(str(task_id))) return cast("dict[str, Any] | None", row) async def _select_task_by_key(self, driver: "Any", key: "str") -> "dict[str, Any] | None": row = await driver.select_one_or_none(self._get_store().select_task_by_key(key)) return cast("dict[str, Any] | None", row) async def _clear_key(self, driver: "Any", task_id: "UUID") -> "None": await driver.execute(self._get_store().clear_key(task_id=str(task_id))) def _get_observability_runtime(self) -> "Any | None": if not 
self._queue_observability: return None return self._get_sqlspec_config().get_observability_runtime() @contextmanager def _observe_queue_operation(self, operation: "str", **attributes: "Any") -> "Iterator[None]": runtime = self._get_observability_runtime() if runtime is None: yield return span_attributes = { "sqlspec.queue.operation": operation, **{f"litestar_queues.{key}": value for key, value in attributes.items() if value is not None}, } span = runtime.start_span(f"sqlspec.queue.{operation}", attributes=span_attributes) error: "Exception | None" = None try: yield except Exception as exc: error = exc raise finally: if span is not None: runtime.end_span(span, error=error) def _increment_queue_metric(self, name: "str", amount: "float" = 1.0) -> "None": runtime = self._get_observability_runtime() if runtime is not None and amount: runtime.increment_metric(f"queue.{name}", amount) async def _existing_records_by_key( self, driver: "Any", store: "Any", keys: "list[str]" ) -> "dict[str, QueuedTaskRecord]": """Return a map of deduplication key to existing record for the given keys.""" existing: "dict[str, QueuedTaskRecord]" = {} if not keys: return existing rows = await driver.select(store.select_tasks_by_keys(keys)) for row in cast("list[dict[str, Any]]", rows): record = self._record_from_row(row) if record.key is not None: existing[record.key] = record return existing def _plan_bulk_enqueue( self, specs: "Sequence[EnqueueSpec]", existing_by_key: "dict[str, QueuedTaskRecord]", now: "datetime" ) -> "tuple[list[QueuedTaskRecord], list[QueuedTaskRecord], list[UUID]]": """Resolve deduplication keys and build records, preserving input order. Returns the ordered result records, the subset that must be inserted, and the ids of terminal-key rows whose key must be cleared before insert. Active (non-terminal) keys, whether already persisted or earlier in the batch, reuse the existing record instead of inserting a duplicate. 
Returns: Ordered result records, records to insert, and terminal-key ids to clear. """ results: "list[QueuedTaskRecord]" = [] to_insert: "list[QueuedTaskRecord]" = [] terminal_keys_to_clear: "list[UUID]" = [] batch_new_by_key: "dict[str, QueuedTaskRecord]" = {} for spec in specs: key = spec.key if key is not None: reused = self._reuse_for_key(key, existing_by_key, batch_new_by_key, terminal_keys_to_clear) if reused is not None: results.append(reused) continue record = self._record_from_spec(spec, now) results.append(record) to_insert.append(record) if key is not None: batch_new_by_key[key] = record return results, to_insert, terminal_keys_to_clear @staticmethod def _reuse_for_key( key: "str", existing_by_key: "dict[str, QueuedTaskRecord]", batch_new_by_key: "dict[str, QueuedTaskRecord]", terminal_keys_to_clear: "list[UUID]", ) -> "QueuedTaskRecord | None": """Return the record to reuse for ``key``, or ``None`` if a new row is needed. Records a terminal key for clearing so its row can be replaced. 
""" active = existing_by_key.get(key) if active is not None and not active.is_terminal: return active earlier = batch_new_by_key.get(key) if earlier is not None: return earlier if active is not None: terminal_keys_to_clear.append(active.id) del existing_by_key[key] return None @staticmethod def _record_from_spec(spec: "EnqueueSpec", now: "datetime") -> "QueuedTaskRecord": return QueuedTaskRecord( task_name=spec.task_name, args=spec.args, kwargs=dict(spec.kwargs or {}), queue=spec.queue, execution_backend=spec.execution_backend, execution_profile=spec.execution_profile, status="scheduled" if spec.scheduled_at is not None and spec.scheduled_at > now else "pending", priority=spec.priority, max_retries=spec.max_retries, scheduled_at=spec.scheduled_at, key=spec.key, metadata=dict(spec.metadata or {}), ) async def _bulk_insert(self, driver: "Any", store: "Any", records: "list[QueuedTaskRecord]") -> "None": """Insert records using the adapter's fastest available bulk tier.""" values = store.bulk_values([self._params_from_record(record) for record in records]) if store.supports_native_bulk_ingest: await driver.load_from_records(store.table_name, values) else: await driver.execute_many(store.insert_tasks_template(), values) def _serialize_datetime(self, value: "datetime | None") -> "datetime | str | None": serialized = _serialize_datetime(value) if serialized is not None and self._get_store().bind_datetime_as_text: return serialized.isoformat() return serialized def _params_from_record(self, record: "QueuedTaskRecord") -> "dict[str, Any]": store = self._get_store() return { "args_json": store.serialize_json("args_json", list(record.args)), "completed_at": self._serialize_datetime(record.completed_at), "created_at": self._serialize_datetime(record.created_at), "error": record.error, "execution_backend": record.execution_backend, "execution_profile": record.execution_profile, "execution_ref": record.execution_ref, "heartbeat_at": self._serialize_datetime(record.heartbeat_at), 
"id": str(record.id), "kwargs_json": store.serialize_json("kwargs_json", record.kwargs), "max_retries": record.max_retries, "metadata_json": store.serialize_json("metadata_json", record.metadata), "priority": record.priority, "queue": record.queue, "result_json": store.serialize_json("result_json", record.result), "retry_count": record.retry_count, "scheduled_at": self._serialize_datetime(record.scheduled_at), "started_at": self._serialize_datetime(record.started_at), "status": record.status, "task_key": record.key, "task_name": record.task_name, } def _record_from_row(self, row: "dict[str, Any]") -> "QueuedTaskRecord": store = self._get_store() args = store.deserialize_json("args_json", row["args_json"]) kwargs = store.deserialize_json("kwargs_json", row["kwargs_json"]) metadata = store.deserialize_json("metadata_json", row["metadata_json"]) return QueuedTaskRecord( id=UUID(str(row["id"])), task_name=str(row["task_name"]), args=tuple(args), kwargs=kwargs, queue=str(row["queue"]), execution_backend=str(row["execution_backend"]), execution_profile=cast("str | None", row["execution_profile"]), execution_ref=cast("str | None", row["execution_ref"]), status=_coerce_status(row["status"]), priority=int(row["priority"]), max_retries=int(row["max_retries"]), retry_count=int(row["retry_count"]), scheduled_at=_deserialize_datetime(row["scheduled_at"]), created_at=cast("datetime", _deserialize_datetime(row["created_at"])), started_at=_deserialize_datetime(row["started_at"]), completed_at=_deserialize_datetime(row["completed_at"]), heartbeat_at=_deserialize_datetime(row["heartbeat_at"]), result=store.deserialize_json("result_json", row["result_json"]), error=cast("str | None", row["error"]), key=cast("str | None", row["task_key"]), metadata=metadata, )
class _ManagedAsyncDriver:
    """Expose sync SQLSpec driver methods through SQLSpec's managed async bridge."""

    __slots__ = ("_driver",)

    def __init__(self, driver: "Any") -> "None":
        self._driver = driver

    def __getattr__(self, name: "str") -> "Any":
        # Only reached for attributes not found on the wrapper itself: callables
        # are wrapped with ``async_`` so callers can ``await`` them; plain values
        # are passed through unchanged.
        attr = getattr(self._driver, name)
        if callable(attr):
            return async_(attr)
        return attr


@asynccontextmanager
async def _bridge_session(sqlspec_manager: "Any", sqlspec_config: "Any") -> "AsyncIterator[Any]":
    """Yield a SQLSpec driver regardless of sync/async config.

    Sync SQLSpec configs (``SqliteConfig``, ``DuckDBConfig``,
    ``MysqlConnectorSyncConfig``, etc.) return sync context managers and sync
    drivers. They are bridged with ``sqlspec.utils.sync_tools.async_`` so
    blocking operations use SQLSpec's managed executor and honor
    ``SQLSPEC_ASYNC_THREAD_LIMIT``.

    Yields:
        A SQLSpec driver whose methods can be awaited regardless of whether the
        underlying config is sync or async.
    """
    session_cm = sqlspec_manager.provide_session(sqlspec_config)
    if sqlspec_config.is_async:
        async with session_cm as driver:
            yield driver
    else:
        driver = await async_(session_cm.__enter__)()
        try:
            yield _ManagedAsyncDriver(driver)
        except BaseException as exc:
            # Mirror ``with`` semantics: re-raise unless ``__exit__`` suppresses.
            if not await async_(session_cm.__exit__)(type(exc), exc, exc.__traceback__):
                raise
        else:
            await async_(session_cm.__exit__)(None, None, None)


async def _select_stream(driver: "Any", statement: "Any", *, chunk_size: "int | None" = None) -> "AsyncIterator[Any]":
    """Yield rows from SQLSpec async and sync stream implementations.

    Bridged sync drivers are read with a single ``select`` (``chunk_size`` is
    ignored). Async drivers use ``select_stream``, whose result may itself need
    awaiting.

    Yields:
        Rows returned by the SQLSpec statement.
    """
    if isinstance(driver, _ManagedAsyncDriver):
        rows = await driver.select(statement)
        for row in rows:
            yield row
    else:
        if chunk_size is None:
            stream = driver.select_stream(statement)
        else:
            stream = driver.select_stream(statement, chunk_size=chunk_size)
        if isawaitable(stream):
            stream = await stream
        async for row in stream:
            yield row


def _utc_now() -> "datetime":
    """Return the current time as an aware UTC datetime."""
    return datetime.now(timezone.utc)


def _extract_count(result: "Any") -> "int":
    """Pull a non-negative ``COUNT(*)`` value off a SQLSpec result.

    Only dict rows (first value) and list/tuple rows (first element) are
    understood. Missing data, an empty row, a NULL value, or any other row type
    yields ``0``.

    Returns:
        The count value from the first result row, or ``0`` when unavailable.
    """
    rows = getattr(result, "data", None) or []
    if not rows:
        return 0
    row = rows[0]
    if isinstance(row, dict):
        # ``next`` with a default avoids leaking StopIteration on an empty row.
        value = next(iter(row.values()), None)
    elif isinstance(row, (list, tuple)):
        value = row[0] if row else None
    else:
        return 0
    return int(value) if value is not None else 0


def _serialize_datetime(value: "datetime | None") -> "datetime | None":
    """Return ``value`` converted to UTC; naive datetimes are assumed to be UTC."""
    if value is None:
        return None
    if value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)


def _deserialize_datetime(value: "Any") -> "datetime | None":
    """Parse a stored timestamp into an aware UTC datetime.

    Accepts ``datetime`` instances, ISO 8601 text (including a trailing ``Z``,
    which ``datetime.fromisoformat`` only understands from Python 3.11), and
    ``DD-MON-YY`` text as a fallback. Naive values are assumed to be UTC.

    Raises:
        ValueError: When the text matches neither supported format.
    """
    if value is None:
        return None
    if isinstance(value, datetime):
        parsed = value
    else:
        value_text = str(value).strip()
        if value_text[-1:] in {"Z", "z"}:
            value_text = f"{value_text[:-1]}+00:00"
        try:
            parsed = datetime.fromisoformat(value_text)
        except ValueError:
            parsed = datetime.strptime(value_text.upper(), "%d-%b-%y").replace(tzinfo=timezone.utc)
    if parsed.tzinfo is None:
        return parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)


def _coerce_status(value: "Any") -> "TaskStatus":
    """Validate a stored status string against the known task statuses.

    Raises:
        ValueError: When the status is not recognized.
    """
    status = str(value)
    if status not in {"cancelled", "completed", "failed", "pending", "running", "scheduled"}:
        msg = f"Unknown queued task status from SQLSpec queue backend: {status!r}"
        raise ValueError(msg)
    return cast("TaskStatus", status)


def _queue_extension_settings(sqlspec_config: "Any | None") -> "dict[str, Any]":
    """Return a copy of the queue extension settings from ``sqlspec_config`` (empty when absent)."""
    if sqlspec_config is None:
        return {}
    extension_config = cast("dict[str, Any]", getattr(sqlspec_config, "extension_config", {}) or {})
    return dict(cast("dict[str, Any]", extension_config.get(QUEUE_EXTENSION_NAME, {}) or {}))


async def _create_schema_statements(store: "Any", driver: "Any") -> "list[str]":
    """Return the DDL needed to create the queue schema.

    Prefers the store's driver-aware ``create_statements_for_driver`` (sync or
    async) when present, otherwise falls back to ``create_statements()``.
    """
    create_for_driver = getattr(store, "create_statements_for_driver", None)
    if callable(create_for_driver):
        result = create_for_driver(driver)
        if isawaitable(result):
            return cast("list[str]", await result)
        return cast("list[str]", result)
    return cast("list[str]", store.create_statements())


def _events_extension_settings(sqlspec_config: "Any | None") -> "dict[str, Any]":
    """Return a copy of the events extension settings from ``sqlspec_config`` (empty when absent)."""
    if sqlspec_config is None:
        return {}
    extension_config = cast("dict[str, Any]", getattr(sqlspec_config, "extension_config", {}) or {})
    return dict(cast("dict[str, Any]", extension_config.get(_EVENT_EXTENSION_NAME, {}) or {}))


def _setting(queue_settings: "dict[str, Any]", *names: "str") -> "Any":
    """Return the value of the first of ``names`` present in ``queue_settings`` (even if None), else None."""
    for name in names:
        if name in queue_settings:
            return queue_settings[name]
    return None


def _coerce_setting_bool(value: "Any") -> "bool":
    """Interpret a configuration value as a boolean.

    Strings are matched case-insensitively, ignoring surrounding whitespace:
    ``"0"``/``"false"``/``"no"``/``"off"`` are False and
    ``"1"``/``"true"``/``"yes"``/``"on"`` are True. Everything else uses
    ``bool()``, so ``bool("false")`` no longer silently enables a flag.
    """
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in {"0", "false", "no", "off"}:
            return False
        if normalized in {"1", "true", "yes", "on"}:
            return True
    return bool(value)


def _resolve_bool(value: "bool | None", queue_settings: "dict[str, Any]", key: "str", default: "bool") -> "bool":
    """Resolve a boolean option.

    Precedence: the explicit ``value``, then ``queue_settings[key]`` (string-aware
    via ``_coerce_setting_bool``), then ``default``.
    """
    if value is not None:
        return value
    if key in queue_settings:
        return _coerce_setting_bool(queue_settings[key])
    return default


def _normalize_notification_channel(channel: "str") -> "str":
    """Normalize ``channel`` with SQLSpec's event-channel rules.

    Raises:
        QueueConfigurationError: When SQLSpec rejects the channel name.
    """
    try:
        return str(normalize_event_channel_name(channel))
    except Exception as exc:
        msg = f"Invalid SQLSpec queue notification channel: {channel!r}"
        raise QueueConfigurationError(msg) from exc