import asyncio
import contextvars
import inspect
import pkgutil
import random
import sys
import zoneinfo
from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from functools import partial
from importlib import import_module, reload
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload
from typing_extensions import ParamSpec, Self
if TYPE_CHECKING:
from concurrent.futures import Executor
from types import ModuleType
from uuid import UUID
from litestar_queues.events import TaskExecutionContext
from litestar_queues.models import QueuedTaskRecord, TaskStatus
from litestar_queues.service import QueueService
# Public names exported by ``from <module> import *``.
__all__ = (
"ScheduleConfig",
"Task",
"TaskResult",
"clear_task_registry",
"discover_tasks",
"get_default_service",
"get_scheduled_tasks",
"get_task_registry",
"load_task_modules",
"set_default_service",
"task",
)
# Type variables describing a wrapped callable's parameters (P) and result (T).
P = ParamSpec("P")
T = TypeVar("T")
# A task callable may return its result directly or as an awaitable.
TaskCallable = Callable[P, T | Awaitable[T]]
AnyTaskCallable = Callable[..., Any]
# Number of whitespace-separated fields a cron expression must have.
CRON_FIELD_COUNT = 5
# The cron search gives up after this many 366-day years' worth of minutes.
CRON_SEARCH_YEARS = 8
# Cron weekday value that is folded into 0 when Sunday normalisation is requested.
SUNDAY_CRON_VALUE = 7
# Task name -> registered Task wrapper; written by the ``task`` decorator.
_task_registry: 'dict[str, "Task[Any, Any]"]' = {}
# Task name -> schedule; written by the ``task`` decorator when cron/interval is given.
_schedule_registry: 'dict[str, "ScheduleConfig"]' = {}
# Names of modules already imported by the module-loading helpers.
_loaded_modules: "set[str]" = set()
# Random source used to draw schedule jitter.
_RANDOM = random.SystemRandom()
# One-element list holding the default service so it can be replaced in place.
_default_service_holder: 'list["QueueService | None"]' = [None]
@dataclass(frozen=True, slots=True)
class _ParsedCron:
minutes: "set[int]"
hours: "set[int]"
days: "set[int]"
months: "set[int]"
weekdays: "set[int]"
day_of_month_restricted: "bool"
day_of_week_restricted: "bool"
[docs]
@dataclass(frozen=True, slots=True)
class ScheduleConfig:
    """Configuration for a recurring task schedule.

    A schedule is driven by either ``cron`` or ``interval``; supplying both is
    rejected. Numeric ``interval``, ``initial_delay`` and ``jitter`` values
    are treated as seconds and are stored as ``timedelta`` once
    ``__post_init__`` has run.
    """

    # Name of the task this schedule belongs to.
    task_name: "str"
    # Five-field cron expression, or one of the "@" aliases known to ``_parse_cron``.
    cron: "str | None" = None
    # Fixed delay between runs; plain numbers are seconds.
    interval: "timedelta | int | float | None" = None
    # Zone name in which the cron fields are evaluated.
    timezone: "str" = "UTC"
    # Delay used by ``get_next_run(use_initial_delay=True)``; plain numbers are seconds.
    initial_delay: "timedelta | int | float" = 0
    # Upper bound of the random delay added to every computed run time.
    jitter: "timedelta | int | float" = 0
    # Stored as given; not interpreted by this class.
    max_instances: "int" = 1
    # Stored as given; not interpreted by this class.
    timeout: "float | None" = None

    def __post_init__(self) -> "None":
        """Validate the configuration and normalise durations to ``timedelta``.

        Raises:
            ValueError: If both cron and interval are set, the interval is not
                positive, a delay or jitter is negative, the timezone is
                unknown, or the cron expression cannot be parsed.
        """
        if self.cron is not None and self.interval is not None:
            msg = "Cannot specify both cron and interval"
            raise ValueError(msg)
        interval = _coerce_interval(self.interval)
        initial_delay = _coerce_interval(self.initial_delay) or timedelta()
        jitter = _coerce_interval(self.jitter) or timedelta()
        if interval is not None and interval <= timedelta():
            msg = "Schedule interval must be positive"
            raise ValueError(msg)
        if initial_delay < timedelta():
            msg = "Schedule initial_delay cannot be negative"
            raise ValueError(msg)
        if jitter < timedelta():
            msg = "Schedule jitter cannot be negative"
            raise ValueError(msg)
        # Called for validation only; the return values are discarded.
        _get_timezone(self.timezone)
        if self.cron is not None:
            self._parse_cron()
        # The dataclass is frozen, so normalised values bypass __setattr__.
        object.__setattr__(self, "interval", interval)
        object.__setattr__(self, "initial_delay", initial_delay)
        object.__setattr__(self, "jitter", jitter)

    def get_next_run(self, after: "datetime | None" = None, *, use_initial_delay: "bool" = False) -> "datetime":
        """Calculate the next scheduled run time.

        Args:
            after: Reference instant; defaults to the current time. Naive
                values are treated as UTC.
            use_initial_delay: When true and ``initial_delay`` is non-zero,
                the result is ``after + initial_delay`` regardless of the
                cron/interval settings.

        Returns:
            The next run time in UTC, with jitter applied.

        Raises:
            ValueError: If no interval or cron expression is configured.
        """
        base = _ensure_utc(after or datetime.now(timezone.utc))
        initial_delay = cast("timedelta", self.initial_delay)
        interval = cast("timedelta | None", self.interval)
        if use_initial_delay and initial_delay:
            return self._apply_jitter(base + initial_delay)
        if interval is not None:
            return self._apply_jitter(base + interval)
        if self.cron is None:
            msg = "Schedule must have either cron or interval"
            raise ValueError(msg)
        return self._apply_jitter(self._get_next_cron_run(base))

    def _parse_cron(self) -> "_ParsedCron":
        """Expand ``self.cron`` into the sets of values each field admits.

        Raises:
            ValueError: If no cron expression is configured or it is invalid.
        """
        if self.cron is None:
            msg = "Cron expression is not configured"
            raise ValueError(msg)
        aliases = {
            "@annually": "0 0 1 1 *",
            "@daily": "0 0 * * *",
            "@hourly": "0 * * * *",
            "@midnight": "0 0 * * *",
            "@monthly": "0 0 1 * *",
            "@weekly": "0 0 * * 0",
            "@yearly": "0 0 1 1 *",
        }
        expression = aliases.get(self.cron, self.cron)
        parts = expression.split()
        if len(parts) != CRON_FIELD_COUNT:
            msg = f"Invalid cron expression: {self.cron}"
            raise ValueError(msg)
        day_field = parts[2]
        weekday_field = parts[4]
        # "?" may stand in for one of the two day fields, but not both.
        if day_field == "?" and weekday_field == "?":
            msg = f"Invalid cron expression: {self.cron}"
            raise ValueError(msg)
        month_names = {
            "JAN": 1,
            "FEB": 2,
            "MAR": 3,
            "APR": 4,
            "MAY": 5,
            "JUN": 6,
            "JUL": 7,
            "AUG": 8,
            "SEP": 9,
            "OCT": 10,
            "NOV": 11,
            "DEC": 12,
        }
        weekday_names = {"SUN": 0, "MON": 1, "TUE": 2, "WED": 3, "THU": 4, "FRI": 5, "SAT": 6}
        try:
            return _ParsedCron(
                minutes=_expand_cron_field(parts[0], minimum=0, maximum=59),
                hours=_expand_cron_field(parts[1], minimum=0, maximum=23),
                days=_expand_cron_field(day_field, minimum=1, maximum=31, allow_question=True),
                months=_expand_cron_field(parts[3], minimum=1, maximum=12, names=month_names),
                weekdays=_expand_cron_field(
                    weekday_field, minimum=0, maximum=7, names=weekday_names, normalize_sunday=True, allow_question=True
                ),
                day_of_month_restricted=day_field not in {"*", "?"},
                day_of_week_restricted=weekday_field not in {"*", "?"},
            )
        except (KeyError, TypeError, ValueError) as exc:
            msg = f"Invalid cron expression: {self.cron}"
            raise ValueError(msg) from exc

    def _get_next_cron_run(self, after: "datetime") -> "datetime":
        """Return the first cron match strictly after ``after``, in UTC.

        The search walks wall-clock minutes in the schedule's timezone, but
        jumps straight to the next day or hour whenever the current one cannot
        match. The set of candidate minutes and the search horizon are the
        same as for a plain minute-by-minute scan, so results are unchanged
        while unsatisfiable expressions fail quickly.

        Raises:
            ValueError: If nothing matches within ``CRON_SEARCH_YEARS`` years.
        """
        parsed = self._parse_cron()
        tz = _get_timezone(self.timezone)
        candidate = after.astimezone(tz).replace(second=0, microsecond=0) + timedelta(minutes=1)
        # Same bound as scanning CRON_SEARCH_YEARS * 366 days one minute at a time.
        horizon = candidate + timedelta(minutes=CRON_SEARCH_YEARS * 366 * 24 * 60)
        one_day = timedelta(days=1)
        one_hour = timedelta(hours=1)
        one_minute = timedelta(minutes=1)
        while candidate < horizon:
            # datetime.weekday() has Monday == 0; cron has Sunday == 0.
            cron_weekday = (candidate.weekday() + 1) % 7
            if candidate.month not in parsed.months or not _cron_day_matches(
                parsed, day=candidate.day, weekday=cron_weekday
            ):
                # Nothing on this date can match: continue at the next midnight.
                candidate = candidate.replace(hour=0, minute=0) + one_day
            elif candidate.hour not in parsed.hours:
                # Nothing in this hour can match: continue at the next full hour.
                candidate = candidate.replace(minute=0) + one_hour
            elif candidate.minute not in parsed.minutes:
                candidate += one_minute
            else:
                return candidate.astimezone(timezone.utc)
        msg = f"No matching run found for cron expression: {self.cron}"
        raise ValueError(msg)

    def _apply_jitter(self, value: "datetime") -> "datetime":
        """Add a random delay of up to ``jitter`` to ``value``; no-op without jitter."""
        jitter = cast("timedelta", self.jitter)
        jitter_seconds = jitter.total_seconds()
        if jitter_seconds <= 0:
            return value
        return value + timedelta(seconds=_RANDOM.uniform(0, jitter_seconds))
