309 lines
8.4 KiB
Python
309 lines
8.4 KiB
Python
from dataclasses import dataclass, field, replace
|
|
from datetime import datetime, timedelta
|
|
from inspect import iscoroutinefunction
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Callable,
|
|
Generic,
|
|
Optional,
|
|
TypeVar,
|
|
Union,
|
|
cast,
|
|
overload,
|
|
)
|
|
|
|
from asgiref.sync import async_to_sync, sync_to_async
|
|
from django.db.models.enums import TextChoices
|
|
from django.utils.translation import gettext_lazy as _
|
|
from typing_extensions import ParamSpec, Self
|
|
|
|
from .exceptions import ResultDoesNotExist
|
|
from .utils import (
|
|
get_module_path,
|
|
json_normalize,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from .backends.base import BaseTaskBackend
|
|
|
|
# Alias of the task backend used when a task does not name one explicitly.
DEFAULT_TASK_BACKEND_ALIAS = "default"

# Queue a task is placed on when it does not name one explicitly.
DEFAULT_QUEUE_NAME = "default"

# Priority bounds and the default priority used by the ``task`` decorator.
# NOTE(review): the bounds are not enforced in this module — presumably the
# backend's ``validate_task`` checks them; confirm.
MIN_PRIORITY, MAX_PRIORITY = -100, 100
DEFAULT_PRIORITY = 0

# TaskResult attributes that ``refresh()`` / ``arefresh()`` copy over from the
# freshly fetched result; no other attribute is updated on refresh.
TASK_REFRESH_ATTRS = {
    # lifecycle state and timestamps
    "status",
    "enqueued_at",
    "started_at",
    "finished_at",
    # outcome of the task function
    "_return_value",
    "_exception_class",
    "_traceback",
}
|
|
|
|
|
|
class ResultStatus(TextChoices):
    """
    Status of a task result.

    Each member's stored value is the upper-case string shown and its label is
    lazily translated. FAILED and SUCCEEDED are the statuses that
    ``TaskResult.is_finished`` treats as finished.
    """

    NEW = "NEW", _("New")
    RUNNING = "RUNNING", _("Running")
    FAILED = "FAILED", _("Failed")
    SUCCEEDED = "SUCCEEDED", _("Succeeded")
|
|
|
|
|
|
T = TypeVar("T")
|
|
P = ParamSpec("P")
|
|
|
|
|
|
@dataclass(frozen=True)
class Task(Generic[P, T]):
    """
    A task definition: a function plus the options that control how a backend
    queues and runs it.

    Instances are frozen; use ``using()`` to derive a copy with other options.
    """

    priority: int
    """The priority of the task"""

    func: Callable[P, T]
    """The task function"""

    backend: str
    """The name of the backend the task will run on"""

    queue_name: str = DEFAULT_QUEUE_NAME
    """The name of the queue the task will run on"""

    run_after: Optional[datetime] = None
    """The earliest this task will run"""

    enqueue_on_commit: Optional[bool] = None
    """
    Whether the task will be enqueued when the current transaction commits,
    immediately, or whatever the backend decides
    """

    def __post_init__(self) -> None:
        # Let the backend reject invalid definitions up front. This also runs
        # for every copy produced by ``using()``, because ``replace`` builds
        # the copy through ``__init__``.
        self.get_backend().validate_task(self)

    @property
    def name(self) -> str:
        """
        An identifier for the task (the task function's ``__name__``)
        """
        return self.func.__name__

    def using(
        self,
        *,
        priority: Optional[int] = None,
        queue_name: Optional[str] = None,
        run_after: Optional[Union[datetime, timedelta]] = None,
        backend: Optional[str] = None,
    ) -> Self:
        """
        Create a new task with modified defaults.

        Only options passed as something other than ``None`` are changed; the
        rest are copied from this task. The new task is validated again by its
        backend.

        :param priority: New priority.
        :param queue_name: New queue name.
        :param run_after: Either an absolute ``datetime``, or a ``timedelta``
            that is added to the current time (``django.utils.timezone.now()``).
        :param backend: Alias of the backend to use instead.
        :returns: A new task of the same class.
        """
        changes: dict[str, Any] = {}

        if priority is not None:
            changes["priority"] = priority
        if queue_name is not None:
            changes["queue_name"] = queue_name
        if run_after is not None:
            if isinstance(run_after, timedelta):
                # The ``run_after`` field holds an absolute datetime, so a
                # relative delay must be resolved against "now" here.
                from django.utils import timezone

                changes["run_after"] = timezone.now() + run_after
            else:
                changes["run_after"] = run_after
        if backend is not None:
            changes["backend"] = backend

        return replace(self, **changes)

    def enqueue(self, *args: P.args, **kwargs: P.kwargs) -> "TaskResult[T]":
        """
        Queue up the task to be executed.

        ``args`` and ``kwargs`` are passed through ``json_normalize`` before
        being handed to the backend.
        """
        return self.get_backend().enqueue(
            self, json_normalize(args), json_normalize(kwargs)
        )

    async def aenqueue(self, *args: P.args, **kwargs: P.kwargs) -> "TaskResult[T]":
        """
        Queue up a task function (or coroutine) to be executed.

        Async counterpart of ``enqueue``; arguments are normalized the same way.
        """
        return await self.get_backend().aenqueue(
            self, json_normalize(args), json_normalize(kwargs)
        )

    def get_result(self, result_id: str) -> "TaskResult[T]":
        """
        Retrieve the result for a task of this type by its id (if one exists).

        If one doesn't, or is the wrong type, raises ResultDoesNotExist.
        """
        result = self.get_backend().get_result(result_id)

        # The backend looks results up by id alone; reject results that belong
        # to a different task function.
        if result.task.func != self.func:
            raise ResultDoesNotExist

        return result

    async def aget_result(self, result_id: str) -> "TaskResult[T]":
        """
        Retrieve the result for a task of this type by its id (if one exists).

        If one doesn't, or is the wrong type, raises ResultDoesNotExist.
        """
        result = await self.get_backend().aget_result(result_id)

        # Same ownership check as ``get_result``.
        if result.task.func != self.func:
            raise ResultDoesNotExist

        return result

    def call(self, *args: P.args, **kwargs: P.kwargs) -> T:
        """
        Call the task function directly (without enqueueing it) and return its
        result. Coroutine functions are run via ``async_to_sync``.
        """
        if iscoroutinefunction(self.func):
            return async_to_sync(self.func)(*args, **kwargs)  # type:ignore[no-any-return]
        return self.func(*args, **kwargs)

    async def acall(self, *args: P.args, **kwargs: P.kwargs) -> T:
        """
        Call the task function directly from async code and return its result.
        Synchronous functions are run via ``sync_to_async``.
        """
        if iscoroutinefunction(self.func):
            return await self.func(*args, **kwargs)  # type:ignore[no-any-return]
        return await sync_to_async(self.func)(*args, **kwargs)

    def get_backend(self) -> "BaseTaskBackend":
        """
        Return the backend registered under ``self.backend``.
        """
        # Imported lazily. NOTE(review): presumably to avoid a circular import
        # with the package ``__init__`` — confirm.
        from . import tasks

        return tasks[self.backend]

    @property
    def module_path(self) -> str:
        """The import path of the task function, as built by ``get_module_path``."""
        return get_module_path(self.func)
|
|
|
|
|
|
# Bare decorator usage
# e.g. @task
@overload
def task(function: Callable[P, T], /) -> Task[P, T]: ...


# Decorator with arguments
# e.g. @task() or @task(priority=1, ...)
@overload
def task(
    *,
    priority: int = DEFAULT_PRIORITY,
    queue_name: str = DEFAULT_QUEUE_NAME,
    backend: str = DEFAULT_TASK_BACKEND_ALIAS,
    enqueue_on_commit: Optional[bool] = None,
) -> Callable[[Callable[P, T]], Task[P, T]]: ...


# Implementation
def task(
    function: Optional[Callable[P, T]] = None,
    *,
    priority: int = DEFAULT_PRIORITY,
    queue_name: str = DEFAULT_QUEUE_NAME,
    backend: str = DEFAULT_TASK_BACKEND_ALIAS,
    enqueue_on_commit: Optional[bool] = None,
) -> Union[Task[P, T], Callable[[Callable[P, T]], Task[P, T]]]:
    """
    A decorator used to create a task.

    Usable bare (``@task``) or with keyword options (``@task(priority=1)``).
    The task object is built by the ``task_class`` of the backend registered
    under ``backend``.

    :param function: The function being decorated (bare usage only).
    :param priority: Priority of the task.
    :param queue_name: Queue the task will run on.
    :param backend: Alias of the backend the task will run on.
    :param enqueue_on_commit: Whether enqueueing waits for the current
        transaction to commit; ``None`` leaves it to the backend.
    :returns: A task when ``function`` is given, otherwise a decorator that
        turns a function into one.
    """
    from . import tasks

    def wrapper(f: Callable[P, T]) -> Task[P, T]:
        return tasks[backend].task_class(
            priority=priority,
            func=f,
            queue_name=queue_name,
            backend=backend,
            enqueue_on_commit=enqueue_on_commit,
        )

    # Compare with None rather than testing truthiness. A callable object that
    # defines __bool__/__len__ can be falsy, and it would otherwise be
    # mistaken for the ``@task(...)`` form.
    if function is not None:
        return wrapper(function)

    return wrapper
|
|
|
|
|
|
@dataclass(frozen=True)
class TaskResult(Generic[T]):
    task: Task
    """The task for which this is a result"""

    id: str
    """A unique identifier for the task result"""

    status: ResultStatus
    """The status of the running task"""

    enqueued_at: Optional[datetime]
    """The time this task was enqueued"""

    started_at: Optional[datetime]
    """The time this task was started"""

    finished_at: Optional[datetime]
    """The time this task was finished"""

    args: list
    """The arguments to pass to the task function"""

    kwargs: dict[str, Any]
    """The keyword arguments to pass to the task function"""

    backend: str
    """The name of the backend the task will run on"""

    _exception_class: Optional[type[BaseException]] = field(init=False, default=None)
    _traceback: Optional[str] = field(init=False, default=None)

    _return_value: Optional[T] = field(init=False, default=None)

    @property
    def return_value(self) -> Optional[T]:
        """
        The value returned by the task function.

        Only available once the task has succeeded. In any other status a
        ValueError is raised, so a task that legitimately returned None can be
        told apart from one that has no return value yet (or never will).
        """
        if self.status != ResultStatus.SUCCEEDED:
            if self.status == ResultStatus.FAILED:
                raise ValueError("Task failed")
            raise ValueError("Task has not finished yet")
        return cast(T, self._return_value)

    def _ensure_finished(self) -> None:
        """Raise ValueError unless the task has reached a finished status."""
        if not self.is_finished:
            raise ValueError("Task has not finished yet")

    @property
    def exception_class(self) -> Optional[type[BaseException]]:
        """The exception raised by the task function"""
        self._ensure_finished()
        return self._exception_class

    @property
    def traceback(self) -> Optional[str]:
        """The traceback of the exception if the task failed"""
        self._ensure_finished()
        return self._traceback

    @property
    def is_finished(self) -> bool:
        """Whether the task has either succeeded or failed."""
        return self.status in {ResultStatus.SUCCEEDED, ResultStatus.FAILED}

    def _copy_refreshable_attrs(self, source: "TaskResult[T]") -> None:
        """Copy every attribute named in TASK_REFRESH_ATTRS from ``source``."""
        # The dataclass is frozen, so normal attribute assignment would raise;
        # bypass it deliberately to update the cached state in place.
        for attr_name in TASK_REFRESH_ATTRS:
            object.__setattr__(self, attr_name, getattr(source, attr_name))

    def refresh(self) -> None:
        """
        Reload the cached task data from the task store
        """
        latest = self.task.get_backend().get_result(self.id)
        self._copy_refreshable_attrs(latest)

    async def arefresh(self) -> None:
        """
        Reload the cached task data from the task store (async variant)
        """
        latest = await self.task.get_backend().aget_result(self.id)
        self._copy_refreshable_attrs(latest)
|