-
Notifications
You must be signed in to change notification settings - Fork 1.7k
ENG-9348: Lifespan tasks execute in registration order #6334
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
51449c5
a5f8ca1
7d92f45
f804a88
06d1040
1be80e8
9918f4d
3cea108
05d405f
7b68106
0948f18
4da2dd1
e437699
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,37 +9,93 @@ | |
| import inspect | ||
| import time | ||
| from collections.abc import Callable, Coroutine | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| from reflex_base.utils import console | ||
| from reflex_base.utils.exceptions import InvalidLifespanTaskTypeError | ||
| from starlette.applications import Starlette | ||
|
|
||
| from .mixin import AppMixin | ||
|
|
||
| if TYPE_CHECKING: | ||
| from typing_extensions import deprecated | ||
|
|
||
|
|
||
| def _get_task_name(task: asyncio.Task | Callable) -> str: | ||
| """Get a display name for a lifespan task. | ||
|
|
||
| Args: | ||
| task: The task to get the name for. | ||
|
|
||
| Returns: | ||
| The name of the task. | ||
| """ | ||
| if isinstance(task, asyncio.Task): | ||
| return task.get_name() | ||
| return task.__name__ # pyright: ignore[reportAttributeAccessIssue] | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class LifespanMixin(AppMixin): | ||
| """A Mixin that allow tasks to run during the whole app lifespan. | ||
|
|
||
| Attributes: | ||
| lifespan_tasks: Lifespan tasks that are planned to run. | ||
| lifespan_tasks: Set of lifespan tasks that are planned to run (deprecated). | ||
| """ | ||
|
|
||
| lifespan_tasks: set[asyncio.Task | Callable] = dataclasses.field( | ||
| default_factory=set | ||
| _lifespan_tasks: dict[asyncio.Task | Callable, None] = dataclasses.field( | ||
| default_factory=dict, init=False, repr=False | ||
| ) | ||
| _lifespan_tasks_started: bool = dataclasses.field( | ||
| default=False, init=False, repr=False | ||
| ) | ||
|
|
||
| if TYPE_CHECKING: | ||
| # Static deprecation warning for IDE/type checkers. | ||
| @property | ||
| @deprecated("Use get_lifespan_tasks method instead.") | ||
| def lifespan_tasks(self) -> frozenset[asyncio.Task | Callable]: | ||
| """Get a copy of registered lifespan tasks (deprecated).""" | ||
| ... | ||
|
|
||
| else: | ||
|
|
||
| @property | ||
| def lifespan_tasks(self) -> frozenset[asyncio.Task | Callable]: | ||
| """Get a copy of registered lifespan tasks. | ||
|
|
||
| Returns: | ||
| A frozenset of registered lifespan tasks. | ||
| """ | ||
| # Runtime deprecation warning prints to the console when accessed. | ||
| console.deprecate( | ||
| feature_name="LifespanMixin.lifespan_tasks", | ||
| reason="Use get_lifespan_tasks method instead to get a copy of registered lifespan tasks.", | ||
| deprecation_version="0.9.0", | ||
| removal_version="1.0", | ||
| ) | ||
| return frozenset(self._lifespan_tasks) | ||
|
|
||
| def get_lifespan_tasks(self) -> tuple[asyncio.Task | Callable, ...]: | ||
| """Get a copy of currently registered lifespan tasks. | ||
|
|
||
| Returns: | ||
| A tuple of registered lifespan tasks. | ||
| """ | ||
| return tuple(self._lifespan_tasks) | ||
|
|
||
| @contextlib.asynccontextmanager | ||
| async def _run_lifespan_tasks(self, app: Starlette): | ||
| self._lifespan_tasks_started = True | ||
| running_tasks = [] | ||
| try: | ||
| async with contextlib.AsyncExitStack() as stack: | ||
| for task in self.lifespan_tasks: | ||
| run_msg = f"Started lifespan task: {task.__name__} as {{type}}" # pyright: ignore [reportAttributeAccessIssue] | ||
| for task in self._lifespan_tasks: | ||
| task_name = _get_task_name(task) | ||
| run_msg = f"Started lifespan task: {task_name} as {{type}}" | ||
| if isinstance(task, asyncio.Task): | ||
|
Comment on lines
92
to
96
|
||
| running_tasks.append(task) | ||
|
masenf marked this conversation as resolved.
|
||
| else: | ||
| task_name = task.__name__ | ||
| signature = inspect.signature(task) | ||
| if "app" in signature.parameters: | ||
| task = functools.partial(task, app=app) | ||
|
|
@@ -90,15 +146,22 @@ def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs): | |
|
|
||
| Raises: | ||
| InvalidLifespanTaskTypeError: If the task is a generator function. | ||
| RuntimeError: If lifespan tasks are already running. | ||
| """ | ||
| if self._lifespan_tasks_started: | ||
| msg = ( | ||
| f"Cannot register lifespan task {_get_task_name(task)!r} after " | ||
| "lifespan has started. Register all tasks before the app starts." | ||
| ) | ||
| raise RuntimeError(msg) | ||
| if inspect.isgeneratorfunction(task) or inspect.isasyncgenfunction(task): | ||
| msg = f"Task {task.__name__} of type generator must be decorated with contextlib.asynccontextmanager." | ||
| raise InvalidLifespanTaskTypeError(msg) | ||
|
|
||
| task_name = task.__name__ # pyright: ignore [reportAttributeAccessIssue] | ||
| task_name = _get_task_name(task) | ||
| if task_kwargs: | ||
| original_task = task | ||
| task = functools.partial(task, **task_kwargs) # pyright: ignore [reportArgumentType] | ||
| functools.update_wrapper(task, original_task) # pyright: ignore [reportArgumentType] | ||
| self.lifespan_tasks.add(task) | ||
| self._lifespan_tasks[task] = None | ||
| console.debug(f"Registered lifespan task: {task_name}") | ||
Uh oh!
There was an error while loading. Please reload this page.