[docs]
class TaskResult:
    """Handle to a queued task result.

    The handle caches the most recently fetched queue record; ``status``,
    ``result`` and ``error`` read from that cache and are ``None`` until a
    record is available.
    """

    __slots__ = ("_cached_record", "_service", "_task_id", "_task_name")

    def __init__(
        self,
        task_id: "UUID",
        task_name: "str",
        *,
        service: "QueueService | None" = None,
        record: "QueuedTaskRecord | None" = None,
    ) -> "None":
        """Create a handle.

        Args:
            task_id: Queue record ID.
            task_name: Registered task name.
            service: Service used by ``refresh()`` to fetch the record.
            record: Record to seed the cache with, if already known.
        """
        self._task_id = task_id
        self._task_name = task_name
        self._service = service
        self._cached_record = record

    @property
    def id(self) -> "UUID":
        """Queue record ID."""
        return self._task_id

    @property
    def task_name(self) -> "str":
        """Registered task name."""
        return self._task_name

    @property
    def status(self) -> "TaskStatus | None":
        """Cached task status."""
        return self._cached_record.status if self._cached_record is not None else None

    @property
    def result(self) -> "Any":
        """Cached task result."""
        return self._cached_record.result if self._cached_record is not None else None

    @property
    def error(self) -> "str | None":
        """Cached task error."""
        return self._cached_record.error if self._cached_record is not None else None

    @property
    def record(self) -> "QueuedTaskRecord | None":
        """Cached queue record."""
        return self._cached_record

    async def refresh(self) -> "Self":
        """Refresh this handle from its queue service.

        Returns:
            The refreshed result handle.

        Raises:
            RuntimeError: If the result has no associated service.
        """
        if self._service is None:
            msg = "TaskResult.refresh() requires an associated QueueService."
            raise RuntimeError(msg)
        self._cached_record = await self._service.get_task(self._task_id)
        return self

    async def wait(self, *, timeout: "float | None" = None, poll_interval: "float" = 0.1) -> "Self":
        """Wait until the task reaches a terminal status.

        Args:
            timeout: Maximum number of seconds to wait; ``None`` waits
                indefinitely.
            poll_interval: Seconds to sleep between refreshes. The sleep is
                shortened when less time than this remains, so the call does
                not overshoot ``timeout`` by a full polling interval.

        Returns:
            The completed result handle.

        Raises:
            TimeoutError: If the timeout elapses before a terminal status.
        """
        terminal = {"cancelled", "completed", "failed"}
        loop = asyncio.get_running_loop()
        deadline = None if timeout is None else loop.time() + timeout
        while self.status not in terminal:
            await self.refresh()
            if self.status in terminal:
                break
            if deadline is None:
                await asyncio.sleep(poll_interval)
                continue
            remaining = deadline - loop.time()
            if remaining <= 0:
                msg = f"Task {self._task_id} did not complete within {timeout}s"
                raise TimeoutError(msg)
            # Wake up at the deadline at the latest; the next pass refreshes
            # once more before giving up.
            await asyncio.sleep(min(poll_interval, remaining))
        return self
