import asyncio
import contextlib
import os
from typing import TYPE_CHECKING
from litestar_queues.config import execution_backend_name
from litestar_queues.execution import get_execution_backend
if TYPE_CHECKING:
from datetime import timedelta
from uuid import UUID
from litestar_queues.models import QueuedTaskRecord
from litestar_queues.service import QueueService
__all__ = ("Worker",)
[docs]
class Worker:
"""Local in-process queue worker."""
__slots__ = (
"_batch_size",
"_final_cancel_timeout",
"_graceful_shutdown_timeout",
"_heartbeat_interval",
"_is_running",
"_last_reconcile_at",
"_last_stale_check_at",
"_max_concurrency",
"_poll_interval",
"_queues",
"_reconcile_interval",
"_running_tasks",
"_service",
"_stale_after",
"_stale_check_interval",
"_stop_event",
"_worker_id",
)
[docs]
def __init__(
self,
service: "QueueService",
*,
batch_size: "int" = 10,
poll_interval: "float" = 0.1,
max_concurrency: "int" = 1,
heartbeat_interval: "float" = 30,
reconcile_interval: "float" = 30,
stale_after: "timedelta | None" = None,
stale_check_interval: "float" = 60.0,
graceful_shutdown_timeout: "float" = 30,
final_cancel_timeout: "float" = 5,
worker_id: "str | None" = None,
queues: "tuple[str, ...]" = (),
) -> "None":
"""Initialize the worker."""
self._service = service
self._batch_size = batch_size
self._poll_interval = poll_interval
self._max_concurrency = max(1, max_concurrency)
self._heartbeat_interval = heartbeat_interval
self._reconcile_interval = reconcile_interval
self._stale_after = stale_after
self._stale_check_interval = stale_check_interval
self._graceful_shutdown_timeout = graceful_shutdown_timeout
self._final_cancel_timeout = final_cancel_timeout
self._worker_id = worker_id if worker_id is not None else f"worker-{os.getpid()}"
self._queues = queues
self._running_tasks: "set[asyncio.Task[None]]" = set()
self._stop_event = asyncio.Event()
self._is_running = False
self._last_reconcile_at = -float("inf")
self._last_stale_check_at = -float("inf")
@property
def is_running(self) -> "bool":
"""Whether the worker loop is active."""
return self._is_running
@property
def worker_id(self) -> "str":
"""Worker identity used for events and logs."""
return self._worker_id
[docs]
async def start(self) -> "None":
"""Run the worker loop until stopped or cancelled."""
self._is_running = True
self._stop_event.clear()
try:
while not self._stop_event.is_set():
await self._maybe_requeue_stale()
await self._maybe_reconcile_external()
processed = await self.run_once()
if processed == 0:
await self._wait_for_work()
finally:
self._is_running = False
[docs]
async def stop(self, *, force: "bool" = False) -> "None":
"""Stop the worker loop and drain or cancel in-flight work."""
self._stop_event.set()
if force:
await self._cancel_running()
else:
await self._drain_running()
[docs]
async def run_once(self) -> "int":
"""Process one batch of due tasks.
Returns:
Number of claimed task records.
"""
queue_backend = self._service.get_queue_backend()
execution_backend = self._service.get_execution_backend()
available = min(self._batch_size, max(0, self._max_concurrency - len(self._running_tasks)))
if available <= 0:
return 0
records = await self._list_pending(limit=available)
if execution_backend.is_external:
return await self._dispatch_external(records)
claimed_records: 'list["QueuedTaskRecord"]' = []
for record in records:
claimed = await queue_backend.claim_task(record.id)
if claimed is None:
continue
claimed_records.append(claimed)
if not claimed_records:
return 0
tasks = [self._track_execution(record) for record in claimed_records]
await asyncio.gather(*tasks, return_exceptions=True)
return len(claimed_records)
async def _list_pending(self, *, limit: "int") -> "list[QueuedTaskRecord]":
queue_backend = self._service.get_queue_backend()
execution_backend_name_ = execution_backend_name(self._service.config.execution_backend)
if not self._queues:
return await queue_backend.list_pending(limit=limit, execution_backend=execution_backend_name_)
records: 'list["QueuedTaskRecord"]' = []
seen: "set[object]" = set()
for queue in self._queues:
if len(records) >= limit:
break
queue_records = await queue_backend.list_pending(
limit=limit - len(records), queue=queue, execution_backend=execution_backend_name_
)
for record in queue_records:
if record.id in seen:
continue
seen.add(record.id)
records.append(record)
if len(records) >= limit:
break
return records
def _track_execution(self, record: "QueuedTaskRecord") -> "asyncio.Task[None]":
task = asyncio.create_task(self._execute_claimed(record))
self._running_tasks.add(task)
task.add_done_callback(self._running_tasks.discard)
return task
[docs]
async def reconcile_external(self, *, limit: "int | None" = None) -> "int":
"""Reconcile externally dispatched records.
Returns:
Number of records that reached a terminal queue status.
"""
queue_backend = self._service.get_queue_backend()
records = await queue_backend.list_running_external(limit=limit)
reconciled = 0
current_backend = self._service.get_execution_backend()
for record in records:
if record.execution_ref is None:
continue
execution_backend = (
current_backend
if record.execution_backend == execution_backend_name(self._service.config.execution_backend)
else get_execution_backend(record.execution_backend, config=self._service.config)
)
updated = await execution_backend.reconcile(self._service, record)
if updated is not None and updated.is_terminal:
reconciled += 1
return reconciled
async def _execute_claimed(self, record: "QueuedTaskRecord") -> "None":
heartbeat_task = asyncio.create_task(self._heartbeat(record.id, expected_retry_count=record.retry_count))
try:
await self._service.get_execution_backend().execute(self._service, record, worker_id=self._worker_id)
finally:
heartbeat_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await heartbeat_task
await self._service.get_queue_backend().null_heartbeats(
[record.id], expected_retry_count=record.retry_count
)
async def _dispatch_external(self, records: "list[QueuedTaskRecord]") -> "int":
execution_backend = self._service.get_execution_backend()
dispatched = 0
for record in records:
if record.execution_ref is not None:
continue
await execution_backend.dispatch(self._service, record)
dispatched += 1
return dispatched
async def _maybe_requeue_stale(self) -> "None":
if self._stale_after is None:
return
now = asyncio.get_running_loop().time()
if now - self._last_stale_check_at < self._stale_check_interval:
return
self._last_stale_check_at = now
await self._service.recover_stale_tasks(stale_after=self._stale_after, worker_id=self._worker_id)
async def _maybe_reconcile_external(self) -> "None":
if self._reconcile_interval <= 0:
await self.reconcile_external()
return
now = asyncio.get_running_loop().time()
if now - self._last_reconcile_at < self._reconcile_interval:
return
self._last_reconcile_at = now
await self.reconcile_external()
async def _heartbeat(self, task_id: "UUID", expected_retry_count: "int | None" = None) -> "None":
while True:
await asyncio.sleep(self._heartbeat_interval)
await self._service.get_queue_backend().touch_heartbeat(task_id, expected_retry_count=expected_retry_count)
async def _drain_running(self) -> "None":
if not self._running_tasks:
return
try:
await asyncio.wait_for(
asyncio.gather(*tuple(self._running_tasks), return_exceptions=True),
timeout=self._graceful_shutdown_timeout,
)
except asyncio.TimeoutError:
await self._cancel_running()
async def _cancel_running(self) -> "None":
tasks = tuple(self._running_tasks)
if not tasks:
return
for task in tasks:
task.cancel()
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=self._final_cancel_timeout)
async def _wait_for_work(self) -> "None":
queue_backend = self._service.get_queue_backend()
notification_task = asyncio.create_task(queue_backend.wait_for_notifications(timeout=self._poll_interval))
stop_task = asyncio.create_task(self._stop_event.wait())
done, pending = await asyncio.wait({notification_task, stop_task}, return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
for task in done:
with contextlib.suppress(asyncio.TimeoutError):
task.result()