Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/dask_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
cache-environment-key: environment-${{ steps.date.outputs.date }}-0

- name: Install current main versions of dask
run: python -m pip install git+https://github.com/dask/dask
run: python -m pip install git+https://github.com/dask/dask@6758d14f9049dbc491368125cb8a4ccc6b4e2d84

- name: Install current main versions of distributed
run: python -m pip install git+https://github.com/dask/distributed
Expand Down
12 changes: 10 additions & 2 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
M,
derived_from,
get_meta_library,
key_split,
maybe_pluralize,
memory_repr,
put_lines,
Expand Down Expand Up @@ -444,7 +445,12 @@ def __dask_postcompute__(self):

def __dask_postpersist__(self):
state = new_collection(self.expr.lower_completely())
return from_graph, (state._meta, state.divisions, state._name)
return from_graph, (
state._meta,
state.divisions,
state.__dask_keys__(),
key_split(state._name),
)

def __getattr__(self, key):
try:
Expand Down Expand Up @@ -4473,7 +4479,9 @@ def from_dask_dataframe(ddf: _Frame, optimize: bool = True) -> FrameBase:
graph = ddf.dask
if optimize:
graph = ddf.__dask_optimize__(graph, ddf.__dask_keys__())
return from_graph(graph, ddf._meta, ddf.divisions, ddf._name)
return from_graph(
graph, ddf._meta, ddf.divisions, ddf.__dask_keys__(), key_split(ddf._name)
)


def from_dask_array(x, columns=None, index=None, meta=None):
Expand Down
61 changes: 39 additions & 22 deletions dask_expr/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,36 @@ def _unpack_collections(o):
class Expr:
_parameters = []
_defaults = {}
_instances = weakref.WeakValueDictionary()

def __init__(self, *args, **kwargs):
def __new__(cls, *args, **kwargs):
operands = list(args)
for parameter in type(self)._parameters[len(operands) :]:
for parameter in cls._parameters[len(operands) :]:
try:
operands.append(kwargs.pop(parameter))
except KeyError:
operands.append(type(self)._defaults[parameter])
operands.append(cls._defaults[parameter])
assert not kwargs, kwargs
operands = [_unpack_collections(o) for o in operands]
self.operands = operands
inst = object.__new__(cls)
inst.operands = [_unpack_collections(o) for o in operands]
_name = inst._name
if _name in Expr._instances:
return Expr._instances[_name]

Expr._instances[_name] = inst
return inst

def _tune_down(self):
return None

def _tune_up(self, parent):
return None

def _cull_down(self):
return None

def _cull_up(self, parent):
return None