[docs]
class Task(Generic[P, T]):
    """Registered task wrapper with direct call and enqueue APIs.

    Calling the instance runs the wrapped callable in-process; ``enqueue``
    hands it to a queue service; ``execute_record`` is the entry point used
    to run the callable for a queued record.
    """

    __slots__ = (
        "__dict__",
        "_description",
        "_execution_backend",
        "_execution_profile",
        "_func",
        "_key",
        "_log_level",
        "_name",
        "_priority",
        "_queue",
        "_quiet_success",
        "_requeue_on_stale",
        "_retries",
        "_run_after",
        "_timeout",
    )

    def __init__(
        self,
        func: "TaskCallable[P, T]",
        *,
        name: "str",
        queue: "str" = "default",
        priority: "int" = 0,
        retries: "int" = 0,
        timeout: "float | None" = None,
        execution_backend: "str | None" = None,
        execution_profile: "str | None" = None,
        key: "str | None" = None,
        run_after: "float | timedelta | None" = None,
        description: "str | None" = None,
        log_level: "str | None" = None,
        quiet_success: "bool | None" = None,
        requeue_on_stale: "bool | None" = None,
    ) -> "None":
        """Wrap ``func`` with its task metadata.

        All options are stored as given, except ``run_after``: a plain number
        is taken as seconds and stored as a ``timedelta``.
        """
        self._func = func
        self._name = name
        self._queue = queue
        self._priority = priority
        self._retries = retries
        self._timeout = timeout
        self._execution_backend = execution_backend
        self._execution_profile = execution_profile
        self._key = key
        self._run_after = _coerce_interval(run_after)
        self._description = description
        self._log_level = log_level
        self._quiet_success = quiet_success
        self._requeue_on_stale = requeue_on_stale

    @property
    def name(self) -> "str":
        """Registered task name."""
        return self._name

    @property
    def queue(self) -> "str":
        """Default queue name."""
        return self._queue

    @property
    def priority(self) -> "int":
        """Default priority."""
        return self._priority

    @property
    def retries(self) -> "int":
        """Maximum retry count."""
        return self._retries

    @property
    def timeout(self) -> "float | None":
        """Execution timeout."""
        return self._timeout

    @property
    def execution_backend(self) -> "str | None":
        """Task-specific execution backend override."""
        return self._execution_backend

    @property
    def execution_profile(self) -> "str | None":
        """Task-specific execution profile override."""
        return self._execution_profile

    @property
    def key(self) -> "str | None":
        """Default deduplication key."""
        return self._key

    @property
    def run_after(self) -> "timedelta | None":
        """Relative delay for enqueue operations."""
        return self._run_after

    @property
    def description(self) -> "str | None":
        """Task description metadata."""
        return self._description

    @property
    def log_level(self) -> "str | None":
        """Task log level metadata."""
        return self._log_level

    @property
    def quiet_success(self) -> "bool | None":
        """Whether successful completion logging should be quiet."""
        return self._quiet_success

    @property
    def requeue_on_stale(self) -> "bool":
        """Whether stale running records should be requeued when retries remain.

        Only an explicit ``False`` disables requeueing; ``None`` counts as enabled.
        """
        return self._requeue_on_stale is not False

    @property
    def function(self) -> "TaskCallable[P, T]":
        """Wrapped callable."""
        return self._func

    async def __call__(self, *args: "P.args", **kwargs: "P.kwargs") -> "T":
        """Execute the wrapped callable directly.

        Returns:
            The wrapped callable result, awaited first if it is awaitable.
        """
        result = self._func(*args, **kwargs)
        if inspect.isawaitable(result):
            return await result
        return result

    async def execute_record(
        self,
        record: "QueuedTaskRecord",
        *,
        task_context: "TaskExecutionContext | None" = None,
        extra_kwargs: "Mapping[str, object] | None" = None,
        sync_executor: "Executor | None" = None,
    ) -> "T":
        """Execute this task for a queued record in worker context.

        Keyword arguments are the record's own, updated with ``extra_kwargs``.
        ``_job_id`` (the record ID as a string) and ``_task_context`` are
        added only when the wrapped callable can receive them by keyword.

        Args:
            record: Queue record supplying positional and keyword arguments.
            task_context: Context passed as ``_task_context`` when accepted.
            extra_kwargs: Additional keyword arguments; they win over the
                record's on conflict.
            sync_executor: Executor for non-coroutine callables; a worker
                thread is used when omitted.

        Returns:
            The wrapped callable result.
        """
        kwargs = dict(record.kwargs)
        if extra_kwargs:
            kwargs.update(extra_kwargs)
        if self._accepts_job_id():
            kwargs["_job_id"] = str(record.id)
        if task_context is not None and self._accepts_task_context():
            kwargs["_task_context"] = task_context
        if inspect.iscoroutinefunction(self._func):
            coroutine_func = cast("Callable[..., Awaitable[T]]", self._func)
            return await coroutine_func(*record.args, **kwargs)
        sync_func = cast("Callable[..., T | Awaitable[T]]", self._func)
        result = await _run_sync_callable(sync_func, record.args, kwargs, sync_executor=sync_executor)
        # A callable that is not an ``async def`` can still hand back an
        # awaitable (e.g. an object with an async ``__call__``); resolve it
        # here just as ``__call__`` does instead of returning it un-awaited.
        if inspect.isawaitable(result):
            return await result
        return cast("T", result)

    def using(
        self,
        *,
        queue: "str | None" = None,
        priority: "int | None" = None,
        retries: "int | None" = None,
        timeout: "float | None" = None,
        execution_backend: "str | None" = None,
        execution_profile: "str | None" = None,
        key: "str | None" = None,
        run_after: "float | timedelta | None" = None,
        description: "str | None" = None,
        log_level: "str | None" = None,
        quiet_success: "bool | None" = None,
        requeue_on_stale: "bool | None" = None,
    ) -> "Task[P, T]":
        """Return a configured copy with enqueue overrides.

        An argument left as ``None`` keeps this task's current value, so an
        option cannot be reset to ``None`` through this method.
        """

        def keep(override: "Any", current: "Any") -> "Any":
            return current if override is None else override

        return Task(
            self._func,
            name=self._name,
            queue=keep(queue, self._queue),
            priority=keep(priority, self._priority),
            retries=keep(retries, self._retries),
            timeout=keep(timeout, self._timeout),
            execution_backend=keep(execution_backend, self._execution_backend),
            execution_profile=keep(execution_profile, self._execution_profile),
            key=keep(key, self._key),
            run_after=keep(run_after, self._run_after),
            description=keep(description, self._description),
            log_level=keep(log_level, self._log_level),
            quiet_success=keep(quiet_success, self._quiet_success),
            requeue_on_stale=keep(requeue_on_stale, self._requeue_on_stale),
        )

    async def enqueue(self, *args: "P.args", **kwargs: "P.kwargs") -> "TaskResult":
        """Enqueue this task using the configured default service or fall back to an immediate service.

        Returns:
            A result handle for the queued record.
        """
        enqueue_kwargs = cast("dict[str, Any]", kwargs)
        service = get_default_service()
        if service is not None:
            return await service.enqueue(cast("Task[Any, Any]", self), *args, **enqueue_kwargs)
        # Imported here rather than at module level; presumably to avoid an
        # import cycle with the service module - confirm before moving.
        from litestar_queues.config import QueueConfig
        from litestar_queues.service import QueueService

        async with QueueService(QueueConfig(execution_backend="immediate")) as service:
            return await service.enqueue(cast("Task[Any, Any]", self), *args, **enqueue_kwargs)

    def _accepts_keyword(self, keyword: "str") -> "bool":
        """Report whether the wrapped callable can receive ``keyword`` by name.

        True when the callable declares a keyword-capable parameter with that
        name or takes ``**kwargs``. Callables whose signature cannot be
        introspected are treated as not accepting it, so nothing is injected.
        """
        try:
            parameters = inspect.signature(self._func).parameters
        except (TypeError, ValueError):
            return False
        declared = parameters.get(keyword)
        if declared is not None and declared.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
        return any(param.kind == inspect.Parameter.VAR_KEYWORD for param in parameters.values())

    def _accepts_job_id(self) -> "bool":
        """Whether ``_job_id`` can be passed to the wrapped callable."""
        return self._accepts_keyword("_job_id")

    def _accepts_task_context(self) -> "bool":
        """Whether ``_task_context`` can be passed to the wrapped callable."""
        return self._accepts_keyword("_task_context")
