10 changes: 8 additions & 2 deletions dask_expr/collection.py
@@ -73,6 +73,12 @@ def _meta(self):
def size(self):
return new_collection(self.expr.size)

def __len__(self):
_len = self.expr._len
if isinstance(_len, expr.Expr):
_len = new_collection(_len).compute()
return _len

def __reduce__(self):
return new_collection, (self._expr,)

@@ -548,7 +554,7 @@ def read_parquet(
index=None,
storage_options=None,
dtype_backend=None,
calculate_divisions=False,
gather_statistics=True,
ignore_metadata_file=False,
metadata_task_size=None,
split_row_groups="infer",
@@ -573,7 +579,7 @@ def read_parquet(
categories=categories,
index=index,
storage_options=storage_options,
calculate_divisions=calculate_divisions,
gather_statistics=gather_statistics,
ignore_metadata_file=ignore_metadata_file,
metadata_task_size=metadata_task_size,
split_row_groups=split_row_groups,
45 changes: 44 additions & 1 deletion dask_expr/expr.py
@@ -57,6 +57,13 @@ def ndim(self):
except AttributeError:
return 0

@functools.cached_property
def _len(self):
statistics = self.statistics()
if "row_count" in statistics:
return statistics["row_count"].sum()
return Len(self)

def __str__(self):
s = ", ".join(
str(param) + "=" + str(operand)
@@ -206,6 +213,42 @@ def _layer(self) -> dict:

return {(self._name, i): self._task(i) for i in range(self.npartitions)}

def _statistics(self):
"""New statistics to add for this expression"""
from dask_expr.statistics import Statistics

# Assume statistics from dependencies
statistics = {}
for dep in self.dependencies():
for k, v in dep.statistics().items():
assert isinstance(v, Statistics)
if k not in statistics:
val = v.assume(self)
if val:
statistics[k] = val
return statistics

def statistics(self) -> dict:
"""Known statistics of an expression, like partition statistics

To define this on a class, create a `._statistics` method that returns a
dictionary of the new statistics known by that class. If nothing is known,
it is fine to return None. Superclasses will also be consulted.

Examples
--------
>>> df.statistics()
{'row_count': RowCountStatistics(data=(1000000,))}
"""
out = {}
for typ in type(self).mro()[::-1]:
if not issubclass(typ, Expr):
continue
d = typ._statistics(self) or {}
if d:
out.update(d) # TODO: this is fragile
return out

def simplify(self):
"""Simplify expression

@@ -1337,4 +1380,4 @@ def _execute_task(graph, name, *deps):


from dask_expr.io import BlockwiseIO
from dask_expr.reductions import Count, Max, Mean, Min, Mode, Prod, Size, Sum
from dask_expr.reductions import Count, Len, Max, Mean, Min, Mode, Prod, Size, Sum
8 changes: 8 additions & 0 deletions dask_expr/io/io.py
@@ -4,6 +4,7 @@
from dask.dataframe.io.io import sorted_division_locations

from dask_expr.expr import Blockwise, Expr, PartitionsFiltered
from dask_expr.statistics import RowCountStatistics


class IO(Expr):
@@ -68,6 +69,13 @@ def _divisions_and_locations(self):
divisions = (None,) * len(locations)
return divisions, locations

def _statistics(self):
locations = self._locations()
row_counts = tuple(
offset - locations[i] for i, offset in enumerate(locations[1:])
)
return {"row_count": RowCountStatistics(row_counts)}

def _divisions(self):
return self._divisions_and_locations[0]

98 changes: 79 additions & 19 deletions dask_expr/io/parquet.py
@@ -6,15 +6,17 @@

from dask.dataframe.io.parquet.core import (
ParquetFunctionWrapper,
aggregate_row_groups,
get_engine,
process_statistics,
set_index_columns,
sorted_columns,
)
from dask.dataframe.io.parquet.utils import _split_user_options
from dask.utils import natural_sort_key

from dask_expr.expr import EQ, GE, GT, LE, LT, NE, And, Expr, Filter, Or, Projection
from dask_expr.io import BlockwiseIO, PartitionsFiltered
from dask_expr.statistics import RowCountStatistics

NONE_LABEL = "__null_dask_index__"

@@ -28,6 +30,60 @@ def _list_columns(columns):
return columns


def _align_statistics(parts, statistics):
# Make sure parts and statistics are aligned
# (if statistics is not empty)
if statistics and len(parts) != len(statistics):
statistics = []
if statistics:
result = list(
zip(
*[
(part, stats)
for part, stats in zip(parts, statistics)
if stats["num-rows"] > 0
]
)
)
parts, statistics = result or [[], []]
return parts, statistics


def _aggregate_row_groups(parts, statistics, dataset_info):
# Aggregate parts/statistics if we are splitting by row-group
blocksize = (
dataset_info["blocksize"] if dataset_info["split_row_groups"] is True else None
)
split_row_groups = dataset_info["split_row_groups"]
fs = dataset_info["fs"]
aggregation_depth = dataset_info["aggregation_depth"]

if statistics:
if blocksize or (split_row_groups and int(split_row_groups) > 1):
parts, statistics = aggregate_row_groups(
parts, statistics, blocksize, split_row_groups, fs, aggregation_depth
)
return parts, statistics


def _calculate_divisions(statistics, dataset_info, npartitions):
# Use statistics to define divisions
divisions = None
if statistics:
calculate_divisions = dataset_info["kwargs"].get("calculate_divisions", None)
index = dataset_info["index"]
process_columns = index if index and len(index) == 1 else None
if (calculate_divisions is not False) and process_columns:
for sorted_column_info in sorted_columns(
statistics, columns=process_columns
):
if sorted_column_info["name"] in index:
divisions = sorted_column_info["divisions"]
break

return divisions or (None,) * (npartitions + 1)


class ReadParquet(PartitionsFiltered, BlockwiseIO):
"""Read a parquet dataset"""

@@ -38,7 +94,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
"categories",
"index",
"storage_options",
"calculate_divisions",
"gather_statistics",
"ignore_metadata_file",
"metadata_task_size",
"split_row_groups",
@@ -56,7 +112,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
"categories": None,
"index": None,
"storage_options": None,
"calculate_divisions": False,
"gather_statistics": True,
"ignore_metadata_file": False,
"metadata_task_size": None,
"split_row_groups": "infer",
@@ -152,7 +208,7 @@ def _dataset_info(self):
fs,
self.categories,
index,
self.calculate_divisions,
self.gather_statistics,
self.filters,
self.split_row_groups,
blocksize,
@@ -169,6 +225,7 @@

# Infer meta, accounting for index and columns arguments.
meta = self.engine._create_dd_meta(dataset_info)
index = dataset_info["index"]
index = [index] if isinstance(index, str) else index
meta, index, columns = set_index_columns(
meta, index, self.operand("columns"), auto_index_allowed
@@ -196,21 +253,14 @@ def _plan(self):
dataset_info
)

# Parse dataset statistics from metadata (if available)
parts, divisions, _ = process_statistics(
parts,
stats,
dataset_info["filters"],
dataset_info["index"],
(
dataset_info["blocksize"]
if dataset_info["split_row_groups"] is True
else None
),
dataset_info["split_row_groups"],
dataset_info["fs"],
dataset_info["aggregation_depth"],
)
# Make sure parts and stats are aligned
parts, stats = _align_statistics(parts, stats)

# Use statistics to aggregate partitions
parts, stats = _aggregate_row_groups(parts, stats, dataset_info)

# Use statistics to calculate divisions
divisions = _calculate_divisions(stats, dataset_info, len(parts))

meta = dataset_info["meta"]
if len(divisions) < 2:
@@ -234,6 +284,7 @@ def _plan(self):
return {
"func": io_func,
"parts": parts,
"statistics": stats,
"divisions": divisions,
}

@@ -246,6 +297,15 @@ def _filtered_task(self, index: int):
return (operator.getitem, tsk, self.columns[0])
return tsk

def _statistics(self):
if self._pq_statistics and not self.filters:
row_count = tuple(stat["num-rows"] for stat in self._pq_statistics)
return {"row_count": RowCountStatistics(row_count)}

@property
def _pq_statistics(self):
return self._plan["statistics"]


#
# Filters
11 changes: 11 additions & 0 deletions dask_expr/io/tests/test_io.py
@@ -204,3 +204,14 @@ def test_parquet_complex_filters(tmpdir):

assert_eq(got, expect)
assert_eq(got.optimize(), expect)


def test_parquet_row_count_statistics(tmpdir):
# NOTE: We should no longer need to set `index`
# or `calculate_divisions` to gather row-count
# statistics after dask#10290
df = read_parquet(_make_file(tmpdir), index="a", calculate_divisions=True)
pdf = df.compute()

s = (df["b"] + 1).astype("Int32")
assert s.statistics().get("row_count").sum() == len(pdf)
74 changes: 74 additions & 0 deletions dask_expr/statistics.py
@@ -0,0 +1,74 @@
from __future__ import annotations
Member comment:

I could use some help understanding this module.

I think that in general we have yet to define what kinds of statistics we're going to capture, and how we plan to encode those. There are lots of options here.

I think what I'm seeing here is that your response is "we'll just make different classes for all the different kinds of things that people might want to encode". Is that correct? If so, I'm not totally bought into this just yet.

I think that the question of "how do we encode dataframe-level or partition-level statistics" is a big open one. I'm ok with us not having a clear answer on this before we move forward, but I want the level of sophistication of our solution to be correlated with our confidence. This feels like a somewhat sophisticated/specific solution (a few classes with some specific method APIs) but I don't have confidence that it's correct (or at least I don't know enough to be confident). Can you help me understand here?

Member Author comment:

Hmmm. We may need to have a real-time chat about this one. My primary goal here was to keep things very simple, and so it worries me a bit that you see something sophisticated.

The general approach here is: “Adopt the same statistics approach suggested in #40, but use a simple data class as a container for the statistics so that we know if/how it should be passed from child to parent.” I only added the simple class structure to the mix after I started experimenting with row-count and min/max column statistics, and felt that there was unnecessary _statistics logic polluting several non-IO Expr classes. Since I know the statistics representation/framework is likely to evolve (or be replaced completely) in the future, I was hoping to keep the logic isolated. In the end, I decided to focus on the simple row-count case and propose a class structure that I expect to be relevant to all statistics: we need to hold some kind of statistics “data”, and we need to expose a mechanism for passing a specific kind of statistics between child and parent.

I suppose you are probably saying that you would prefer not to introduce classes until we know that those classes will capture some of the other kinds of statistics we will want to track (e.g. min/max/null column statistics, and “shuffled-by” information)? This request is perfectly fair. I’ll admit that part of the reason I didn’t include min/max column statistics in this PR is that I hadn’t decided on the best way to represent partition-wise column statistics.

Aside: My favorite column-statistics approach I’ve played with so far is to track a ColumnStatistics(Statistics) object for each column, and for the data of that object to be a ColumnMaxima(PartitionStatistics) object where data is a tuple of {‘min’: …, ‘max’: …} dicts.
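
For reference, a rough sketch of the column-statistics shape described in this aside. The class names, fields, and the `minimum`/`maximum` helpers are hypothetical illustrations, not part of this PR:

```python
from dataclasses import dataclass

from dask_expr.statistics import PartitionStatistics, Statistics


@dataclass(frozen=True)
class ColumnMaxima(PartitionStatistics):
    """One {'min': ..., 'max': ...} dict per partition for a single column"""

    data: tuple


@dataclass(frozen=True)
class ColumnStatistics(Statistics):
    """Statistics for a single column; ``data`` holds a ColumnMaxima"""

    data: ColumnMaxima

    def minimum(self):
        # Smallest per-partition minimum across the column
        return min(d["min"] for d in self.data.data)

    def maximum(self):
        # Largest per-partition maximum across the column
        return max(d["max"] for d in self.data.data)
```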



Another consideration is whether this design will allow us to push down “requests” for missing statistics into a ReadParquet expression at optimization time. I think the answer is “yes,” but this question is another reason I’d like to keep the statistics logic isolated in the meantime.

Member Author comment:

One thing I don't like about the design in this PR is that it still uses the dict approach (as is) from #40 for tracking all known statistics. Whatever design we ultimately go with, we will probably need to enforce explicit rules for key names and collisions. I didn't bother to deal with this yet, but it was certainly on my mind.
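
To ground the discussion, a minimal usage sketch of the propagation this module enables. The file path is a placeholder, the keyword arguments mirror the new parquet test in this PR, and `read_parquet` is assumed to be importable from the package top level:

```python
from dask_expr import read_parquet

# Row counts parsed from parquet metadata survive element-wise operations,
# so the total length is known without executing the task graph.
df = read_parquet("dataset.parquet", index="a", calculate_divisions=True)
s = (df["b"] + 1).astype("Int32")

row_count = s.statistics().get("row_count")
if row_count is not None:
    print(row_count.sum())  # total number of rows, straight from metadata
```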


from collections.abc import Iterable
from dataclasses import dataclass
from functools import singledispatchmethod
from typing import Any

from dask_expr.expr import Elemwise, Expr, Partitions


@dataclass(frozen=True)
class Statistics:
"""Abstract class for expression statistics

See Also
--------
PartitionStatistics
"""

data: Any

@singledispatchmethod
def assume(self, parent: Expr) -> Statistics | None:
"""Statistics that a "parent" Expr may assume

A return value of `None` (the default) means that
`parent` is not eligible to assume this kind of
statistics.
"""
return None


@dataclass(frozen=True)
class PartitionStatistics(Statistics):
"""Statistics containing a distinct value for every partition

See Also
--------
RowCountStatistics
"""

data: Iterable


@PartitionStatistics.assume.register
def _partitionstatistics_partitions(self, parent: Partitions):
# A `Partitions` expression may assume statistics
# from the selected partitions
return type(self)(
type(self.data)(
part for i, part in enumerate(self.data) if i in parent.partitions
)
)


#
# PartitionStatistics sub-classes
#


@dataclass(frozen=True)
class RowCountStatistics(PartitionStatistics):
"""Tracks the row count of each partition"""

def sum(self):
"""Return the total row-count of all partitions"""
return sum(self.data)


@RowCountStatistics.assume.register
def _rowcount_elemwise(self, parent: Elemwise):
# All Element-wise operations may assume
# row-count statistics
return self
17 changes: 17 additions & 0 deletions dask_expr/tests/test_collection.py
Original file line number Diff line number Diff line change
@@ -429,3 +429,20 @@ def test_repartition_divisions(df, opt):
if len(part):
assert part.min() >= df2.divisions[p]
assert part.max() < df2.divisions[p + 1]


def test_len(df, pdf):
df2 = df[["x"]] + 1
assert len(df2) == len(pdf)

first = df2.partitions[0].compute()
assert len(df2.partitions[0]) == len(first)


def test_row_count_statistics(df, pdf):
df2 = df[["x"]] + 1
assert df2.statistics().get("row_count").sum() == len(pdf)
assert df[df.x > 5].statistics().get("row_count") is None
assert df2.partitions[0].statistics().get("row_count").sum() == len(
df2.partitions[0]
)