Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions docs/utility_methods/lifespan_tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Lifespan tasks are defined as async coroutines or async contextmanagers. To avoi
blocking the event thread, never use `time.sleep` or perform non-async I/O within
a lifespan task.

Tasks execute in the order they are registered.

In dev mode, lifespan tasks will stop and restart when a hot-reload occurs.

## Tasks
Expand Down Expand Up @@ -38,14 +40,23 @@ async def long_running_task(foo, bar):
To register a lifespan task, use `app.register_lifespan_task(coro_func, **kwargs)`.
Any keyword arguments specified during registration will be passed to the task.

If the task accepts the special argument, `app`, it will be an instance of the `FastAPI` object
associated with the app.
If the task accepts the special argument, `app`, it will be passed the `Starlette`
application instance.

```python
app = rx.App()
app.register_lifespan_task(long_running_task, foo=42, bar=os.environ["BAR_PARAM"])
```

All tasks must be registered before the app starts. Calling
`register_lifespan_task` after the lifespan has begun (for example, from an
event handler or from within another lifespan task) will raise a `RuntimeError`.

### Inspecting Registered Tasks

To get the currently registered lifespan tasks, use `app.get_lifespan_tasks()`,
which returns a `tuple` of tasks in registration order.

## Context Managers

Lifespan tasks can also be defined as async contextmanagers. This is useful for
Expand All @@ -55,9 +66,6 @@ protocol.
Code up to the first `yield` will run when the backend comes up. As the backend
is shutting down, the code after the `yield` will run to clean up.

Here is an example borrowed from the FastAPI docs and modified to work with this
interface.

