Skip to content

Commit ef63147

Browse files
committed
refactor: replace 'self: Task' magic with 'CURRENT_TASK' to fix issues with IDE typing
1 parent 54ed32d commit ef63147

File tree

6 files changed

+31
-36
lines changed

6 files changed

+31
-36
lines changed

src/sheppy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .backend import Backend, BackendError, MemoryBackend, RedisBackend
2-
from .models import Task
2+
from .models import CURRENT_TASK, Task
33
from .queue import Queue
44
from .task_factory import task
55
from .testqueue import TestQueue
@@ -12,7 +12,7 @@
1212
# fastapi
1313
"Depends",
1414
# task
15-
"task", "Task",
15+
"task", "Task", "CURRENT_TASK",
1616
# queue
1717
"Queue",
1818
# testqueue

src/sheppy/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626

2727
TASK_CRON_NS = UUID('7005b432-c135-4131-b19e-d3dc89703a9a')
2828

29+
# sentinel object for current task injection (def my_task(x: int, task: Task = CURRENT_TASK): ...)
30+
CURRENT_TASK = object()
31+
2932

3033
def cron_expression_validator(value: str) -> str:
3134
if not croniter.is_valid(value):

src/sheppy/utils/task_execution.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import anyio
1616
from pydantic import ConfigDict, PydanticSchemaGenerationError, TypeAdapter
1717

18-
from ..models import Task
18+
from ..models import CURRENT_TASK, Task
1919
from .fastapi import Depends
2020

2121
cache_main_module: str | None = None
@@ -214,9 +214,13 @@ async def process_function_parameters(
214214
remaining_args = list(args)
215215

216216
for param_name, param in list(signature.parameters.items()):
217-
# Task injection (self: Task)
217+
# current Task injection (current: Task = CURRENT_TASK)
218218
if task and TaskProcessor._is_task_injection(param):
219-
final_args.append(task)
219+
# inject positionally for positional params to maintain correct order
220+
if param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
221+
final_args.append(task)
222+
else:
223+
final_kwargs[param_name] = task
220224
continue
221225

222226
# validate positional args
@@ -237,13 +241,7 @@ async def process_function_parameters(
237241

238242
@staticmethod
239243
def _is_task_injection(param: inspect.Parameter) -> bool:
240-
if param.name != 'self':
241-
return False
242-
243-
if param.annotation == inspect.Parameter.empty:
244-
return False
245-
246-
return param.annotation is Task or param.annotation == 'Task'
244+
return param.default is CURRENT_TASK
247245

248246

249247
@staticmethod

src/sheppy/utils/validation.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from pydantic import TypeAdapter, ValidationError
1010

11-
from sheppy.models import Task
11+
from sheppy.models import CURRENT_TASK
1212

1313
from .fastapi import Depends
1414

@@ -39,11 +39,8 @@ def validate_input(
3939

4040
for param_name, param in signature.parameters.items():
4141

42+
# task self injection
4243
if _is_task_injection(param):
43-
if param.default != inspect.Parameter.empty:
44-
raise ValidationError.from_exception_data(
45-
f"Task injection parameter '{param_name}' cannot have a default value", line_errors=[]
46-
)
4744
if param_name in remaining_kwargs:
4845
raise ValidationError.from_exception_data(
4946
f"Cannot provide value for Task injection parameter '{param_name}'", line_errors=[]
@@ -109,10 +106,7 @@ def validate_input(
109106

110107

111108
def _is_task_injection(param: inspect.Parameter) -> bool:
112-
if param.name != 'self':
113-
return False
114-
115-
return param.annotation is Task or param.annotation == 'Task'
109+
return param.default is CURRENT_TASK
116110

117111

118112
def _is_depends_parameter(param: inspect.Parameter) -> bool:

tests/conftest.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
import pytest_asyncio
66

7-
from sheppy import Queue, Task, Worker, task
7+
from sheppy import CURRENT_TASK, Queue, Task, Worker, task
88
from sheppy.backend import Backend, MemoryBackend, RedisBackend
99

1010
TEST_QUEUE_NAME = "pytest"
@@ -67,15 +67,15 @@ async def queue(backend: Backend) -> Queue:
6767

6868

6969
@task(retry=2, retry_delay=0.1)
70-
async def async_fail_once(self: Task) -> str:
71-
if self.retry_count == 0:
70+
async def async_fail_once(current: Task = CURRENT_TASK) -> str:
71+
if current.retry_count == 0:
7272
raise Exception("transient error")
7373
return "ok"
7474

7575

7676
@task(retry=2, retry_delay=0)
77-
def sync_fail_once(self: Task) -> str:
78-
if self.retry_count == 0:
77+
def sync_fail_once(current: Task = CURRENT_TASK) -> str:
78+
if current.retry_count == 0:
7979
raise Exception("transient error")
8080
return "ok"
8181

tests/dependencies.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from pydantic import BaseModel
77

8-
from sheppy import Depends, Task, task
8+
from sheppy import CURRENT_TASK, Depends, Task, task
99

1010

1111
class User(BaseModel):
@@ -139,18 +139,18 @@ async def async_task_with_pydantic_model(user: User) -> Status:
139139

140140

141141
@task
142-
def task_with_self(self: Task, x: int, y: int) -> dict[str, Any]:
143-
return {"task_id": self.id, "result": x + y}
142+
def task_with_current_task(current: Task = CURRENT_TASK, x: int = 5, y: int = 6) -> dict[str, Any]:
143+
return {"task_id": current.id, "result": x + y}
144144

145145

146146
@task
147-
def task_with_self_middle(x: int, self: Task, y: int) -> dict[str, Any]:
148-
return {"task_id": self.id, "result": x + y}
147+
def task_with_current_task_middle(x: int, current: Task = CURRENT_TASK, y: int = 7) -> dict[str, Any]:
148+
return {"task_id": current.id, "result": x + y}
149149

150150

151151
@task
152-
def task_with_self_end(x: int, y: int, self: Task) -> dict[str, Any]:
153-
return {"task_id": self.id, "result": x + y}
152+
def task_with_current_task_end(x: int, y: int, current: Task = CURRENT_TASK) -> dict[str, Any]:
153+
return {"task_id": current.id, "result": x + y}
154154

155155

156156
@task
@@ -482,9 +482,9 @@ def deep_recursion_tasks() -> list[TaskTestCase]:
482482
def self_referencing_tasks() -> list[TaskTestCase]:
483483
"""Tasks that should fail."""
484484
return [
485-
TaskTestCase("task_with_self", task_with_self, (22, 33), expected_result=55),
486-
TaskTestCase("task_with_self_middle", task_with_self_middle, (22, 33), expected_result=55),
487-
TaskTestCase("task_with_self_end", task_with_self_end, (22, 33), expected_result=55),
485+
TaskTestCase("task_with_current_task", task_with_current_task, (), {"x": 22, "y": 33}, expected_result=55),
486+
TaskTestCase("task_with_current_task_middle", task_with_current_task_middle, (22, ), {"y": 33}, expected_result=55),
487+
TaskTestCase("task_with_current_task_end", task_with_current_task_end, (22, 33), expected_result=55),
488488
]
489489

490490
@staticmethod

0 commit comments

Comments
 (0)