Skip to content

Commit d1f2343

Browse files
committed
add default comparison
1 parent 7fa450d commit d1f2343

File tree

5 files changed

+54
-8
lines changed

5 files changed

+54
-8
lines changed

python/pydantic_core/core_schema.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2186,6 +2186,7 @@ def with_default_schema(
21862186
*,
21872187
default: Any = PydanticUndefined,
21882188
default_factory: Callable[[], Any] | None = None,
2189+
default_comparison: Callable[[Any, Any], bool] | None = None,
21892190
on_error: Literal['raise', 'omit', 'default'] | None = None,
21902191
validate_default: bool | None = None,
21912192
strict: bool | None = None,
@@ -2211,6 +2212,7 @@ def with_default_schema(
22112212
schema: The schema to add a default value to
22122213
default: The default value to use
22132214
default_factory: A function that returns the default value to use
2215+
default_comparison: A function to compare the default value with any other given
22142216
on_error: What to do if the schema validation fails. One of 'raise', 'omit', 'default'
22152217
validate_default: Whether the default value should be validated
22162218
strict: Whether the underlying schema should be validated with strict mode
@@ -2222,6 +2224,7 @@ def with_default_schema(
22222224
type='default',
22232225
schema=schema,
22242226
default_factory=default_factory,
2227+
default_comparison=default_comparison,
22252228
on_error=on_error,
22262229
validate_default=validate_default,
22272230
strict=strict,

src/serializers/fields.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,9 @@ impl SerField {
6868
}
6969

7070
fn exclude_default(value: &PyAny, extra: &Extra, serializer: &CombinedSerializer) -> PyResult<bool> {
71+
let py = value.py();
7172
if extra.exclude_defaults {
72-
if let Some(default) = serializer.get_default(value.py())? {
73-
if value.eq(default)? {
74-
return Ok(true);
75-
}
76-
}
73+
return serializer.compare_with_default(py, value);
7774
}
7875
Ok(false)
7976
}

src/serializers/shared.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,10 @@ pub(crate) trait TypeSerializer: Send + Sync + Clone + Debug {
301301
fn get_default(&self, _py: Python) -> PyResult<Option<PyObject>> {
302302
Ok(None)
303303
}
304+
305+
fn compare_with_default(&self, _py: Python, _value: &PyAny) -> PyResult<bool> {
306+
Ok(false)
307+
}
304308
}
305309

306310
pub(crate) struct PydanticSerializer<'py> {

src/serializers/type_serializers/with_default.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use super::{BuildSerializer, CombinedSerializer, Extra, TypeSerializer};
1313
#[derive(Debug, Clone)]
1414
pub struct WithDefaultSerializer {
1515
default: DefaultType,
16+
default_comparison: Option<PyObject>,
1617
serializer: Box<CombinedSerializer>,
1718
}
1819

@@ -26,11 +27,16 @@ impl BuildSerializer for WithDefaultSerializer {
2627
) -> PyResult<CombinedSerializer> {
2728
let py = schema.py();
2829
let default = DefaultType::new(schema)?;
29-
30+
let default_comparison = schema.get_as(intern!(py, "default_comparison"))?;
3031
let sub_schema: &PyDict = schema.get_as_req(intern!(py, "schema"))?;
3132
let serializer = Box::new(CombinedSerializer::build(sub_schema, config, definitions)?);
3233

33-
Ok(Self { default, serializer }.into())
34+
Ok(Self {
35+
default,
36+
default_comparison,
37+
serializer,
38+
}
39+
.into())
3440
}
3541
}
3642

@@ -74,4 +80,16 @@ impl TypeSerializer for WithDefaultSerializer {
7480
fn get_default(&self, py: Python) -> PyResult<Option<PyObject>> {
7581
self.default.default_value(py)
7682
}
83+
84+
fn compare_with_default(&self, py: Python, value: &PyAny) -> PyResult<bool> {
85+
if let Some(default) = self.get_default(py)? {
86+
if let Some(default_comparison) = &self.default_comparison {
87+
return default_comparison.call(py, (value, default), None)?.extract::<bool>(py);
88+
} else if value.eq(default)? {
89+
return Ok(true);
90+
}
91+
}
92+
93+
Ok(false)
94+
}
7795
}

tests/serializers/test_typed_dict.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,27 @@ def test_exclude_none():
150150

151151

152152
def test_exclude_default():
153+
class TestComparison:
154+
def __init__(self, val: Any):
155+
self.val = val
156+
157+
def __eq__(self, other):
158+
return self.val == other.val
159+
153160
v = SchemaSerializer(
154161
core_schema.typed_dict_schema(
155162
{
156163
'foo': core_schema.typed_dict_field(core_schema.nullable_schema(core_schema.int_schema())),
157164
'bar': core_schema.typed_dict_field(
158165
core_schema.with_default_schema(core_schema.bytes_schema(), default=b'[default]')
159166
),
167+
'foobar': core_schema.typed_dict_field(
168+
core_schema.with_default_schema(
169+
core_schema.any_schema(),
170+
default=TestComparison(val=1),
171+
default_comparison=lambda value, default: value.val == -1 * default.val,
172+
)
173+
),
160174
}
161175
)
162176
)
@@ -165,9 +179,19 @@ def test_exclude_default():
165179
assert v.to_python({'foo': 1, 'bar': b'[default]'}, exclude_defaults=True) == {'foo': 1}
166180
assert v.to_python({'foo': 1, 'bar': b'[default]'}, mode='json') == {'foo': 1, 'bar': '[default]'}
167181
assert v.to_python({'foo': 1, 'bar': b'[default]'}, exclude_defaults=True, mode='json') == {'foo': 1}
168-
169182
assert v.to_json({'foo': 1, 'bar': b'[default]'}) == b'{"foo":1,"bar":"[default]"}'
170183
assert v.to_json({'foo': 1, 'bar': b'[default]'}, exclude_defaults=True) == b'{"foo":1}'
184+
# Note that due to the custom comparison operator foobar must be excluded
185+
assert v.to_python({'foo': 1, 'bar': b'x', 'foobar': TestComparison(val=-1)}, exclude_defaults=True) == {
186+
'foo': 1,
187+
'bar': b'x',
188+
}
189+
# foobar here must be included
190+
assert v.to_python({'foo': 1, 'bar': b'x', 'foobar': TestComparison(val=1)}, exclude_defaults=True) == {
191+
'foo': 1,
192+
'bar': b'x',
193+
'foobar': TestComparison(val=1),
194+
}
171195

172196

173197
def test_function_plain_field_serializer_to_python():

0 commit comments

Comments
 (0)