```python
from contextlib import asynccontextmanager

Expand All @@ -70,7 +78,7 @@ ml_models = \{}


@asynccontextmanager
async def setup_model(app: FastAPI):
async def setup_model(app):
# Load the ML model
ml_models["answer_to_everything"] = fake_answer_to_everything_ml_model
yield
Expand Down
79 changes: 71 additions & 8 deletions reflex/app_mixins/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,93 @@
import inspect
import time
from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING

from reflex_base.utils import console
from reflex_base.utils.exceptions import InvalidLifespanTaskTypeError
from starlette.applications import Starlette

from .mixin import AppMixin

if TYPE_CHECKING:
from typing_extensions import deprecated


def _get_task_name(task: asyncio.Task | Callable) -> str:
"""Get a display name for a lifespan task.

Args:
task: The task to get the name for.

Returns:
The name of the task.
"""
if isinstance(task, asyncio.Task):
return task.get_name()
return task.__name__ # pyright: ignore[reportAttributeAccessIssue]


@dataclasses.dataclass
class LifespanMixin(AppMixin):
"""A Mixin that allow tasks to run during the whole app lifespan.

Attributes:
lifespan_tasks: Lifespan tasks that are planned to run.
lifespan_tasks: Set of lifespan tasks that are planned to run (deprecated).
"""

lifespan_tasks: set[asyncio.Task | Callable] = dataclasses.field(
default_factory=set
_lifespan_tasks: dict[asyncio.Task | Callable, None] = dataclasses.field(
default_factory=dict, init=False, repr=False
)
_lifespan_tasks_started: bool = dataclasses.field(
default=False, init=False, repr=False
)

if TYPE_CHECKING:
# Static deprecation warning for IDE/type checkers.
@property
@deprecated("Use get_lifespan_tasks method instead.")
def lifespan_tasks(self) -> frozenset[asyncio.Task | Callable]:
"""Get a copy of registered lifespan tasks (deprecated)."""
...

else:

@property
def lifespan_tasks(self) -> frozenset[asyncio.Task | Callable]:
"""Get a copy of registered lifespan tasks.

Returns:
A frozenset of registered lifespan tasks.
"""
# Runtime deprecation warning prints to the console when accessed.
console.deprecate(
feature_name="LifespanMixin.lifespan_tasks",
reason="Use get_lifespan_tasks method instead to get a copy of registered lifespan tasks.",
deprecation_version="0.9.0",
removal_version="1.0",
)
return frozenset(self._lifespan_tasks)
Comment thread
masenf marked this conversation as resolved.

def get_lifespan_tasks(self) -> tuple[asyncio.Task | Callable, ...]:
"""Get a copy of currently registered lifespan tasks.

Returns:
A tuple of registered lifespan tasks.
"""
return tuple(self._lifespan_tasks)

@contextlib.asynccontextmanager
async def _run_lifespan_tasks(self, app: Starlette):
self._lifespan_tasks_started = True
running_tasks = []
try:
async with contextlib.AsyncExitStack() as stack:
for task in self.lifespan_tasks:
run_msg = f"Started lifespan task: {task.__name__} as {{type}}" # pyright: ignore [reportAttributeAccessIssue]
for task in self._lifespan_tasks:
task_name = _get_task_name(task)
run_msg = f"Started lifespan task: {task_name} as {{type}}"
if isinstance(task, asyncio.Task):
Comment on lines 92 to 96
Copy link

Copilot AI Apr 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Iterating directly over self._lifespan_tasks will raise RuntimeError: dictionary changed size during iteration if a lifespan task registers additional lifespan tasks while startup is in progress (e.g., an async contextmanager that calls register_lifespan_task). Even if it doesn’t crash, tasks added mid-iteration won’t be visited and therefore won’t be tracked in running_tasks for cancellation. Consider iterating over a snapshot and/or using an index-based loop that can safely pick up newly-registered tasks (preserving registration order) so dynamically-registered asyncio.Tasks are also cancelled on shutdown.

Copilot uses AI. Check for mistakes.
running_tasks.append(task)
Comment thread
masenf marked this conversation as resolved.
else:
task_name = task.__name__
signature = inspect.signature(task)
if "app" in signature.parameters:
task = functools.partial(task, app=app)
Expand Down Expand Up @@ -90,15 +146,22 @@ def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):

Raises:
InvalidLifespanTaskTypeError: If the task is a generator function.
RuntimeError: If lifespan tasks are already running.
"""
if self._lifespan_tasks_started:
msg = (
f"Cannot register lifespan task {_get_task_name(task)!r} after "
"lifespan has started. Register all tasks before the app starts."
)
raise RuntimeError(msg)
if inspect.isgeneratorfunction(task) or inspect.isasyncgenfunction(task):
msg = f"Task {task.__name__} of type generator must be decorated with contextlib.asynccontextmanager."
raise InvalidLifespanTaskTypeError(msg)

task_name = task.__name__ # pyright: ignore [reportAttributeAccessIssue]
task_name = _get_task_name(task)
if task_kwargs:
original_task = task
task = functools.partial(task, **task_kwargs) # pyright: ignore [reportArgumentType]
functools.update_wrapper(task, original_task) # pyright: ignore [reportArgumentType]
self.lifespan_tasks.add(task)
self._lifespan_tasks[task] = None
console.debug(f"Registered lifespan task: {task_name}")
126 changes: 126 additions & 0 deletions tests/integration/test_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ def LifespanApp(
from contextlib import asynccontextmanager

import reflex as rx
from reflex.istate.manager.token import BaseStateToken

lifespan_task_global = 0
lifespan_context_global = 0
raw_asyncio_task_global = 0
connected_tokens: set[str] = set()

@asynccontextmanager
async def lifespan_context(app, inc: int = 1): # noqa: RUF029
Expand All @@ -52,13 +55,47 @@ async def lifespan_task(inc: int = 1):
print(f"Lifespan global cancelled: {ce}.")
lifespan_task_global = 0

async def raw_asyncio_task_coro():
global raw_asyncio_task_global
print("Raw asyncio task started.")
try:
while True:
raw_asyncio_task_global += 1 # pyright: ignore[reportUnboundVariable, reportPossiblyUnboundVariable]
await asyncio.sleep(0.1)
except asyncio.CancelledError as ce:
print(f"Raw asyncio task cancelled: {ce}.")
raw_asyncio_task_global = 0

@asynccontextmanager
async def assert_register_blocked_during_lifespan(app): # noqa: RUF029
"""Negative test: registering a task after lifespan has started must raise."""
from reflex.utils.prerequisites import get_app

reflex_app = get_app().app
task = asyncio.create_task(raw_asyncio_task_coro(), name="raw_asyncio_task")
try:
reflex_app.register_lifespan_task(task)
except RuntimeError as exc:
print(f"Expected RuntimeError: {exc}")
else:
msg = "register_lifespan_task should have raised RuntimeError"
raise AssertionError(msg)
finally:
task.cancel()
yield

class LifespanState(rx.State):
interval: int = 100
modify_count: int = 0

@rx.event
def set_interval(self, interval: int):
self.interval = interval

@rx.event
def register_token(self):
connected_tokens.add(self.router.session.client_token)

@rx.var(cache=False)
def task_global(self) -> int:
return lifespan_task_global
Expand All @@ -67,14 +104,36 @@ def task_global(self) -> int:
def context_global(self) -> int:
return lifespan_context_global

@rx.var(cache=False)
def asyncio_task_global(self) -> int:
return raw_asyncio_task_global

@rx.event
def tick(self, date):
pass

async def modify_state_task():
from reflex.utils.prerequisites import get_app

reflex_app = get_app().app
try:
while True:
for token in list(connected_tokens):
async with reflex_app.modify_state(
BaseStateToken(ident=token, cls=LifespanState)
) as state:
lifespan_state = await state.get_state(LifespanState)
lifespan_state.modify_count += 1
await asyncio.sleep(0.1)
except asyncio.CancelledError:
print("modify_state_task cancelled.")

def index():
return rx.vstack(
rx.text(LifespanState.task_global, id="task_global"),
rx.text(LifespanState.context_global, id="context_global"),
rx.text(LifespanState.modify_count, id="modify_count"),
rx.text(LifespanState.asyncio_task_global, id="asyncio_task_global"),
rx.button(
rx.moment(
interval=LifespanState.interval, on_change=LifespanState.tick
Expand All @@ -84,6 +143,7 @@ def index():
),
id="toggle-tick",
),
on_mount=LifespanState.register_token,
)

from fastapi import FastAPI
Expand All @@ -95,6 +155,9 @@ def index():

app.register_lifespan_task(lifespan_task)
app.register_lifespan_task(lifespan_context, inc=2)
app.register_lifespan_task(raw_asyncio_task_coro)
app.register_lifespan_task(assert_register_blocked_during_lifespan)
app.register_lifespan_task(modify_state_task)
app.add_page(index)


Expand Down Expand Up @@ -160,6 +223,63 @@ def lifespan_app(
yield harness


def test_lifespan_modify_state(lifespan_app: AppHarness):
"""Test that a lifespan task can use app.modify_state to push state updates.

Args:
lifespan_app: harness for LifespanApp app
"""
assert lifespan_app.app_module is not None, "app module is not found"
assert lifespan_app.app_instance is not None, "app is not running"
driver = lifespan_app.frontend()
Comment thread
masenf marked this conversation as resolved.

ss = SessionStorage(driver)
assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"

modify_count = driver.find_element(By.ID, "modify_count")

# Wait for modify_count to become non-zero (lifespan task is pushing updates)
assert lifespan_app.poll_for_content(modify_count, exp_not_equal="0")

# Verify it continues to increase
first_value = modify_count.text
next_value = lifespan_app.poll_for_content(modify_count, exp_not_equal=first_value)
assert int(next_value) > int(first_value)


def test_lifespan_raw_asyncio_task(lifespan_app: AppHarness):
"""Test that a coroutine function registered as a lifespan task runs as an asyncio.Task.

Args:
lifespan_app: harness for LifespanApp app
"""
assert lifespan_app.app_module is not None, "app module is not found"
assert lifespan_app.app_instance is not None, "app is not running"
driver = lifespan_app.frontend()

ss = SessionStorage(driver)
assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found"

asyncio_task_global = driver.find_element(By.ID, "asyncio_task_global")

# Wait for asyncio_task_global to become non-zero
assert lifespan_app.poll_for_content(asyncio_task_global, exp_not_equal="0")

# Verify it continues to increase
first_value = asyncio_task_global.text
next_value = lifespan_app.poll_for_content(
asyncio_task_global, exp_not_equal=first_value
)
assert int(next_value) > int(first_value)
assert lifespan_app.app_module.raw_asyncio_task_global > 0


# --- test_lifespan MUST be the last test in this file. ---
# It shuts down the backend and asserts cancellation of lifespan tasks.
# The lifespan_app fixture is session-scoped (expensive to rebuild), so all
# other tests that need a running backend must be defined ABOVE this point.


def test_lifespan(lifespan_app: AppHarness):
"""Test the lifespan integration.

Expand Down Expand Up @@ -195,3 +315,9 @@ def test_lifespan(lifespan_app: AppHarness):
# Check that the lifespan tasks have been cancelled
assert lifespan_app.app_module.lifespan_task_global == 0
assert lifespan_app.app_module.lifespan_context_global == 4
assert lifespan_app.app_module.raw_asyncio_task_global == 0


# --- Do NOT add new test cases below this line. ---
# test_lifespan (above) kills the backend; any test defined after it will
# find the harness in a stopped state and fail.
Loading