From 13b828f3d87529d67f3fa44ce7b56ebe24ac574d Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Mon, 15 May 2023 08:54:08 -0500 Subject: [PATCH 01/31] start experimenting with parquet statistics --- dask_expr/collection.py | 10 ++++- dask_expr/expr.py | 12 ++++- dask_expr/io/parquet.py | 99 +++++++++++++++++++++++++++++++++-------- 3 files changed, 98 insertions(+), 23 deletions(-) diff --git a/dask_expr/collection.py b/dask_expr/collection.py index e1ab97b27..093043fa0 100644 --- a/dask_expr/collection.py +++ b/dask_expr/collection.py @@ -73,6 +73,12 @@ def _meta(self): def size(self): return new_collection(self.expr.size) + def __len__(self): + _len = self.expr._len + if isinstance(_len, expr.Expr): + _len = new_collection(_len).compute() + return _len + def __reduce__(self): return new_collection, (self._expr,) @@ -546,7 +552,7 @@ def read_parquet( index=None, storage_options=None, dtype_backend=None, - calculate_divisions=False, + gather_statistics=True, ignore_metadata_file=False, metadata_task_size=None, split_row_groups="infer", @@ -571,7 +577,7 @@ def read_parquet( categories=categories, index=index, storage_options=storage_options, - calculate_divisions=calculate_divisions, + gather_statistics=gather_statistics, ignore_metadata_file=ignore_metadata_file, metadata_task_size=metadata_task_size, split_row_groups=split_row_groups, diff --git a/dask_expr/expr.py b/dask_expr/expr.py index 192aee52a..0b6169984 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -57,6 +57,11 @@ def ndim(self): except AttributeError: return 0 + @functools.cached_property + def _len(self): + # TODO: Use single column + return Len(self) + def __str__(self): s = ", ".join( str(param) + "=" + str(operand) @@ -724,7 +729,10 @@ class Elemwise(Blockwise): optimizations, like `len` will care about which operations preserve length """ - pass + @property + def _len(self): + # Length must be equal to parent + return self.frame._len class AsType(Elemwise): @@ -1318,4 +1326,4 @@ def _execute_task(graph, name, *deps): from dask_expr.io import BlockwiseIO -from dask_expr.reductions import Count, Max, Mean, Min, Mode, Size, Sum +from dask_expr.reductions import Count, Len, Max, Mean, Min, Mode, Size, Sum diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index c5fc0510b..6e74a3063 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -3,11 +3,13 @@ import operator from functools import cached_property +import pandas as pd from dask.dataframe.io.parquet.core import ( ParquetFunctionWrapper, + aggregate_row_groups, get_engine, - process_statistics, set_index_columns, + sorted_columns, ) from dask.dataframe.io.parquet.utils import _split_user_options from dask.utils import natural_sort_key @@ -27,6 +29,60 @@ def _list_columns(columns): return columns +def _align_statistics(parts, statistics): + # Make sure parts and statistics are aligned + # (if statistics is not empty) + if statistics and len(parts) != len(statistics): + statistics = [] + if statistics: + result = list( + zip( + *[ + (part, stats) + for part, stats in zip(parts, statistics) + if stats["num-rows"] > 0 + ] + ) + ) + parts, statistics = result or [[], []] + return parts, statistics + + +def _aggregate_row_groups(parts, statistics, dataset_info): + # Aggregate parts/statistics if we are splitting by row-group + blocksize = ( + dataset_info["blocksize"] if dataset_info["split_row_groups"] is True else None + ) + split_row_groups = dataset_info["split_row_groups"] + fs = dataset_info["fs"] + aggregation_depth = 
dataset_info["aggregation_depth"] + + if statistics: + if blocksize or (split_row_groups and int(split_row_groups) > 1): + parts, statistics = aggregate_row_groups( + parts, statistics, blocksize, split_row_groups, fs, aggregation_depth + ) + return parts, statistics + + +def _calculate_divisions(statistics, dataset_info, npartitions): + # Use statistics to define divisions + divisions = None + if statistics: + calculate_divisions = dataset_info["kwargs"].get("calculate_divisions", None) + index = dataset_info["index"] + process_columns = index if index and len(index) == 1 else None + if (calculate_divisions is not False) and process_columns: + for sorted_column_info in sorted_columns( + statistics, columns=process_columns + ): + if sorted_column_info["name"] in index: + divisions = sorted_column_info["divisions"] + break + + return divisions or (None,) * (npartitions + 1) + + class ReadParquet(PartitionsFiltered, BlockwiseIO): """Read a parquet dataset""" @@ -37,7 +93,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): "categories", "index", "storage_options", - "calculate_divisions", + "gather_statistics", "ignore_metadata_file", "metadata_task_size", "split_row_groups", @@ -55,7 +111,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): "categories": None, "index": None, "storage_options": None, - "calculate_divisions": False, + "gather_statistics": True, "ignore_metadata_file": False, "metadata_task_size": None, "split_row_groups": "infer", @@ -172,7 +228,7 @@ def _dataset_info(self): fs, self.categories, index, - self.calculate_divisions, + self.gather_statistics, self.filters, self.split_row_groups, blocksize, @@ -189,6 +245,7 @@ def _dataset_info(self): # Infer meta, accounting for index and columns arguments. meta = self.engine._create_dd_meta(dataset_info) + index = dataset_info["index"] index = [index] if isinstance(index, str) else index meta, index, columns = set_index_columns( meta, index, self.operand("columns"), auto_index_allowed @@ -216,21 +273,14 @@ def _plan(self): dataset_info ) - # Parse dataset statistics from metadata (if available) - parts, divisions, _ = process_statistics( - parts, - stats, - dataset_info["filters"], - dataset_info["index"], - ( - dataset_info["blocksize"] - if dataset_info["split_row_groups"] is True - else None - ), - dataset_info["split_row_groups"], - dataset_info["fs"], - dataset_info["aggregation_depth"], - ) + # Make sure parts and stats are aligned + parts, stats = _align_statistics(parts, stats) + + # Use statistics to aggregate partitions + parts, stats = _aggregate_row_groups(parts, stats, dataset_info) + + # Use statistics to calculate divisions + divisions = _calculate_divisions(stats, dataset_info, len(parts)) meta = dataset_info["meta"] if len(divisions) < 2: @@ -254,6 +304,7 @@ def _plan(self): return { "func": io_func, "parts": parts, + "statistics": stats, "divisions": divisions, } @@ -265,3 +316,13 @@ def _filtered_task(self, index: int): if self._series: return (operator.getitem, tsk, self.columns[0]) return tsk + + @property + def _statistics(self): + return self._plan["statistics"] + + @property + def _len(self): + if self._statistics and not self.filters: + return pd.DataFrame(self._statistics)["num-rows"].sum() + return super()._len From 990ba4cb78a8dc606a0c2b887ce7f5be6cf76314 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Mon, 15 May 2023 19:50:37 -0500 Subject: [PATCH 02/31] adopt parts of #40 --- dask_expr/expr.py | 40 ++++++++++++++++++++++++++++++++++++---- dask_expr/io/io.py | 7 +++++++ 
dask_expr/io/parquet.py | 12 +++++------- 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/dask_expr/expr.py b/dask_expr/expr.py index 4ccf2ec70..3fbb42a20 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -59,6 +59,9 @@ def ndim(self): @functools.cached_property def _len(self): + stats = self.statistics() + if "row_count" in stats: + return sum(stats["row_count"]) # TODO: Use single column return Len(self) @@ -211,6 +214,30 @@ def _layer(self) -> dict: return {(self._name, i): self._task(i) for i in range(self.npartitions)} + def _statistics(self): + return {} + + def statistics(self) -> dict: + """Known quantities of an expression, like length or min/max + + To define this on a class create a `._statistics` method that returns a + dictionary of new statistics known by that class. If nothing is known it + is ok to return None. Superclasses will also be consulted. + + Examples + -------- + >>> df.statistics() + {"length": 1000000} + """ + out = {} + for typ in type(self).mro()[::-1]: + if not issubclass(typ, Expr): + continue + d = typ._statistics(self) # TODO: maybe this should be cached + if d: + out.update(d) # TODO: this is fragile + return out + def simplify(self): """Simplify expression @@ -738,10 +765,10 @@ class Elemwise(Blockwise): optimizations, like `len` will care about which operations preserve length """ - @property - def _len(self): - # Length must be equal to parent - return self.frame._len + def _statistics(self): + for dep in self.dependencies(): + if "row_count" in dep.statistics(): + return {"row_count": dep.statistics()["row_count"]} class AsType(Elemwise): @@ -1052,6 +1079,11 @@ def _simplify_down(self): def _node_label_args(self): return [self.frame, self.partitions] + def _statistics(self): + if "row_count" in self.frame.statistics(): + row_counts = self.frame.statistics()["row_count"] + return {"row_count": tuple(row_counts[p] for p in self.partitions)} + class PartitionsFiltered(Expr): """Mixin class for partition filtering diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 18927d8d8..40941002e 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -68,6 +68,13 @@ def _divisions_and_locations(self): divisions = (None,) * len(locations) return divisions, locations + def _statistics(self): + locations = self._locations() + row_counts = tuple( + offset - locations[i] for i, offset in enumerate(locations[1:]) + ) + return {"row_count": row_counts} + def _divisions(self): return self._divisions_and_locations[0] diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 3c76c689d..f617b1c99 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -4,7 +4,6 @@ import operator from functools import cached_property -import pandas as pd from dask.dataframe.io.parquet.core import ( ParquetFunctionWrapper, aggregate_row_groups, @@ -297,15 +296,14 @@ def _filtered_task(self, index: int): return (operator.getitem, tsk, self.columns[0]) return tsk - @property def _statistics(self): - return self._plan["statistics"] + if self._pq_statistics and not self.filters: + row_count = tuple(stat["num-rows"] for stat in self._pq_statistics) + return {"row_count": row_count} @property - def _len(self): - if self._statistics and not self.filters: - return pd.DataFrame(self._statistics)["num-rows"].sum() - return super()._len + def _pq_statistics(self): + return self._plan["statistics"] # From 1c62f4c5e89f05a4ebbe79eb46f9b2731333ca94 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 16 May 2023 14:18:22 -0500 Subject: [PATCH 
03/31] experimenting with dedicated Metadata class structure --- dask_expr/expr.py | 49 ++++++++++++++++++++++------------------- dask_expr/io/io.py | 5 +++-- dask_expr/io/parquet.py | 5 +++-- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/dask_expr/expr.py b/dask_expr/expr.py index 3fbb42a20..55e7495fa 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -59,10 +59,9 @@ def ndim(self): @functools.cached_property def _len(self): - stats = self.statistics() - if "row_count" in stats: - return sum(stats["row_count"]) - # TODO: Use single column + metadata = self.metadata() + if "row_count" in metadata: + return metadata["row_count"].sum() return Len(self) def __str__(self): @@ -214,26 +213,38 @@ def _layer(self) -> dict: return {(self._name, i): self._task(i) for i in range(self.npartitions)} - def _statistics(self): - return {} - - def statistics(self) -> dict: - """Known quantities of an expression, like length or min/max + def _metadata(self): + """New metadata to add for this expression""" + from dask_expr.metadata import Metadata - To define this on a class create a `._statistics` method that returns a - dictionary of new statistics known by that class. If nothing is known it + # Inherit metadata from dependencies + metadata = {} + for dep in self.dependencies(): + for k, v in dep.metadata().items(): + assert isinstance(v, Metadata) + if k not in metadata: + val = v.inherit(self) + if val: + metadata[k] = val + return metadata + + def metadata(self) -> dict: + """Known metadata of an expression, like partition statistics + + To define this on a class create a `._metadata` method that returns a + dictionary of new metadata known by that class. If nothing is known it is ok to return None. Superclasses will also be consulted. Examples -------- - >>> df.statistics() - {"length": 1000000} + >>> df.metadata() + {'row_count': RowCountMetadata(data=(1000000,))} """ out = {} for typ in type(self).mro()[::-1]: if not issubclass(typ, Expr): continue - d = typ._statistics(self) # TODO: maybe this should be cached + d = typ._metadata(self) or {} if d: out.update(d) # TODO: this is fragile return out @@ -765,10 +776,7 @@ class Elemwise(Blockwise): optimizations, like `len` will care about which operations preserve length """ - def _statistics(self): - for dep in self.dependencies(): - if "row_count" in dep.statistics(): - return {"row_count": dep.statistics()["row_count"]} + pass class AsType(Elemwise): @@ -1079,11 +1087,6 @@ def _simplify_down(self): def _node_label_args(self): return [self.frame, self.partitions] - def _statistics(self): - if "row_count" in self.frame.statistics(): - row_counts = self.frame.statistics()["row_count"] - return {"row_count": tuple(row_counts[p] for p in self.partitions)} - class PartitionsFiltered(Expr): """Mixin class for partition filtering diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 40941002e..500c74f10 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -4,6 +4,7 @@ from dask.dataframe.io.io import sorted_division_locations from dask_expr.expr import Blockwise, Expr, PartitionsFiltered +from dask_expr.metadata import RowCountMetadata class IO(Expr): @@ -68,12 +69,12 @@ def _divisions_and_locations(self): divisions = (None,) * len(locations) return divisions, locations - def _statistics(self): + def _metadata(self): locations = self._locations() row_counts = tuple( offset - locations[i] for i, offset in enumerate(locations[1:]) ) - return {"row_count": row_counts} + return {"row_count": RowCountMetadata(row_counts)} def 
_divisions(self): return self._divisions_and_locations[0] diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index f617b1c99..cccfc9ff2 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -16,6 +16,7 @@ from dask_expr.expr import EQ, GE, GT, LE, LT, NE, And, Expr, Filter, Or, Projection from dask_expr.io import BlockwiseIO, PartitionsFiltered +from dask_expr.metadata import RowCountMetadata NONE_LABEL = "__null_dask_index__" @@ -296,10 +297,10 @@ def _filtered_task(self, index: int): return (operator.getitem, tsk, self.columns[0]) return tsk - def _statistics(self): + def _metadata(self): if self._pq_statistics and not self.filters: row_count = tuple(stat["num-rows"] for stat in self._pq_statistics) - return {"row_count": row_count} + return {"row_count": RowCountMetadata(row_count)} @property def _pq_statistics(self): From afd59d7e9fe7ce3fbd12eb575d7f58a0e31a924b Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 16 May 2023 14:25:32 -0500 Subject: [PATCH 04/31] add missing file --- dask_expr/metadata.py | 86 ++++++++++++++++++++++++++++++ dask_expr/tests/test_collection.py | 12 +++++ 2 files changed, 98 insertions(+) create mode 100644 dask_expr/metadata.py diff --git a/dask_expr/metadata.py b/dask_expr/metadata.py new file mode 100644 index 000000000..dd5f1f2a5 --- /dev/null +++ b/dask_expr/metadata.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from collections.abc import Iterable +from dataclasses import dataclass +from functools import singledispatchmethod +from typing import Any + +from dask_expr.expr import Elemwise, Expr, Partitions + + +@dataclass(frozen=True) +class Metadata: + """Abstract expression-metadata class + + See Also + -------- + StaticMetadata + PartitionMetadata + """ + + data: Any + + @singledispatchmethod + def inherit(self, child: Expr) -> Metadata | None: + """New `Metadata` object that a "child" Expr mayinherit + + A return value of `None` means that `type(Expr)` is + not eligable to inherit this kind of metadata. + """ + return None + + +@dataclass(frozen=True) +class StaticMetadata(Metadata): + """A static metadata object + + This metadata is not partition-specific, and can be + inherited by any child `Expr`. 
+ """ + + def inherit(self, child: Expr) -> StaticMetadata: + return self + + +@dataclass(frozen=True) +class PartitionMetadata(Metadata): + """Metadata containing a distinct value for every partition + + See Also + -------- + RowCountMetadata + """ + + data: Iterable + + +@PartitionMetadata.inherit.register +def _partitionmetadata_partitions(self, child: Partitions): + # A `Partitions` expression may inherit metadata + # from the selected partitions + return type(self)( + type(self.data)( + part for i, part in enumerate(self.data) if i in child.partitions + ) + ) + + +# +# PartitionMetadata sub-classes +# + + +@dataclass(frozen=True) +class RowCountMetadata(PartitionMetadata): + """Tracks the row count of each partition""" + + def sum(self): + """Return the total row-count of all partitions""" + return sum(self.data) + + +@RowCountMetadata.inherit.register +def _rowcount_elemwise(self, child: Elemwise): + # All Element-wise operations may inherit + # row-count metadata "as is" + return self diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 21dfe370a..cae821b7b 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -429,3 +429,15 @@ def test_repartition_divisions(df, opt): if len(part): assert part.min() >= df2.divisions[p] assert part.max() < df2.divisions[p + 1] + + +def test_statistics(df, pdf): + df2 = df[["x"]] + 1 + assert len(df2) == len(pdf) + assert df2.metadata().get("row_count").sum() == len(pdf) + assert df[df.x > 5].metadata().get("row_count") is None + + # Check `partitions` + first = df2.partitions[0].compute() + assert len(df2.partitions[0]) == len(first) + assert df2.partitions[0].metadata().get("row_count").sum() == len(first) From 830230587f3f55dc1107b65195d1a107bf4f5839 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 16 May 2023 14:55:23 -0500 Subject: [PATCH 05/31] go back to and remove sub-class for now --- dask_expr/expr.py | 40 ++++++++++----------- dask_expr/io/io.py | 6 ++-- dask_expr/io/parquet.py | 6 ++-- dask_expr/{metadata.py => statistics.py} | 45 +++++++++--------------- dask_expr/tests/test_collection.py | 6 ++-- 5 files changed, 45 insertions(+), 58 deletions(-) rename dask_expr/{metadata.py => statistics.py} (50%) diff --git a/dask_expr/expr.py b/dask_expr/expr.py index 55e7495fa..c0c77b8f9 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -59,9 +59,9 @@ def ndim(self): @functools.cached_property def _len(self): - metadata = self.metadata() - if "row_count" in metadata: - return metadata["row_count"].sum() + statistics = self.statistics() + if "row_count" in statistics: + return statistics["row_count"].sum() return Len(self) def __str__(self): @@ -213,38 +213,38 @@ def _layer(self) -> dict: return {(self._name, i): self._task(i) for i in range(self.npartitions)} - def _metadata(self): - """New metadata to add for this expression""" - from dask_expr.metadata import Metadata + def _statistics(self): + """New statistics to add for this expression""" + from dask_expr.statistics import Statistics - # Inherit metadata from dependencies - metadata = {} + # Inherit statistics from dependencies + statistics = {} for dep in self.dependencies(): - for k, v in dep.metadata().items(): - assert isinstance(v, Metadata) - if k not in metadata: + for k, v in dep.statistics().items(): + assert isinstance(v, Statistics) + if k not in statistics: val = v.inherit(self) if val: - metadata[k] = val - return metadata + statistics[k] = val + return statistics - def metadata(self) -> dict: 
- """Known metadata of an expression, like partition statistics + def statistics(self) -> dict: + """Known statistics of an expression, like partition statistics - To define this on a class create a `._metadata` method that returns a - dictionary of new metadata known by that class. If nothing is known it + To define this on a class create a `._statistics` method that returns a + dictionary of new statistics known by that class. If nothing is known it is ok to return None. Superclasses will also be consulted. Examples -------- - >>> df.metadata() - {'row_count': RowCountMetadata(data=(1000000,))} + >>> df.statistics() + {'row_count': RowCountStatistics(data=(1000000,))} """ out = {} for typ in type(self).mro()[::-1]: if not issubclass(typ, Expr): continue - d = typ._metadata(self) or {} + d = typ._statistics(self) or {} if d: out.update(d) # TODO: this is fragile return out diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 500c74f10..8c451efc6 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -4,7 +4,7 @@ from dask.dataframe.io.io import sorted_division_locations from dask_expr.expr import Blockwise, Expr, PartitionsFiltered -from dask_expr.metadata import RowCountMetadata +from dask_expr.statistics import RowCountStatistics class IO(Expr): @@ -69,12 +69,12 @@ def _divisions_and_locations(self): divisions = (None,) * len(locations) return divisions, locations - def _metadata(self): + def _statistics(self): locations = self._locations() row_counts = tuple( offset - locations[i] for i, offset in enumerate(locations[1:]) ) - return {"row_count": RowCountMetadata(row_counts)} + return {"row_count": RowCountStatistics(row_counts)} def _divisions(self): return self._divisions_and_locations[0] diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index cccfc9ff2..dd1c18627 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -16,7 +16,7 @@ from dask_expr.expr import EQ, GE, GT, LE, LT, NE, And, Expr, Filter, Or, Projection from dask_expr.io import BlockwiseIO, PartitionsFiltered -from dask_expr.metadata import RowCountMetadata +from dask_expr.statistics import RowCountStatistics NONE_LABEL = "__null_dask_index__" @@ -297,10 +297,10 @@ def _filtered_task(self, index: int): return (operator.getitem, tsk, self.columns[0]) return tsk - def _metadata(self): + def _statistics(self): if self._pq_statistics and not self.filters: row_count = tuple(stat["num-rows"] for stat in self._pq_statistics) - return {"row_count": RowCountMetadata(row_count)} + return {"row_count": RowCountStatistics(row_count)} @property def _pq_statistics(self): diff --git a/dask_expr/metadata.py b/dask_expr/statistics.py similarity index 50% rename from dask_expr/metadata.py rename to dask_expr/statistics.py index dd5f1f2a5..02f00b66d 100644 --- a/dask_expr/metadata.py +++ b/dask_expr/statistics.py @@ -9,54 +9,41 @@ @dataclass(frozen=True) -class Metadata: - """Abstract expression-metadata class +class Statistics: + """Abstract expression-statistics class See Also -------- - StaticMetadata - PartitionMetadata + PartitionStatistics """ data: Any @singledispatchmethod - def inherit(self, child: Expr) -> Metadata | None: - """New `Metadata` object that a "child" Expr mayinherit + def inherit(self, child: Expr) -> Statistics | None: + """New `Statistics` object that a "child" Expr mayinherit A return value of `None` means that `type(Expr)` is - not eligable to inherit this kind of metadata. + not eligable to inherit this kind of statistics. 
""" return None @dataclass(frozen=True) -class StaticMetadata(Metadata): - """A static metadata object - - This metadata is not partition-specific, and can be - inherited by any child `Expr`. - """ - - def inherit(self, child: Expr) -> StaticMetadata: - return self - - -@dataclass(frozen=True) -class PartitionMetadata(Metadata): - """Metadata containing a distinct value for every partition +class PartitionStatistics(Statistics): + """Statistics containing a distinct value for every partition See Also -------- - RowCountMetadata + RowCountStatistics """ data: Iterable -@PartitionMetadata.inherit.register -def _partitionmetadata_partitions(self, child: Partitions): - # A `Partitions` expression may inherit metadata +@PartitionStatistics.inherit.register +def _partitionstatistics_partitions(self, child: Partitions): + # A `Partitions` expression may inherit statistics # from the selected partitions return type(self)( type(self.data)( @@ -66,12 +53,12 @@ def _partitionmetadata_partitions(self, child: Partitions): # -# PartitionMetadata sub-classes +# PartitionStatistics sub-classes # @dataclass(frozen=True) -class RowCountMetadata(PartitionMetadata): +class RowCountStatistics(PartitionStatistics): """Tracks the row count of each partition""" def sum(self): @@ -79,8 +66,8 @@ def sum(self): return sum(self.data) -@RowCountMetadata.inherit.register +@RowCountStatistics.inherit.register def _rowcount_elemwise(self, child: Elemwise): # All Element-wise operations may inherit - # row-count metadata "as is" + # row-count statistics "as is" return self diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index cae821b7b..b77a431da 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -434,10 +434,10 @@ def test_repartition_divisions(df, opt): def test_statistics(df, pdf): df2 = df[["x"]] + 1 assert len(df2) == len(pdf) - assert df2.metadata().get("row_count").sum() == len(pdf) - assert df[df.x > 5].metadata().get("row_count") is None + assert df2.statistics().get("row_count").sum() == len(pdf) + assert df[df.x > 5].statistics().get("row_count") is None # Check `partitions` first = df2.partitions[0].compute() assert len(df2.partitions[0]) == len(first) - assert df2.partitions[0].metadata().get("row_count").sum() == len(first) + assert df2.partitions[0].statistics().get("row_count").sum() == len(first) From a3c5f2c353b8b424ec584a13ec8bd44f15c07802 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 16 May 2023 15:09:16 -0500 Subject: [PATCH 06/31] add parquet test --- dask_expr/io/tests/test_io.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py index 60e85394f..0f149750d 100644 --- a/dask_expr/io/tests/test_io.py +++ b/dask_expr/io/tests/test_io.py @@ -204,3 +204,14 @@ def test_parquet_complex_filters(tmpdir): assert_eq(got, expect) assert_eq(got.optimize(), expect) + + +def test_parquet_row_count_statistics(tmpdir): + # NOTE: We should no longer need to set `index` + # or `calculate_divisions` to gather row-count + # statistics after dask#10290 + df = read_parquet(_make_file(tmpdir), index="a", calculate_divisions=True) + pdf = df.compute() + + s = (df["b"] + 1).astype("Int32") + assert s.statistics().get("row_count").sum() == len(pdf) From cbced80d005bd5e219e93389625a35123043a706 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 16 May 2023 15:21:22 -0500 Subject: [PATCH 07/31] use assume vs inherit --- dask_expr/statistics.py | 22 
+++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/dask_expr/statistics.py b/dask_expr/statistics.py index 02f00b66d..0116ba5ff 100644 --- a/dask_expr/statistics.py +++ b/dask_expr/statistics.py @@ -20,11 +20,11 @@ class Statistics: data: Any @singledispatchmethod - def inherit(self, child: Expr) -> Statistics | None: - """New `Statistics` object that a "child" Expr mayinherit + def assume(self, parent: Expr) -> Statistics | None: + """Statistics that a "parent" Expr may assume A return value of `None` means that `type(Expr)` is - not eligable to inherit this kind of statistics. + not eligable to assume these kind of statistics. """ return None @@ -41,13 +41,13 @@ class PartitionStatistics(Statistics): data: Iterable -@PartitionStatistics.inherit.register -def _partitionstatistics_partitions(self, child: Partitions): - # A `Partitions` expression may inherit statistics +@PartitionStatistics.assume.register +def _partitionstatistics_partitions(self, parent: Partitions): + # A `Partitions` expression may assume statistics # from the selected partitions return type(self)( type(self.data)( - part for i, part in enumerate(self.data) if i in child.partitions + part for i, part in enumerate(self.data) if i in parent.partitions ) ) @@ -66,8 +66,8 @@ def sum(self): return sum(self.data) -@RowCountStatistics.inherit.register -def _rowcount_elemwise(self, child: Elemwise): - # All Element-wise operations may inherit - # row-count statistics "as is" +@RowCountStatistics.assume.register +def _rowcount_elemwise(self, parent: Elemwise): + # All Element-wise operations may assume + # row-count statistics return self From 5fe58629feb93aae8e39ba5f59946a9abf5d3174 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 16 May 2023 15:21:57 -0500 Subject: [PATCH 08/31] use assume vs inherit --- dask_expr/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_expr/expr.py b/dask_expr/expr.py index c0c77b8f9..46851d188 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -223,7 +223,7 @@ def _statistics(self): for k, v in dep.statistics().items(): assert isinstance(v, Statistics) if k not in statistics: - val = v.inherit(self) + val = v.assume(self) if val: statistics[k] = val return statistics From b0946f8432ea327182218f10b9c482f85dce7e18 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 16 May 2023 15:26:22 -0500 Subject: [PATCH 09/31] split test --- dask_expr/tests/test_collection.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index b77a431da..5c1ceb67e 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -431,13 +431,18 @@ def test_repartition_divisions(df, opt): assert part.max() < df2.divisions[p + 1] -def test_statistics(df, pdf): +def test_len(df, pdf): df2 = df[["x"]] + 1 assert len(df2) == len(pdf) - assert df2.statistics().get("row_count").sum() == len(pdf) - assert df[df.x > 5].statistics().get("row_count") is None - # Check `partitions` first = df2.partitions[0].compute() assert len(df2.partitions[0]) == len(first) - assert df2.partitions[0].statistics().get("row_count").sum() == len(first) + + +def test_row_count_statistics(df, pdf): + df2 = df[["x"]] + 1 + assert df2.statistics().get("row_count").sum() == len(pdf) + assert df[df.x > 5].statistics().get("row_count") is None + assert df2.partitions[0].statistics().get("row_count").sum() == len( + df2.partitions[0] + ) From 
bfd87105ede7b3bece84c21aae1e475c118ce8b8 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 16 May 2023 15:36:14 -0500 Subject: [PATCH 10/31] fix doc-string --- dask_expr/statistics.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dask_expr/statistics.py b/dask_expr/statistics.py index 0116ba5ff..3fbba5b0e 100644 --- a/dask_expr/statistics.py +++ b/dask_expr/statistics.py @@ -10,7 +10,7 @@ @dataclass(frozen=True) class Statistics: - """Abstract expression-statistics class + """Abstract class for expression statistics See Also -------- @@ -23,8 +23,9 @@ class Statistics: def assume(self, parent: Expr) -> Statistics | None: """Statistics that a "parent" Expr may assume - A return value of `None` means that `type(Expr)` is - not eligable to assume these kind of statistics. + A return value of `None` (the default) means that + `type(Expr)` is not eligable to assume these kind + of statistics. """ return None From 2d343c71334fef84aaa381b2c45441c92f02acc1 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 16 May 2023 15:50:21 -0500 Subject: [PATCH 11/31] fix typos --- dask_expr/expr.py | 2 +- dask_expr/statistics.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dask_expr/expr.py b/dask_expr/expr.py index 46851d188..7ea7d9268 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -217,7 +217,7 @@ def _statistics(self): """New statistics to add for this expression""" from dask_expr.statistics import Statistics - # Inherit statistics from dependencies + # Assume statistics from dependencies statistics = {} for dep in self.dependencies(): for k, v in dep.statistics().items(): diff --git a/dask_expr/statistics.py b/dask_expr/statistics.py index 3fbba5b0e..ae631fc6c 100644 --- a/dask_expr/statistics.py +++ b/dask_expr/statistics.py @@ -24,8 +24,8 @@ def assume(self, parent: Expr) -> Statistics | None: """Statistics that a "parent" Expr may assume A return value of `None` (the default) means that - `type(Expr)` is not eligable to assume these kind - of statistics. + `parent` is not eligable to assume this kind of + statistics. 
""" return None From 4ce604d639e2db827aa6e50a1cba8900202d304c Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Wed, 17 May 2023 18:03:50 -0500 Subject: [PATCH 12/31] use _lengths ILO statistics --- dask_expr/collection.py | 9 ++-- dask_expr/expr.py | 86 ++++++++++-------------------- dask_expr/io/io.py | 11 ++-- dask_expr/io/parquet.py | 12 +++-- dask_expr/io/tests/test_io.py | 4 +- dask_expr/reductions.py | 6 ++- dask_expr/statistics.py | 74 ------------------------- dask_expr/tests/test_collection.py | 10 ++-- 8 files changed, 58 insertions(+), 154 deletions(-) delete mode 100644 dask_expr/statistics.py diff --git a/dask_expr/collection.py b/dask_expr/collection.py index 0d757697d..b90888c3d 100644 --- a/dask_expr/collection.py +++ b/dask_expr/collection.py @@ -74,10 +74,11 @@ def size(self): return new_collection(self.expr.size) def __len__(self): - _len = self.expr._len - if isinstance(_len, expr.Expr): - _len = new_collection(_len).compute() - return _len + return self._len + + @functools.cached_property + def _len(self): + return new_collection(expr.Len(self.expr)).compute() def __reduce__(self): return new_collection, (self._expr,) diff --git a/dask_expr/expr.py b/dask_expr/expr.py index 3fd5a891e..470d663dc 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -19,6 +19,7 @@ is_dataframe_like, is_index_like, is_series_like, + make_meta, ) from dask.utils import M, apply, funcname, import_required, is_arraylike @@ -38,6 +39,7 @@ class Expr: associative = False _parameters = [] _defaults = {} + _lengths = None def __init__(self, *args, **kwargs): operands = list(args) @@ -57,13 +59,6 @@ def ndim(self): except AttributeError: return 0 - @functools.cached_property - def _len(self): - statistics = self.statistics() - if "row_count" in statistics: - return statistics["row_count"].sum() - return Len(self) - def __str__(self): s = ", ".join( str(param) + "=" + str(operand) @@ -213,42 +208,6 @@ def _layer(self) -> dict: return {(self._name, i): self._task(i) for i in range(self.npartitions)} - def _statistics(self): - """New statistics to add for this expression""" - from dask_expr.statistics import Statistics - - # Assume statistics from dependencies - statistics = {} - for dep in self.dependencies(): - for k, v in dep.statistics().items(): - assert isinstance(v, Statistics) - if k not in statistics: - val = v.assume(self) - if val: - statistics[k] = val - return statistics - - def statistics(self) -> dict: - """Known statistics of an expression, like partition statistics - - To define this on a class create a `._statistics` method that returns a - dictionary of new statistics known by that class. If nothing is known it - is ok to return None. Superclasses will also be consulted. 
- - Examples - -------- - >>> df.statistics() - {'row_count': RowCountStatistics(data=(1000000,))} - """ - out = {} - for typ in type(self).mro()[::-1]: - if not issubclass(typ, Expr): - continue - d = typ._statistics(self) or {} - if d: - out.update(d) # TODO: this is fragile - return out - def simplify(self): """Simplify expression @@ -634,6 +593,23 @@ def visualize(self, filename="dask-expr.svg", format=None, **kwargs): return g +class Literal(Expr): + """Represent a literal (known) value as an `Expr`""" + + _parameters = ["value"] + + def _divisions(self): + return (None, None) + + @property + def _meta(self): + return make_meta(self.value) + + def _task(self, index: int): + assert index == 0 + return self.value + + class Blockwise(Expr): """Super-class for block-wise operations @@ -782,7 +758,9 @@ class Elemwise(Blockwise): optimizations, like `len` will care about which operations preserve length """ - pass + @property + def _lengths(self): + return self.dependencies()[0]._lengths class AsType(Elemwise): @@ -1093,6 +1071,12 @@ def _simplify_down(self): def _node_label_args(self): return [self.frame, self.partitions] + @property + def _lengths(self): + lengths = self.frame._lengths + if lengths: + return tuple(lengths[i] for i in self.partitions) + class PartitionsFiltered(Expr): """Mixin class for partition filtering @@ -1386,16 +1370,4 @@ def _execute_task(graph, name, *deps): from dask_expr.io import BlockwiseIO -from dask_expr.reductions import ( - All, - Any, - Count, - Len, - Max, - Mean, - Min, - Mode, - Prod, - Size, - Sum, -) +from dask_expr.reductions import All, Any, Count, Max, Mean, Min, Mode, Prod, Size, Sum diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 8c451efc6..545829eee 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -4,7 +4,6 @@ from dask.dataframe.io.io import sorted_division_locations from dask_expr.expr import Blockwise, Expr, PartitionsFiltered -from dask_expr.statistics import RowCountStatistics class IO(Expr): @@ -69,12 +68,14 @@ def _divisions_and_locations(self): divisions = (None,) * len(locations) return divisions, locations - def _statistics(self): + @functools.cached_property + def _lengths(self): locations = self._locations() - row_counts = tuple( - offset - locations[i] for i, offset in enumerate(locations[1:]) + return tuple( + offset - locations[i] + for i, offset in enumerate(locations[1:]) + if not self._filtered or i in self._partitions ) - return {"row_count": RowCountStatistics(row_counts)} def _divisions(self): return self._divisions_and_locations[0] diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index dd1c18627..b6b044f91 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -16,7 +16,6 @@ from dask_expr.expr import EQ, GE, GT, LE, LT, NE, And, Expr, Filter, Or, Projection from dask_expr.io import BlockwiseIO, PartitionsFiltered -from dask_expr.statistics import RowCountStatistics NONE_LABEL = "__null_dask_index__" @@ -297,10 +296,15 @@ def _filtered_task(self, index: int): return (operator.getitem, tsk, self.columns[0]) return tsk - def _statistics(self): + @cached_property + def _lengths(self): if self._pq_statistics and not self.filters: - row_count = tuple(stat["num-rows"] for stat in self._pq_statistics) - return {"row_count": RowCountStatistics(row_count)} + row_count = tuple( + stat["num-rows"] + for i, stat in enumerate(self._pq_statistics) + if not self._filtered or i in self._partitions + ) + return row_count @property def _pq_statistics(self): diff --git 
a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py index 0f149750d..389ab86d0 100644 --- a/dask_expr/io/tests/test_io.py +++ b/dask_expr/io/tests/test_io.py @@ -206,7 +206,7 @@ def test_parquet_complex_filters(tmpdir): assert_eq(got.optimize(), expect) -def test_parquet_row_count_statistics(tmpdir): +def test_parquet_lengths(tmpdir): # NOTE: We should no longer need to set `index` # or `calculate_divisions` to gather row-count # statistics after dask#10290 @@ -214,4 +214,4 @@ def test_parquet_row_count_statistics(tmpdir): pdf = df.compute() s = (df["b"] + 1).astype("Int32") - assert s.statistics().get("row_count").sum() == len(pdf) + assert sum(s._lengths) == len(pdf) diff --git a/dask_expr/reductions.py b/dask_expr/reductions.py index 3300f2ec0..0f7a27965 100644 --- a/dask_expr/reductions.py +++ b/dask_expr/reductions.py @@ -9,7 +9,7 @@ ) from dask.utils import M, apply -from dask_expr.expr import Elemwise, Expr, Projection +from dask_expr.expr import Elemwise, Expr, Literal, Projection class ApplyConcatApply(Expr): @@ -244,7 +244,9 @@ class Len(Reduction): reduction_aggregate = sum def _simplify_down(self): - if isinstance(self.frame, Elemwise): + if self.frame._lengths: + return Literal(sum(self.frame._lengths)) + elif isinstance(self.frame, Elemwise): child = max(self.frame.dependencies(), key=lambda expr: expr.npartitions) return Len(child) diff --git a/dask_expr/statistics.py b/dask_expr/statistics.py deleted file mode 100644 index ae631fc6c..000000000 --- a/dask_expr/statistics.py +++ /dev/null @@ -1,74 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterable -from dataclasses import dataclass -from functools import singledispatchmethod -from typing import Any - -from dask_expr.expr import Elemwise, Expr, Partitions - - -@dataclass(frozen=True) -class Statistics: - """Abstract class for expression statistics - - See Also - -------- - PartitionStatistics - """ - - data: Any - - @singledispatchmethod - def assume(self, parent: Expr) -> Statistics | None: - """Statistics that a "parent" Expr may assume - - A return value of `None` (the default) means that - `parent` is not eligable to assume this kind of - statistics. 
- """ - return None - - -@dataclass(frozen=True) -class PartitionStatistics(Statistics): - """Statistics containing a distinct value for every partition - - See Also - -------- - RowCountStatistics - """ - - data: Iterable - - -@PartitionStatistics.assume.register -def _partitionstatistics_partitions(self, parent: Partitions): - # A `Partitions` expression may assume statistics - # from the selected partitions - return type(self)( - type(self.data)( - part for i, part in enumerate(self.data) if i in parent.partitions - ) - ) - - -# -# PartitionStatistics sub-classes -# - - -@dataclass(frozen=True) -class RowCountStatistics(PartitionStatistics): - """Tracks the row count of each partition""" - - def sum(self): - """Return the total row-count of all partitions""" - return sum(self.data) - - -@RowCountStatistics.assume.register -def _rowcount_elemwise(self, parent: Elemwise): - # All Element-wise operations may assume - # row-count statistics - return self diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 5ccb88bf2..f59a57d2d 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -441,10 +441,8 @@ def test_len(df, pdf): assert len(df2.partitions[0]) == len(first) -def test_row_count_statistics(df, pdf): +def test_lengths(df, pdf): df2 = df[["x"]] + 1 - assert df2.statistics().get("row_count").sum() == len(pdf) - assert df[df.x > 5].statistics().get("row_count") is None - assert df2.partitions[0].statistics().get("row_count").sum() == len( - df2.partitions[0] - ) + assert sum(df2._lengths) == len(pdf) + assert df[df.x > 5]._lengths is None + assert sum(df2.partitions[0]._lengths) == len(df2.partitions[0]) From 4ad6fb2e57eca15f534b5a37680cb6a76b91bd78 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Thu, 18 May 2023 07:42:58 -0500 Subject: [PATCH 13/31] start pushing on _column_statistics --- dask_expr/collection.py | 3 ++- dask_expr/expr.py | 3 +++ dask_expr/io/io.py | 17 +++++++++++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/dask_expr/collection.py b/dask_expr/collection.py index b90888c3d..c740f8293 100644 --- a/dask_expr/collection.py +++ b/dask_expr/collection.py @@ -17,6 +17,7 @@ from dask_expr import expr from dask_expr.expr import no_default from dask_expr.merge import Merge +from dask_expr.reductions import Len from dask_expr.repartition import Repartition # @@ -78,7 +79,7 @@ def __len__(self): @functools.cached_property def _len(self): - return new_collection(expr.Len(self.expr)).compute() + return new_collection(Len(self.expr)).compute() def __reduce__(self): return new_collection, (self._expr,) diff --git a/dask_expr/expr.py b/dask_expr/expr.py index 470d663dc..13f060995 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -59,6 +59,9 @@ def ndim(self): except AttributeError: return 0 + def _column_statistics(self, columns: list | None = None): + return None + def __str__(self): s = ", ".join( str(param) + "=" + str(operand) diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 545829eee..1c676d00b 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -77,6 +77,23 @@ def _lengths(self): if not self._filtered or i in self._partitions ) + def _column_statistics(self, columns: list | None = None): + columns = columns or self.columns + maxes = [] + mins = [] + nulls = [] + for i in range(self.npartitions): + df = self._task(i)[columns] + maxes.append(df.max().to_dict()) + mins.append(df.min().to_dict()) + nulls.append(df.isna().sum().to_dict()) + if maxes: + maxes = 
type(df)(maxes) + mins = type(df)(mins) + nulls = type(df)(nulls) + # TODO: return something like: + # pd.concat([maxes, mins, nulls], axis=1, keys=["max", "min", "null"]) + def _divisions(self): return self._divisions_and_locations[0] From d5e93a4b70b919f2f6146043969611095a599d2e Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Thu, 18 May 2023 14:08:48 -0500 Subject: [PATCH 14/31] add _collect_statistics machinery to ReadParquet --- dask_expr/collection.py | 4 +- dask_expr/io/io.py | 17 -- dask_expr/io/parquet.py | 339 +++++++++++++++++++++++++++++++--------- 3 files changed, 264 insertions(+), 96 deletions(-) diff --git a/dask_expr/collection.py b/dask_expr/collection.py index c740f8293..57fddbcd0 100644 --- a/dask_expr/collection.py +++ b/dask_expr/collection.py @@ -556,7 +556,7 @@ def read_parquet( index=None, storage_options=None, dtype_backend=None, - gather_statistics=True, + calculate_divisions=True, ignore_metadata_file=False, metadata_task_size=None, split_row_groups="infer", @@ -581,7 +581,7 @@ def read_parquet( categories=categories, index=index, storage_options=storage_options, - gather_statistics=gather_statistics, + calculate_divisions=calculate_divisions, ignore_metadata_file=ignore_metadata_file, metadata_task_size=metadata_task_size, split_row_groups=split_row_groups, diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 1c676d00b..545829eee 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -77,23 +77,6 @@ def _lengths(self): if not self._filtered or i in self._partitions ) - def _column_statistics(self, columns: list | None = None): - columns = columns or self.columns - maxes = [] - mins = [] - nulls = [] - for i in range(self.npartitions): - df = self._task(i)[columns] - maxes.append(df.max().to_dict()) - mins.append(df.min().to_dict()) - nulls.append(df.isna().sum().to_dict()) - if maxes: - maxes = type(df)(maxes) - mins = type(df)(mins) - nulls = type(df)(nulls) - # TODO: return something like: - # pd.concat([maxes, mins, nulls], axis=1, keys=["max", "min", "null"]) - def _divisions(self): return self._divisions_and_locations[0] diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index b6b044f91..5fbb4ca3e 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -2,8 +2,11 @@ import itertools import operator +from collections import defaultdict from functools import cached_property +import dask +import pyarrow.parquet as pq from dask.dataframe.io.parquet.core import ( ParquetFunctionWrapper, aggregate_row_groups, @@ -12,6 +15,8 @@ sorted_columns, ) from dask.dataframe.io.parquet.utils import _split_user_options +from dask.dataframe.io.utils import _is_local_fs +from dask.delayed import delayed from dask.utils import natural_sort_key from dask_expr.expr import EQ, GE, GT, LE, LT, NE, And, Expr, Filter, Or, Projection @@ -20,69 +25,6 @@ NONE_LABEL = "__null_dask_index__" -def _list_columns(columns): - # Simple utility to convert columns to list - if isinstance(columns, (str, int)): - columns = [columns] - elif isinstance(columns, tuple): - columns = list(columns) - return columns - - -def _align_statistics(parts, statistics): - # Make sure parts and statistics are aligned - # (if statistics is not empty) - if statistics and len(parts) != len(statistics): - statistics = [] - if statistics: - result = list( - zip( - *[ - (part, stats) - for part, stats in zip(parts, statistics) - if stats["num-rows"] > 0 - ] - ) - ) - parts, statistics = result or [[], []] - return parts, statistics - - -def _aggregate_row_groups(parts, statistics, 
dataset_info): - # Aggregate parts/statistics if we are splitting by row-group - blocksize = ( - dataset_info["blocksize"] if dataset_info["split_row_groups"] is True else None - ) - split_row_groups = dataset_info["split_row_groups"] - fs = dataset_info["fs"] - aggregation_depth = dataset_info["aggregation_depth"] - - if statistics: - if blocksize or (split_row_groups and int(split_row_groups) > 1): - parts, statistics = aggregate_row_groups( - parts, statistics, blocksize, split_row_groups, fs, aggregation_depth - ) - return parts, statistics - - -def _calculate_divisions(statistics, dataset_info, npartitions): - # Use statistics to define divisions - divisions = None - if statistics: - calculate_divisions = dataset_info["kwargs"].get("calculate_divisions", None) - index = dataset_info["index"] - process_columns = index if index and len(index) == 1 else None - if (calculate_divisions is not False) and process_columns: - for sorted_column_info in sorted_columns( - statistics, columns=process_columns - ): - if sorted_column_info["name"] in index: - divisions = sorted_column_info["divisions"] - break - - return divisions or (None,) * (npartitions + 1) - - class ReadParquet(PartitionsFiltered, BlockwiseIO): """Read a parquet dataset""" @@ -93,7 +35,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): "categories", "index", "storage_options", - "gather_statistics", + "calculate_divisions", "ignore_metadata_file", "metadata_task_size", "split_row_groups", @@ -111,7 +53,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): "categories": None, "index": None, "storage_options": None, - "gather_statistics": True, + "calculate_divisions": True, "ignore_metadata_file": False, "metadata_task_size": None, "split_row_groups": "infer", @@ -123,6 +65,9 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): "_partitions": None, "_series": False, } + _length_pq_stats = None + _min_pq_stats = None + _max_pq_stats = None @property def engine(self): @@ -207,7 +152,7 @@ def _dataset_info(self): fs, self.categories, index, - self.gather_statistics, + self.calculate_divisions, self.filters, self.split_row_groups, blocksize, @@ -298,24 +243,206 @@ def _filtered_task(self, index: int): @cached_property def _lengths(self): - if self._pq_statistics and not self.filters: - row_count = tuple( - stat["num-rows"] - for i, stat in enumerate(self._pq_statistics) + if not self.filters: + self._update_length_statistics() + return self._length_pq_stats + + def _update_length_statistics(self): + """Ensure that partition-length statistics are up to date""" + + if not self._length_pq_stats: + if self._plan["statistics"]: + # Already have statistics from original API call + self._length_pq_stats = tuple( + stat["num-rows"] + for i, stat in enumerate(self._plan["statistics"]) + if not self._filtered or i in self._partitions + ) + else: + # Need to go back and collect statistics + self._length_pq_stats = tuple( + stat["num-rows"] for stat in self._collect_statistics() + ) + + def _update_column_statistics(self, columns: list | None = None): + """Ensure that min/max column statistics are up to date""" + + def _record_column_statistics(all_stats, columns: list | None = None): + # Helper function to translate List[Dict] to Tuple or Dict + lengths = [] + columns = columns or list(self.columns) + column_mins = {} + column_maxes = {} + for stats in all_stats: + lengths.append(stats["num-rows"]) + for col_stats in stats.get("columns", []): + name = col_stats.get("name") + if name in columns: + # Min + if name not in column_mins: + 
column_mins[name] = [] + column_mins[name].append(col_stats.get("min")) + # Max + if name not in column_maxes: + column_maxes[name] = [] + column_maxes[name].append(col_stats.get("max")) + return lengths, column_mins, column_maxes + + # First, try to use the statistics we already have + self._min_pq_stats = self._min_pq_stats or {} + self._max_pq_stats = self._max_pq_stats or {} + if not self._min_pq_stats and self._plan["statistics"]: + stats = [ + stat + for i, stat in enumerate(self._plan["statistics"]) if not self._filtered or i in self._partitions - ) - return row_count - - @property - def _pq_statistics(self): - return self._plan["statistics"] + ] + lengths, column_mins, column_maxes = _record_column_statistics(stats) + if not self._length_pq_stats: + self._length_pq_stats = lengths + self._min_pq_stats = column_mins + self._max_pq_stats = column_maxes + + # Find which column statistics are missing, and collect them + columns = [ + col + for col in (columns or list(self.columns)) + if col not in self._min_pq_stats + ] + if columns: + ( + lengths, + column_mins, + column_maxes, + ) = _record_column_statistics(self._collect_statistics(columns), columns) + self._length_pq_stats = lengths + self._min_pq_stats.update(column_mins) + self._max_pq_stats.update(column_maxes) + + def _collect_statistics(self, columns: list | None = None) -> list[dict] | None: + """Collect Parquet statistic for dataset paths""" + + # Be strict about columns argument + if columns: + if not isinstance(columns, list): + raise ValueError(f"Expected columns to be a list, got {type(columns)}.") + elif not set(columns).issubset(set(self.columns)): + raise ValueError( + f"columns={columns} must be a subset of {self.columns}" + ) + + # Collect statistics using layer information + fs = self._plan["func"].fs + parts = [ + part + for i, part in enumerate(self._plan["parts"]) + if not self._filtered or i in self._partitions + ] + + # Execute with delayed for large and remote datasets + parallel = int(False if _is_local_fs(fs) else 16) + if parallel: + # Group parts corresponding to the same file. 
+ # A single task should always parse statistics + # for all these parts at once (since they will + # all be in the same footer) + groups = defaultdict(list) + for part in parts: + for p in [part] if isinstance(part, dict) else part: + path = p.get("piece")[0] + groups[path].append(p) + group_keys = list(groups.keys()) + + # Compute and return flattened result + func = delayed(_read_partition_stats_group) + result = dask.compute( + [ + func( + list( + itertools.chain( + *[groups[k] for k in group_keys[i : i + parallel]] + ) + ), + fs, + columns=columns, + ) + for i in range(0, len(group_keys), parallel) + ] + )[0] + return list(itertools.chain(*result)) + else: + # Serial computation on client + return _read_partition_stats_group(parts, fs, columns=columns) # -# Filters +# Helper utilities # +def _list_columns(columns): + # Simple utility to convert columns to list + if isinstance(columns, (str, int)): + columns = [columns] + elif isinstance(columns, tuple): + columns = list(columns) + return columns + + +def _align_statistics(parts, statistics): + # Make sure parts and statistics are aligned + # (if statistics is not empty) + if statistics and len(parts) != len(statistics): + statistics = [] + if statistics: + result = list( + zip( + *[ + (part, stats) + for part, stats in zip(parts, statistics) + if stats["num-rows"] > 0 + ] + ) + ) + parts, statistics = result or [[], []] + return parts, statistics + + +def _aggregate_row_groups(parts, statistics, dataset_info): + # Aggregate parts/statistics if we are splitting by row-group + blocksize = ( + dataset_info["blocksize"] if dataset_info["split_row_groups"] is True else None + ) + split_row_groups = dataset_info["split_row_groups"] + fs = dataset_info["fs"] + aggregation_depth = dataset_info["aggregation_depth"] + + if statistics: + if blocksize or (split_row_groups and int(split_row_groups) > 1): + parts, statistics = aggregate_row_groups( + parts, statistics, blocksize, split_row_groups, fs, aggregation_depth + ) + return parts, statistics + + +def _calculate_divisions(statistics, dataset_info, npartitions): + # Use statistics to define divisions + divisions = None + if statistics: + calculate_divisions = dataset_info["kwargs"].get("calculate_divisions", None) + index = dataset_info["index"] + process_columns = index if index and len(index) == 1 else None + if (calculate_divisions is not False) and process_columns: + for sorted_column_info in sorted_columns( + statistics, columns=process_columns + ): + if sorted_column_info["name"] in index: + divisions = sorted_column_info["divisions"] + break + + return divisions or (None,) * (npartitions + 1) + + class _DNF: """Manage filters in Disjunctive Normal Form (DNF)""" @@ -429,3 +556,61 @@ def extract_pq_filters(cls, pq_expr: ReadParquet, predicate_expr: Expr) -> _DNF: _filters = cls._Or([left, right]) return _DNF(_filters) + + +def _read_partition_stats_group(parts, fs, columns=None): + """Parse the statistics for a group of files""" + + def _read_partition_stats(part, fs, columns=None): + # Helper function to read Parquet-metadata + # statistics for a single partition + + if not isinstance(part, list): + part = [part] + + column_stats = {} + num_rows = 0 + columns = columns or [] + for p in part: + piece = p["piece"] + path = piece[0] + row_groups = None if piece[1] == [None] else piece[1] + with fs.open(path, default_cache="none") as f: + md = pq.ParquetFile(f).metadata + if row_groups is None: + row_groups = list(range(md.num_row_groups)) + for rg in row_groups: + row_group = 
md.row_group(rg) + num_rows += row_group.num_rows + for i in range(row_group.num_columns): + col = row_group.column(i) + name = col.path_in_schema + if name in columns: + if col.statistics and col.statistics.has_min_max: + if name in column_stats: + column_stats[name]["min"] = min( + column_stats[name]["min"], col.statistics.min + ) + column_stats[name]["max"] = max( + column_stats[name]["max"], col.statistics.max + ) + else: + column_stats[name] = { + "min": col.statistics.min, + "max": col.statistics.max, + } + + # Convert dict-of-dict to list-of-dict to be consistent + # with current `dd.read_parquet` convention (for now) + column_stats_list = [ + { + "name": name, + "min": column_stats[name]["min"], + "max": column_stats[name]["max"], + } + for name in column_stats.keys() + ] + return {"num-rows": num_rows, "columns": column_stats_list} + + # Helper function used by _extract_statistics + return [_read_partition_stats(part, fs, columns=columns) for part in parts] From 7b137c5f9c4c138ccc58f361bff1f29e33cee98a Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Thu, 18 May 2023 14:22:21 -0500 Subject: [PATCH 15/31] move utilities out of class body --- dask_expr/io/parquet.py | 269 +++++++++++++++++++++------------------- 1 file changed, 140 insertions(+), 129 deletions(-) diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 5fbb4ca3e..89d54b70a 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -244,139 +244,12 @@ def _filtered_task(self, index: int): @cached_property def _lengths(self): if not self.filters: - self._update_length_statistics() + _update_length_statistics(self) return self._length_pq_stats - def _update_length_statistics(self): - """Ensure that partition-length statistics are up to date""" - - if not self._length_pq_stats: - if self._plan["statistics"]: - # Already have statistics from original API call - self._length_pq_stats = tuple( - stat["num-rows"] - for i, stat in enumerate(self._plan["statistics"]) - if not self._filtered or i in self._partitions - ) - else: - # Need to go back and collect statistics - self._length_pq_stats = tuple( - stat["num-rows"] for stat in self._collect_statistics() - ) - - def _update_column_statistics(self, columns: list | None = None): - """Ensure that min/max column statistics are up to date""" - - def _record_column_statistics(all_stats, columns: list | None = None): - # Helper function to translate List[Dict] to Tuple or Dict - lengths = [] - columns = columns or list(self.columns) - column_mins = {} - column_maxes = {} - for stats in all_stats: - lengths.append(stats["num-rows"]) - for col_stats in stats.get("columns", []): - name = col_stats.get("name") - if name in columns: - # Min - if name not in column_mins: - column_mins[name] = [] - column_mins[name].append(col_stats.get("min")) - # Max - if name not in column_maxes: - column_maxes[name] = [] - column_maxes[name].append(col_stats.get("max")) - return lengths, column_mins, column_maxes - - # First, try to use the statistics we already have - self._min_pq_stats = self._min_pq_stats or {} - self._max_pq_stats = self._max_pq_stats or {} - if not self._min_pq_stats and self._plan["statistics"]: - stats = [ - stat - for i, stat in enumerate(self._plan["statistics"]) - if not self._filtered or i in self._partitions - ] - lengths, column_mins, column_maxes = _record_column_statistics(stats) - if not self._length_pq_stats: - self._length_pq_stats = lengths - self._min_pq_stats = column_mins - self._max_pq_stats = column_maxes - - # Find which column 
statistics are missing, and collect them - columns = [ - col - for col in (columns or list(self.columns)) - if col not in self._min_pq_stats - ] - if columns: - ( - lengths, - column_mins, - column_maxes, - ) = _record_column_statistics(self._collect_statistics(columns), columns) - self._length_pq_stats = lengths - self._min_pq_stats.update(column_mins) - self._max_pq_stats.update(column_maxes) - - def _collect_statistics(self, columns: list | None = None) -> list[dict] | None: - """Collect Parquet statistic for dataset paths""" - - # Be strict about columns argument - if columns: - if not isinstance(columns, list): - raise ValueError(f"Expected columns to be a list, got {type(columns)}.") - elif not set(columns).issubset(set(self.columns)): - raise ValueError( - f"columns={columns} must be a subset of {self.columns}" - ) - - # Collect statistics using layer information - fs = self._plan["func"].fs - parts = [ - part - for i, part in enumerate(self._plan["parts"]) - if not self._filtered or i in self._partitions - ] - - # Execute with delayed for large and remote datasets - parallel = int(False if _is_local_fs(fs) else 16) - if parallel: - # Group parts corresponding to the same file. - # A single task should always parse statistics - # for all these parts at once (since they will - # all be in the same footer) - groups = defaultdict(list) - for part in parts: - for p in [part] if isinstance(part, dict) else part: - path = p.get("piece")[0] - groups[path].append(p) - group_keys = list(groups.keys()) - - # Compute and return flattened result - func = delayed(_read_partition_stats_group) - result = dask.compute( - [ - func( - list( - itertools.chain( - *[groups[k] for k in group_keys[i : i + parallel]] - ) - ), - fs, - columns=columns, - ) - for i in range(0, len(group_keys), parallel) - ] - )[0] - return list(itertools.chain(*result)) - else: - # Serial computation on client - return _read_partition_stats_group(parts, fs, columns=columns) - # -# Helper utilities +# Helper functions # @@ -443,6 +316,11 @@ def _calculate_divisions(statistics, dataset_info, npartitions): return divisions or (None,) * (npartitions + 1) +# +# Filtering logic +# + + class _DNF: """Manage filters in Disjunctive Normal Form (DNF)""" @@ -558,6 +436,139 @@ def extract_pq_filters(cls, pq_expr: ReadParquet, predicate_expr: Expr) -> _DNF: return _DNF(_filters) +# +# Parquet-statistics handling +# + + +def _update_length_statistics(expr: ReadParquet): + """Ensure that partition-length statistics are up to date""" + + if not expr._length_pq_stats: + if expr._plan["statistics"]: + # Already have statistics from original API call + expr._length_pq_stats = tuple( + stat["num-rows"] + for i, stat in enumerate(expr._plan["statistics"]) + if not expr._filtered or i in expr._partitions + ) + else: + # Need to go back and collect statistics + expr._length_pq_stats = tuple( + stat["num-rows"] for stat in expr._collect_statistics() + ) + + +def _update_column_statistics(expr: ReadParquet, columns: list | None = None): + """Ensure that min/max column statistics are up to date""" + + def _record_column_statistics(all_stats, columns: list | None = None): + # Helper function to translate List[Dict] to Tuple or Dict + lengths = [] + columns = columns or list(expr.columns) + column_mins = {} + column_maxes = {} + for stats in all_stats: + lengths.append(stats["num-rows"]) + for col_stats in stats.get("columns", []): + name = col_stats.get("name") + if name in columns: + # Min + if name not in column_mins: + column_mins[name] = [] + 
column_mins[name].append(col_stats.get("min")) + # Max + if name not in column_maxes: + column_maxes[name] = [] + column_maxes[name].append(col_stats.get("max")) + return lengths, column_mins, column_maxes + + # First, try to use the statistics we already have + expr._min_pq_stats = expr._min_pq_stats or {} + expr._max_pq_stats = expr._max_pq_stats or {} + if not expr._min_pq_stats and expr._plan["statistics"]: + stats = [ + stat + for i, stat in enumerate(expr._plan["statistics"]) + if not expr._filtered or i in expr._partitions + ] + lengths, column_mins, column_maxes = _record_column_statistics(stats) + if not expr._length_pq_stats: + expr._length_pq_stats = lengths + expr._min_pq_stats = column_mins + expr._max_pq_stats = column_maxes + + # Find which column statistics are missing, and collect them + columns = [ + col for col in (columns or list(expr.columns)) if col not in expr._min_pq_stats + ] + if columns: + ( + lengths, + column_mins, + column_maxes, + ) = _record_column_statistics(expr._collect_statistics(columns), columns) + expr._length_pq_stats = lengths + expr._min_pq_stats.update(column_mins) + expr._max_pq_stats.update(column_maxes) + + +def _collect_statistics( + expr: ReadParquet, columns: list | None = None +) -> list[dict] | None: + """Collect Parquet statistic for dataset paths""" + + # Be strict about columns argument + if columns: + if not isinstance(columns, list): + raise ValueError(f"Expected columns to be a list, got {type(columns)}.") + elif not set(columns).issubset(set(expr.columns)): + raise ValueError(f"columns={columns} must be a subset of {expr.columns}") + + # Collect statistics using layer information + fs = expr._plan["func"].fs + parts = [ + part + for i, part in enumerate(expr._plan["parts"]) + if not expr._filtered or i in expr._partitions + ] + + # Execute with delayed for large and remote datasets + parallel = int(False if _is_local_fs(fs) else 16) + if parallel: + # Group parts corresponding to the same file. 
+ # A single task should always parse statistics + # for all these parts at once (since they will + # all be in the same footer) + groups = defaultdict(list) + for part in parts: + for p in [part] if isinstance(part, dict) else part: + path = p.get("piece")[0] + groups[path].append(p) + group_keys = list(groups.keys()) + + # Compute and return flattened result + func = delayed(_read_partition_stats_group) + result = dask.compute( + [ + func( + list( + itertools.chain( + *[groups[k] for k in group_keys[i : i + parallel]] + ) + ), + fs, + columns=columns, + ) + for i in range(0, len(group_keys), parallel) + ] + )[0] + return list(itertools.chain(*result)) + else: + # Serial computation on client + return _read_partition_stats_group(parts, fs, columns=columns) + + def _read_partition_stats_group(parts, fs, columns=None): """Parse the statistics for a group of files""" From f6823d1c30ca3b85ab4e2477d20a551d0de96cbd Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Fri, 19 May 2023 12:02:31 -0500 Subject: [PATCH 16/31] introduce _partitioning --- dask_expr/expr.py | 38 ++++++++++++++++++++++++++++ dask_expr/io/parquet.py | 56 ++++++++++++++++++++++++++++++++++++++--- dask_expr/shuffle.py | 8 ++++++ 3 files changed, 98 insertions(+), 4 deletions(-) diff --git a/dask_expr/expr.py b/dask_expr/expr.py index 13f060995..813edbecb 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -595,6 +595,44 @@ def visualize(self, filename="dask-expr.svg", format=None, **kwargs): graphviz_to_file(g, filename, format) return g + def _partitioning(self, columns: list) -> dict: + """Known partitioning information + + Return known-partitioning information for the specified + list of columns. This information should be formatted + as a dict containing "columns" and "how" keys, where + the `"columns"` value should be a tuple of column names, + and the `"how"` value should be a tuple that uniquely + identifies the partitioning. + + If no partitioning information is known, an empty + dictionary will be returned. + + Examples of `"how"`: + + - Sorted data: `("increasing", )` + - Reverse-sorted data: `("decreasing", )` + - Shuffled data: `("hash", )` + + Note that un-named index columns must be specified as + `"__index__"` (`None` is not supported). + + Return + ------ + partitioning: dict + """ + assert isinstance(columns, list), "columns must be list" + + # By default, we only know about partitioning from known divisions + index_name = self._meta.index.name or "__index__" + if self.known_divisions and columns[0] == index_name: + return { + "columns": (index_name,), + "how": self.divisions, + } + + return {} + class Literal(Expr): """Represent a literal (known) value as an `Expr`""" diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 89d54b70a..27349d848 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -243,10 +243,33 @@ def _filtered_task(self, index: int): @cached_property def _lengths(self): + """Return known partition lengths using parquet statistics""" if not self.filters: _update_length_statistics(self) return self._length_pq_stats + def _partitioning(self, columns: list): + """Return known partitioning details using parquet statistics""" + # Check if we can already use divisions + divisions_partitioning = super()._partitioning(columns) + if divisions_partitioning: + return divisions_partitioning + # TODO: Account for directory partitioning? 
+ + # Use parquet statistics to check if we are ordered + # by the first column in `columns` + _update_column_statistics(self, columns) + column = columns[0] + if column in self._min_pq_stats and column in self._max_pq_stats: + ordering = _known_ordering( + self._min_pq_stats[column], + self._max_pq_stats[column], + ) + if ordering: + return {"columns": (column,), "how": ordering} + + return {} + # # Helper functions @@ -441,6 +464,30 @@ def extract_pq_filters(cls, pq_expr: ReadParquet, predicate_expr: Expr) -> _DNF: # +def _known_ordering(mins, maxes) -> tuple: + # Check for increasing order + divisions = [mins[0]] + for max_i, min_ip1 in zip(maxes[:-1], mins[1:]): + if max_i < min_ip1: + divisions.append(max_i) + else: + divisions = [] + break + if divisions: + return "increasing", tuple(divisions + [maxes[-1]]) + + # Check for decreasing order + divisions = [maxes[0]] + for min_i, max_ip1 in zip(mins[:-1], maxes[1:]): + if min_i > max_ip1: + divisions.append(min_i) + else: + divisions = [] + break + if divisions: + return "decreasing", tuple(divisions + [mins[-1]]) + + def _update_length_statistics(expr: ReadParquet): """Ensure that partition-length statistics are up to date""" @@ -455,7 +502,7 @@ def _update_length_statistics(expr: ReadParquet): else: # Need to go back and collect statistics expr._length_pq_stats = tuple( - stat["num-rows"] for stat in expr._collect_statistics() + stat["num-rows"] for stat in _collect_statistics(expr) ) @@ -507,7 +554,7 @@ def _record_column_statistics(all_stats, columns: list | None = None): lengths, column_mins, column_maxes, - ) = _record_column_statistics(expr._collect_statistics(columns), columns) + ) = _record_column_statistics(_collect_statistics(expr, columns), columns) expr._length_pq_stats = lengths expr._min_pq_stats.update(column_mins) expr._max_pq_stats.update(column_maxes) @@ -522,8 +569,9 @@ def _collect_statistics( if columns: if not isinstance(columns, list): raise ValueError(f"Expected columns to be a list, got {type(columns)}.") - elif not set(columns).issubset(set(expr.columns)): - raise ValueError(f"columns={columns} must be a subset of {expr.columns}") + allowed = {expr._meta.index.name} | set(expr.columns) + if not set(columns).issubset(allowed): + raise ValueError(f"columns={columns} must be a subset of {allowed}") # Collect statistics using layer information fs = expr._plan["func"].fs diff --git a/dask_expr/shuffle.py b/dask_expr/shuffle.py index c66ef3671..c3f15d99f 100644 --- a/dask_expr/shuffle.py +++ b/dask_expr/shuffle.py @@ -115,6 +115,14 @@ def _meta(self): def _divisions(self): return (None,) * (self.npartitions_out + 1) + def _partitioning(self, columns: list): + """Return known partitioning details using parquet statistics""" + by = self.partitioning_index + by = [by] if isinstance(by, (str, int)) else by + if columns == by[: len(columns)]: + return {"columns": tuple(columns), "how": ("hash", self.npartitions)} + return {} + # # ShuffleBackend Implementations From e600ea12416e3f8c68665f36dafb35fe80bb690f Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Fri, 19 May 2023 13:36:03 -0500 Subject: [PATCH 17/31] add simple test coverage for _partitions --- dask_expr/expr.py | 3 +++ dask_expr/io/tests/test_io.py | 19 ++++++++++++++++++- dask_expr/tests/test_shuffle.py | 9 +++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/dask_expr/expr.py b/dask_expr/expr.py index 813edbecb..d186c86d2 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -888,6 +888,9 @@ def __str__(self): base = "(" + base + ")" 
return f"{base}[{repr(self.columns)}]" + def _partitioning(self, columns: list) -> dict: + return self.frame._partitioning(columns) + def _simplify_down(self): if isinstance(self.frame, Projection): # df[a][b] diff --git a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py index 389ab86d0..ed280a7f9 100644 --- a/dask_expr/io/tests/test_io.py +++ b/dask_expr/io/tests/test_io.py @@ -160,7 +160,7 @@ def test_io_culling(tmpdir, fmt): if fmt == "parquet": dd.from_pandas(pdf, 2).to_parquet(tmpdir) df = read_parquet(tmpdir) - elif fmt == "parquet": + elif fmt == "csv": dd.from_pandas(pdf, 2).to_csv(tmpdir) df = read_csv(tmpdir + "/*") else: @@ -215,3 +215,20 @@ def test_parquet_lengths(tmpdir): s = (df["b"] + 1).astype("Int32") assert sum(s._lengths) == len(pdf) + + +@pytest.mark.parametrize("order", ["increasing", "decreasing"]) +def test_parquet_ordered_partitioning_info(tmpdir, order): + pdf = pd.DataFrame( + { + "a": range(10, 0, -1) if order == "decreasing" else range(10), + "b": range(10), + "c": range(10), + } + ) + dd.from_pandas(pdf, 2).to_parquet(tmpdir) + df = read_parquet(tmpdir) + + _partitioning = df._partitioning(["a", "b"]) + assert _partitioning["columns"] == ("a",) + assert _partitioning["how"][0] == order diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index db249e278..842846c6b 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -114,3 +114,12 @@ def test_shuffle_column_projection(): df2 = df.shuffle("x")[["x"]].simplify() assert "y" not in df2.expr.operands[0].columns + + +def test_shuffle_partitioning_info(): + pdf = pd.DataFrame({"x": list(range(20)) * 5, "y": range(100)}) + df = from_pandas(pdf, npartitions=10) + df2 = df.shuffle("x")[["x"]] + + assert df2._partitioning(["x"])["columns"] == ("x",) + assert df2._partitioning(["x"])["how"] == ("hash", df.npartitions) From 1dbfb1806fcfdb431e075aeb833e8978e50d0e6a Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Fri, 19 May 2023 13:45:59 -0500 Subject: [PATCH 18/31] improve test and fix bug --- dask_expr/shuffle.py | 2 +- dask_expr/tests/test_shuffle.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/dask_expr/shuffle.py b/dask_expr/shuffle.py index c3f15d99f..98835901b 100644 --- a/dask_expr/shuffle.py +++ b/dask_expr/shuffle.py @@ -119,7 +119,7 @@ def _partitioning(self, columns: list): """Return known partitioning details using parquet statistics""" by = self.partitioning_index by = [by] if isinstance(by, (str, int)) else by - if columns == by[: len(columns)]: + if columns[: len(by)] == by: return {"columns": tuple(columns), "how": ("hash", self.npartitions)} return {} diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index 842846c6b..1fd1abf91 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -117,9 +117,15 @@ def test_shuffle_column_projection(): def test_shuffle_partitioning_info(): - pdf = pd.DataFrame({"x": list(range(20)) * 5, "y": range(100)}) + pdf = pd.DataFrame({"x": list(range(20)) * 5, "y": range(100), "z": range(100)}) df = from_pandas(pdf, npartitions=10) - df2 = df.shuffle("x")[["x"]] + df2 = df.shuffle(["x", "y"])[["x", "y"]] + + # We should have ["x", "y"] partitioning info + assert df2._partitioning(["x", "y"])["columns"] == ("x", "y") + assert df2._partitioning(["x", "y"])["how"] == ("hash", df.npartitions) - assert df2._partitioning(["x"])["columns"] == ("x",) - assert df2._partitioning(["x"])["how"] == ("hash", df.npartitions) + # 
We should not have partitioning info for ["x"], + # because we are only guarenteed to have unique + # ["x", "y"] combinations in each partition + assert df2._partitioning(["x"]) == {} From 423cfcbf0d12d48353c19d1b732b28f4c8f279a2 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Fri, 19 May 2023 13:52:59 -0500 Subject: [PATCH 19/31] remove leftover --- dask_expr/expr.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/dask_expr/expr.py b/dask_expr/expr.py index 393fa638c..51080f710 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -60,9 +60,6 @@ def ndim(self): except AttributeError: return 0 - def _column_statistics(self, columns: list | None = None): - return None - def __str__(self): s = ", ".join( str(param) + "=" + str(operand) From 58ebf5a33f97d341f38739ca453398523d179a29 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Fri, 19 May 2023 13:54:29 -0500 Subject: [PATCH 20/31] fix parquet len test --- dask_expr/io/tests/test_io.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py index ed280a7f9..c1a2db68c 100644 --- a/dask_expr/io/tests/test_io.py +++ b/dask_expr/io/tests/test_io.py @@ -207,10 +207,7 @@ def test_parquet_complex_filters(tmpdir): def test_parquet_lengths(tmpdir): - # NOTE: We should no longer need to set `index` - # or `calculate_divisions` to gather row-count - # statistics after dask#10290 - df = read_parquet(_make_file(tmpdir), index="a", calculate_divisions=True) + df = read_parquet(_make_file(tmpdir)) pdf = df.compute() s = (df["b"] + 1).astype("Int32") From 5790fb18913ba81fe36cb81fb360da5bf436fbb7 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Mon, 22 May 2023 08:29:09 -0500 Subject: [PATCH 21/31] fix calculate_divisions default --- dask_expr/io/parquet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 27349d848..b3c9f4b7c 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -53,7 +53,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): "categories": None, "index": None, "storage_options": None, - "calculate_divisions": True, + "calculate_divisions": False, "ignore_metadata_file": False, "metadata_task_size": None, "split_row_groups": "infer", From cc01ebbc755c37c90a100ce83e0fe7bd509723f6 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 23 May 2023 10:20:17 -0500 Subject: [PATCH 22/31] strip out _partitioning changes --- dask_expr/expr.py | 41 --------- dask_expr/io/parquet.py | 145 +++++--------------------------- dask_expr/io/tests/test_io.py | 17 ---- dask_expr/shuffle.py | 8 -- dask_expr/tests/test_shuffle.py | 15 ---- 5 files changed, 21 insertions(+), 205 deletions(-) diff --git a/dask_expr/expr.py b/dask_expr/expr.py index fb062ac56..ce91ccdc6 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -612,44 +612,6 @@ def visualize(self, filename="dask-expr.svg", format=None, **kwargs): graphviz_to_file(g, filename, format) return g - def _partitioning(self, columns: list) -> dict: - """Known partitioning information - - Return known-partitioning information for the specified - list of columns. This information should be formatted - as a dict containing "columns" and "how" keys, where - the `"columns"` value should be a tuple of column names, - and the `"how"` value should be a tuple that uniquely - identifies the partitioning. - - If no partitioning information is known, an empty - dictionary will be returned. 
- - Examples of `"how"`: - - - Sorted data: `("increasing", )` - - Reverse-sorted data: `("decreasing", )` - - Shuffled data: `("hash", )` - - Note that un-named index columns must be specified as - `"__index__"` (`None` is not supported). - - Return - ------ - partitioning: dict - """ - assert isinstance(columns, list), "columns must be list" - - # By default, we only know about partitioning from known divisions - index_name = self._meta.index.name or "__index__" - if self.known_divisions and columns[0] == index_name: - return { - "columns": (index_name,), - "how": self.divisions, - } - - return {} - class Literal(Expr): """Represent a literal (known) value as an `Expr`""" @@ -975,9 +937,6 @@ def __str__(self): base = "(" + base + ")" return f"{base}[{repr(self.columns)}]" - def _partitioning(self, columns: list) -> dict: - return self.frame._partitioning(columns) - def _simplify_down(self): if isinstance(self.frame, Projection): # df[a][b] diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index b3c9f4b7c..d18790596 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -65,9 +65,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO): "_partitions": None, "_series": False, } - _length_pq_stats = None - _min_pq_stats = None - _max_pq_stats = None + _pq_length_stats = None @property def engine(self): @@ -245,30 +243,25 @@ def _filtered_task(self, index: int): def _lengths(self): """Return known partition lengths using parquet statistics""" if not self.filters: - _update_length_statistics(self) - return self._length_pq_stats - - def _partitioning(self, columns: list): - """Return known partitioning details using parquet statistics""" - # Check if we can already use divisions - divisions_partitioning = super()._partitioning(columns) - if divisions_partitioning: - return divisions_partitioning - # TODO: Account for directory partitioning? 
- - # Use parquet statistics to check if we are ordered - # by the first column in `columns` - _update_column_statistics(self, columns) - column = columns[0] - if column in self._min_pq_stats and column in self._max_pq_stats: - ordering = _known_ordering( - self._min_pq_stats[column], - self._max_pq_stats[column], - ) - if ordering: - return {"columns": (column,), "how": ordering} - - return {} + self._update_length_statistics() + return self._pq_length_stats + + def _update_length_statistics(self): + """Ensure that partition-length statistics are up to date""" + + if not self._pq_length_stats: + if self._plan["statistics"]: + # Already have statistics from original API call + self._pq_length_stats = tuple( + stat["num-rows"] + for i, stat in enumerate(self._plan["statistics"]) + if not self._filtered or i in self._partitions + ) + else: + # Need to go back and collect statistics + self._pq_length_stats = tuple( + stat["num-rows"] for stat in _collect_pq_statistics(self) + ) # @@ -464,103 +457,7 @@ def extract_pq_filters(cls, pq_expr: ReadParquet, predicate_expr: Expr) -> _DNF: # -def _known_ordering(mins, maxes) -> tuple: - # Check for increasing order - divisions = [mins[0]] - for max_i, min_ip1 in zip(maxes[:-1], mins[1:]): - if max_i < min_ip1: - divisions.append(max_i) - else: - divisions = [] - break - if divisions: - return "increasing", tuple(divisions + [maxes[-1]]) - - # Check for decreasing order - divisions = [maxes[0]] - for min_i, max_ip1 in zip(mins[:-1], maxes[1:]): - if min_i > max_ip1: - divisions.append(min_i) - else: - divisions = [] - break - if divisions: - return "decreasing", tuple(divisions + [mins[-1]]) - - -def _update_length_statistics(expr: ReadParquet): - """Ensure that partition-length statistics are up to date""" - - if not expr._length_pq_stats: - if expr._plan["statistics"]: - # Already have statistics from original API call - expr._length_pq_stats = tuple( - stat["num-rows"] - for i, stat in enumerate(expr._plan["statistics"]) - if not expr._filtered or i in expr._partitions - ) - else: - # Need to go back and collect statistics - expr._length_pq_stats = tuple( - stat["num-rows"] for stat in _collect_statistics(expr) - ) - - -def _update_column_statistics(expr: ReadParquet, columns: list | None = None): - """Ensure that min/max column statistics are up to date""" - - def _record_column_statistics(all_stats, columns: list | None = None): - # Helper function to translate List[Dict] to Tuple or Dict - lengths = [] - columns = columns or list(expr.columns) - column_mins = {} - column_maxes = {} - for stats in all_stats: - lengths.append(stats["num-rows"]) - for col_stats in stats.get("columns", []): - name = col_stats.get("name") - if name in columns: - # Min - if name not in column_mins: - column_mins[name] = [] - column_mins[name].append(col_stats.get("min")) - # Max - if name not in column_maxes: - column_maxes[name] = [] - column_maxes[name].append(col_stats.get("max")) - return lengths, column_mins, column_maxes - - # First, try to use the statistics we already have - expr._min_pq_stats = expr._min_pq_stats or {} - expr._max_pq_stats = expr._max_pq_stats or {} - if not expr._min_pq_stats and expr._plan["statistics"]: - stats = [ - stat - for i, stat in enumerate(expr._plan["statistics"]) - if not expr._filtered or i in expr._partitions - ] - lengths, column_mins, column_maxes = _record_column_statistics(stats) - if not expr._length_pq_stats: - expr._length_pq_stats = lengths - expr._min_pq_stats = column_mins - expr._max_pq_stats = column_maxes - - # 
Find which column statistics are missing, and collect them - columns = [ - col for col in (columns or list(expr.columns)) if col not in expr._min_pq_stats - ] - if columns: - ( - lengths, - column_mins, - column_maxes, - ) = _record_column_statistics(_collect_statistics(expr, columns), columns) - expr._length_pq_stats = lengths - expr._min_pq_stats.update(column_mins) - expr._max_pq_stats.update(column_maxes) - - -def _collect_statistics( +def _collect_pq_statistics( expr: ReadParquet, columns: list | None = None ) -> list[dict] | None: """Collect Parquet statistic for dataset paths""" diff --git a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py index c1a2db68c..d3b32199d 100644 --- a/dask_expr/io/tests/test_io.py +++ b/dask_expr/io/tests/test_io.py @@ -212,20 +212,3 @@ def test_parquet_lengths(tmpdir): s = (df["b"] + 1).astype("Int32") assert sum(s._lengths) == len(pdf) - - -@pytest.mark.parametrize("order", ["increasing", "decreasing"]) -def test_parquet_ordered_partitioning_info(tmpdir, order): - pdf = pd.DataFrame( - { - "a": range(10, 0, -1) if order == "decreasing" else range(10), - "b": range(10), - "c": range(10), - } - ) - dd.from_pandas(pdf, 2).to_parquet(tmpdir) - df = read_parquet(tmpdir) - - _partitioning = df._partitioning(["a", "b"]) - assert _partitioning["columns"] == ("a",) - assert _partitioning["how"][0] == order diff --git a/dask_expr/shuffle.py b/dask_expr/shuffle.py index 98835901b..c66ef3671 100644 --- a/dask_expr/shuffle.py +++ b/dask_expr/shuffle.py @@ -115,14 +115,6 @@ def _meta(self): def _divisions(self): return (None,) * (self.npartitions_out + 1) - def _partitioning(self, columns: list): - """Return known partitioning details using parquet statistics""" - by = self.partitioning_index - by = [by] if isinstance(by, (str, int)) else by - if columns[: len(by)] == by: - return {"columns": tuple(columns), "how": ("hash", self.npartitions)} - return {} - # # ShuffleBackend Implementations diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index 1fd1abf91..db249e278 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -114,18 +114,3 @@ def test_shuffle_column_projection(): df2 = df.shuffle("x")[["x"]].simplify() assert "y" not in df2.expr.operands[0].columns - - -def test_shuffle_partitioning_info(): - pdf = pd.DataFrame({"x": list(range(20)) * 5, "y": range(100), "z": range(100)}) - df = from_pandas(pdf, npartitions=10) - df2 = df.shuffle(["x", "y"])[["x", "y"]] - - # We should have ["x", "y"] partitioning info - assert df2._partitioning(["x", "y"])["columns"] == ("x", "y") - assert df2._partitioning(["x", "y"])["how"] == ("hash", df.npartitions) - - # We should not have partitioning info for ["x"], - # because we are only guarenteed to have unique - # ["x", "y"] combinations in each partition - assert df2._partitioning(["x"]) == {} From 0345d196fec709ceaa0bf9c86423dc8a1e397195 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 23 May 2023 10:21:49 -0500 Subject: [PATCH 23/31] missing calculate_divisions default --- dask_expr/collection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_expr/collection.py b/dask_expr/collection.py index 658d0e332..1027356de 100644 --- a/dask_expr/collection.py +++ b/dask_expr/collection.py @@ -634,7 +634,7 @@ def read_parquet( index=None, storage_options=None, dtype_backend=None, - calculate_divisions=True, + calculate_divisions=False, ignore_metadata_file=False, metadata_task_size=None, split_row_groups="infer", From 
7052a26fa531977a122f1bb711704a255c9097d6 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 23 May 2023 10:44:10 -0500 Subject: [PATCH 24/31] move _lengths to a method with force option --- dask_expr/expr.py | 23 ++++++++++++++++------- dask_expr/io/io.py | 3 +-- dask_expr/io/parquet.py | 6 +++--- dask_expr/io/tests/test_io.py | 2 +- dask_expr/reductions.py | 5 +++-- dask_expr/tests/test_collection.py | 6 +++--- 6 files changed, 27 insertions(+), 18 deletions(-) diff --git a/dask_expr/expr.py b/dask_expr/expr.py index ce91ccdc6..c239ae620 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -40,7 +40,6 @@ class Expr: associative = False _parameters = [] _defaults = {} - _lengths = None def __init__(self, *args, **kwargs): operands = list(args) @@ -60,6 +59,18 @@ def ndim(self): except AttributeError: return 0 + def _lengths(self, force: bool = False) -> tuple | None: + """Return a tuple of known partition lengths + + Parameters + ---------- + force: + Whether to attempt to collect missing length + statistics manually if they are missing. + Defaults to `False`. + """ + return None + def __str__(self): s = ", ".join( str(param) + "=" + str(operand) @@ -805,9 +816,8 @@ class Elemwise(Blockwise): optimizations, like `len` will care about which operations preserve length """ - @property - def _lengths(self): - return self.dependencies()[0]._lengths + def _lengths(self, force: bool = False) -> tuple | None: + return self.dependencies()[0]._lengths(force=force) class ToTimestamp(Elemwise): @@ -1165,9 +1175,8 @@ def _simplify_down(self): def _node_label_args(self): return [self.frame, self.partitions] - @property - def _lengths(self): - lengths = self.frame._lengths + def _lengths(self, force: bool = False) -> tuple | None: + lengths = self.frame._lengths(force=force) if lengths: return tuple(lengths[i] for i in self.partitions) diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 545829eee..952fdaf7c 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -68,8 +68,7 @@ def _divisions_and_locations(self): divisions = (None,) * len(locations) return divisions, locations - @functools.cached_property - def _lengths(self): + def _lengths(self, force: bool = False) -> tuple | None: locations = self._locations() return tuple( offset - locations[i] diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index d18790596..64faf0a95 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -239,11 +239,11 @@ def _filtered_task(self, index: int): return (operator.getitem, tsk, self.columns[0]) return tsk - @cached_property - def _lengths(self): + def _lengths(self, force: bool = False) -> tuple | None: """Return known partition lengths using parquet statistics""" if not self.filters: - self._update_length_statistics() + if force: + self._update_length_statistics() return self._pq_length_stats def _update_length_statistics(self): diff --git a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py index d3b32199d..7dd3e71af 100644 --- a/dask_expr/io/tests/test_io.py +++ b/dask_expr/io/tests/test_io.py @@ -211,4 +211,4 @@ def test_parquet_lengths(tmpdir): pdf = df.compute() s = (df["b"] + 1).astype("Int32") - assert sum(s._lengths) == len(pdf) + assert sum(s._lengths(force=True)) == len(pdf) diff --git a/dask_expr/reductions.py b/dask_expr/reductions.py index 463c92f1f..1ee6de897 100644 --- a/dask_expr/reductions.py +++ b/dask_expr/reductions.py @@ -334,8 +334,9 @@ class Len(Reduction): reduction_aggregate = sum def _simplify_down(self): - if self.frame._lengths: - 
return Literal(sum(self.frame._lengths)) + _lengths = self.frame._lengths(force=True) + if _lengths: + return Literal(sum(_lengths)) elif isinstance(self.frame, Elemwise): child = max(self.frame.dependencies(), key=lambda expr: expr.npartitions) return Len(child) diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index a24188dbf..1715092c9 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -519,9 +519,9 @@ def test_len(df, pdf): def test_lengths(df, pdf): df2 = df[["x"]] + 1 - assert sum(df2._lengths) == len(pdf) - assert df[df.x > 5]._lengths is None - assert sum(df2.partitions[0]._lengths) == len(df2.partitions[0]) + assert sum(df2._lengths()) == len(pdf) + assert df[df.x > 5]._lengths() is None + assert sum(df2.partitions[0]._lengths()) == len(df2.partitions[0]) def test_drop_duplicates(df, pdf): From e26d6cd353e6210aafcbc211ce8dc6c62b81da25 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 23 May 2023 10:47:25 -0500 Subject: [PATCH 25/31] cache pd lengths --- dask_expr/io/io.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 952fdaf7c..0c8b5a143 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -44,6 +44,7 @@ class FromPandas(PartitionsFiltered, BlockwiseIO): _parameters = ["frame", "npartitions", "sort", "_partitions"] _defaults = {"npartitions": 1, "sort": True, "_partitions": None} + _pd_length_stats = None @property def _meta(self): @@ -69,12 +70,14 @@ def _divisions_and_locations(self): return divisions, locations def _lengths(self, force: bool = False) -> tuple | None: - locations = self._locations() - return tuple( - offset - locations[i] - for i, offset in enumerate(locations[1:]) - if not self._filtered or i in self._partitions - ) + if self._pd_length_stats is None: + locations = self._locations() + self._pd_length_stats = tuple( + offset - locations[i] + for i, offset in enumerate(locations[1:]) + if not self._filtered or i in self._partitions + ) + return self._pd_length_stats def _divisions(self): return self._divisions_and_locations[0] From 5c376b9ad2a43027ab0723abbfcbed45887bb9eb Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 23 May 2023 13:26:34 -0500 Subject: [PATCH 26/31] missing annotations import --- dask_expr/io/io.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 0c8b5a143..27166d3f6 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools import math From 253cfeb15c11e42e18128c2e87b7a0bd0a9fa2d6 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 30 May 2023 12:00:57 -0500 Subject: [PATCH 27/31] use Lengths --- dask_expr/collection.py | 3 +- dask_expr/expr.py | 46 ++++++++++++++++++------------ dask_expr/io/io.py | 10 +++++-- dask_expr/io/parquet.py | 32 +++++++++++++++++---- dask_expr/io/tests/test_io.py | 6 ++-- dask_expr/reductions.py | 8 +++--- dask_expr/tests/test_collection.py | 9 ++---- 7 files changed, 73 insertions(+), 41 deletions(-) diff --git a/dask_expr/collection.py b/dask_expr/collection.py index b3776de78..55a9fd1da 100644 --- a/dask_expr/collection.py +++ b/dask_expr/collection.py @@ -22,7 +22,6 @@ from dask_expr.merge import Merge from dask_expr.reductions import ( DropDuplicates, - Len, MemoryUsageFrame, MemoryUsageIndex, NLargest, @@ -91,7 +90,7 @@ def __len__(self): @functools.cached_property def _len(self): - return 
new_collection(Len(self.expr)).compute() + return sum(new_collection(expr.Lengths(self.expr)).compute()) @property def nbytes(self): diff --git a/dask_expr/expr.py b/dask_expr/expr.py index c8ce0ffe5..9807bc43e 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -60,18 +60,6 @@ def ndim(self): except AttributeError: return 0 - def _lengths(self, force: bool = False) -> tuple | None: - """Return a tuple of known partition lengths - - Parameters - ---------- - force: - Whether to attempt to collect missing length - statistics manually if they are missing. - Defaults to `False`. - """ - return None - def __str__(self): s = ", ".join( str(param) + "=" + str(operand) @@ -874,8 +862,7 @@ class Elemwise(Blockwise): optimizations, like `len` will care about which operations preserve length """ - def _lengths(self, force: bool = False) -> tuple | None: - return self.dependencies()[0]._lengths(force=force) + pass class Clip(Elemwise): @@ -1062,6 +1049,32 @@ def _task(self, index: int): ) +class Lengths(Expr): + """Returns a tuple of partition lengths""" + + _parameters = ["frame"] + + @property + def _meta(self): + return tuple() + + def _divisions(self): + return (None, None) + + def _simplify_down(self): + if isinstance(self.frame, Elemwise): + return Lengths(self.frame.operands[0]) + + def _layer(self): + name = "part-" + self._name + dsk = { + (name, i): (len, (self.frame._name, i)) + for i in range(self.frame.npartitions) + } + dsk[(self._name, 0)] = (tuple, list(dsk.keys())) + return dsk + + class Head(Expr): """Take the first `n` rows of the first partition""" @@ -1301,11 +1314,6 @@ def _simplify_down(self): def _node_label_args(self): return [self.frame, self.partitions] - def _lengths(self, force: bool = False) -> tuple | None: - lengths = self.frame._lengths(force=force) - if lengths: - return tuple(lengths[i] for i in self.partitions) - class PartitionsFiltered(Expr): """Mixin class for partition filtering diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index 8b4f70810..bb7351ad6 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -5,7 +5,7 @@ from dask.dataframe.io.io import sorted_division_locations -from dask_expr.expr import Blockwise, Expr, PartitionsFiltered +from dask_expr.expr import Blockwise, Expr, Lengths, Literal, PartitionsFiltered class IO(Expr): @@ -71,7 +71,7 @@ def _divisions_and_locations(self): divisions = (None,) * len(locations) return divisions, locations - def _lengths(self, force: bool = False) -> tuple | None: + def _get_lengths(self) -> tuple | None: if self._pd_length_stats is None: locations = self._locations() self._pd_length_stats = tuple( @@ -81,6 +81,12 @@ def _lengths(self, force: bool = False) -> tuple | None: ) return self._pd_length_stats + def _simplify_up(self, parent): + if isinstance(parent, Lengths): + _lengths = self._get_lengths() + if _lengths: + return Literal(_lengths) + def _divisions(self): return self._divisions_and_locations[0] diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index 64faf0a95..aece8fdc6 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -19,7 +19,21 @@ from dask.delayed import delayed from dask.utils import natural_sort_key -from dask_expr.expr import EQ, GE, GT, LE, LT, NE, And, Expr, Filter, Or, Projection +from dask_expr.expr import ( + EQ, + GE, + GT, + LE, + LT, + NE, + And, + Expr, + Filter, + Lengths, + Literal, + Or, + Projection, +) from dask_expr.io import BlockwiseIO, PartitionsFiltered NONE_LABEL = "__null_dask_index__" @@ -102,6 +116,11 @@ def 
_simplify_up(self, parent): kwargs["filters"] = filters.combine(kwargs["filters"]).to_list_tuple() return ReadParquet(**kwargs) + if isinstance(parent, Lengths): + _lengths = self._get_lengths() + if _lengths: + return Literal(_lengths) + @cached_property def _dataset_info(self): # Process and split user options @@ -239,12 +258,15 @@ def _filtered_task(self, index: int): return (operator.getitem, tsk, self.columns[0]) return tsk - def _lengths(self, force: bool = False) -> tuple | None: + def _get_lengths(self) -> tuple | None: """Return known partition lengths using parquet statistics""" if not self.filters: - if force: - self._update_length_statistics() - return self._pq_length_stats + self._update_length_statistics() + return tuple( + length + for i, length in enumerate(self._pq_length_stats) + if not self._filtered or i in self._partitions + ) def _update_length_statistics(self): """Ensure that partition-length statistics are up to date""" diff --git a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py index f06952240..b28740852 100644 --- a/dask_expr/io/tests/test_io.py +++ b/dask_expr/io/tests/test_io.py @@ -207,12 +207,14 @@ def test_parquet_complex_filters(tmpdir): assert_eq(got.optimize(), expect) -def test_parquet_lengths(tmpdir): +def test_parquet_len(tmpdir): df = read_parquet(_make_file(tmpdir)) pdf = df.compute() + assert len(df[df.a > 5]) == len(pdf[pdf.a > 5]) + s = (df["b"] + 1).astype("Int32") - assert sum(s._lengths(force=True)) == len(pdf) + assert len(s) == len(pdf) @pytest.mark.parametrize("optimize", [True, False]) diff --git a/dask_expr/reductions.py b/dask_expr/reductions.py index f2d49619a..4ec69bb7c 100644 --- a/dask_expr/reductions.py +++ b/dask_expr/reductions.py @@ -13,7 +13,7 @@ ) from dask.utils import M, apply -from dask_expr.expr import Elemwise, Expr, Literal, Projection +from dask_expr.expr import Elemwise, Expr, Lengths, Literal, Projection class ApplyConcatApply(Expr): @@ -346,9 +346,9 @@ class Len(Reduction): reduction_aggregate = sum def _simplify_down(self): - _lengths = self.frame._lengths(force=True) - if _lengths: - return Literal(sum(_lengths)) + _lengths = Lengths(self.frame).optimize() + if isinstance(_lengths, Literal): + return Literal(sum(_lengths.value)) elif isinstance(self.frame, Elemwise): child = max(self.frame.dependencies(), key=lambda expr: expr.npartitions) return Len(child) diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 3fe446548..3acb84042 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -594,17 +594,12 @@ def test_len(df, pdf): df2 = df[["x"]] + 1 assert len(df2) == len(pdf) + assert len(df[df.x > 5]) == len(pdf[pdf.x > 5]) + first = df2.partitions[0].compute() assert len(df2.partitions[0]) == len(first) -def test_lengths(df, pdf): - df2 = df[["x"]] + 1 - assert sum(df2._lengths()) == len(pdf) - assert df[df.x > 5]._lengths() is None - assert sum(df2.partitions[0]._lengths()) == len(df2.partitions[0]) - - def test_drop_duplicates(df, pdf): assert_eq(df.drop_duplicates(), pdf.drop_duplicates()) assert_eq( From 32e4f943ac4dca8ad2d8d82e03f4e7bec675beb4 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 30 May 2023 17:13:17 -0500 Subject: [PATCH 28/31] partial fixup --- dask_expr/collection.py | 5 +++-- dask_expr/expr.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/dask_expr/collection.py b/dask_expr/collection.py index 4ab80d5be..9a57c5cc0 100644 --- a/dask_expr/collection.py +++ b/dask_expr/collection.py @@ 
-22,6 +22,7 @@ from dask_expr.merge import Merge from dask_expr.reductions import ( DropDuplicates, + Len, MemoryUsageFrame, MemoryUsageIndex, NLargest, @@ -88,9 +89,9 @@ def size(self): def __len__(self): return self._len - @functools.cached_property + @property def _len(self): - return sum(new_collection(expr.Lengths(self.expr)).compute()) + return new_collection(Len(self.expr)).compute() @property def nbytes(self): diff --git a/dask_expr/expr.py b/dask_expr/expr.py index 046d9ae24..dad8e7343 100644 --- a/dask_expr/expr.py +++ b/dask_expr/expr.py @@ -1063,7 +1063,8 @@ def _divisions(self): def _simplify_down(self): if isinstance(self.frame, Elemwise): - return Lengths(self.frame.operands[0]) + child = max(self.frame.dependencies(), key=lambda expr: expr.npartitions) + return Lengths(child) def _layer(self): name = "part-" + self._name From bd5395ad741e6f460e115b398f87ef2f969366c8 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Tue, 30 May 2023 17:32:47 -0500 Subject: [PATCH 29/31] improve testing --- dask_expr/io/tests/test_io.py | 6 +++++- dask_expr/tests/test_collection.py | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py index b28740852..cd65ea7f8 100644 --- a/dask_expr/io/tests/test_io.py +++ b/dask_expr/io/tests/test_io.py @@ -6,8 +6,9 @@ from dask.dataframe.utils import assert_eq from dask_expr import from_dask_dataframe, from_pandas, optimize, read_csv, read_parquet -from dask_expr.expr import Expr +from dask_expr.expr import Expr, Literal from dask_expr.io import ReadParquet +from dask_expr.reductions import Len def _make_file(dir, format="parquet", df=None): @@ -216,6 +217,9 @@ def test_parquet_len(tmpdir): s = (df["b"] + 1).astype("Int32") assert len(s) == len(pdf) + expr = Len(s.expr) + assert isinstance(expr.optimize(), Literal) + @pytest.mark.parametrize("optimize", [True, False]) def test_from_dask_dataframe(optimize): diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index dc900e91a..60a3e787c 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -12,6 +12,7 @@ from dask_expr import expr, from_pandas, optimize from dask_expr.datasets import timeseries +from dask_expr.reductions import Len @pytest.fixture @@ -607,6 +608,8 @@ def test_len(df, pdf): first = df2.partitions[0].compute() assert len(df2.partitions[0]) == len(first) + assert isinstance(Len(df2.expr).optimize(), expr.Literal) + def test_drop_duplicates(df, pdf): assert_eq(df.drop_duplicates(), pdf.drop_duplicates()) From be4af189cb3f4907f58874b8c76ed88489263e06 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Wed, 31 May 2023 09:28:33 -0500 Subject: [PATCH 30/31] cleanup --- dask_expr/io/io.py | 6 ++++++ dask_expr/io/parquet.py | 6 ++++++ dask_expr/io/tests/test_io.py | 6 +++--- dask_expr/reductions.py | 7 ++----- dask_expr/tests/test_collection.py | 1 + dask_expr/tests/test_datasets.py | 7 +++++++ 6 files changed, 25 insertions(+), 8 deletions(-) diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py index bb7351ad6..ebee423e9 100644 --- a/dask_expr/io/io.py +++ b/dask_expr/io/io.py @@ -6,6 +6,7 @@ from dask.dataframe.io.io import sorted_division_locations from dask_expr.expr import Blockwise, Expr, Lengths, Literal, PartitionsFiltered +from dask_expr.reductions import Len class IO(Expr): @@ -87,6 +88,11 @@ def _simplify_up(self, parent): if _lengths: return Literal(_lengths) + if isinstance(parent, Len): + _lengths = self._get_lengths() + if _lengths: + return 
Literal(sum(_lengths)) + def _divisions(self): return self._divisions_and_locations[0] diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py index aece8fdc6..dd7bea178 100644 --- a/dask_expr/io/parquet.py +++ b/dask_expr/io/parquet.py @@ -35,6 +35,7 @@ Projection, ) from dask_expr.io import BlockwiseIO, PartitionsFiltered +from dask_expr.reductions import Len NONE_LABEL = "__null_dask_index__" @@ -121,6 +122,11 @@ def _simplify_up(self, parent): if _lengths: return Literal(_lengths) + if isinstance(parent, Len): + _lengths = self._get_lengths() + if _lengths: + return Literal(sum(_lengths)) + @cached_property def _dataset_info(self): # Process and split user options diff --git a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py index cd65ea7f8..b2a7036fa 100644 --- a/dask_expr/io/tests/test_io.py +++ b/dask_expr/io/tests/test_io.py @@ -6,7 +6,7 @@ from dask.dataframe.utils import assert_eq from dask_expr import from_dask_dataframe, from_pandas, optimize, read_csv, read_parquet -from dask_expr.expr import Expr, Literal +from dask_expr.expr import Expr, Lengths, Literal from dask_expr.io import ReadParquet from dask_expr.reductions import Len @@ -217,8 +217,8 @@ def test_parquet_len(tmpdir): s = (df["b"] + 1).astype("Int32") assert len(s) == len(pdf) - expr = Len(s.expr) - assert isinstance(expr.optimize(), Literal) + assert isinstance(Len(s.expr).optimize(), Literal) + assert isinstance(Lengths(s.expr).optimize(), Literal) @pytest.mark.parametrize("optimize", [True, False]) diff --git a/dask_expr/reductions.py b/dask_expr/reductions.py index 4ec69bb7c..b21df3d12 100644 --- a/dask_expr/reductions.py +++ b/dask_expr/reductions.py @@ -13,7 +13,7 @@ ) from dask.utils import M, apply -from dask_expr.expr import Elemwise, Expr, Lengths, Literal, Projection +from dask_expr.expr import Elemwise, Expr, Projection class ApplyConcatApply(Expr): @@ -346,10 +346,7 @@ class Len(Reduction): reduction_aggregate = sum def _simplify_down(self): - _lengths = Lengths(self.frame).optimize() - if isinstance(_lengths, Literal): - return Literal(sum(_lengths.value)) - elif isinstance(self.frame, Elemwise): + if isinstance(self.frame, Elemwise): child = max(self.frame.dependencies(), key=lambda expr: expr.npartitions) return Len(child) diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 60a3e787c..6c32da8c3 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -609,6 +609,7 @@ def test_len(df, pdf): assert len(df2.partitions[0]) == len(first) assert isinstance(Len(df2.expr).optimize(), expr.Literal) + assert isinstance(expr.Lengths(df2.expr).optimize(), expr.Literal) def test_drop_duplicates(df, pdf): diff --git a/dask_expr/tests/test_datasets.py b/dask_expr/tests/test_datasets.py index ff2ac9d77..1a8532ccb 100644 --- a/dask_expr/tests/test_datasets.py +++ b/dask_expr/tests/test_datasets.py @@ -1,6 +1,8 @@ from dask.dataframe.utils import assert_eq +from dask_expr import new_collection from dask_expr.datasets import timeseries +from dask_expr.expr import Lengths def test_timeseries(): @@ -48,3 +50,8 @@ def test_persist(): assert_eq(a, b) assert len(a.dask) > len(b.dask) assert len(b.dask) == b.npartitions + + +def test_lengths(): + df = timeseries(freq="1H", start="2000-01-01", end="2000-01-03", seed=123) + assert len(df) == sum(new_collection(Lengths(df.expr).optimize()).compute()) From 47ff1d3fa7027aa55f3cd0e651f8352f69d34be7 Mon Sep 17 00:00:00 2001 From: Rick Zamora Date: Wed, 31 May 2023 09:34:46 -0500 Subject: 
[PATCH 31/31] remove _len for now --- dask_expr/collection.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/dask_expr/collection.py b/dask_expr/collection.py index 9a57c5cc0..da0c48960 100644 --- a/dask_expr/collection.py +++ b/dask_expr/collection.py @@ -87,10 +87,6 @@ def size(self): return new_collection(self.expr.size) def __len__(self): - return self._len - - @property - def _len(self): return new_collection(Len(self.expr)).compute() @property
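
Note (editorial, not part of the patch series): the sketch below illustrates the behavior this series converges on, namely len() building a Len reduction that optimization collapses into a Literal whenever partition lengths are known from parquet metadata or from_pandas division locations. It is a minimal usage sketch only, assuming the dask_expr API as of PATCH 31/31 and an installed parquet engine such as pyarrow; the "example-data" path and the column names are illustrative, not taken from the patches.

    import pandas as pd
    import dask.dataframe as dd

    from dask_expr import from_pandas, read_parquet
    from dask_expr.expr import Lengths, Literal
    from dask_expr.reductions import Len

    # Write a small example dataset (path and columns are illustrative).
    pdf = pd.DataFrame({"a": range(100), "b": range(100)})
    dd.from_pandas(pdf, npartitions=4).to_parquet("example-data")

    df = read_parquet("example-data")

    # __len__ now builds a Len reduction over the expression graph ...
    assert len(df) == len(pdf)

    # ... and optimization collapses Len/Lengths into a Literal when the
    # partition lengths are known from parquet row-group statistics, so no
    # partition data needs to be read.
    s = (df["b"] + 1).astype("Int32")
    assert isinstance(Len(s.expr).optimize(), Literal)
    assert isinstance(Lengths(s.expr).optimize(), Literal)

    # The same short-circuit applies to in-memory frames, where FromPandas
    # knows its partition lengths from the sorted division locations.
    df2 = from_pandas(pdf, npartitions=4)[["a"]] + 1
    assert isinstance(Len(df2.expr).optimize(), Literal)

This mirrors the assertions added in test_parquet_len, test_lengths, and test_len over the course of the series, so the expression classes and imports above are the ones the patches themselves exercise.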