[docs]
def get_task_registry() -> "dict[str, Task[Any, Any]]":
    """Return the global task registry.

    Returns:
        The live mapping of task name to ``Task``; it is not a copy, so
        changes made by the caller affect the registry itself.
    """
    registry = _task_registry
    return registry
[docs]
def get_scheduled_tasks() -> "dict[str, ScheduleConfig]":
    """Return the global scheduled task registry.

    Returns:
        The live mapping of task name to ``ScheduleConfig``; it is not a
        copy, so changes made by the caller affect the registry itself.
    """
    schedules = _schedule_registry
    return schedules
[docs]
def get_default_service() -> "QueueService | None":
    """Return the global default QueueService instance.

    Returns:
        The service stored by ``set_default_service``, or ``None`` when none
        is set.
    """
    (service,) = _default_service_holder
    return service
[docs]
def set_default_service(service: "QueueService | None") -> "None":
    """Set the global default QueueService instance.

    Args:
        service: Service to use as the default, or ``None`` to unset it.
    """
    # Replace the single slot in place so the holder list itself is unchanged.
    _default_service_holder[:] = [service]
[docs]
def clear_task_registry() -> "None":
    """Reset all module-level registration state.

    Empties the task and schedule registries, forgets which modules have
    been loaded, and unsets the default service.
    """
    for container in (_task_registry, _schedule_registry, _loaded_modules):
        container.clear()
    _default_service_holder[0] = None
