From 0a9617f6e75e5d0bc1a047ea3bb4b14b296f2e09 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 23 Jan 2024 12:22:49 +0100 Subject: [PATCH 1/8] Expr as singleton --- dask_expr/_core.py | 19 +++++++++++++------ dask_expr/io/_delayed.py | 1 + 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 72e8cc1ba..a2c6125f4 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -28,17 +28,24 @@ def _unpack_collections(o): class Expr: _parameters = [] _defaults = {} + _instances = weakref.WeakValueDictionary() - def __init__(self, *args, **kwargs): + def __new__(cls, *args, **kwargs): operands = list(args) - for parameter in type(self)._parameters[len(operands) :]: + for parameter in cls._parameters[len(operands) :]: try: operands.append(kwargs.pop(parameter)) except KeyError: - operands.append(type(self)._defaults[parameter]) + operands.append(cls._defaults[parameter]) assert not kwargs, kwargs - operands = [_unpack_collections(o) for o in operands] - self.operands = operands + inst = object.__new__(cls) + inst.operands = [_unpack_collections(o) for o in operands] + _name = inst._name + if _name in Expr._instances: + return Expr._instances[_name] + + Expr._instances[_name] = inst + return inst def __str__(self): s = ", ".join( @@ -129,7 +136,7 @@ def operand(self, key): def dependencies(self): # Dependencies are `Expr` operands only - return [operand for operand in self.operands if isinstance(operand, Expr)] + return [operand for operand in operands if isinstance(operand, Expr)] def _task(self, index: int): """The task for the i'th partition diff --git a/dask_expr/io/_delayed.py b/dask_expr/io/_delayed.py index cc2aa4ed3..a0f864f7c 100644 --- a/dask_expr/io/_delayed.py +++ b/dask_expr/io/_delayed.py @@ -20,6 +20,7 @@ class _DelayedExpr(Expr): # Wraps a Delayed object to make it an Expr for now. This is hacky and we should # integrate this properly... 
# TODO + _parameters = ["obj"] def __init__(self, obj): self.obj = obj From 8433deb2de09332cf93857704c34cfff8d7f2e43 Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 23 Jan 2024 12:25:07 +0100 Subject: [PATCH 2/8] Ensure rewrite methods always exist --- dask_expr/_core.py | 44 +++++++++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index a2c6125f4..99fc87c88 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -47,6 +47,18 @@ def __new__(cls, *args, **kwargs): Expr._instances[_name] = inst return inst + def _tune_down(self): + return None + + def _tune_up(self, parent): + return None + + def _cull_down(self): + return None + + def _cull_up(self, parent): + return None + def __str__(self): s = ", ".join( str(param) + "=" + str(operand) @@ -211,28 +223,26 @@ def rewrite(self, kind: str): _continue = False # Rewrite this node - if down_name in expr.__dir__(): - out = getattr(expr, down_name)() + out = getattr(expr, down_name)() + if out is None: + out = expr + if not isinstance(out, Expr): + return out + if out._name != expr._name: + expr = out + continue + + # Allow children to rewrite their parents + for child in expr.dependencies(): + out = getattr(child, up_name)(expr) if out is None: out = expr if not isinstance(out, Expr): return out - if out._name != expr._name: + if out is not expr and out._name != expr._name: expr = out - continue - - # Allow children to rewrite their parents - for child in expr.dependencies(): - if up_name in child.__dir__(): - out = getattr(child, up_name)(expr) - if out is None: - out = expr - if not isinstance(out, Expr): - return out - if out is not expr and out._name != expr._name: - expr = out - _continue = True - break + _continue = True + break if _continue: continue From ed93aa71253a1525ece8fbc84efd5b0c8d8ede66 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 1 Feb 2024 12:03:28 +0100 Subject: [PATCH 3/8] calc checksum for datasets and tie cache to expr ancestry --- dask_expr/_core.py | 5 ++- dask_expr/io/parquet.py | 77 +++++++++++++++++++++++------------------ 2 files changed, 47 insertions(+), 35 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 99fc87c88..62f648b81 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -36,7 +36,10 @@ def __new__(cls, *args, **kwargs): try: operands.append(kwargs.pop(parameter)) except KeyError: - operands.append(cls._defaults[parameter]) + default = cls._defaults[parameter] + if callable(default): + default = default() + operands.append(default) assert not kwargs, kwargs inst = object.__new__(cls) inst.operands = [_unpack_collections(o) for o in operands] diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 6d348b252..c93703369 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -1,7 +1,6 @@ from __future__ import annotations import contextlib -import functools import itertools import operator import warnings @@ -26,8 +25,9 @@ from dask.dataframe.io.parquet.utils import _split_user_options from dask.dataframe.io.utils import _is_local_fs from dask.delayed import delayed -from dask.utils import apply, natural_sort_key, typename +from dask.utils import apply, funcname, natural_sort_key, typename from fsspec.utils import stringify_path +from toolz import identity from dask_expr._expr import ( EQ, @@ -47,26 +47,15 @@ determine_column_projection, ) from dask_expr._reductions import Len -from dask_expr._util import _convert_to_list +from dask_expr._util import 
_convert_to_list, _tokenize_deterministic from dask_expr.io import BlockwiseIO, PartitionsFiltered NONE_LABEL = "__null_dask_index__" -_cached_dataset_info = {} -_CACHED_DATASET_SIZE = 10 _CACHED_PLAN_SIZE = 10 _cached_plan = {} -def _control_cached_dataset_info(key): - if ( - len(_cached_dataset_info) > _CACHED_DATASET_SIZE - and key not in _cached_dataset_info - ): - key_to_pop = list(_cached_dataset_info.keys())[0] - _cached_dataset_info.pop(key_to_pop) - - def _control_cached_plan(key): if len(_cached_plan) > _CACHED_PLAN_SIZE and key not in _cached_plan: key_to_pop = list(_cached_plan.keys())[0] @@ -121,7 +110,7 @@ def _lower(self): class ToParquetData(Blockwise): _parameters = ToParquet._parameters - @cached_property + @property def io_func(self): return ToParquetFunctionWrapper( self.engine, @@ -257,7 +246,6 @@ def to_parquet( # Clear read_parquet caches in case we are # also reading from the overwritten path - _cached_dataset_info.clear() _cached_plan.clear() # Always skip divisions checks if divisions are unknown @@ -383,11 +371,6 @@ def to_parquet( if compute: out = out.compute(**compute_kwargs) - # Invalidate the filesystem listing cache for the output path after write. - # We do this before returning, even if `compute=False`. This helps ensure - # that reading files that were just written succeeds. - fs.invalidate_cache(path) - return out @@ -413,6 +396,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): "kwargs", "_partitions", "_series", + "_dataset_info_cache", ] _defaults = { "columns": None, @@ -432,6 +416,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): "kwargs": None, "_partitions": None, "_series": False, + "_dataset_info_cache": list, } _pq_length_stats = None _absorb_projections = True @@ -474,7 +459,21 @@ def _simplify_up(self, parent, dependents): return Literal(sum(_lengths)) @cached_property + def _name(self): + return ( + funcname(type(self)).lower() + + "-" + + _tokenize_deterministic(self.checksum, *self.operands) + ) + + @property + def checksum(self): + return self._dataset_info["checksum"] + + @property def _dataset_info(self): + if rv := self.operand("_dataset_info_cache"): + return rv[0] # Process and split user options ( dataset_options, @@ -536,13 +535,20 @@ def _dataset_info(self): **other_options, }, ) - dataset_token = tokenize(*args) - if dataset_token not in _cached_dataset_info: - _control_cached_dataset_info(dataset_token) - _cached_dataset_info[dataset_token] = self.engine._collect_dataset_info( - *args - ) - dataset_info = _cached_dataset_info[dataset_token].copy() + dataset_info = self.engine._collect_dataset_info(*args) + checksum = [] + files_for_checksum = [] + if dataset_info["has_metadata_file"]: + files_for_checksum = [self.path + fs.sep + "_metadata"] + else: + files_for_checksum = dataset_info["ds"].files + + for file in files_for_checksum: + # The checksum / file info is usually already cached by the fsspec + # FileSystem dir_cache since this info was already asked for in + # _collect_dataset_info + checksum.append(fs.checksum(file)) + dataset_info["checksum"] = tokenize(checksum) # Infer meta, accounting for index and columns arguments. 
meta = self.engine._create_dd_meta(dataset_info) @@ -558,6 +564,7 @@ def _dataset_info(self): dataset_info["all_columns"] = all_columns dataset_info["calculate_divisions"] = self.calculate_divisions + self._dataset_info_cache.append(dataset_info) return dataset_info @property @@ -571,10 +578,10 @@ def _meta(self): return meta[columns] return meta - @cached_property + @property def _io_func(self): if self._plan["empty"]: - return lambda x: x + return identity dataset_info = self._dataset_info return ParquetFunctionWrapper( self.engine, @@ -662,7 +669,7 @@ def _update_length_statistics(self): stat["num-rows"] for stat in _collect_pq_statistics(self) ) - @functools.cached_property + @property def _fusion_compression_factor(self): if self.operand("columns") is None: return 1 @@ -767,9 +774,11 @@ def _maybe_list(val): return [val] return [ - _maybe_list(val.to_list_tuple()) - if hasattr(val, "to_list_tuple") - else _maybe_list(val) + ( + _maybe_list(val.to_list_tuple()) + if hasattr(val, "to_list_tuple") + else _maybe_list(val) + ) for val in self ] From a24c6c94bf4beb1287e59eb730e14a106cfd46a7 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 1 Feb 2024 12:10:42 +0100 Subject: [PATCH 4/8] fix merge conflicts --- dask_expr/_core.py | 2 +- dask_expr/tests/test_collection.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 62f648b81..76148faf5 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -151,7 +151,7 @@ def operand(self, key): def dependencies(self): # Dependencies are `Expr` operands only - return [operand for operand in operands if isinstance(operand, Expr)] + return [operand for operand in self.operands if isinstance(operand, Expr)] def _task(self, index: int): """The task for the i'th partition diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 4d869ca32..4d43e8d70 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -11,7 +11,7 @@ import dask.array as da import numpy as np import pytest -from dask.dataframe._compat import PANDAS_GE_210 +from dask.dataframe._compat import PANDAS_GE_210, PANDAS_GE_220 from dask.dataframe.utils import UNKNOWN_CATEGORIES from dask.utils import M @@ -1035,6 +1035,7 @@ def test_head_down(df): assert not isinstance(optimized.expr, expr.Head) +@pytest.mark.skipif(not PANDAS_GE_220, reason="not implemented") def test_case_when(pdf, df): result = df.x.case_when([(df.x.eq(1), 1), (df.y == 10, 2.5)]) expected = pdf.x.case_when([(pdf.x.eq(1), 1), (pdf.y == 10, 2.5)]) From f8b9f4e8649fd747d04c32ffc4e740df500e3fc7 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 1 Feb 2024 14:01:28 +0100 Subject: [PATCH 5/8] deal with weird paths --- dask_expr/_core.py | 5 +---- dask_expr/io/parquet.py | 24 ++++++++++++++++++++---- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/dask_expr/_core.py b/dask_expr/_core.py index 76148faf5..4e75e0a47 100644 --- a/dask_expr/_core.py +++ b/dask_expr/_core.py @@ -36,10 +36,7 @@ def __new__(cls, *args, **kwargs): try: operands.append(kwargs.pop(parameter)) except KeyError: - default = cls._defaults[parameter] - if callable(default): - default = default() - operands.append(default) + operands.append(cls._defaults[parameter]) assert not kwargs, kwargs inst = object.__new__(cls) inst.operands = [_unpack_collections(o) for o in operands] diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index c93703369..ad311e189 100644 --- a/dask_expr/io/parquet.py +++ 
b/dask_expr/io/parquet.py @@ -56,6 +56,15 @@ _cached_plan = {} +class _DatasetInfoCache(dict): + ... + + +@normalize_token.register(_DatasetInfoCache) +def _tokeniz_dataset_info_cache(x): + return x["checksum"] + + def _control_cached_plan(key): if len(_cached_plan) > _CACHED_PLAN_SIZE and key not in _cached_plan: key_to_pop = list(_cached_plan.keys())[0] @@ -416,7 +425,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): "kwargs": None, "_partitions": None, "_series": False, - "_dataset_info_cache": list, + "_dataset_info_cache": None, } _pq_length_stats = None _absorb_projections = True @@ -473,7 +482,7 @@ def checksum(self): @property def _dataset_info(self): if rv := self.operand("_dataset_info_cache"): - return rv[0] + return rv # Process and split user options ( dataset_options, @@ -539,7 +548,12 @@ def _dataset_info(self): checksum = [] files_for_checksum = [] if dataset_info["has_metadata_file"]: - files_for_checksum = [self.path + fs.sep + "_metadata"] + if isinstance(self.path, list): + files_for_checksum = [ + next(path for path in self.path if path.endswith("_metadata")) + ] + else: + files_for_checksum = [self.path + fs.sep + "_metadata"] else: files_for_checksum = dataset_info["ds"].files @@ -564,7 +578,9 @@ def _dataset_info(self): dataset_info["all_columns"] = all_columns dataset_info["calculate_divisions"] = self.calculate_divisions - self._dataset_info_cache.append(dataset_info) + self.operands[ + type(self)._parameters.index("_dataset_info_cache") + ] = _DatasetInfoCache(dataset_info) return dataset_info @property From 129aa7a4eeca3714309250933aa04c9e51a3c014 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 1 Feb 2024 14:24:01 +0100 Subject: [PATCH 6/8] restore invalidate cache --- dask_expr/io/parquet.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index ad311e189..90c2aa53a 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -56,15 +56,6 @@ _cached_plan = {} -class _DatasetInfoCache(dict): - ... - - -@normalize_token.register(_DatasetInfoCache) -def _tokeniz_dataset_info_cache(x): - return x["checksum"] - - def _control_cached_plan(key): if len(_cached_plan) > _CACHED_PLAN_SIZE and key not in _cached_plan: key_to_pop = list(_cached_plan.keys())[0] @@ -380,6 +371,11 @@ def to_parquet( if compute: out = out.compute(**compute_kwargs) + # Invalidate the filesystem listing cache for the output path after write. + # We do this before returning, even if `compute=False`. This helps ensure + # that reading files that were just written succeeds. 
+ fs.invalidate_cache(path) + return out @@ -580,7 +576,7 @@ def _dataset_info(self): self.operands[ type(self)._parameters.index("_dataset_info_cache") - ] = _DatasetInfoCache(dataset_info) + ] = dataset_info return dataset_info @property From cb85ffc2b7023a3fd70b66db125b37eb3efb6124 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 2 Feb 2024 15:35:05 +0100 Subject: [PATCH 7/8] refactor FromGraph to tokenize operands --- dask_expr/_collection.py | 4 ++-- dask_expr/io/io.py | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index a76202629..c40f5b74d 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -444,7 +444,7 @@ def __dask_postcompute__(self): def __dask_postpersist__(self): state = new_collection(self.expr.lower_completely()) - return from_graph, (state._meta, state.divisions, state._name) + return from_graph, (state._meta, state.divisions, state.__dask_keys__()) def __getattr__(self, key): try: @@ -4473,7 +4473,7 @@ def from_dask_dataframe(ddf: _Frame, optimize: bool = True) -> FrameBase: graph = ddf.dask if optimize: graph = ddf.__dask_optimize__(graph, ddf.__dask_keys__()) - return from_graph(graph, ddf._meta, ddf.divisions, ddf._name) + return from_graph(graph, ddf._meta, ddf.divisions, ddf.__dask_keys__()) def from_dask_array(x, columns=None, index=None, meta=None): diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 3d8b020cd..eec6451da 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -36,7 +36,7 @@ class FromGraph(IO): conversion from legacy dataframes. """ - _parameters = ["layer", "_meta", "divisions", "_name"] + _parameters = ["layer", "_meta", "divisions", "_dask_keys"] @property def _meta(self): @@ -45,12 +45,13 @@ def _meta(self): def _divisions(self): return self.operand("divisions") - @property - def _name(self): - return self.operand("_name") - def _layer(self): - return dict(self.operand("layer")) + dsk = dict(self.operand("layer")) + # The name may not actually match the layers name therefore rewrite this + # using an alias + for part, k in enumerate(self.operand("_dask_keys")): + dsk[(self._name, part)] = k + return dsk class BlockwiseIO(Blockwise, IO): From 4f8e0e3ef3dc11c5205cae59bcd623b7cffc6784 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 2 Feb 2024 16:20:19 +0100 Subject: [PATCH 8/8] ensure name prefix stays the same --- dask_expr/_collection.py | 12 ++++++++++-- dask_expr/io/io.py | 10 ++++++++-- dask_expr/tests/test_collection.py | 5 +++-- dask_expr/tests/test_datasets.py | 2 +- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index c40f5b74d..6f9a0f4ff 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -48,6 +48,7 @@ M, derived_from, get_meta_library, + key_split, maybe_pluralize, memory_repr, put_lines, @@ -444,7 +445,12 @@ def __dask_postcompute__(self): def __dask_postpersist__(self): state = new_collection(self.expr.lower_completely()) - return from_graph, (state._meta, state.divisions, state.__dask_keys__()) + return from_graph, ( + state._meta, + state.divisions, + state.__dask_keys__(), + key_split(state._name), + ) def __getattr__(self, key): try: @@ -4473,7 +4479,9 @@ def from_dask_dataframe(ddf: _Frame, optimize: bool = True) -> FrameBase: graph = ddf.dask if optimize: graph = ddf.__dask_optimize__(graph, ddf.__dask_keys__()) - return from_graph(graph, ddf._meta, ddf.divisions, ddf.__dask_keys__()) + return from_graph( + graph, 
ddf._meta, ddf.divisions, ddf.__dask_keys__(), key_split(ddf._name) + ) def from_dask_array(x, columns=None, index=None, meta=None): diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index eec6451da..25ef1abc2 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -36,7 +36,7 @@ class FromGraph(IO): conversion from legacy dataframes. """ - _parameters = ["layer", "_meta", "divisions", "_dask_keys"] + _parameters = ["layer", "_meta", "divisions", "keys", "name_prefix"] @property def _meta(self): @@ -45,11 +45,17 @@ def _meta(self): def _divisions(self): return self.operand("divisions") + @functools.cached_property + def _name(self): + return ( + self.operand("name_prefix") + "-" + _tokenize_deterministic(*self.operands) + ) + def _layer(self): dsk = dict(self.operand("layer")) # The name may not actually match the layers name therefore rewrite this # using an alias - for part, k in enumerate(self.operand("_dask_keys")): + for part, k in enumerate(self.operand("keys")): dsk[(self._name, part)] = k return dsk diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 4d43e8d70..9501c512a 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -983,14 +983,15 @@ def test_broadcast(pdf, df): def test_persist(pdf, df): a = df + 2 + a *= 2 b = a.persist() assert_eq(a, b) assert len(a.__dask_graph__()) > len(b.__dask_graph__()) - assert len(b.__dask_graph__()) == b.npartitions + assert len(b.__dask_graph__()) == 2 * b.npartitions - assert_eq(b.y.sum(), (pdf + 2).y.sum()) + assert_eq(b.y.sum(), ((pdf + 2) * 2).y.sum()) def test_index(pdf, df): diff --git a/dask_expr/tests/test_datasets.py b/dask_expr/tests/test_datasets.py index 1541beb29..aa210dce7 100644 --- a/dask_expr/tests/test_datasets.py +++ b/dask_expr/tests/test_datasets.py @@ -53,7 +53,7 @@ def test_persist(): b = a.persist() assert_eq(a, b) - assert len(b.dask) == b.npartitions + assert len(b.dask) == 2 * b.npartitions def test_lengths():
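
Notes on the techniques in this series follow. Each note is a minimal, self-contained sketch of the pattern a patch uses; any class, function, or key name that does not appear in the patches above is invented for illustration.

PATCH 1/8 turns Expr construction into a cache lookup: __new__ normalizes positional and keyword operands, derives a deterministic name from them, and returns an already-live instance from a weakref.WeakValueDictionary when one exists. A sketch of that pattern, using a simple repr-based key in place of dask's deterministic tokenizer:

import weakref


class Expr:
    """Toy version of the singleton pattern from PATCH 1/8."""

    _parameters = []
    _defaults = {}
    _instances = weakref.WeakValueDictionary()

    def __new__(cls, *args, **kwargs):
        operands = list(args)
        # Fill the remaining parameters from kwargs, falling back to the
        # class-level defaults, exactly as the patched Expr.__new__ does.
        for parameter in cls._parameters[len(operands):]:
            try:
                operands.append(kwargs.pop(parameter))
            except KeyError:
                operands.append(cls._defaults[parameter])
        assert not kwargs, kwargs

        inst = object.__new__(cls)
        inst.operands = operands
        name = inst._name
        # Reuse an equivalent expression if one is still alive; the weak
        # values let unreferenced expressions be garbage collected.
        try:
            return cls._instances[name]
        except KeyError:
            cls._instances[name] = inst
            return inst

    @property
    def _name(self):
        # Stand-in for dask's deterministic tokenize(); any stable digest of
        # the type plus operands would do here.
        return f"{type(self).__name__.lower()}-{self.operands!r}"


class Add(Expr):
    _parameters = ["left", "right"]
    _defaults = {"right": 1}


assert Add(1, 2) is Add(1, 2)     # equal operands -> the same object
assert Add(1) is Add(1, right=1)  # defaults take part in the key

The weak values mean expressions that nothing references anymore can be collected instead of accumulating in the registry. The same patch also gives _DelayedExpr a _parameters list so the wrapped object becomes a declared operand under the new construction path.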
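
PATCH 2/8 replaces the `down_name in expr.__dir__()` probing with no-op hook methods (_tune_down, _tune_up, _cull_down, _cull_up) on the base class, so rewrite() can call them unconditionally and subclasses opt in by overriding. A much-simplified sketch of that shape (Node, Value and Inc are invented; the real loop compares _name values rather than object identity, recurses into operands, and handles hooks that return non-Expr results):

class Node:
    # Base class: the rewrite hooks always exist and do nothing, which is
    # what PATCH 2/8 adds to Expr.
    def _tune_down(self):
        return None

    def _tune_up(self, parent):
        return None

    def dependencies(self):
        return []


def rewrite(expr, kind="tune"):
    # Apply node-level ("down") and child-driven ("up") rewrites until a
    # fixed point.  No hasattr/__dir__ probing is needed because the hooks
    # are always defined.
    while True:
        out = getattr(expr, f"_{kind}_down")()
        if out is None:
            out = expr
        if out is not expr:
            expr = out
            continue
        for child in expr.dependencies():
            up = getattr(child, f"_{kind}_up")(expr)
            if up is not None and up is not expr:
                out = up
                break
        if out is expr:
            return expr
        expr = out


class Value(Node):
    def __init__(self, x):
        self.x = x


class Inc(Node):
    def __init__(self, child):
        self.child = child

    def dependencies(self):
        return [self.child]

    def _tune_down(self):
        # A subclass opts in by overriding: fold Inc(Value(x)) into Value(x + 1).
        if isinstance(self.child, Value):
            return Value(self.child.x + 1)
        return None


assert rewrite(Inc(Value(1))).x == 2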
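
PATCH 3/8 makes ReadParquet._name depend on a checksum of the files backing the dataset, so writing new files under the same path yields a different expression token (and therefore different task keys and caches). When the dataset has a _metadata file only that file is checksummed, and PATCH 5/8 extends the path handling to list-of-paths inputs. A sketch of computing such a checksum with fsspec on a local directory; fs.find and fs.checksum are real fsspec methods, while dataset_checksum itself is an invented helper:

import fsspec
from dask.base import tokenize


def dataset_checksum(path):
    # fsspec's checksum() is derived from the file info (size, mtime, ETag,
    # ...), which usually already sits in the directory-listing cache after
    # the dataset has been inspected once, so no data is re-read.
    fs = fsspec.filesystem("file")
    return tokenize([fs.checksum(f) for f in sorted(fs.find(path))])


# Folding the checksum into the expression name, as the patch does for
# ReadParquet._name, makes every downstream task key change when the
# underlying files change:
#
#   name = "readparquet-" + tokenize(dataset_checksum(path), *operands)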
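
PATCHES 3/8 and 5/8 also move the expensive _collect_dataset_info result out of the global _cached_dataset_info dict and into the expression itself, via a _dataset_info_cache operand filled on first access. Combined with the singleton behaviour from PATCH 1/8, the metadata is computed once per unique expression and dies with it. A sketch of the "memoize into an operand slot" idea; ReadSomething and _expensive_scan are illustrative stand-ins, not the real reader:

class ReadSomething:
    """Illustrative reader that memoizes costly metadata in an operand slot."""

    _parameters = ["path", "_dataset_info_cache"]

    def __init__(self, path, _dataset_info_cache=None):
        # dask-expr manages the operands list in Expr.__new__; a plain list
        # keeps this sketch self-contained.
        self.operands = [path, _dataset_info_cache]

    def operand(self, key):
        return self.operands[type(self)._parameters.index(key)]

    @property
    def _dataset_info(self):
        if (cached := self.operand("_dataset_info_cache")) is not None:
            return cached
        info = self._expensive_scan(self.operand("path"))
        # Write the result back into the operand slot, as PATCH 5/8 does via
        # self.operands[...index("_dataset_info_cache")] = ..., so the costly
        # scan happens once per unique expression instance.
        self.operands[type(self)._parameters.index("_dataset_info_cache")] = info
        return info

    def _expensive_scan(self, path):
        # Stand-in for engine._collect_dataset_info(...)
        return {"path": path, "checksum": "abc123"}


r = ReadSomething("data/")
assert r._dataset_info is r._dataset_info  # scanned once, then reused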
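
PATCH 5/8 briefly wraps the cached metadata in a _DatasetInfoCache dict subclass and registers it with dask.base.normalize_token, so that tokenizing an expression hashes only the dataset checksum rather than the whole (large, partly non-deterministic) metadata blob. The registration pattern is generally useful; a sketch with an invented DatasetInfo class:

from dask.base import normalize_token, tokenize


class DatasetInfo(dict):
    """Large metadata payload whose identity is captured by a checksum."""


@normalize_token.register(DatasetInfo)
def _normalize_dataset_info(info):
    # Only the checksum participates in tokenization, so two payloads with
    # the same checksum tokenize identically even if the rest of the
    # metadata (fragments, open handles, ...) differs.
    return info["checksum"]


a = DatasetInfo(checksum="abc", fragments=object())
b = DatasetInfo(checksum="abc", fragments=object())
assert tokenize(a) == tokenize(b)

PATCH 6/8 then removes the wrapper again (storing the plain dict in the operand slot) and restores the fs.invalidate_cache(path) call at the end of to_parquet, so a read issued right after a write does not see a stale fsspec directory listing.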
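
PATCH 7/8 changes FromGraph (used by persist and by conversion from legacy dataframes) to carry the persisted keys explicitly and to expose them under its own name through alias tasks: _layer() maps (self._name, i) to the i-th original key. In a dask task graph, a value that is itself a key acts as a reference, which is also why test_persist now counts 2 * npartitions entries in the persisted graph. A sketch of that aliasing with made-up key names:

from dask.local import get_sync

# Keys produced by an earlier computation (e.g. the persisted partitions).
dsk = {
    ("add-123", 0): 10,
    ("add-123", 1): 20,
}

# Alias layer, as built by FromGraph._layer(): the new collection name simply
# points at the existing keys, and dask's schedulers resolve a value that is
# itself a graph key as a reference to that task's result.
name = "fromgraph-456"
for i, key in enumerate([("add-123", 0), ("add-123", 1)]):
    dsk[(name, i)] = key

assert get_sync(dsk, (name, 0)) == 10
assert get_sync(dsk, (name, 1)) == 20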
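
PATCH 8/8 adds a name_prefix parameter so a persisted collection keeps a human-readable prefix: FromGraph._name becomes the prefix plus a deterministic token of the operands, with the prefix recovered from the original name via dask.utils.key_split. key_split strips the trailing token from a key name, so dashboard and diagnostics groupings stay stable across persist while the token part still tracks the underlying keys. A quick illustration of its behaviour:

from dask.utils import key_split

assert key_split("x-1") == "x"
assert key_split("hello-world-1") == "hello-world"
assert key_split(("readparquet-a1b2c3d4e5", 3)) == "readparquet"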