Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/reference/task-config.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# `TaskConfig` model reference

::: sheppy.models.Config
::: sheppy.models.TaskConfig
2 changes: 1 addition & 1 deletion docs/reference/task-spec.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# `TaskSpec` model reference

::: sheppy.models.Spec
::: sheppy.models.TaskSpec
4 changes: 2 additions & 2 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ nav:
- CLI reference: reference/cli.md
- Models:
- Task: reference/task.md
- Spec: reference/task-spec.md
- Config: reference/task-config.md
- TaskSpec: reference/task-spec.md
- TaskConfig: reference/task-config.md
- TaskCron: reference/task-cron.md
- Backends:
- Backend class: reference/backends/backend.md
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ sheppy = "sheppy.cli.cli:app"

[tool.mypy]
strict = true
plugins = ['pydantic.mypy']

[tool.ruff.lint]
select = [
Expand Down
4 changes: 2 additions & 2 deletions src/sheppy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .backend import Backend, BackendError, MemoryBackend, RedisBackend
from .models import Task
from .models import CURRENT_TASK, Task
from .queue import Queue
from .task_factory import task
from .testqueue import TestQueue
Expand All @@ -12,7 +12,7 @@
# fastapi
"Depends",
# task
"task", "Task",
"task", "Task", "CURRENT_TASK",
# queue
"Queue",
# testqueue
Expand Down
31 changes: 17 additions & 14 deletions src/sheppy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@

TASK_CRON_NS = UUID('7005b432-c135-4131-b19e-d3dc89703a9a')

# sentinel object for current task injection (def my_task(x: int, task: Task = CURRENT_TASK): ...)
CURRENT_TASK = object()


def cron_expression_validator(value: str) -> str:
if not croniter.is_valid(value):
Expand All @@ -36,7 +39,7 @@ def cron_expression_validator(value: str) -> str:
CronExpression = Annotated[str, AfterValidator(cron_expression_validator)]


class Spec(BaseModel):
class TaskSpec(BaseModel):
"""Task specification.

Attributes:
Expand All @@ -47,7 +50,7 @@ class Spec(BaseModel):
middleware (list[str]|None): List of fully qualified middleware function names to be applied to the task, e.g. `['my_module.submodule:my_middleware']`. Middleware will be applied in the order they are listed.

Note:
- You should not create Spec instances directly. Instead, use the `@task` decorator to define a task function, and then call that function to create a Task instance.
- You should not create TaskSpec instances directly. Instead, use the `@task` decorator to define a task function, and then call that function to create a Task instance.
- `args` and `kwargs` must be JSON serializable.

Example:
Expand All @@ -65,7 +68,7 @@ def my_task(x: int, y: str) -> str:
print(t.spec.return_type) # "builtins.str"
```
"""
model_config = ConfigDict(frozen=True)
model_config = ConfigDict(frozen=True, extra="forbid")

func: str
"""str: Fully qualified function name, e.g. `my_module.my_submodule:my_function`"""
Expand All @@ -79,15 +82,15 @@ def my_task(x: int, y: str) -> str:
"""list[str]|None: List of fully qualified middleware function names to be applied to the task, e.g. `['my_module.submodule:my_middleware']`. Middleware will be applied in the order they are listed."""


class Config(BaseModel):
class TaskConfig(BaseModel):
"""Task configuration

Attributes:
retry (int): Number of times to retry the task if it fails. Default is 0 (no retries).
retry_delay (float|list[float]): Delay between retries in seconds. If a single float is provided, it will be used for all retries. If a list is provided, it will be used for each retry attempt respectively (exponential backoff). Default is 1.0 seconds.

Note:
- You should not create Config instances directly. Instead, use the `@task` decorator to define a task function, and then call that function to create a Task instance.
- You should not create TaskConfig instances directly. Instead, use the `@task` decorator to define a task function, and then call that function to create a Task instance.
- `retry` must be a non-negative integer.
- `retry_delay` must be a positive float or a list of positive floats.

Expand All @@ -104,7 +107,7 @@ def my_task():
print(t.config.retry_delay) # [1.0, 2.0, 3.0]
```
"""
model_config = ConfigDict(frozen=True)
model_config = ConfigDict(frozen=True, extra="forbid")

retry: int = Field(default=0, ge=0)
"""int: Number of times to retry the task if it fails. Default is 0 (no retries)."""
Expand All @@ -130,8 +133,8 @@ class Task(BaseModel):
completed (bool): A completion flag that is set to True only if task finished successfully.
error (str|None): Error message if the task failed. None if the task succeeded or is not yet executed.
result (Any): The result of the task execution. If the task failed, this will be None.
spec (sheppy.models.Spec): Task specification
config (sheppy.models.Config): Task configuration
spec (sheppy.models.TaskSpec): Task specification
config (sheppy.models.TaskConfig): Task configuration
created_at (datetime): Timestamp when the task was created.
finished_at (datetime|None): Timestamp when the task was finished. None if the task is not yet finished.
scheduled_at (datetime|None): Timestamp when the task is scheduled to run. None if the task is not scheduled.
Expand Down Expand Up @@ -171,9 +174,9 @@ def add(x: int, y: int) -> int:
result: Any = None
"""Any: The result of the task execution. This will be None if the task failed or is not yet executed."""

spec: Spec
spec: TaskSpec
"""Task specification"""
config: Config = Field(default_factory=Config)
config: TaskConfig = Field(default_factory=TaskConfig)
"""Task configuration"""

created_at: AwareDatetime = Field(default_factory=lambda: datetime.now(timezone.utc))
Expand Down Expand Up @@ -244,8 +247,8 @@ class TaskCron(BaseModel):
Attributes:
id (UUID): Unique identifier for the cron definition.
expression (str): Cron expression defining the schedule, e.g. "*/5 * * * *" for every 5 minutes.
spec (sheppy.models.Spec): Task specification
config (sheppy.models.Config): Task configuration
spec (sheppy.models.TaskSpec): Task specification
config (sheppy.models.TaskConfig): Task configuration

Note:
- You should not create TaskCron instances directly. Instead, use the `add_cron` method of the Queue class to create a cron definition.
Expand Down Expand Up @@ -283,9 +286,9 @@ def say_hello(to: str) -> str:
expression: CronExpression
"""str: Cron expression defining the schedule, e.g. "*/5 * * * *" for every 5 minutes."""

spec: Spec
spec: TaskSpec
"""Task specification"""
config: Config
config: TaskConfig
"""Task configuration"""

# enabled: bool = True
Expand Down
2 changes: 1 addition & 1 deletion src/sheppy/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ async def get_crons(self) -> list[TaskCron]:
crons = await q.get_crons()

for cron in crons:
print(f"Cron ID: {cron.id}, Expression: {cron.expression}, Task Spec: {cron.spec}")
print(f"Cron ID: {cron.id}, Expression: {cron.expression}, TaskSpec: {cron.spec}")
```
"""
await self.__ensure_backend_is_connected()
Expand Down
6 changes: 3 additions & 3 deletions src/sheppy/task_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
overload,
)

from .models import Config, Spec, Task, TaskCron
from .models import Task, TaskConfig, TaskCron, TaskSpec
from .utils.validation import validate_input

P = ParamSpec('P')
Expand Down Expand Up @@ -86,14 +86,14 @@ def create_task(func: Callable[..., Any],
stringified_middlewares.append(TaskFactory._stringify_function(m))

_task = Task(
spec=Spec(
spec=TaskSpec(
func=func_string,
args=args,
kwargs=kwargs,
return_type=return_type,
middleware=stringified_middlewares
),
config=Config(**task_config)
config=TaskConfig(**task_config)
)

return _task
Expand Down
2 changes: 1 addition & 1 deletion src/sheppy/testqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def get_crons(self) -> list[TaskCron]:
crons = q.get_crons()

for cron in crons:
print(f"Cron ID: {cron.id}, Expression: {cron.expression}, Task Spec: {cron.spec}")
print(f"Cron ID: {cron.id}, Expression: {cron.expression}, TaskSpec: {cron.spec}")
```
"""
return asyncio.run(self._queue.get_crons())
Expand Down
18 changes: 8 additions & 10 deletions src/sheppy/utils/task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import anyio
from pydantic import ConfigDict, PydanticSchemaGenerationError, TypeAdapter

from ..models import Task
from ..models import CURRENT_TASK, Task
from .fastapi import Depends

cache_main_module: str | None = None
Expand Down Expand Up @@ -214,9 +214,13 @@ async def process_function_parameters(
remaining_args = list(args)

for param_name, param in list(signature.parameters.items()):
# Task injection (self: Task)
# current Task injection (current: Task = CURRENT_TASK)
if task and TaskProcessor._is_task_injection(param):
final_args.append(task)
# inject positionally for positional params to maintain correct order
if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
final_args.append(task)
else:
final_kwargs[param_name] = task
continue

# validate positional args
Expand All @@ -237,13 +241,7 @@ async def process_function_parameters(

@staticmethod
def _is_task_injection(param: inspect.Parameter) -> bool:
if param.name != 'self':
return False

if param.annotation == inspect.Parameter.empty:
return False

return param.annotation is Task or param.annotation == 'Task'
return param.default is CURRENT_TASK


@staticmethod
Expand Down
12 changes: 3 additions & 9 deletions src/sheppy/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from pydantic import TypeAdapter, ValidationError

from sheppy.models import Task
from sheppy.models import CURRENT_TASK

from .fastapi import Depends

Expand Down Expand Up @@ -39,11 +39,8 @@ def validate_input(

for param_name, param in signature.parameters.items():

# task self injection
if _is_task_injection(param):
if param.default != inspect.Parameter.empty:
raise ValidationError.from_exception_data(
f"Task injection parameter '{param_name}' cannot have a default value", line_errors=[]
)
if param_name in remaining_kwargs:
raise ValidationError.from_exception_data(
f"Cannot provide value for Task injection parameter '{param_name}'", line_errors=[]
Expand Down Expand Up @@ -109,10 +106,7 @@ def validate_input(


def _is_task_injection(param: inspect.Parameter) -> bool:
if param.name != 'self':
return False

return param.annotation is Task or param.annotation == 'Task'
return param.default is CURRENT_TASK


def _is_depends_parameter(param: inspect.Parameter) -> bool:
Expand Down
10 changes: 5 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import pytest_asyncio

from sheppy import Queue, Task, Worker, task
from sheppy import CURRENT_TASK, Queue, Task, Worker, task
from sheppy.backend import Backend, MemoryBackend, RedisBackend

TEST_QUEUE_NAME = "pytest"
Expand Down Expand Up @@ -67,15 +67,15 @@ async def queue(backend: Backend) -> Queue:


@task(retry=2, retry_delay=0.1)
async def async_fail_once(self: Task) -> str:
if self.retry_count == 0:
async def async_fail_once(current: Task = CURRENT_TASK) -> str:
if current.retry_count == 0:
raise Exception("transient error")
return "ok"


@task(retry=2, retry_delay=0)
def sync_fail_once(self: Task) -> str:
if self.retry_count == 0:
def sync_fail_once(current: Task = CURRENT_TASK) -> str:
if current.retry_count == 0:
raise Exception("transient error")
return "ok"

Expand Down
4 changes: 2 additions & 2 deletions tests/contract/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import ValidationError

from sheppy import Task, task
from sheppy.models import Config
from sheppy.models import TaskConfig
from tests.dependencies import (
Status,
failing_task,
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_task_is_frozen(task_fn):
task.config.retry = 5.5

with pytest.raises(ValidationError, match="Instance is frozen"):
task.config = Config()
task.config = TaskConfig()

with pytest.raises(TypeError, match="does not support item assignment"):
task.spec.args[0] = 5
Expand Down
6 changes: 3 additions & 3 deletions tests/contract/test_taskcron.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from pydantic import ValidationError

from sheppy.models import Config, Spec, TaskCron
from sheppy.models import TaskConfig, TaskCron, TaskSpec


@pytest.mark.parametrize("cron_expression", [
Expand All @@ -26,7 +26,7 @@
])
def test_valid_cron_expressions(cron_expression):

TaskCron(expression=cron_expression, spec=Spec(func=""), config=Config())
TaskCron(expression=cron_expression, spec=TaskSpec(func=""), config=TaskConfig())


@pytest.mark.parametrize("cron_expression", [
Expand All @@ -50,7 +50,7 @@ def test_valid_cron_expressions(cron_expression):
def test_invalid_cron_expressions(cron_expression):

with pytest.raises(ValidationError):
TaskCron(expression=cron_expression, spec=Spec(func=""), config=Config())
TaskCron(expression=cron_expression, spec=TaskSpec(func=""), config=TaskConfig())


@pytest.mark.parametrize("expression,spec,config", [
Expand Down
20 changes: 10 additions & 10 deletions tests/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pydantic import BaseModel

from sheppy import Depends, Task, task
from sheppy import CURRENT_TASK, Depends, Task, task


class User(BaseModel):
Expand Down Expand Up @@ -139,18 +139,18 @@ async def async_task_with_pydantic_model(user: User) -> Status:


@task
def task_with_self(self: Task, x: int, y: int) -> dict[str, Any]:
return {"task_id": self.id, "result": x + y}
def task_with_current_task(current: Task = CURRENT_TASK, x: int = 5, y: int = 6) -> dict[str, Any]:
return {"task_id": current.id, "result": x + y}


@task
def task_with_self_middle(x: int, self: Task, y: int) -> dict[str, Any]:
return {"task_id": self.id, "result": x + y}
def task_with_current_task_middle(x: int, current: Task = CURRENT_TASK, y: int = 7) -> dict[str, Any]:
return {"task_id": current.id, "result": x + y}


@task
def task_with_self_end(x: int, y: int, self: Task) -> dict[str, Any]:
return {"task_id": self.id, "result": x + y}
def task_with_current_task_end(x: int, y: int, current: Task = CURRENT_TASK) -> dict[str, Any]:
return {"task_id": current.id, "result": x + y}


@task
Expand Down Expand Up @@ -482,9 +482,9 @@ def deep_recursion_tasks() -> list[TaskTestCase]:
def self_referencing_tasks() -> list[TaskTestCase]:
"""Tasks that should fail."""
return [
TaskTestCase("task_with_self", task_with_self, (22, 33), expected_result=55),
TaskTestCase("task_with_self_middle", task_with_self_middle, (22, 33), expected_result=55),
TaskTestCase("task_with_self_end", task_with_self_end, (22, 33), expected_result=55),
TaskTestCase("task_with_current_task", task_with_current_task, (), {"x": 22, "y": 33}, expected_result=55),
TaskTestCase("task_with_current_task_middle", task_with_current_task_middle, (22, ), {"y": 33}, expected_result=55),
TaskTestCase("task_with_current_task_end", task_with_current_task_end, (22, 33), expected_result=55),
]

@staticmethod
Expand Down