[docs]
def load_task_modules(modules: "tuple[str, ...] | list[str]", *, force_reload: "bool" = False) -> "int":
    """Import configured task modules so decorators register tasks.

    Modules already handled by an earlier call are skipped unless
    ``force_reload`` is set. For packages, child modules are loaded as well.

    Args:
        modules: Dotted names of the modules to import.
        force_reload: Process modules again even if they were loaded before.

    Returns:
        Number of imported modules.
    """
    loaded = 0
    for module_name in modules:
        if module_name in _loaded_modules and not force_reload:
            continue
        if module_name in sys.modules:
            # Already imported elsewhere: reload so its decorators run again.
            module = reload(sys.modules[module_name])
        else:
            # Also covers force_reload=True for a module that was never
            # imported, which must not be looked up in sys.modules.
            module = import_module(module_name)
        # Record the module only once the import succeeded, so a failed
        # import is retried on the next call instead of being skipped.
        _loaded_modules.add(module_name)
        loaded += 1
        loaded += _load_child_modules(module, force_reload=force_reload)
    return loaded
[docs]
def discover_tasks(package: "str", subpackage: "str" = "jobs", *, force_reload: "bool" = False) -> "tuple[str, ...]":
    """Walk ``package`` and import every ``<package>.<...>.<subpackage>.<...>`` module.

    Adopters with ``app.domain.<x>.jobs/`` layouts can call this once at
    startup so ``@task``-decorated callables register without having to
    enumerate ``QueueConfig.task_modules`` by hand.

    Args:
        package: Dotted package name to walk (e.g. ``"app.domain"``).
        subpackage: Path segment that marks task modules. Any module whose
            dotted path (excluding the root) contains this segment is
            imported. Defaults to ``"jobs"``.
        force_reload: Re-import modules already in ``sys.modules``.

    Returns:
        Sorted, deduplicated tuple of task names registered after the walk.

    Raises:
        ModuleNotFoundError: If ``package`` cannot be imported, or if it
            resolves to a plain module rather than a package.
    """
    root = reload(sys.modules[package]) if force_reload and package in sys.modules else import_module(package)
    if not hasattr(root, "__path__"):
        msg = f"discover_tasks requires a package; {package!r} is a module"
        raise ModuleNotFoundError(msg)
    prefix = f"{root.__name__}."
    matched: "list[str]" = []
    for _, module_name, _is_package in pkgutil.walk_packages(cast("Any", root).__path__, prefix=prefix):
        # Compare against the path below the root package only. The root may
        # itself be dotted (e.g. "app.jobs"), and its own segments must not
        # count as a match for the marker segment.
        if subpackage in module_name[len(prefix):].split("."):
            matched.append(module_name)
    for module_name in matched:
        if module_name in _loaded_modules and not force_reload:
            continue
        if force_reload and module_name in sys.modules:
            reload(sys.modules[module_name])
        else:
            import_module(module_name)
        _loaded_modules.add(module_name)
    return tuple(sorted(_task_registry.keys()))