def __str__(self):
s = ", ".join(
Expand Down Expand Up @@ -204,28 +223,26 @@ def rewrite(self, kind: str):
_continue = False

# Rewrite this node
if down_name in expr.__dir__():
out = getattr(expr, down_name)()
out = getattr(expr, down_name)()
if out is None:
out = expr
if not isinstance(out, Expr):
return out
if out._name != expr._name:
expr = out
continue

# Allow children to rewrite their parents
for child in expr.dependencies():
out = getattr(child, up_name)(expr)
if out is None:
out = expr
if not isinstance(out, Expr):
return out
if out._name != expr._name:
if out is not expr and out._name != expr._name:
expr = out
continue

# Allow children to rewrite their parents
for child in expr.dependencies():
if up_name in child.__dir__():
out = getattr(child, up_name)(expr)
if out is None:
out = expr
if not isinstance(out, Expr):
return out
if out is not expr and out._name != expr._name:
expr = out
_continue = True
break
_continue = True
break

if _continue:
continue
Expand Down
1 change: 1 addition & 0 deletions dask_expr/io/_delayed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class _DelayedExpr(Expr):
# Wraps a Delayed object to make it an Expr for now. This is hacky and we should
# integrate this properly...
# TODO
_parameters = ["obj"]

def __init__(self, obj):
self.obj = obj
Expand Down
15 changes: 11 additions & 4 deletions dask_expr/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class FromGraph(IO):
conversion from legacy dataframes.
"""

_parameters = ["layer", "_meta", "divisions", "_name"]
_parameters = ["layer", "_meta", "divisions", "keys", "name_prefix"]

@property
def _meta(self):
Expand All @@ -45,12 +45,19 @@ def _meta(self):
def _divisions(self):
return self.operand("divisions")

@property
@functools.cached_property
def _name(self):
return self.operand("_name")
return (
self.operand("name_prefix") + "-" + _tokenize_deterministic(*self.operands)
)

def _layer(self):
return dict(self.operand("layer"))
dsk = dict(self.operand("layer"))
# The name may not actually match the layers name therefore rewrite this
# using an alias
for part, k in enumerate(self.operand("keys")):
dsk[(self._name, part)] = k
return dsk


class BlockwiseIO(Blockwise, IO):
Expand Down
79 changes: 50 additions & 29 deletions dask_expr/io/parquet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import contextlib
import functools
import itertools
import operator
import warnings
Expand All @@ -26,8 +25,9 @@
from dask.dataframe.io.parquet.utils import _split_user_options
from dask.dataframe.io.utils import _is_local_fs
from dask.delayed import delayed
from dask.utils import apply, natural_sort_key, typename
from dask.utils import apply, funcname, natural_sort_key, typename
from fsspec.utils import stringify_path
from toolz import identity

from dask_expr._expr import (
EQ,
Expand All @@ -47,26 +47,15 @@
determine_column_projection,
)
from dask_expr._reductions import Len
from dask_expr._util import _convert_to_list
from dask_expr._util import _convert_to_list, _tokenize_deterministic
from dask_expr.io import BlockwiseIO, PartitionsFiltered

NONE_LABEL = "__null_dask_index__"

_cached_dataset_info = {}
_CACHED_DATASET_SIZE = 10
_CACHED_PLAN_SIZE = 10
_cached_plan = {}


def _control_cached_dataset_info(key):
if (
len(_cached_dataset_info) > _CACHED_DATASET_SIZE
and key not in _cached_dataset_info
):
key_to_pop = list(_cached_dataset_info.keys())[0]
_cached_dataset_info.pop(key_to_pop)


def _control_cached_plan(key):
if len(_cached_plan) > _CACHED_PLAN_SIZE and key not in _cached_plan:
key_to_pop = list(_cached_plan.keys())[0]
Expand Down Expand Up @@ -121,7 +110,7 @@ def _lower(self):
class ToParquetData(Blockwise):
_parameters = ToParquet._parameters

@cached_property
@property
def io_func(self):
return ToParquetFunctionWrapper(
self.engine,
Expand Down Expand Up @@ -257,7 +246,6 @@ def to_parquet(

# Clear read_parquet caches in case we are
# also reading from the overwritten path
_cached_dataset_info.clear()
_cached_plan.clear()

# Always skip divisions checks if divisions are unknown
Expand Down Expand Up @@ -413,6 +401,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
"kwargs",
"_partitions",
"_series",
"_dataset_info_cache",
]
_defaults = {
"columns": None,
Expand All @@ -432,6 +421,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
"kwargs": None,
"_partitions": None,
"_series": False,
"_dataset_info_cache": None,
}
_pq_length_stats = None
_absorb_projections = True
Expand Down Expand Up @@ -474,7 +464,21 @@ def _simplify_up(self, parent, dependents):
return Literal(sum(_lengths))

@cached_property
def _name(self):
return (
funcname(type(self)).lower()
+ "-"
+ _tokenize_deterministic(self.checksum, *self.operands)
)

@property
def checksum(self):
return self._dataset_info["checksum"]

@property
def _dataset_info(self):
if rv := self.operand("_dataset_info_cache"):
return rv
# Process and split user options
(
dataset_options,
Expand Down Expand Up @@ -536,13 +540,25 @@ def _dataset_info(self):
**other_options,
},
)
dataset_token = tokenize(*args)
if dataset_token not in _cached_dataset_info:
_control_cached_dataset_info(dataset_token)
_cached_dataset_info[dataset_token] = self.engine._collect_dataset_info(
*args
)
dataset_info = _cached_dataset_info[dataset_token].copy()
dataset_info = self.engine._collect_dataset_info(*args)
checksum = []
files_for_checksum = []
if dataset_info["has_metadata_file"]:
if isinstance(self.path, list):
files_for_checksum = [
next(path for path in self.path if path.endswith("_metadata"))
]
else:
files_for_checksum = [self.path + fs.sep + "_metadata"]
else:
files_for_checksum = dataset_info["ds"].files

for file in files_for_checksum:
# The checksum / file info is usually already cached by the fsspec
# FileSystem dir_cache since this info was already asked for in
# _collect_dataset_info
checksum.append(fs.checksum(file))
dataset_info["checksum"] = tokenize(checksum)

# Infer meta, accounting for index and columns arguments.
meta = self.engine._create_dd_meta(dataset_info)
Expand All @@ -558,6 +574,9 @@ def _dataset_info(self):
dataset_info["all_columns"] = all_columns
dataset_info["calculate_divisions"] = self.calculate_divisions

self.operands[
type(self)._parameters.index("_dataset_info_cache")
] = dataset_info
return dataset_info

@property
Expand All @@ -571,10 +590,10 @@ def _meta(self):
return meta[columns]
return meta

@cached_property
@property
def _io_func(self):
if self._plan["empty"]:
return lambda x: x
return identity
dataset_info = self._dataset_info
return ParquetFunctionWrapper(
self.engine,
Expand Down Expand Up @@ -662,7 +681,7 @@ def _update_length_statistics(self):
stat["num-rows"] for stat in _collect_pq_statistics(self)
)

@functools.cached_property
@property
def _fusion_compression_factor(self):
if self.operand("columns") is None:
return 1
Expand Down Expand Up @@ -767,9 +786,11 @@ def _maybe_list(val):
return [val]

return [
_maybe_list(val.to_list_tuple())
if hasattr(val, "to_list_tuple")
else _maybe_list(val)
(
_maybe_list(val.to_list_tuple())
if hasattr(val, "to_list_tuple")
else _maybe_list(val)
)
for val in self
]

Expand Down
8 changes: 5 additions & 3 deletions dask_expr/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import dask.array as da
import numpy as np
import pytest
from dask.dataframe._compat import PANDAS_GE_210
from dask.dataframe._compat import PANDAS_GE_210, PANDAS_GE_220
from dask.dataframe.utils import UNKNOWN_CATEGORIES
from dask.utils import M

Expand Down Expand Up @@ -983,14 +983,15 @@ def test_broadcast(pdf, df):

def test_persist(pdf, df):
a = df + 2
a *= 2
b = a.persist()

assert_eq(a, b)
assert len(a.__dask_graph__()) > len(b.__dask_graph__())

assert len(b.__dask_graph__()) == b.npartitions
assert len(b.__dask_graph__()) == 2 * b.npartitions

assert_eq(b.y.sum(), (pdf + 2).y.sum())
assert_eq(b.y.sum(), ((pdf + 2) * 2).y.sum())


def test_index(pdf, df):
Expand Down Expand Up @@ -1035,6 +1036,7 @@ def test_head_down(df):
assert not isinstance(optimized.expr, expr.Head)


@pytest.mark.skipif(not PANDAS_GE_220, reason="not implemented")
def test_case_when(pdf, df):
result = df.x.case_when([(df.x.eq(1), 1), (df.y == 10, 2.5)])
expected = pdf.x.case_when([(pdf.x.eq(1), 1), (pdf.y == 10, 2.5)])
Expand Down
2 changes: 1 addition & 1 deletion dask_expr/tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_persist():
b = a.persist()

assert_eq(a, b)
assert len(b.dask) == b.npartitions
assert len(b.dask) == 2 * b.npartitions


def test_lengths():
Expand Down