Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions docs/utility_methods/lifespan_tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@ async def long_running_task(foo, bar):
To register a lifespan task, use `app.register_lifespan_task(coro_func, **kwargs)`.
Any keyword arguments specified during registration will be passed to the task.

If the task accepts the special argument, `app`, it will be passed the `Starlette`
application instance.
If the task accepts the special argument, `app`, it will be passed the Reflex app
instance (`rx.App`/`LifespanMixin`).

If the task accepts the special argument, `starlette_app`, it will be passed the
underlying `Starlette` application instance.

```python
app = rx.App()
Expand Down
6 changes: 4 additions & 2 deletions reflex/app_mixins/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def get_lifespan_tasks(self) -> tuple[asyncio.Task | Callable, ...]:
return tuple(self._lifespan_tasks)

@contextlib.asynccontextmanager
async def _run_lifespan_tasks(self, app: Starlette):
async def _run_lifespan_tasks(self, starlette_app: Starlette):
self._lifespan_tasks_started = True
running_tasks = []
try:
Expand All @@ -100,7 +100,9 @@ async def _run_lifespan_tasks(self, app: Starlette):
else:
signature = inspect.signature(task)
if "app" in signature.parameters:
task = functools.partial(task, app=app)
task = functools.partial(task, app=self)
if "starlette_app" in signature.parameters:
task = functools.partial(task, starlette_app=starlette_app)
t_ = task()
if isinstance(t_, contextlib._AsyncGeneratorContextManager):
await stack.enter_async_context(t_)
Expand Down
81 changes: 81 additions & 0 deletions tests/units/app_mixins/test_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest
from reflex_base.utils.exceptions import InvalidLifespanTaskTypeError
from starlette.applications import Starlette

from reflex.app_mixins.lifespan import LifespanMixin

Expand Down Expand Up @@ -38,6 +39,7 @@ def check_for_updates(timeout: int) -> int:
assert registered_task() == 10


@pytest.mark.asyncio
async def test_register_lifespan_task_rejects_kwargs_for_asyncio_task():
"""Registering kwargs against an asyncio.Task raises a clear error."""
mixin = LifespanMixin()
Expand All @@ -53,3 +55,82 @@ async def test_register_lifespan_task_rejects_kwargs_for_asyncio_task():
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task


@pytest.mark.asyncio
async def test_lifespan_task_app_param_receives_reflex_app_instance():
"""Lifespan tasks should receive the Reflex app instance, not Starlette."""

class DummyApp(LifespanMixin):
"""Minimal test app based on the lifespan mixin."""

app = DummyApp()
received: dict[str, object] = {}

def lifespan_task(app):
"""Record the app argument injected by the lifespan runner."""
received["app"] = app

app.register_lifespan_task(lifespan_task)

async with app._run_lifespan_tasks(Starlette()):
await asyncio.sleep(0)

assert received["app"] is app
Comment on lines +60 to +79
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing test coverage for starlette_app injection

The PR introduces starlette_app as a new injectable parameter for lifespan tasks, but the test suite only verifies the app (Reflex instance) injection path. A complementary test covering starlette_app injection — and ideally a task that declares both parameters simultaneously — would complete coverage for the new feature.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lemme add



@pytest.mark.asyncio
async def test_lifespan_task_starlette_app_param_receives_starlette_instance():
"""Lifespan tasks should receive the Starlette app when requested."""

class DummyApp(LifespanMixin):
"""Minimal test app based on the lifespan mixin."""

app = DummyApp()
received: dict[str, object] = {}
starlette_app = Starlette()

def lifespan_task(starlette_app):
"""Record the Starlette app argument injected by the lifespan runner.

Args:
starlette_app: Starlette app object injected by the lifespan runner.
"""
received["starlette_app"] = starlette_app

app.register_lifespan_task(lifespan_task)

async with app._run_lifespan_tasks(starlette_app):
await asyncio.sleep(0)

assert received["starlette_app"] is starlette_app


@pytest.mark.asyncio
async def test_lifespan_task_both_app_and_starlette_app_params_are_injected():
"""Lifespan tasks should receive both app and starlette_app when declared."""

class DummyApp(LifespanMixin):
"""Minimal test app based on the lifespan mixin."""

app = DummyApp()
received: dict[str, object] = {}
starlette_app = Starlette()

def lifespan_task(app, starlette_app):
"""Record both injected app objects from the lifespan runner.

Args:
app: Reflex app object injected by the lifespan runner.
starlette_app: Starlette app object injected by the lifespan runner.
"""
received["app"] = app
received["starlette_app"] = starlette_app

app.register_lifespan_task(lifespan_task)

async with app._run_lifespan_tasks(starlette_app):
await asyncio.sleep(0)

assert received["app"] is app
assert received["starlette_app"] is starlette_app
Loading