@overload
def task(func: "Callable[P, Awaitable[T]]", /) -> "Task[P, T]": ...
@overload
def task(func: "Callable[P, T]", /) -> "Task[P, T]": ...
@overload
def task(
    name: "str | None" = None,
    /,
    *,
    queue: "str" = "default",
    priority: "int" = 0,
    retries: "int" = 0,
    timeout: "float | None" = None,
    execution_backend: "str | None" = None,
    execution_profile: "str | None" = None,
    key: "str | None" = None,
    run_after: "float | timedelta | None" = None,
    description: "str | None" = None,
    log_level: "str | None" = None,
    quiet_success: "bool | None" = None,
    requeue_on_stale: "bool | None" = None,
    cron: "str | None" = None,
    interval: "float | timedelta | None" = None,
    timezone: "str" = "UTC",
    initial_delay: "float | timedelta" = 0,
    jitter: "float | timedelta" = 0,
    max_instances: "int" = 1,
) -> "Callable[[AnyTaskCallable], Task[Any, Any]]": ...
def task(
    func_or_name: "AnyTaskCallable | str | None" = None,
    /,
    *,
    queue: "str" = "default",
    priority: "int" = 0,
    retries: "int" = 0,
    timeout: "float | None" = None,
    execution_backend: "str | None" = None,
    execution_profile: "str | None" = None,
    key: "str | None" = None,
    run_after: "float | timedelta | None" = None,
    description: "str | None" = None,
    log_level: "str | None" = None,
    quiet_success: "bool | None" = None,
    requeue_on_stale: "bool | None" = None,
    cron: "str | None" = None,
    interval: "float | timedelta | None" = None,
    timezone: "str" = "UTC",
    initial_delay: "float | timedelta" = 0,
    jitter: "float | timedelta" = 0,
    max_instances: "int" = 1,
) -> "Task[Any, Any] | Callable[[AnyTaskCallable], Task[Any, Any]]":
    """Register a callable as a queue task.

    Usable bare (``@task``), with options (``@task(queue="mail")``) or with
    an explicit name (``@task("send_mail")``). Without an explicit name the
    callable's ``__name__`` is used. The task is stored in the task registry
    under that name, replacing any earlier entry. When ``cron`` or
    ``interval`` is given, a ``ScheduleConfig`` is stored in the schedule
    registry as well; otherwise any schedule previously stored under the same
    name is removed, so a schedule dropped from the source does not survive a
    re-registration.

    Args:
        func_or_name: The callable (bare usage), an explicit task name, or
            ``None``.
        queue, priority, retries, timeout, execution_backend,
        execution_profile, key, run_after, description, log_level,
        quiet_success, requeue_on_stale: Passed through to ``Task``.
        cron, interval, timezone, initial_delay, jitter, max_instances:
            Passed through to ``ScheduleConfig`` together with ``timeout``.

    Returns:
        A task wrapper when used bare, otherwise a decorator.

    Raises:
        ValueError: If both cron and interval are configured, or the
            schedule options are otherwise rejected by ``ScheduleConfig``.
    """
    if cron is not None and interval is not None:
        msg = "Cannot specify both cron and interval"
        raise ValueError(msg)
    explicit_name = func_or_name if isinstance(func_or_name, str) else None
    is_scheduled = cron is not None or interval is not None
    # Build the schedule eagerly so invalid options fail where ``task(...)``
    # is called, before any function has been decorated.
    schedule = (
        ScheduleConfig(
            task_name=explicit_name or "",
            cron=cron,
            initial_delay=initial_delay,
            interval=interval,
            jitter=jitter,
            max_instances=max_instances,
            timeout=timeout,
            timezone=timezone,
        )
        if is_scheduled
        else None
    )

    def decorator(func: "AnyTaskCallable") -> "Task[Any, Any]":
        task_name = explicit_name or func.__name__
        task_obj: "Task[Any, Any]" = Task(
            cast("TaskCallable[..., Any]", func),
            name=task_name,
            queue=queue,
            priority=priority,
            retries=retries,
            timeout=timeout,
            execution_backend=execution_backend,
            execution_profile=execution_profile,
            key=key,
            run_after=run_after,
            description=description,
            log_level=log_level,
            quiet_success=quiet_success,
            requeue_on_stale=requeue_on_stale,
        )
        _task_registry[task_name] = task_obj
        if schedule is None:
            # Drop a schedule left behind by an earlier registration of this name.
            _schedule_registry.pop(task_name, None)
        else:
            # Rebuild with the final task name, which is only known here.
            _schedule_registry[task_name] = ScheduleConfig(
                task_name=task_name,
                cron=schedule.cron,
                initial_delay=cast("timedelta", schedule.initial_delay),
                interval=cast("timedelta | None", schedule.interval),
                jitter=cast("timedelta", schedule.jitter),
                max_instances=schedule.max_instances,
                timeout=schedule.timeout,
                timezone=schedule.timezone,
            )
        return task_obj

    if callable(func_or_name) and not isinstance(func_or_name, str):
        return decorator(func_or_name)
    return decorator
