diff --git a/dask_expr/collection.py b/dask_expr/collection.py
index 1ddd6b87f..0d757697d 100644
--- a/dask_expr/collection.py
+++ b/dask_expr/collection.py
@@ -73,6 +73,12 @@ def _meta(self):
     def size(self):
         return new_collection(self.expr.size)
 
+    def __len__(self):
+        _len = self.expr._len
+        if isinstance(_len, expr.Expr):
+            _len = new_collection(_len).compute()
+        return _len
+
     def __reduce__(self):
         return new_collection, (self._expr,)
 
@@ -548,7 +554,7 @@ def read_parquet(
     index=None,
     storage_options=None,
     dtype_backend=None,
-    calculate_divisions=False,
+    gather_statistics=True,
     ignore_metadata_file=False,
     metadata_task_size=None,
     split_row_groups="infer",
@@ -573,7 +579,7 @@ def read_parquet(
             categories=categories,
             index=index,
             storage_options=storage_options,
-            calculate_divisions=calculate_divisions,
+            gather_statistics=gather_statistics,
             ignore_metadata_file=ignore_metadata_file,
             metadata_task_size=metadata_task_size,
             split_row_groups=split_row_groups,
diff --git a/dask_expr/expr.py b/dask_expr/expr.py
index 248b46222..7ea7d9268 100644
--- a/dask_expr/expr.py
+++ b/dask_expr/expr.py
@@ -57,6 +57,13 @@ def ndim(self):
         except AttributeError:
             return 0
 
+    @functools.cached_property
+    def _len(self):
+        statistics = self.statistics()
+        if "row_count" in statistics:
+            return statistics["row_count"].sum()
+        return Len(self)
+
     def __str__(self):
         s = ", ".join(
             str(param) + "=" + str(operand)
@@ -206,6 +213,42 @@ def _layer(self) -> dict:
 
         return {(self._name, i): self._task(i) for i in range(self.npartitions)}
 
+    def _statistics(self):
+        """New statistics to add for this expression"""
+        from dask_expr.statistics import Statistics
+
+        # Assume statistics from dependencies
+        statistics = {}
+        for dep in self.dependencies():
+            for k, v in dep.statistics().items():
+                assert isinstance(v, Statistics)
+                if k not in statistics:
+                    val = v.assume(self)
+                    if val:
+                        statistics[k] = val
+        return statistics
+
+    def statistics(self) -> dict:
+        """Known statistics of an expression, like partition statistics
+
+        To define this on a class create a `._statistics` method that returns a
+        dictionary of new statistics known by that class. If nothing is known it
+        is ok to return None. Superclasses will also be consulted.
+
+        Examples
+        --------
+        >>> df.statistics()
+        {'row_count': RowCountStatistics(data=(1000000,))}
+        """
+        out = {}
+        for typ in type(self).mro()[::-1]:
+            if not issubclass(typ, Expr):
+                continue
+            d = typ._statistics(self) or {}
+            if d:
+                out.update(d)  # TODO: this is fragile
+        return out
+
     def simplify(self):
         """Simplify expression
 
@@ -1337,4 +1380,4 @@ def _execute_task(graph, name, *deps):
 
 
 from dask_expr.io import BlockwiseIO
-from dask_expr.reductions import Count, Max, Mean, Min, Mode, Prod, Size, Sum
+from dask_expr.reductions import Count, Len, Max, Mean, Min, Mode, Prod, Size, Sum
diff --git a/dask_expr/io/io.py b/dask_expr/io/io.py
index 18927d8d8..8c451efc6 100644
--- a/dask_expr/io/io.py
+++ b/dask_expr/io/io.py
@@ -4,6 +4,7 @@
 from dask.dataframe.io.io import sorted_division_locations
 
 from dask_expr.expr import Blockwise, Expr, PartitionsFiltered
+from dask_expr.statistics import RowCountStatistics
 
 
 class IO(Expr):
@@ -68,6 +69,13 @@ def _divisions_and_locations(self):
             divisions = (None,) * len(locations)
         return divisions, locations
 
+    def _statistics(self):
+        locations = self._locations()
+        row_counts = tuple(
+            offset - locations[i] for i, offset in enumerate(locations[1:])
+        )
+        return {"row_count": RowCountStatistics(row_counts)}
+
     def _divisions(self):
         return self._divisions_and_locations[0]
 
diff --git a/dask_expr/io/parquet.py b/dask_expr/io/parquet.py
index 388432878..dd1c18627 100644
--- a/dask_expr/io/parquet.py
+++ b/dask_expr/io/parquet.py
@@ -6,15 +6,17 @@
 from dask.dataframe.io.parquet.core import (
     ParquetFunctionWrapper,
+    aggregate_row_groups,
     get_engine,
-    process_statistics,
     set_index_columns,
+    sorted_columns,
 )
 from dask.dataframe.io.parquet.utils import _split_user_options
 from dask.utils import natural_sort_key
 
 from dask_expr.expr import EQ, GE, GT, LE, LT, NE, And, Expr, Filter, Or, Projection
 from dask_expr.io import BlockwiseIO, PartitionsFiltered
+from dask_expr.statistics import RowCountStatistics
 
 NONE_LABEL = "__null_dask_index__"
 
@@ -28,6 +30,60 @@ def _list_columns(columns):
     return columns
 
 
+def _align_statistics(parts, statistics):
+    # Make sure parts and statistics are aligned
+    # (if statistics is not empty)
+    if statistics and len(parts) != len(statistics):
+        statistics = []
+    if statistics:
+        result = list(
+            zip(
+                *[
+                    (part, stats)
+                    for part, stats in zip(parts, statistics)
+                    if stats["num-rows"] > 0
+                ]
+            )
+        )
+        parts, statistics = result or [[], []]
+    return parts, statistics
+
+
+def _aggregate_row_groups(parts, statistics, dataset_info):
+    # Aggregate parts/statistics if we are splitting by row-group
+    blocksize = (
+        dataset_info["blocksize"] if dataset_info["split_row_groups"] is True else None
+    )
+    split_row_groups = dataset_info["split_row_groups"]
+    fs = dataset_info["fs"]
+    aggregation_depth = dataset_info["aggregation_depth"]
+
+    if statistics:
+        if blocksize or (split_row_groups and int(split_row_groups) > 1):
+            parts, statistics = aggregate_row_groups(
+                parts, statistics, blocksize, split_row_groups, fs, aggregation_depth
+            )
+    return parts, statistics
+
+
+def _calculate_divisions(statistics, dataset_info, npartitions):
+    # Use statistics to define divisions
+    divisions = None
+    if statistics:
+        calculate_divisions = dataset_info["kwargs"].get("calculate_divisions", None)
+        index = dataset_info["index"]
+        process_columns = index if index and len(index) == 1 else None
+        if (calculate_divisions is not False) and process_columns:
+            for sorted_column_info in sorted_columns(
+                statistics, columns=process_columns
+            ):
+                if sorted_column_info["name"] in index:
+                    divisions = sorted_column_info["divisions"]
+                    break
+
+    return divisions or (None,) * (npartitions + 1)
+
+
 class ReadParquet(PartitionsFiltered, BlockwiseIO):
     """Read a parquet dataset"""
 
@@ -38,7 +94,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
         "categories",
         "index",
         "storage_options",
-        "calculate_divisions",
+        "gather_statistics",
         "ignore_metadata_file",
         "metadata_task_size",
         "split_row_groups",
@@ -56,7 +112,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
         "categories": None,
         "index": None,
         "storage_options": None,
-        "calculate_divisions": False,
+        "gather_statistics": True,
         "ignore_metadata_file": False,
         "metadata_task_size": None,
         "split_row_groups": "infer",
@@ -152,7 +208,7 @@ def _dataset_info(self):
             fs,
             self.categories,
             index,
-            self.calculate_divisions,
+            self.gather_statistics,
             self.filters,
             self.split_row_groups,
             blocksize,
@@ -169,6 +225,7 @@ def _dataset_info(self):
 
         # Infer meta, accounting for index and columns arguments.
         meta = self.engine._create_dd_meta(dataset_info)
+        index = dataset_info["index"]
         index = [index] if isinstance(index, str) else index
         meta, index, columns = set_index_columns(
             meta, index, self.operand("columns"), auto_index_allowed
@@ -196,21 +253,14 @@ def _plan(self):
             dataset_info
         )
 
-        # Parse dataset statistics from metadata (if available)
-        parts, divisions, _ = process_statistics(
-            parts,
-            stats,
-            dataset_info["filters"],
-            dataset_info["index"],
-            (
-                dataset_info["blocksize"]
-                if dataset_info["split_row_groups"] is True
-                else None
-            ),
-            dataset_info["split_row_groups"],
-            dataset_info["fs"],
-            dataset_info["aggregation_depth"],
-        )
+        # Make sure parts and stats are aligned
+        parts, stats = _align_statistics(parts, stats)
+
+        # Use statistics to aggregate partitions
+        parts, stats = _aggregate_row_groups(parts, stats, dataset_info)
+
+        # Use statistics to calculate divisions
+        divisions = _calculate_divisions(stats, dataset_info, len(parts))
 
         meta = dataset_info["meta"]
         if len(divisions) < 2:
@@ -234,6 +284,7 @@ def _plan(self):
         return {
             "func": io_func,
            "parts": parts,
+            "statistics": stats,
            "divisions": divisions,
        }
 
@@ -246,6 +297,15 @@ def _filtered_task(self, index: int):
             return (operator.getitem, tsk, self.columns[0])
         return tsk
 
+    def _statistics(self):
+        if self._pq_statistics and not self.filters:
+            row_count = tuple(stat["num-rows"] for stat in self._pq_statistics)
+            return {"row_count": RowCountStatistics(row_count)}
+
+    @property
+    def _pq_statistics(self):
+        return self._plan["statistics"]
+
 
 #
 # Filters
diff --git a/dask_expr/io/tests/test_io.py b/dask_expr/io/tests/test_io.py
index 60e85394f..0f149750d 100644
--- a/dask_expr/io/tests/test_io.py
+++ b/dask_expr/io/tests/test_io.py
@@ -204,3 +204,14 @@ def test_parquet_complex_filters(tmpdir):
 
     assert_eq(got, expect)
     assert_eq(got.optimize(), expect)
+
+
+def test_parquet_row_count_statistics(tmpdir):
+    # NOTE: We should no longer need to set `index`
+    # or `calculate_divisions` to gather row-count
+    # statistics after dask#10290
+    df = read_parquet(_make_file(tmpdir), index="a", calculate_divisions=True)
+    pdf = df.compute()
+
+    s = (df["b"] + 1).astype("Int32")
+    assert s.statistics().get("row_count").sum() == len(pdf)
diff --git a/dask_expr/statistics.py b/dask_expr/statistics.py
new file mode 100644
index 000000000..ae631fc6c
--- /dev/null
+++ b/dask_expr/statistics.py
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+from dataclasses import dataclass
+from functools import singledispatchmethod
+from typing import Any
+
+from dask_expr.expr import Elemwise, Expr, Partitions
+
+
+@dataclass(frozen=True)
+class Statistics:
+    """Abstract class for expression statistics
+
+    See Also
+    --------
+    PartitionStatistics
+    """
+
+    data: Any
+
+    @singledispatchmethod
+    def assume(self, parent: Expr) -> Statistics | None:
+        """Statistics that a "parent" Expr may assume
+
+        A return value of `None` (the default) means that
+        `parent` is not eligible to assume this kind of
+        statistics.
+        """
+        return None
+
+
+@dataclass(frozen=True)
+class PartitionStatistics(Statistics):
+    """Statistics containing a distinct value for every partition
+
+    See Also
+    --------
+    RowCountStatistics
+    """
+
+    data: Iterable
+
+
+@PartitionStatistics.assume.register
+def _partitionstatistics_partitions(self, parent: Partitions):
+    # A `Partitions` expression may assume statistics
+    # from the selected partitions
+    return type(self)(
+        type(self.data)(
+            part for i, part in enumerate(self.data) if i in parent.partitions
+        )
+    )
+
+
+#
+# PartitionStatistics sub-classes
+#
+
+
+@dataclass(frozen=True)
+class RowCountStatistics(PartitionStatistics):
+    """Tracks the row count of each partition"""
+
+    def sum(self):
+        """Return the total row-count of all partitions"""
+        return sum(self.data)
+
+
+@RowCountStatistics.assume.register
+def _rowcount_elemwise(self, parent: Elemwise):
+    # All element-wise operations may assume
+    # row-count statistics
+    return self
diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py
index 21dfe370a..5c1ceb67e 100644
--- a/dask_expr/tests/test_collection.py
+++ b/dask_expr/tests/test_collection.py
@@ -429,3 +429,20 @@ def test_repartition_divisions(df, opt):
         if len(part):
             assert part.min() >= df2.divisions[p]
             assert part.max() < df2.divisions[p + 1]
+
+
+def test_len(df, pdf):
+    df2 = df[["x"]] + 1
+    assert len(df2) == len(pdf)
+
+    first = df2.partitions[0].compute()
+    assert len(df2.partitions[0]) == len(first)
+
+
+def test_row_count_statistics(df, pdf):
+    df2 = df[["x"]] + 1
+    assert df2.statistics().get("row_count").sum() == len(pdf)
+    assert df[df.x > 5].statistics().get("row_count") is None
+    assert df2.partitions[0].statistics().get("row_count").sum() == len(
+        df2.partitions[0]
+    )
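
A minimal end-to-end sketch of how the new row-count statistics compose, mirroring test_parquet_row_count_statistics above; the file path and the top-level read_parquet import are illustrative assumptions rather than part of this patch:

    from dask_expr import read_parquet  # assumed import location

    # Mirrors the test above: an index plus divisions currently triggers
    # row-count statistics collection (see the NOTE in test_io.py).
    df = read_parquet("data.parquet", index="a", calculate_divisions=True)

    # Element-wise operations assume row_count statistics from their
    # dependencies, so the derived Series keeps per-partition row counts.
    s = (df["b"] + 1).astype("Int32")

    stats = s.statistics()  # consults _statistics() across the class MRO
    stats["row_count"]      # RowCountStatistics(data=(...,))

    # __len__ returns the known row-count sum without executing the task graph.
    assert len(s) == stats["row_count"].sum()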