Skip to content

Commit 66ee2ac

Browse files
Restrict input values (#166)
* Restrict input values To a subset of JSONifiable types Signed-off-by: liamhuber <liamhuber@greyhavensolutions.com> * Allow dictionaries and lists of primitive types * Update test_models.py * Update models.py * Update models.py * Update test_models.py * Update models.py * extend tests --------- Signed-off-by: liamhuber <liamhuber@greyhavensolutions.com> Co-authored-by: Jan Janssen <jan-janssen@users.noreply.github.com>
1 parent b354427 commit 66ee2ac

File tree

2 files changed

+69
-2
lines changed

2 files changed

+69
-2
lines changed

src/python_workflow_definition/models.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
from pathlib import Path
2-
from typing import List, Union, Optional, Literal, Any, Annotated, Type, TypeVar
2+
from typing import (
3+
List,
4+
Union,
5+
Optional,
6+
Literal,
7+
Any,
8+
Annotated,
9+
Type,
10+
TypeVar,
11+
)
12+
from typing_extensions import TypeAliasType
313
from pydantic import BaseModel, Field, field_validator, field_serializer
414
from pydantic import ValidationError
515
import json
@@ -19,6 +29,13 @@
1929
)
2030

2131

32+
JsonPrimitive = Union[str, int, float, bool, None]
33+
AllowableDefaults = TypeAliasType(
34+
"AllowableDefaults",
35+
"Union[JsonPrimitive, dict[str, AllowableDefaults], list[AllowableDefaults]]",
36+
)
37+
38+
2239
class PythonWorkflowDefinitionBaseNode(BaseModel):
2340
"""Base model for all node types, containing common fields."""
2441

@@ -33,7 +50,7 @@ class PythonWorkflowDefinitionInputNode(PythonWorkflowDefinitionBaseNode):
3350

3451
type: Literal["input"]
3552
name: str
36-
value: Optional[Any] = None
53+
value: Optional[AllowableDefaults] = None
3754

3855

3956
class PythonWorkflowDefinitionOutputNode(PythonWorkflowDefinitionBaseNode):

tests/test_models.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from unittest import mock
55
from pydantic import ValidationError
66
from python_workflow_definition.models import (
7+
JsonPrimitive,
78
PythonWorkflowDefinitionInputNode,
89
PythonWorkflowDefinitionOutputNode,
910
PythonWorkflowDefinitionFunctionNode,
@@ -12,6 +13,11 @@
1213
INTERNAL_DEFAULT_HANDLE,
1314
)
1415

16+
17+
class _NoTrivialSerialization:
18+
pass
19+
20+
1521
class TestModels(unittest.TestCase):
1622
def setUp(self):
1723
self.valid_workflow_dict = {
@@ -40,6 +46,50 @@ def test_input_node(self):
4046
)
4147
self.assertEqual(node_with_value.value, 42)
4248

49+
def test_input_node_valid_values(self):
50+
good_values = (
51+
1,
52+
1.1,
53+
"string",
54+
True,
55+
None,
56+
[1, 2],
57+
[["recursive", "tuple"], [True, False]],
58+
)
59+
for value in good_values:
60+
with self.subTest(value=value):
61+
model = PythonWorkflowDefinitionInputNode.model_validate(
62+
{
63+
"id": 0,
64+
"type": "input",
65+
"name": "x",
66+
"value": value,
67+
}
68+
)
69+
self.assertEqual(
70+
value,
71+
PythonWorkflowDefinitionInputNode.model_validate(
72+
model.model_dump(mode="json")
73+
).value
74+
)
75+
76+
def test_input_node_invalid_value_raises(self):
77+
bad_values = (
78+
{1: 2},
79+
_NoTrivialSerialization(),
80+
)
81+
for value in bad_values:
82+
with self.subTest(value=value):
83+
with self.assertRaises(ValidationError):
84+
PythonWorkflowDefinitionInputNode.model_validate(
85+
{
86+
"id": 0,
87+
"type": "input",
88+
"name": "x",
89+
"value": value,
90+
}
91+
)
92+
4393
def test_output_node(self):
4494
node = PythonWorkflowDefinitionOutputNode(id=1, type="output", name="test_output")
4595
self.assertEqual(node.id, 1)

0 commit comments

Comments
 (0)