def _ensure_utc(value: "datetime") -> "datetime":
    """Return ``value`` as an aware UTC datetime.

    Naive values are labelled as UTC without shifting the clock reading;
    aware values are converted to UTC.
    """
    utc = timezone.utc
    is_naive = value.tzinfo is None
    return value.replace(tzinfo=utc) if is_naive else value.astimezone(utc)
def _coerce_interval(value: "float | timedelta | None") -> "timedelta | None":
    """Normalise a duration to ``timedelta``.

    ``None`` and existing ``timedelta`` objects are returned unchanged; any
    other value is taken as a number of seconds.
    """
    if value is None or isinstance(value, timedelta):
        return value
    return timedelta(seconds=value)
def _get_timezone(name: "str") -> "zoneinfo.ZoneInfo":
    """Look up a time zone by name.

    Args:
        name: Zone key such as ``"UTC"`` or ``"Europe/Berlin"``.

    Returns:
        The matching ``ZoneInfo`` object.

    Raises:
        ValueError: If the name does not resolve to a usable time zone.
    """
    try:
        return zoneinfo.ZoneInfo(name)
    # Besides the not-found error, ZoneInfo raises ValueError for malformed
    # keys (absolute or escaping paths) and can surface OSError when a key
    # names something that is not a zone file (e.g. a directory). Report all
    # of them uniformly.
    except (zoneinfo.ZoneInfoNotFoundError, ValueError, OSError) as exc:
        msg = f"Invalid timezone: {name}"
        raise ValueError(msg) from exc
