diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..209a865 --- /dev/null +++ b/Makefile @@ -0,0 +1,4 @@ +.PHONY: test + +test: + uv run --dev pytest diff --git a/README.md b/README.md index 2928d33..22fa668 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,10 @@ headers of the SSE response yourself. A datastar response consists of 0..N datastar events. There are response classes included to make this easy in all of the supported frameworks. +Each framework also exposes a `@datastar_response` decorator that will wrap +return values (including generators) into the right response class while +preserving sync handlers as sync so frameworks can keep them in their +threadpools. The following examples will work across all supported frameworks when the response class is imported from the appropriate framework package. diff --git a/pyproject.toml b/pyproject.toml index 3ff0506..e7f4c24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,12 +50,14 @@ urls.GitHub = "https://github.com/starfederation/datastar-python" dev = [ "django>=4.2.23", "fastapi>=0.116.1", + "httpx>=0.27", "litestar>=2.17", "pre-commit>=4.2", "python-fasthtml>=0.12.25; python_full_version>='3.10'", "quart>=0.20", "sanic>=25.3", "starlette>=0.47.3", + "uvicorn>=0.30", ] [tool.ruff] diff --git a/src/datastar_py/django.py b/src/datastar_py/django.py index 1c14b92..629da56 100644 --- a/src/datastar_py/django.py +++ b/src/datastar_py/django.py @@ -2,6 +2,7 @@ from collections.abc import Awaitable, Mapping from functools import wraps +from inspect import isawaitable from typing import Any, Callable, ParamSpec from django.http import HttpRequest @@ -54,7 +55,17 @@ def datastar_response( @wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: r = func(*args, **kwargs) - if isinstance(r, Awaitable): + + if hasattr(r, "__aiter__"): + raise NotImplementedError( + "Async generators/iterables are not yet supported by the Django adapter; " + "use a sync generator or return a single value/awaitable instead." + ) + + if hasattr(r, "__iter__") and not isinstance(r, (str, bytes)): + return DatastarResponse(r) + + if isawaitable(r): return DatastarResponse(await r) return DatastarResponse(r) diff --git a/src/datastar_py/litestar.py b/src/datastar_py/litestar.py index 97a3a5f..3cc24ce 100644 --- a/src/datastar_py/litestar.py +++ b/src/datastar_py/litestar.py @@ -2,6 +2,7 @@ from collections.abc import Awaitable, Mapping from functools import wraps +from inspect import isawaitable from typing import ( TYPE_CHECKING, Any, @@ -64,17 +65,28 @@ def __init__( def datastar_response( func: Callable[P, Awaitable[DatastarEvents] | DatastarEvents], -) -> Callable[P, Awaitable[DatastarResponse]]: +) -> Callable[P, DatastarResponse]: """A decorator which wraps a function result in DatastarResponse. Can be used on a sync or async function or generator function. """ @wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: r = func(*args, **kwargs) - if isinstance(r, Awaitable): - return DatastarResponse(await r) + + if hasattr(r, "__aiter__"): + return DatastarResponse(r) + + if hasattr(r, "__iter__") and not isinstance(r, (str, bytes)): + return DatastarResponse(r) + + if isawaitable(r): + async def await_and_yield(): + yield await r + + return DatastarResponse(await_and_yield()) + return DatastarResponse(r) wrapper.__annotations__["return"] = DatastarResponse diff --git a/src/datastar_py/sanic.py b/src/datastar_py/sanic.py index 283465a..8122a9e 100644 --- a/src/datastar_py/sanic.py +++ b/src/datastar_py/sanic.py @@ -4,6 +4,7 @@ from contextlib import aclosing, closing from functools import wraps from inspect import isasyncgen, isgenerator +from inspect import isawaitable from typing import Any, Callable, ParamSpec, Union from sanic import HTTPResponse, Request @@ -70,7 +71,7 @@ def datastar_response( @wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse | None: r = func(*args, **kwargs) - if isinstance(r, Awaitable): + if isawaitable(r): return DatastarResponse(await r) if isasyncgen(r): request = args[0] diff --git a/src/datastar_py/starlette.py b/src/datastar_py/starlette.py index 53989af..f1da760 100644 --- a/src/datastar_py/starlette.py +++ b/src/datastar_py/starlette.py @@ -2,6 +2,7 @@ from collections.abc import Awaitable, Mapping from functools import wraps +from inspect import isawaitable from typing import ( TYPE_CHECKING, Any, @@ -54,17 +55,33 @@ def __init__( def datastar_response( func: Callable[P, Awaitable[DatastarEvents] | DatastarEvents], -) -> Callable[P, Awaitable[DatastarResponse]]: +) -> Callable[P, DatastarResponse]: """A decorator which wraps a function result in DatastarResponse. Can be used on a sync or async function or generator function. """ @wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: r = func(*args, **kwargs) - if isinstance(r, Awaitable): - return DatastarResponse(await r) + + # Check for async generator/iterator first (most specific case) + if hasattr(r, "__aiter__"): + return DatastarResponse(r) + + # Check for sync generator/iterator (before Awaitable to avoid false positives) + if hasattr(r, "__iter__") and not isinstance(r, (str, bytes)): + return DatastarResponse(r) + + # Check for coroutines/tasks (but NOT async generators, already handled above) + if isawaitable(r): + # Wrap awaitable in an async generator that yields the result + async def await_and_yield(): + yield await r + + return DatastarResponse(await_and_yield()) + + # Default case: single value or unknown type return DatastarResponse(r) wrapper.__annotations__["return"] = DatastarResponse diff --git a/tests/test_datastar_decorator_runtime.py b/tests/test_datastar_decorator_runtime.py new file mode 100644 index 0000000..2ecd905 --- /dev/null +++ b/tests/test_datastar_decorator_runtime.py @@ -0,0 +1,127 @@ +"""Runtime-focused tests for datastar_response decorators.""" + +from __future__ import annotations + +import importlib +import inspect +import threading +import time +from typing import Any + +import anyio +import httpx +import pytest +import uvicorn +from starlette.applications import Starlette +from starlette.responses import PlainTextResponse +from starlette.routing import Route + +from datastar_py.sse import ServerSentEventGenerator as SSE + + +@pytest.fixture +def anyio_backend() -> str: + """Limit anyio plugin to asyncio backend for these tests.""" + return "asyncio" + + +@pytest.mark.parametrize("module_path", ["datastar_py.starlette", "datastar_py.fasthtml"]) +@pytest.mark.parametrize( + "variant", + [ + "sync_value", + "sync_generator", + "async_value", + "async_generator", + ], +) +def test_decorator_returns_response_objects(module_path: str, variant: str) -> None: + """Decorated handlers should stay sync-callable and return DatastarResponse immediately.""" + + mod = importlib.import_module(module_path) + datastar_response = mod.datastar_response + DatastarResponse = mod.DatastarResponse + + if variant == "sync_value": + @datastar_response + def handler() -> Any: + return SSE.patch_signals({"ok": True}) + elif variant == "sync_generator": + @datastar_response + def handler() -> Any: + yield SSE.patch_signals({"ok": True}) + elif variant == "async_value": + @datastar_response + async def handler() -> Any: + return SSE.patch_signals({"ok": True}) + else: + @datastar_response + async def handler() -> Any: + yield SSE.patch_signals({"ok": True}) + + result = handler() + if inspect.iscoroutine(result): + result.close() # avoid "coroutine was never awaited" warnings + + assert not inspect.iscoroutinefunction(handler), "Decorator should preserve sync callable semantics" + assert isinstance(result, DatastarResponse) + + +async def _fetch( + client: httpx.AsyncClient, path: str, timings: dict[str, float], key: str +) -> None: + start = time.perf_counter() + resp = await client.get(path, timeout=5.0) + timings[key] = time.perf_counter() - start + resp.raise_for_status() + + +@pytest.mark.anyio("asyncio") +async def test_sync_handler_runs_off_event_loop() -> None: + """Sync routes should stay in the threadpool; otherwise they block the event loop.""" + + entered = threading.Event() + + from datastar_py.starlette import datastar_response + + @datastar_response + def slow(request) -> Any: # noqa: ANN001 + entered.set() + time.sleep(1.0) # if run on the event loop, this blocks other requests + return SSE.patch_signals({"slow": True}) + + async def ping(request) -> PlainTextResponse: # noqa: ANN001 + return PlainTextResponse("pong") + + app = Starlette(routes=[Route("/slow", slow), Route("/ping", ping)]) + + config = uvicorn.Config(app, host="127.0.0.1", port=0, log_level="warning", lifespan="off") + server = uvicorn.Server(config) + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + try: + # Wait for server to start and expose sockets + for _ in range(50): + if server.started and getattr(server, "servers", None): + break + await anyio.sleep(0.05) + else: + pytest.fail("Server did not start") + + sock = server.servers[0].sockets[0] + host, port = sock.getsockname()[:2] + base_url = f"http://{host}:{port}" + + async with httpx.AsyncClient(base_url=base_url) as client: + timings: dict[str, float] = {} + async with anyio.create_task_group() as tg: + tg.start_soon(_fetch, client, "/slow", timings, "slow") + await anyio.to_thread.run_sync(entered.wait, 1.0) + tg.start_soon(_fetch, client, "/ping", timings, "ping") + + assert timings["slow"] >= 0.9 + assert timings["ping"] < 0.3, "Ping should not be blocked by slow sync handler" + finally: + server.should_exit = True + thread.join(timeout=2) diff --git a/tests/test_decorator_matrix.py b/tests/test_decorator_matrix.py new file mode 100644 index 0000000..852ed14 --- /dev/null +++ b/tests/test_decorator_matrix.py @@ -0,0 +1,109 @@ +"""Matrix tests for datastar_response across frameworks and callable types.""" + +from __future__ import annotations + +import importlib +import inspect +from typing import Any, Iterable + +import pytest + +from datastar_py.sse import ServerSentEventGenerator as SSE + + +FRAMEWORKS = [ + # name, module path, iterator attribute on response (None means use response directly) + ("starlette", "datastar_py.starlette", "body_iterator"), + ("fasthtml", "datastar_py.fasthtml", "body_iterator"), + ("fastapi", "datastar_py.fastapi", "body_iterator"), + ("litestar", "datastar_py.litestar", "iterator"), + ("django", "datastar_py.django", None), + # Quart and Sanic need full request contexts; covered elsewhere + ("quart", "datastar_py.quart", None), + ("sanic", "datastar_py.sanic", None), +] + + +@pytest.fixture +def anyio_backend() -> str: + """Limit anyio plugin to asyncio backend for these tests.""" + return "asyncio" + + +def _require_module(module_path: str) -> Any: + if not importlib.util.find_spec(module_path): + pytest.skip(f"{module_path} not installed") + return importlib.import_module(module_path) + + +async def _collect_events(resp: Any, iterator_attr: str | None) -> list[Any]: + """Gather events from response regardless of iterator style.""" + iterator = getattr(resp, iterator_attr) if iterator_attr else resp + events: list[Any] = [] + + if hasattr(iterator, "__aiter__"): + async for event in iterator: # type: ignore[has-type] + events.append(event) + elif isinstance(iterator, Iterable): + for event in iterator: + events.append(event) + else: + raise TypeError(f"Cannot iterate response events for {type(resp)}") + + return events + + +@pytest.mark.anyio +@pytest.mark.parametrize("framework_name,module_path,iterator_attr", FRAMEWORKS) +@pytest.mark.parametrize( + "variant", + ["sync_value", "sync_generator", "async_value", "async_generator"], +) +async def test_datastar_response_matrix( + framework_name: str, module_path: str, iterator_attr: str | None, variant: str +) -> None: + """Ensure decorator works for sync/async and generator/non-generator functions.""" + + if framework_name in {"quart", "sanic"}: + pytest.skip(f"{framework_name} decorator requires full request context to exercise") + if framework_name == "django": + from django.conf import settings + + if not settings.configured: + settings.configure(DEFAULT_CHARSET="utf-8") + if variant == "async_generator": + pytest.skip("Django adapter does not support async generators yet") + + mod = _require_module(module_path) + datastar_response = mod.datastar_response + DatastarResponse = mod.DatastarResponse + + if variant == "sync_value": + @datastar_response + def handler() -> Any: + return SSE.patch_signals({"ok": True}) + elif variant == "sync_generator": + @datastar_response + def handler() -> Any: + yield SSE.patch_signals({"ok": True}) + elif variant == "async_value": + @datastar_response + async def handler() -> Any: + return SSE.patch_signals({"ok": True}) + else: + @datastar_response + async def handler() -> Any: + yield SSE.patch_signals({"ok": True}) + + result = handler() + try: + if inspect.isawaitable(result): + result = await result + + assert isinstance(result, DatastarResponse) + events = await _collect_events(result, iterator_attr) + assert events, "Expected at least one event from response iterator" + finally: + # Avoid "coroutine was never awaited" warnings when assertions fail + if inspect.iscoroutine(result): + result.close() diff --git a/tests/test_fastapi_decorator_integration.py b/tests/test_fastapi_decorator_integration.py new file mode 100644 index 0000000..0ee3440 --- /dev/null +++ b/tests/test_fastapi_decorator_integration.py @@ -0,0 +1,120 @@ +"""Integration test: datastar_response within a live FastAPI app.""" + +from __future__ import annotations + +import threading +import time +from typing import Any + +import anyio +import httpx +import pytest +import uvicorn +from fastapi import FastAPI +from starlette.responses import PlainTextResponse + +from datastar_py.sse import ServerSentEventGenerator as SSE +from datastar_py.fastapi import datastar_response + + +@pytest.fixture +def anyio_backend() -> str: + return "asyncio" + + +async def _fetch(client: httpx.AsyncClient, path: str) -> httpx.Response: + resp = await client.get(path, timeout=5.0) + resp.raise_for_status() + return resp + + +@pytest.mark.anyio +async def test_fastapi_handlers_cover_matrix() -> None: + """Ensure FastAPI handlers across sync/async and gen/value work end-to-end.""" + + entered = threading.Event() + app = FastAPI() + + @app.get("/sync-value") + @datastar_response + def sync_value() -> Any: + entered.set() + time.sleep(0.2) # should run in threadpool + return SSE.patch_signals({"src": "sync_value"}) + + @app.get("/sync-generator") + @datastar_response + def sync_gen() -> Any: + yield SSE.patch_signals({"src": "sync_generator", "idx": 1}) + yield SSE.patch_signals({"src": "sync_generator", "idx": 2}) + + @app.get("/async-value") + @datastar_response + async def async_value() -> Any: + return SSE.patch_signals({"src": "async_value"}) + + @app.get("/async-generator") + @datastar_response + async def async_gen() -> Any: + yield SSE.patch_signals({"src": "async_generator", "idx": 1}) + yield SSE.patch_signals({"src": "async_generator", "idx": 2}) + + @app.get("/ping") + async def ping() -> PlainTextResponse: + return PlainTextResponse("pong") + + config = uvicorn.Config(app, host="127.0.0.1", port=0, log_level="warning", lifespan="off") + server = uvicorn.Server(config) + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + try: + for _ in range(50): + if server.started and getattr(server, "servers", None): + break + await anyio.sleep(0.05) + else: + pytest.fail("Server did not start") + + sock = server.servers[0].sockets[0] + host, port = sock.getsockname()[:2] + base_url = f"http://{host}:{port}" + + async with httpx.AsyncClient(base_url=base_url) as client: + # Concurrency sanity: sync_value should not stall ping + async with anyio.create_task_group() as tg: + slow_resp: httpx.Response | None = None + ping_resp: httpx.Response | None = None + + async def hit_slow(): + nonlocal slow_resp + slow_resp = await _fetch(client, "/sync-value") + + async def hit_ping(): + nonlocal ping_resp + await anyio.to_thread.run_sync(entered.wait, 1.0) + ping_resp = await _fetch(client, "/ping") + + tg.start_soon(hit_slow) + tg.start_soon(hit_ping) + + assert slow_resp is not None and slow_resp.status_code == 200 + assert ping_resp is not None and ping_resp.status_code == 200 + assert float(ping_resp.elapsed.total_seconds()) < 0.35 + + sync_value_body = (await _fetch(client, "/sync-value")).text + assert '"src":"sync_value"' in sync_value_body + + sync_gen_body = (await _fetch(client, "/sync-generator")).text + assert '"src":"sync_generator"' in sync_gen_body + assert '"idx":1' in sync_gen_body and '"idx":2' in sync_gen_body + + async_value_body = (await _fetch(client, "/async-value")).text + assert '"src":"async_value"' in async_value_body + + async_gen_body = (await _fetch(client, "/async-generator")).text + assert '"src":"async_generator"' in async_gen_body + assert '"idx":1' in async_gen_body and '"idx":2' in async_gen_body + finally: + server.should_exit = True + thread.join(timeout=2) diff --git a/tests/test_fasthtml_decorator_integration.py b/tests/test_fasthtml_decorator_integration.py new file mode 100644 index 0000000..3645c48 --- /dev/null +++ b/tests/test_fasthtml_decorator_integration.py @@ -0,0 +1,118 @@ +"""Integration test: datastar_response within a live FastHTML app.""" + +from __future__ import annotations + +import threading +import time +from typing import Any + +import anyio +import httpx +import pytest +import uvicorn +from fasthtml.common import fast_app +from starlette.responses import PlainTextResponse + +from datastar_py.sse import ServerSentEventGenerator as SSE +from datastar_py.fasthtml import datastar_response + + +@pytest.fixture +def anyio_backend() -> str: + return "asyncio" + + +async def _fetch(client: httpx.AsyncClient, path: str) -> httpx.Response: + resp = await client.get(path, timeout=5.0) + resp.raise_for_status() + return resp + + +@pytest.mark.anyio +async def test_fasthtml_sync_and_streaming_handlers() -> None: + """Ensure FastHTML routes across sync/async and gen/value work end-to-end.""" + + entered = threading.Event() + + app, rt = fast_app(htmx=False, live=False) + + @rt("/slow") + @datastar_response + def slow(request) -> Any: # noqa: ANN001 + entered.set() + time.sleep(0.2) # should not block event loop for other requests + return SSE.patch_signals({"src": "sync_value"}) + + @rt("/sync-generator") + @datastar_response + def sync_gen(request) -> Any: # noqa: ANN001 + yield SSE.patch_signals({"src": "sync_generator", "idx": 1}) + yield SSE.patch_signals({"src": "sync_generator", "idx": 2}) + + @rt("/stream") + @datastar_response + async def async_gen(request) -> Any: # noqa: ANN001 + yield SSE.patch_signals({"src": "async_generator", "idx": 1}) + yield SSE.patch_signals({"src": "async_generator", "idx": 2}) + + @rt("/async-value") + @datastar_response + async def async_value(request) -> Any: # noqa: ANN001 + return SSE.patch_signals({"src": "async_value"}) + + @rt("/ping") + async def ping(request) -> PlainTextResponse: # noqa: ANN001 + return PlainTextResponse("pong") + + config = uvicorn.Config(app, host="127.0.0.1", port=0, log_level="warning", lifespan="off") + server = uvicorn.Server(config) + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + try: + for _ in range(50): + if server.started and getattr(server, "servers", None): + break + await anyio.sleep(0.05) + else: + pytest.fail("Server did not start") + + sock = server.servers[0].sockets[0] + host, port = sock.getsockname()[:2] + base_url = f"http://{host}:{port}" + + async with httpx.AsyncClient(base_url=base_url) as client: + async with anyio.create_task_group() as tg: + slow_resp: httpx.Response | None = None + ping_resp: httpx.Response | None = None + + async def hit_slow(): + nonlocal slow_resp + slow_resp = await _fetch(client, "/slow") + + async def hit_ping(): + nonlocal ping_resp + await anyio.to_thread.run_sync(entered.wait, 1.0) + ping_resp = await _fetch(client, "/ping") + + tg.start_soon(hit_slow) + tg.start_soon(hit_ping) + + assert slow_resp is not None and slow_resp.status_code == 200 + assert ping_resp is not None and ping_resp.status_code == 200 + assert float(ping_resp.elapsed.total_seconds()) < 0.3 + + sync_gen_body = (await _fetch(client, "/sync-generator")).text + assert '"src":"sync_generator"' in sync_gen_body + assert '"idx":1' in sync_gen_body and '"idx":2' in sync_gen_body + + async_value_body = (await _fetch(client, "/async-value")).text + assert '"src":"async_value"' in async_value_body + + stream = await _fetch(client, "/stream") + body = stream.text + assert '"src":"async_generator"' in body + assert '"idx":1' in body and '"idx":2' in body + finally: + server.should_exit = True + thread.join(timeout=2) diff --git a/tests/test_litestar_decorator_integration.py b/tests/test_litestar_decorator_integration.py new file mode 100644 index 0000000..1529d7f --- /dev/null +++ b/tests/test_litestar_decorator_integration.py @@ -0,0 +1,120 @@ +"""Integration test: datastar_response within a live Litestar app.""" + +from __future__ import annotations + +import threading +import time +from typing import Any + +import anyio +import httpx +import pytest +import uvicorn +from litestar import Litestar, get +from starlette.responses import PlainTextResponse + +from datastar_py.sse import ServerSentEventGenerator as SSE +from datastar_py.litestar import datastar_response + + +@pytest.fixture +def anyio_backend() -> str: + return "asyncio" + + +async def _fetch(client: httpx.AsyncClient, path: str) -> httpx.Response: + resp = await client.get(path, timeout=5.0) + resp.raise_for_status() + return resp + + +@pytest.mark.anyio +async def test_litestar_handlers_cover_matrix() -> None: + """Ensure Litestar handlers across sync/async and gen/value work end-to-end.""" + + entered = threading.Event() + + @get("/sync-value") + @datastar_response + def sync_value() -> Any: + entered.set() + time.sleep(0.2) + return SSE.patch_signals({"src": "sync_value"}) + + @get("/sync-generator") + @datastar_response + def sync_gen() -> Any: + yield SSE.patch_signals({"src": "sync_generator", "idx": 1}) + yield SSE.patch_signals({"src": "sync_generator", "idx": 2}) + + @get("/async-value") + @datastar_response + async def async_value() -> Any: + return SSE.patch_signals({"src": "async_value"}) + + @get("/async-generator") + @datastar_response + async def async_gen() -> Any: + yield SSE.patch_signals({"src": "async_generator", "idx": 1}) + yield SSE.patch_signals({"src": "async_generator", "idx": 2}) + + @get("/ping") + async def ping() -> PlainTextResponse: + return PlainTextResponse("pong") + + app = Litestar(route_handlers=[sync_value, sync_gen, async_value, async_gen, ping]) + + config = uvicorn.Config(app, host="127.0.0.1", port=0, log_level="warning", lifespan="off") + server = uvicorn.Server(config) + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + try: + for _ in range(50): + if server.started and getattr(server, "servers", None): + break + await anyio.sleep(0.05) + else: + pytest.fail("Server did not start") + + sock = server.servers[0].sockets[0] + host, port = sock.getsockname()[:2] + base_url = f"http://{host}:{port}" + + async with httpx.AsyncClient(base_url=base_url) as client: + async with anyio.create_task_group() as tg: + slow_resp: httpx.Response | None = None + ping_resp: httpx.Response | None = None + + async def hit_slow(): + nonlocal slow_resp + slow_resp = await _fetch(client, "/sync-value") + + async def hit_ping(): + nonlocal ping_resp + await anyio.to_thread.run_sync(entered.wait, 1.0) + ping_resp = await _fetch(client, "/ping") + + tg.start_soon(hit_slow) + tg.start_soon(hit_ping) + + assert slow_resp is not None and slow_resp.status_code == 200 + assert ping_resp is not None and ping_resp.status_code == 200 + assert float(ping_resp.elapsed.total_seconds()) < 0.35 + + sync_value_body = (await _fetch(client, "/sync-value")).text + assert '"src":"sync_value"' in sync_value_body + + sync_gen_body = (await _fetch(client, "/sync-generator")).text + assert '"src":"sync_generator"' in sync_gen_body + assert '"idx":1' in sync_gen_body and '"idx":2' in sync_gen_body + + async_value_body = (await _fetch(client, "/async-value")).text + assert '"src":"async_value"' in async_value_body + + async_gen_body = (await _fetch(client, "/async-generator")).text + assert '"src":"async_generator"' in async_gen_body + assert '"idx":1' in async_gen_body and '"idx":2' in async_gen_body + finally: + server.should_exit = True + thread.join(timeout=2) diff --git a/tests/test_starlette_decorator_integration.py b/tests/test_starlette_decorator_integration.py new file mode 100644 index 0000000..fccdb94 --- /dev/null +++ b/tests/test_starlette_decorator_integration.py @@ -0,0 +1,127 @@ +"""Integration test: datastar_response within a live Starlette app.""" + +from __future__ import annotations + +import threading +import time +from typing import Any + +import anyio +import httpx +import pytest +import uvicorn +from starlette.applications import Starlette +from starlette.responses import PlainTextResponse +from starlette.routing import Route + +from datastar_py.sse import ServerSentEventGenerator as SSE +from datastar_py.starlette import datastar_response + + +@pytest.fixture +def anyio_backend() -> str: + return "asyncio" + + +async def _fetch(client: httpx.AsyncClient, path: str) -> httpx.Response: + resp = await client.get(path, timeout=5.0) + resp.raise_for_status() + return resp + + +@pytest.mark.anyio +async def test_starlette_sync_handler_runs_in_threadpool_and_streams() -> None: + """Ensure all handler shapes work end-to-end and sync stays in threadpool.""" + + entered = threading.Event() + + @datastar_response + def sync_value(request) -> Any: # noqa: ANN001 + entered.set() + time.sleep(0.2) # should not block event loop + return SSE.patch_signals({"src": "sync_value"}) + + @datastar_response + def sync_gen(request) -> Any: # noqa: ANN001 + yield SSE.patch_signals({"src": "sync_generator", "idx": 1}) + yield SSE.patch_signals({"src": "sync_generator", "idx": 2}) + + @datastar_response + async def async_value(request) -> Any: # noqa: ANN001 + return SSE.patch_signals({"src": "async_value"}) + + @datastar_response + async def async_gen(request) -> Any: # noqa: ANN001 + yield SSE.patch_signals({"src": "async_generator", "idx": 1}) + yield SSE.patch_signals({"src": "async_generator", "idx": 2}) + + async def ping(request) -> PlainTextResponse: # noqa: ANN001 + return PlainTextResponse("pong") + + app = Starlette( + routes=[ + Route("/sync-value", sync_value), + Route("/sync-generator", sync_gen), + Route("/async-value", async_value), + Route("/async-generator", async_gen), + Route("/ping", ping), + ] + ) + + config = uvicorn.Config(app, host="127.0.0.1", port=0, log_level="warning", lifespan="off") + server = uvicorn.Server(config) + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + try: + for _ in range(50): + if server.started and getattr(server, "servers", None): + break + await anyio.sleep(0.05) + else: + pytest.fail("Server did not start") + + sock = server.servers[0].sockets[0] + host, port = sock.getsockname()[:2] + base_url = f"http://{host}:{port}" + + async with httpx.AsyncClient(base_url=base_url) as client: + # Verify blocking sync handler doesn't stall other requests + # Concurrency sanity: sync_value blocks 0.2s but should not stall ping + async with anyio.create_task_group() as tg: + slow_resp: httpx.Response | None = None + ping_resp: httpx.Response | None = None + + async def hit_slow(): + nonlocal slow_resp + slow_resp = await _fetch(client, "/sync-value") + + async def hit_ping(): + nonlocal ping_resp + await anyio.to_thread.run_sync(entered.wait, 1.0) + ping_resp = await _fetch(client, "/ping") + + tg.start_soon(hit_slow) + tg.start_soon(hit_ping) + + assert slow_resp is not None and slow_resp.status_code == 200 + assert ping_resp is not None and ping_resp.status_code == 200 + assert float(ping_resp.elapsed.total_seconds()) < 0.35 + + # Verify content of each endpoint + sync_value_body = (await _fetch(client, "/sync-value")).text + assert '"src":"sync_value"' in sync_value_body + + sync_gen_body = (await _fetch(client, "/sync-generator")).text + assert '"src":"sync_generator"' in sync_gen_body + assert '"idx":1' in sync_gen_body and '"idx":2' in sync_gen_body + + async_value_body = (await _fetch(client, "/async-value")).text + assert '"src":"async_value"' in async_value_body + + async_gen_body = (await _fetch(client, "/async-generator")).text + assert '"src":"async_generator"' in async_gen_body + assert '"idx":1' in async_gen_body and '"idx":2' in async_gen_body + finally: + server.should_exit = True + thread.join(timeout=2)