From 3e0ba57e2ecd0748c70a625792fea64b84897fb2 Mon Sep 17 00:00:00 2001 From: Owen Kephart Date: Fri, 7 Nov 2025 15:45:29 -0800 Subject: [PATCH] [rfc][rip] dedup serdes --- .../partitions/subset/time_window.py | 1 + .../general_tests/test_serdes_dedup.py | 84 +++++++++++ .../dagster_shared/serdes/serdes.py | 135 ++++++++++++++++-- 3 files changed, 211 insertions(+), 9 deletions(-) create mode 100644 python_modules/dagster/dagster_tests/general_tests/test_serdes_dedup.py diff --git a/python_modules/dagster/dagster/_core/definitions/partitions/subset/time_window.py b/python_modules/dagster/dagster/_core/definitions/partitions/subset/time_window.py index e23e788729e18..875ea396a5c8f 100644 --- a/python_modules/dagster/dagster/_core/definitions/partitions/subset/time_window.py +++ b/python_modules/dagster/dagster/_core/definitions/partitions/subset/time_window.py @@ -55,6 +55,7 @@ def before_pack(self, value: "TimeWindowPartitionsSubset") -> "TimeWindowPartiti # value.num_partitions will calculate the number of partitions if the field is None # We want to check if the field is None and replace the value with the calculated value # for serialization + if value._asdict()["num_partitions"] is None: return TimeWindowPartitionsSubset( partitions_def=value.partitions_def, diff --git a/python_modules/dagster/dagster_tests/general_tests/test_serdes_dedup.py b/python_modules/dagster/dagster_tests/general_tests/test_serdes_dedup.py new file mode 100644 index 0000000000000..225c7a1d6662b --- /dev/null +++ b/python_modules/dagster/dagster_tests/general_tests/test_serdes_dedup.py @@ -0,0 +1,84 @@ +import dagster as dg +from dagster import DagsterInstance +from dagster._core.definitions.asset_daemon_cursor import AssetDaemonCursor +from dagster_shared.record import record +from dagster_shared.serdes import whitelist_for_serdes +from dagster_shared.serdes.serdes import deserialize_value_with_dedup, serialize_value_with_dedup + + +@whitelist_for_serdes +@record +class Inner: + number: float + + +@whitelist_for_serdes +@record +class Foo: + name: str + value: int + inner: Inner + + +@whitelist_for_serdes +@record +class Bar: + name: str + single: Foo + multiple: list[Foo] + + +def test_dedup(): + # same object, different ids + f1 = Foo(name="f1", value=1, inner=Inner(number=1.0)) + f1_same = Foo(name="f1", value=1, inner=Inner(number=1.0)) + + f2 = Foo(name="f2", value=2, inner=Inner(number=2.0)) + + bar = Bar(name="bar", single=f1, multiple=[f1, f1, f1_same, f1_same, f2]) + + serialized = serialize_value_with_dedup(bar) + assert "__dedup_mapping__" in serialized + assert "__dedup_ref__" in serialized + deserialized = deserialize_value_with_dedup(serialized, as_type=Bar) + assert deserialized == bar + + +def test_cursor(): + daily_partitions = dg.DailyPartitionsDefinition(start_date="2024-01-01") + + @dg.asset(partitions_def=daily_partitions) + def upstream_1() -> None: ... + + @dg.asset(partitions_def=daily_partitions) + def upstream_2() -> None: ... + + @dg.asset(partitions_def=daily_partitions) + def upstream_3() -> None: ... + + @dg.asset(partitions_def=daily_partitions) + def upstream_4() -> None: ... + + @dg.asset(partitions_def=daily_partitions) + def upstream_5() -> None: ... + + @dg.asset( + deps=[upstream_1, upstream_2, upstream_3, upstream_4, upstream_5], + automation_condition=dg.AutomationCondition.on_cron(cron_schedule="0 * * * *"), + ) + def downstream() -> None: ... 
+ + defs = dg.Definitions( + assets=[upstream_1, upstream_2, upstream_3, upstream_4, upstream_5, downstream] + ) + instance = DagsterInstance.ephemeral() + + result = dg.evaluate_automation_conditions(defs=defs, instance=instance) + cursor = result.cursor + assert isinstance(cursor, AssetDaemonCursor) + + serialized = serialize_value_with_dedup(cursor) + assert "__dedup_mapping__" in serialized + assert "__dedup_ref__" in serialized + deserialized = deserialize_value_with_dedup(serialized, as_type=AssetDaemonCursor) + assert deserialized == cursor diff --git a/python_modules/libraries/dagster-shared/dagster_shared/serdes/serdes.py b/python_modules/libraries/dagster-shared/dagster_shared/serdes/serdes.py index 9222dc401495a..bcb03c05cc88e 100644 --- a/python_modules/libraries/dagster-shared/dagster_shared/serdes/serdes.py +++ b/python_modules/libraries/dagster-shared/dagster_shared/serdes/serdes.py @@ -470,8 +470,10 @@ def is_whitelisted_for_serdes_object( class UnpackContext: """values are unpacked bottom up.""" - def __init__(self): + def __init__(self, obj_mapping: dict): self.observed_unknown_serdes_values: set[UnknownSerdesValue] = set() + self.obj_mapping = obj_mapping + self.obj_mapping_unpacked = {} def assert_no_unknown_values(self, obj: UnpackedValue) -> PackableValue: if isinstance(obj, UnknownSerdesValue): @@ -639,8 +641,11 @@ def pack_items( self, value: T, whitelist_map: WhitelistMap, - object_handler: Callable[[SerializableObject, WhitelistMap, str], JsonSerializableValue], + object_handler: Callable[ + [SerializableObject, WhitelistMap, str, Optional["DedupContext"]], JsonSerializableValue + ], descent_path: str, + dedup_context: Optional["DedupContext"], ) -> Iterator[tuple[str, JsonSerializableValue]]: yield "__class__", self.get_storage_name() for key, inner_value in self.object_as_mapping(self.before_pack(value)).items(): @@ -667,6 +672,7 @@ def pack_items( whitelist_map=whitelist_map, object_handler=object_handler, descent_path=f"{descent_path}.{key}", + dedup_context=dedup_context, ), ) for key, default in self.old_fields.items(): @@ -839,7 +845,29 @@ def serialize_value( whitelist_map=whitelist_map, object_handler=_wrap_object, descent_path=_root(val), + dedup_context=None, + ) + return seven.json.dumps(serializable_value, **json_kwargs) + + +def serialize_value_with_dedup( + val: PackableValue, + whitelist_map: WhitelistMap = _WHITELIST_MAP, + **json_kwargs: Any, +) -> str: + ctx = DedupContext() + serializable_value = _transform_for_serialization( + val, + whitelist_map=whitelist_map, + object_handler=_wrap_object, + descent_path=_root(val), + dedup_context=ctx, ) + if ctx.obj_mapping: + serializable_value = { + "__dedup_mapping__": ctx.obj_mapping, + "value": serializable_value, + } return seven.json.dumps(serializable_value, **json_kwargs) @@ -902,15 +930,51 @@ def pack_value( whitelist_map=whitelist_map, descent_path=descent_path, object_handler=_pack_object, + dedup_context=DedupContext(), ) +class DedupContext: + def __init__(self): + self.obj_ids = set() + self.obj_mapping = {} + + def _transform_for_serialization( val: PackableValue, whitelist_map: WhitelistMap, - object_handler: Callable[[SerializableObject, WhitelistMap, str], JsonSerializableValue], + object_handler: Callable[ + [SerializableObject, WhitelistMap, str, Optional[DedupContext]], JsonSerializableValue + ], descent_path: str, + dedup_context: Optional[DedupContext], ) -> JsonSerializableValue: + if dedup_context is not None and is_whitelisted_for_serdes_object(val): + try: + obj_id = 
hash((type(val), val)) + except TypeError: + # unhashable object + obj_id = None + # we've seen this object before, so return a reference to it + if obj_id in dedup_context.obj_ids: + # make sure we add the object to the mapping + if obj_id not in dedup_context.obj_mapping: + value = _pack_object( + val, # type: ignore + whitelist_map, + descent_path, + dedup_context, + ) + dedup_context.obj_mapping[obj_id] = value + return {"__dedup_ref__": obj_id} + elif obj_id is not None: + dedup_context.obj_ids.add(obj_id) + return _pack_object( + val, # type: ignore + whitelist_map, + descent_path, + dedup_context, + ) # this is a hot code path so we handle the common base cases without isinstance tval = type(val) if tval in (int, float, str, bool) or val is None: @@ -922,6 +986,7 @@ def _transform_for_serialization( whitelist_map, object_handler, f"{descent_path}[{idx}]", + dedup_context, ) for idx, item in enumerate(cast("list", val)) ] @@ -932,6 +997,7 @@ def _transform_for_serialization( whitelist_map, object_handler, f"{descent_path}.{key}", + dedup_context, ) for key, value in cast("dict", val).items() } @@ -944,12 +1010,14 @@ def _transform_for_serialization( whitelist_map, object_handler, f"{descent_path}.{k}", + dedup_context, ), _transform_for_serialization( v, whitelist_map, object_handler, f"{descent_path}.{k}", + dedup_context, ), ] for k, v in cast("dict", val).items() @@ -984,6 +1052,7 @@ def _transform_for_serialization( cast("SerializableObject", val), whitelist_map, descent_path, + dedup_context, ) if isinstance(val, set): set_path = descent_path + "{}" @@ -994,6 +1063,7 @@ def _transform_for_serialization( whitelist_map, object_handler, set_path, + dedup_context, ) for item in sorted(list(val), key=str) ] @@ -1007,6 +1077,7 @@ def _transform_for_serialization( whitelist_map, object_handler, frz_set_path, + dedup_context, ) for item in sorted(list(val), key=str) ] @@ -1024,6 +1095,7 @@ def _transform_for_serialization( whitelist_map, object_handler, f"{descent_path}.{key}", + dedup_context, ) for key, value in val.items() } @@ -1034,6 +1106,7 @@ def _transform_for_serialization( whitelist_map, object_handler, f"{descent_path}[{idx}]", + dedup_context, ) for idx, item in enumerate(val) ] @@ -1042,13 +1115,18 @@ def _transform_for_serialization( def _pack_object( - obj: SerializableObject, whitelist_map: WhitelistMap, descent_path: str + obj: SerializableObject, + whitelist_map: WhitelistMap, + descent_path: str, + dedup_context: Optional[DedupContext], ) -> Mapping[str, JsonSerializableValue]: # the object_handler for _transform_for_serialization to produce dicts for objects klass_name = obj.__class__.__name__ serializer = whitelist_map.object_serializers[klass_name] - return dict(serializer.pack_items(obj, whitelist_map, _pack_object, descent_path)) + return dict( + serializer.pack_items(obj, whitelist_map, _pack_object, descent_path, dedup_context) + ) class _LazySerializationWrapper(dict): @@ -1058,6 +1136,7 @@ class _LazySerializationWrapper(dict): """ __slots__ = [ + "_dedup_context", "_descent_path", "_obj", "_whitelist_map", @@ -1068,10 +1147,12 @@ def __init__( obj: SerializableObject, whitelist_map: WhitelistMap, descent_path: str, + dedup_context: Optional[DedupContext], ): self._obj = obj self._whitelist_map = whitelist_map self._descent_path = descent_path + self._dedup_context = dedup_context # populate backing native dict to work around c fast path check # https://github.com/python/cpython/blob/0fb18b02c8ad56299d6a2910be0bab8ad601ef24/Modules/_json.c#L1542 @@ -1081,7 
+1162,11 @@ def items(self) -> Iterator[tuple[str, JsonSerializableValue]]: # pyright: igno klass_name = self._obj.__class__.__name__ serializer = self._whitelist_map.object_serializers[klass_name] yield from serializer.pack_items( - self._obj, self._whitelist_map, _wrap_object, self._descent_path + self._obj, + self._whitelist_map, + _wrap_object, + self._descent_path, + self._dedup_context, ) @@ -1089,9 +1174,10 @@ def _wrap_object( obj: SerializableObject, whitelist_map: WhitelistMap, descent_path: str, + dedup_context: Optional[DedupContext], ) -> "_LazySerializationWrapper": # the object_handler for _transform_for_serialization to use in conjunction with json.dumps for iterative serialization - return _LazySerializationWrapper(obj, whitelist_map, descent_path) + return _LazySerializationWrapper(obj, whitelist_map, descent_path, dedup_context) ################################################################################################### @@ -1185,7 +1271,7 @@ def deserialize_values( ): unpacked_values = [] for val in vals: - context = UnpackContext() + context = UnpackContext(obj_mapping={}) unpacked_value = seven.json.loads( val, object_hook=partial(_unpack_object, whitelist_map=whitelist_map, context=context), @@ -1205,6 +1291,29 @@ def deserialize_values( return unpacked_values +def deserialize_value_with_dedup( + val: str, + as_type: Optional[ + Union[type[T_PackableValue], tuple[type[T_PackableValue], type[U_PackableValue]]] + ] = None, + whitelist_map: WhitelistMap = _WHITELIST_MAP, +) -> Union[PackableValue, T_PackableValue, Union[T_PackableValue, U_PackableValue]]: + parsed = seven.json.loads(val) + if "__dedup_mapping__" in parsed: + context = UnpackContext(obj_mapping=parsed["__dedup_mapping__"]) + value = parsed["value"] + else: + context = UnpackContext(obj_mapping={}) + value = parsed + unpacked_value = _unpack_value(value, whitelist_map, context) + unpacked_value = context.finalize_unpack(unpacked_value) + if as_type and not isinstance(unpacked_value, as_type): + raise DeserializationError( + f"Unpacked object was not expected type {as_type}, got {type(unpacked_value)}" + ) + return unpacked_value + + class UnknownSerdesValue: def __init__(self, message: str, value: Mapping[str, UnpackedValue]): self.message = message @@ -1212,6 +1321,14 @@ def __init__(self, message: str, value: Mapping[str, UnpackedValue]): def _unpack_object(val: dict, whitelist_map: WhitelistMap, context: UnpackContext) -> UnpackedValue: + if "__dedup_ref__" in val: + obj_id = val["__dedup_ref__"] + if obj_id not in context.obj_mapping_unpacked: + context.obj_mapping_unpacked[obj_id] = _unpack_value( + context.obj_mapping[str(obj_id)], whitelist_map, context + ) + return context.obj_mapping_unpacked[obj_id] + if "__class__" in val: klass_name = val["__class__"] if klass_name not in whitelist_map.object_deserializers: @@ -1303,7 +1420,7 @@ def unpack_value( - {"__class__": "", ...}: becomes an instance of the class, where `class` is a NamedTuple, dataclass or pydantic model """ - context = UnpackContext() if context is None else context + context = UnpackContext(obj_mapping={}) if context is None else context unpacked_value = _unpack_value( val, whitelist_map,
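
Usage sketch (illustrative only, not part of the diff above): assuming this patch is applied, the new dedup-aware entry points can be exercised on any @whitelist_for_serdes-registered records. Pair and Holder below are hypothetical stand-ins mirroring the Foo/Bar records in test_serdes_dedup.py, and the assertions reflect the envelope format this patch introduces (a top-level "__dedup_mapping__" table plus "__dedup_ref__" back-references for repeated objects).

import json

from dagster_shared.record import record
from dagster_shared.serdes import whitelist_for_serdes
from dagster_shared.serdes.serdes import (
    deserialize_value_with_dedup,
    serialize_value_with_dedup,
)


# Hypothetical records used only for this sketch.
@whitelist_for_serdes
@record
class Pair:
    left: int
    right: int


@whitelist_for_serdes
@record
class Holder:
    pairs: list[Pair]


# The same Pair instance repeated many times inside one serializable value.
holder = Holder(pairs=[Pair(left=1, right=2)] * 100)

serialized = serialize_value_with_dedup(holder)
payload = json.loads(serialized)

# After the first occurrence is packed inline, repeated occurrences are replaced
# by {"__dedup_ref__": <id>} entries pointing into the top-level
# "__dedup_mapping__" table, so the packed Pair dict is not repeated per element.
assert "__dedup_mapping__" in payload
assert any(
    isinstance(item, dict) and "__dedup_ref__" in item
    for item in payload["value"]["pairs"]
)

# The dedup-aware deserializer resolves the references and restores the value.
assert deserialize_value_with_dedup(serialized, as_type=Holder) == holder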