Skip to content

Commit 8c1a250

Browse files
committed
Allow immutable json primitives recursively
Signed-off-by: liamhuber <liamhuber@greyhavensolutions.com>
1 parent b354427 commit 8c1a250

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

src/python_workflow_definition/models.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
from pathlib import Path
2-
from typing import List, Union, Optional, Literal, Any, Annotated, Type, TypeVar
2+
from typing import (
3+
List,
4+
Union,
5+
Optional,
6+
Literal,
7+
Any,
8+
Annotated,
9+
Type,
10+
TypeAliasType,
11+
TypeVar,
12+
)
313
from pydantic import BaseModel, Field, field_validator, field_serializer
414
from pydantic import ValidationError
515
import json
@@ -18,6 +28,11 @@
1828
"PythonWorkflowDefinitionWorkflow",
1929
)
2030

31+
JsonPrimitive = Union[str, int, float, bool, None]
32+
AllowableDefaults = TypeAliasType(
33+
"AllowableDefaults", "Union[JsonPrimitive, tuple[AllowableDefaults, ...]]"
34+
)
35+
2136

2237
class PythonWorkflowDefinitionBaseNode(BaseModel):
2338
"""Base model for all node types, containing common fields."""
@@ -33,7 +48,7 @@ class PythonWorkflowDefinitionInputNode(PythonWorkflowDefinitionBaseNode):
3348

3449
type: Literal["input"]
3550
name: str
36-
value: Optional[Any] = None
51+
value: Optional[AllowableDefaults] = None
3752

3853

3954
class PythonWorkflowDefinitionOutputNode(PythonWorkflowDefinitionBaseNode):

tests/test_models.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
INTERNAL_DEFAULT_HANDLE,
1313
)
1414

15+
16+
class _NoTrivialSerialization:
17+
pass
18+
19+
1520
class TestModels(unittest.TestCase):
1621
def setUp(self):
1722
self.valid_workflow_dict = {
@@ -40,6 +45,51 @@ def test_input_node(self):
4045
)
4146
self.assertEqual(node_with_value.value, 42)
4247

48+
def test_input_node_valid_values(self):
49+
good_values = (
50+
1,
51+
1.1,
52+
"string",
53+
True,
54+
None,
55+
(1, 2),
56+
(("recursive", "tuple"), (True, False)),
57+
)
58+
for value in good_values:
59+
with self.subTest(value=value):
60+
model = PythonWorkflowDefinitionInputNode.model_validate(
61+
{
62+
"id": 0,
63+
"type": "input",
64+
"name": "x",
65+
"value": value,
66+
}
67+
)
68+
self.assertEqual(
69+
value,
70+
PythonWorkflowDefinitionInputNode.model_validate(
71+
model.model_dump(mode="json")
72+
).value
73+
)
74+
75+
76+
def test_input_node_invalid_value_raises(self):
77+
bad_values = (
78+
{"mutable": "thing"},
79+
_NoTrivialSerialization(),
80+
)
81+
for value in bad_values:
82+
with self.subTest(value=value):
83+
with self.assertRaises(ValidationError):
84+
PythonWorkflowDefinitionInputNode.model_validate(
85+
{
86+
"id": 0,
87+
"type": "input",
88+
"name": "x",
89+
"value": value,
90+
}
91+
)
92+
4393
def test_output_node(self):
4494
node = PythonWorkflowDefinitionOutputNode(id=1, type="output", name="test_output")
4595
self.assertEqual(node.id, 1)

0 commit comments

Comments
 (0)