diff --git a/fastapi_utilities/repeat/repeat_at.py b/fastapi_utilities/repeat/repeat_at.py index 97cf77c..6775e27 100644 --- a/fastapi_utilities/repeat/repeat_at.py +++ b/fastapi_utilities/repeat/repeat_at.py @@ -25,6 +25,7 @@ def repeat_at( logger: logging.Logger = None, raise_exceptions: bool = False, max_repetitions: int = None, + on_startup = False, ) -> typing.Callable[[_FuncType], _FuncType]: """ Decorator to schedule a function's execution based on a cron expression. @@ -39,6 +40,8 @@ def repeat_at( Whether to raise exceptions or log them. max_repetitions: int (default None) Maximum number of times to repeat the function. If None, repeats indefinitely. + on_startup: bool (default False) + Whether execute on application startup """ def decorator(func: _FuncType) -> _FuncType: @@ -50,6 +53,10 @@ async def async_wrapper(*args, **kwargs): if not croniter.is_valid(cron): raise ValueError(f"Invalid cron expression: '{cron}'") + if on_startup: + await func(*args, **kwargs) + repetitions += 1 + while max_repetitions is None or repetitions < max_repetitions: try: sleep_time = get_delta(cron) @@ -70,6 +77,11 @@ def sync_wrapper(*args, **kwargs): async def loop(): nonlocal repetitions + + if on_startup: + await func(*args, **kwargs) + repetitions += 1 + while max_repetitions is None or repetitions < max_repetitions: try: sleep_time = get_delta(cron) diff --git a/tests/test_repeat_at.py b/tests/test_repeat_at.py index 0bbbe42..28a496f 100644 --- a/tests/test_repeat_at.py +++ b/tests/test_repeat_at.py @@ -115,3 +115,20 @@ def raise_exc(): out, err = capsys.readouterr() assert out == "" assert err == "" + +@pytest.mark.asyncio +async def test_repeat_at_on_startup(capsys: CaptureFixture[str]): + """ + Test repeat_at with on_startup=True function + """ + @repeat_at(cron="0 * * * *", max_repetitions=1, on_startup=True) + async def hello(): + print("executed") + + asyncio.create_task(hello()) + + await asyncio.sleep(0.1) + out, err = capsys.readouterr() + assert "executed" in out + assert err == "" +