def _cron_day_matches(parsed: "_ParsedCron", *, day: "int", weekday: "int") -> "bool":
    """Decide whether a date satisfies the two day fields of a cron expression.

    When both day-of-month and day-of-week are restricted, matching either
    one is enough; otherwise both must match.
    """
    in_days = day in parsed.days
    in_weekdays = weekday in parsed.weekdays
    both_restricted = parsed.day_of_month_restricted and parsed.day_of_week_restricted
    if not both_restricted:
        return in_days and in_weekdays
    return in_days or in_weekdays
async def _run_sync_callable(
    func: "Callable[..., T]", args: "tuple[Any, ...]", kwargs: "dict[str, Any]", *, sync_executor: "Executor | None"
) -> "T":
    """Run a blocking callable without blocking the event loop.

    With an executor, the callable runs there inside a copy of the current
    ``contextvars`` context; without one it is handed to ``asyncio.to_thread``.

    Returns:
        Whatever ``func(*args, **kwargs)`` returns.
    """
    if sync_executor is not None:
        snapshot = contextvars.copy_context()
        bound = partial(snapshot.run, func, *args, **kwargs)
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(sync_executor, bound)
    return await asyncio.to_thread(func, *args, **kwargs)
def _parse_cron_value(value: "str", names: "dict[str, int]") -> "int":
    """Translate one cron token into its integer value.

    The token is first looked up in ``names`` ignoring case; when it is not a
    known name it is parsed as an integer.

    Raises:
        ValueError: If the token is neither a known name nor an integer.
    """
    named = names.get(value.upper())
    return int(value) if named is None else named
def _expand_cron_field(
    field: "str",
    *,
    minimum: "int",
    maximum: "int",
    names: "dict[str, int] | None" = None,
    normalize_sunday: "bool" = False,
    allow_question: "bool" = False,
) -> "set[int]":
    """Expand one cron field into the set of values it admits.

    Supports comma-separated lists whose items are ``*``, a single value, a
    ``start-end`` range, or any of those followed by ``/step``. A single
    value with a step runs from that value up to ``maximum``.

    Args:
        field: Raw field text.
        minimum: Smallest permitted value.
        maximum: Largest permitted value.
        names: Optional name-to-value table for symbolic tokens.
        normalize_sunday: Replace ``SUNDAY_CRON_VALUE`` by ``0`` in the result.
        allow_question: Treat ``?`` as "every value".

    Returns:
        The admitted values.

    Raises:
        ValueError: If the field is empty, malformed, reversed or out of range.
    """
    if allow_question and field == "?":
        return set(range(minimum, maximum + 1))
    lookup = names or {}
    expanded: "set[int]" = set()
    for raw_part in field.split(","):
        token = raw_part.strip()
        if not token:
            msg = "Cron fields cannot be empty"
            raise ValueError(msg)
        base, slash, step_text = token.partition("/")
        has_step = bool(slash)
        step = int(step_text) if has_step else 1
        if step <= 0:
            msg = "Cron step values must be positive"
            raise ValueError(msg)
        if base == "*":
            low, high = minimum, maximum
        elif "-" in base:
            low_text, high_text = base.split("-", 1)
            low = _parse_cron_value(low_text, lookup)
            high = _parse_cron_value(high_text, lookup)
        else:
            low = _parse_cron_value(base, lookup)
            # "N/step" means "from N to the top of the range".
            high = maximum if has_step else low
        if low > high:
            msg = f"Invalid cron range: {raw_part}"
            raise ValueError(msg)
        if not (minimum <= low <= maximum and minimum <= high <= maximum):
            msg = f"Cron value out of range: {raw_part}"
            raise ValueError(msg)
        expanded.update(range(low, high + 1, step))
    if normalize_sunday and SUNDAY_CRON_VALUE in expanded:
        expanded.discard(SUNDAY_CRON_VALUE)
        expanded.add(0)
    return expanded
def _load_child_modules(module: "ModuleType", *, force_reload: "bool") -> "int":
    """Import the non-package modules found beneath a package.

    Args:
        module: Module to inspect; anything without ``__path__`` has no
            children and yields ``0``.
        force_reload: Process children again even if they were loaded
            before, reloading those already present in ``sys.modules``.

    Returns:
        Number of child modules imported or reloaded by this call.
    """
    if not hasattr(module, "__path__"):
        return 0
    search_paths = cast("Any", module).__path__
    imported = 0
    for _finder, child_name, child_is_package in pkgutil.walk_packages(search_paths, prefix=f"{module.__name__}."):
        # Packages themselves are not counted; only their modules are.
        if child_is_package:
            continue
        if child_name in _loaded_modules and not force_reload:
            continue
        if force_reload and child_name in sys.modules:
            reload(sys.modules[child_name])
        else:
            import_module(child_name)
        _loaded_modules.add(child_name)
        imported += 1
    return imported