Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions fastapi_utilities/repeat/repeat_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def repeat_at(
logger: logging.Logger = None,
raise_exceptions: bool = False,
max_repetitions: int = None,
on_startup = False,
) -> typing.Callable[[_FuncType], _FuncType]:
"""
Decorator to schedule a function's execution based on a cron expression.
Expand All @@ -39,6 +40,8 @@ def repeat_at(
Whether to raise exceptions or log them.
max_repetitions: int (default None)
Maximum number of times to repeat the function. If None, repeats indefinitely.
on_startup: bool (default False)
Whether execute on application startup
"""

def decorator(func: _FuncType) -> _FuncType:
Expand All @@ -50,6 +53,10 @@ async def async_wrapper(*args, **kwargs):
if not croniter.is_valid(cron):
raise ValueError(f"Invalid cron expression: '{cron}'")

if on_startup:
await func(*args, **kwargs)
repetitions += 1

while max_repetitions is None or repetitions < max_repetitions:
try:
sleep_time = get_delta(cron)
Expand All @@ -70,6 +77,11 @@ def sync_wrapper(*args, **kwargs):

async def loop():
nonlocal repetitions

if on_startup:
await func(*args, **kwargs)
repetitions += 1

while max_repetitions is None or repetitions < max_repetitions:
try:
sleep_time = get_delta(cron)
Expand Down
17 changes: 17 additions & 0 deletions tests/test_repeat_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,20 @@ def raise_exc():
out, err = capsys.readouterr()
assert out == ""
assert err == ""

@pytest.mark.asyncio
async def test_repeat_at_on_startup(capsys: CaptureFixture[str]):
"""
Test repeat_at with on_startup=True function
"""
@repeat_at(cron="0 * * * *", max_repetitions=1, on_startup=True)
async def hello():
print("executed")

asyncio.create_task(hello())

await asyncio.sleep(0.1)
out, err = capsys.readouterr()
assert "executed" in out
assert err == ""