Commit 422bc49

Add predicate filtering
1 parent 19c126a commit 422bc49

File tree: 10 files changed (+119, -125 lines)


plateau/core/common_metadata.py

Lines changed: 1 addition & 9 deletions
@@ -166,15 +166,11 @@ def gen_metadata(schema: SchemaWrapper) -> dict[str, Any]:
             {
                 "name": field.name,
                 "field_name": field.name,
-                # the following fields are NOT accessed when resorting the columns
-                # "pandas_type": str(field.type), # optional
-                # "numpy_type": str(field.type), # optional
-                # "metadata": field.metadata, # optional: decode if needed
+                # other are NOT accessed when resorting the columns
             }
         )

     return pandas_metadata
-    # > {'columns': [{'name': 'A', 'field_name': 'A'}, {'name': 'B', 'field_name': 'B'}, {'name': 'C', 'field_name': 'C'}, {'name': 'D', 'field_name': 'D'}, {'name': 'E', 'field_name': 'E'}, {'name': 'F', 'field_name': 'F'}], 'index_columns': [], 'pandas_version': '2.2.3'}


 def normalize_column_order(schema, partition_keys=None):
@@ -284,10 +280,6 @@ def make_meta(obj, origin, partition_keys=None):
     elif isinstance(obj, pa.Table):
         return obj.schema

-        # normalize_column_order(
-        #     SchemaWrapper(obj.schema, origin), partition_keys=partition_keys
-        # )
-
     if not isinstance(obj, pd.DataFrame):
         raise ValueError("Input must be a pyarrow schema, or a pandas dataframe")


plateau/io/duckdb/dataframe.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
 def read_table_as_ddb(
     uuid: str,
     store: KeyValueStore,
-    table: str,
+    as_table: str,
     predicates: list[list[tuple[str, str, Any]]] | None = None,
     **kwargs, # support for everything else
 ) -> duckdb.DuckDBPyConnection:
@@ -37,7 +37,7 @@ def read_table_as_ddb(

     table_obj = read_table_as_arrow(uuid, store=store, predicates=predicates, **kwargs)
     con = duckdb.connect()
-    con.register(table, table_obj)
+    con.register(as_table, table_obj)
     return con


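Note: a minimal usage sketch for the renamed parameter (not part of the commit). The store factory, store URL, and dataset uuid below are hypothetical placeholders, and the dataset is assumed to have been written beforehand.

from minimalkv import get_store_from_url  # assumed store factory, adjust to your setup

from plateau.io.duckdb.dataframe import read_table_as_ddb

store = get_store_from_url("hfs:///tmp/plateau-data")  # hypothetical local store

# Register the dataset as a DuckDB view named "readings"; the predicate is
# pushed down into the arrow read before registration.
con = read_table_as_ddb(
    "sensor_readings",  # hypothetical dataset uuid
    store,
    as_table="readings",
    predicates=[[("sensor_id", "==", "A1")]],
)
print(con.execute("SELECT COUNT(*) FROM readings").fetchone())

The rename from table to as_table makes the DuckDB registration target explicit at the call site instead of overloading the generic name "table".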
plateau/io/duckdb/helper.py

Lines changed: 2 additions & 15 deletions
@@ -54,9 +54,7 @@ def align_categories(tables: list[pa.Table], categoricals: list[str]) -> list[pa
     if not categoricals:
         return tables

-    # Process each categorical column
     for column in categoricals:
-        all_types = [table[column].type for table in tables]

         union_values = set()
         baseline_categories = None
@@ -69,29 +67,22 @@ def align_categories(tables: list[pa.Table], categoricals: list[str]) -> list[pa
                 continue

             col = table[column]
-            # Ensure the column is dictionary encoded.
-            # if not pa.types.is_dictionary(col.type):
-            #     col = pc.dictionary_encode(col)
-
-            # Combine chunks to get a single Array (if needed)
             col_combined = (
                 col.combine_chunks() if isinstance(col, pa.ChunkedArray) else col
             )
-            # Extract the dictionary as a Python list.
             cats = col_combined.dictionary.to_pylist()
             union_values.update(cats)
             if table.num_rows > baseline_num_rows:
                 baseline_num_rows = table.num_rows
                 baseline_categories = cats

         if baseline_categories is None:
-            # No table contained this column.
             continue

         # Build the new dictionary order: use the baseline order then add any additional values
+        # stay consistent with the utils:align_categories function
         extra = union_values - set(baseline_categories)
         new_dictionary = baseline_categories + sorted(extra)
-        # Build a lookup map for quick conversion: value -> new index
         union_map = {val: idx for idx, val in enumerate(new_dictionary)}

         # Second pass: recast the column in every table to use the new dictionary
@@ -104,24 +95,20 @@ def align_categories(tables: list[pa.Table], categoricals: list[str]) -> list[pa
             col = table[column]
             if not pa.types.is_dictionary(col.type):
                 col = pc.dictionary_encode(col)
-            # Decode the column to its raw values (as a Python list)
             col_combined = (
                 col.combine_chunks() if isinstance(col, pa.ChunkedArray) else col
             )
             decoded = col_combined.to_pylist()
-            # Map each value to the new dictionary index (preserving nulls)
             new_indices = [
                 union_map[val] if val is not None else None for val in decoded
             ]
             new_indices_array = pa.array(new_indices, type=pa.int32())
-            # Create a new dictionary array with the new dictionary
             new_dict_array = pa.DictionaryArray.from_arrays(
                 new_indices_array, pa.array(new_dictionary, type=col.type.value_type)
             )
-            # Replace the column in the table
             col_index = table.schema.get_field_index(column)
             table = table.set_column(col_index, column, new_dict_array)
             new_tables.append(table)
-        tables = new_tables # update tables for next categorical column
+        tables = new_tables

     return tables

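Note: a small illustration (not part of the commit) of what align_categories does after this cleanup; the input tables are made up.

import pyarrow as pa

from plateau.io.duckdb.helper import align_categories

# Two tables whose dictionary-encoded "color" column disagrees on categories.
t1 = pa.table({"color": pa.array(["red", "blue", "red"]).dictionary_encode()})
t2 = pa.table({"color": pa.array(["green", "blue"]).dictionary_encode()})

aligned = align_categories([t1, t2], ["color"])

# Both columns now share one dictionary: the categories of the largest table
# first ("red", "blue"), then the remaining values in sorted order ("green").
for t in aligned:
    print(t.column("color").combine_chunks().dictionary.to_pylist())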
plateau/io/iter.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ def read_dataset_as_metapartitions__iterator(
         MetaPartition.concat_metapartitions_arrow
         if arrow_mode
         else MetaPartition.concat_metapartitions
-    ) # Dirty, refactor later
+    )
     mp = concatenate(
         [
             mp_inner.load_dataframes(

plateau/io_components/metapartition.py

Lines changed: 4 additions & 11 deletions
@@ -256,7 +256,7 @@ def __init__(
         self.schema = schema
         self.table_name = table_name
         if data is not None and schema is None:
-            self.schema = make_meta( # handles pa.Table as well
+            self.schema = make_meta(
                 data, origin=f"{table_name}/{label}", partition_keys=partition_keys
             )

@@ -691,9 +691,7 @@ def load_dataframes(
                 predicate_pushdown_to_io=predicate_pushdown_to_io,
                 predicates=filtered_predicates,
                 date_as_object=dates_as_object,
-                **(
-                    {"return_pyarrow_table": True} if arrow_mode else {}
-                ), # dirty hack for now
+                **({"return_pyarrow_table": True} if arrow_mode else {}),
             )
             LOGGER.debug(
                 "Loaded dataframe %s in %s seconds.", self.file, time.time() - start
@@ -736,7 +734,6 @@ def load_dataframes(
                         ", ".join(sorted(missing_cols))
                     )
                 )
-            # Really ugly, refactor later!
             if arrow_mode and list(df_or_arrow.column_names) != columns:
                 # Arrow tables are immutable, so we need to create a new table
                 df_or_arrow = df_or_arrow.select(columns)
@@ -811,7 +808,6 @@ def _reconstruct_index_columns_arrow(
            )

            # Create an array filled with the repeated key value
-            # FIXME: remove pdb.set_trace()
            if categories and name in categories:
                # Use dictionary type (categorical)
                dictionary_array = pa.DictionaryArray.from_arrays(
@@ -824,8 +820,7 @@ def _reconstruct_index_columns_arrow(

            new_columns.append((name, arrow_value))

-        # Prepend new index columns
-        for name, array in reversed(new_columns): # insert in reverse to maintain order
+        for name, array in reversed(new_columns):
            table = table.append_column(name, array)

            # move newly added column to front
@@ -1210,7 +1205,7 @@ def partition_on(self, partition_on: str | Sequence[str]):
            partition_on = [partition_on]
        partition_on = self._ensure_compatible_partitioning(partition_on)

-        new_data = self._partition_data(partition_on) # WIP: needs arrow compatibility
+        new_data = self._partition_data(partition_on)

        for label, data in new_data.items():
            tmp_mp = MetaPartition(
@@ -1399,8 +1394,6 @@ def concat_metapartitions_arrow(

        new_table = pa.concat_tables(data)

-        # TODO: What about align_categories?
-
        new_schema = validate_compatible(schema)

        new_label = MetaPartition._merge_labels(metapartitions, label_merger)

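Note: the context around the removed FIXME shows how the arrow path rebuilds a partition-key column. Below is a standalone sketch of that idea (the names, values, and front-reordering are invented for illustration and are not the function itself):

import pyarrow as pa

# A partition ("country=DE") read back without its key column.
table = pa.table({"value": [1, 2, 3]})
name, key_value = "country", "DE"  # hypothetical key parsed from the storage path

# Constant, dictionary-encoded column: every row points at index 0 of ["DE"].
arrow_value = pa.DictionaryArray.from_arrays(
    pa.array([0] * table.num_rows, type=pa.int32()),
    pa.array([key_value]),
)

# Append the column, then move it to the front so index columns lead the schema.
table = table.append_column(name, arrow_value)
front = [table.num_columns - 1] + list(range(table.num_columns - 1))
table = table.select([table.column_names[i] for i in front])
print(table.column_names)  # ['country', 'value']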
plateau/io_components/utils.py

Lines changed: 4 additions & 2 deletions
@@ -126,16 +126,18 @@ def _ensure_compatible_indices(
 def group_table_by_partition_keys(table: pa.Table, partition_on: list[str]):
     """Yield tuples of partition keys and pyarrow tables (excluding the partition_on columns) using polars.

-    Pyarrow's groupby is not really useful for this specific purpose, thus the detour through polars."""
+    Pyarrow's groupby is not really useful for this specific purpose, thus the detour through polars.
+    """

     df = pl.from_arrow(table)

     groups = df.group_by(partition_on, maintain_order=True)

     for key, group in groups:
-        arrow_table = group.drop(partition_on).to_arrow() # drop partition keys
+        arrow_table = group.drop(partition_on).to_arrow()  # drop partition keys
         yield key, arrow_table

+
 def validate_partition_keys(
     dataset_uuid,
     store,

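Note: a short sketch (not part of the commit) of what the generator yields, assuming polars is installed; the data is made up.

import pyarrow as pa

from plateau.io_components.utils import group_table_by_partition_keys

table = pa.table({"country": ["DE", "DE", "US"], "value": [1, 2, 3]})

for key, part in group_table_by_partition_keys(table, ["country"]):
    # key holds the partition value(s); part no longer contains "country"
    print(key, part.column_names, part.num_rows)
# e.g. ('DE',) ['value'] 2
#      ('US',) ['value'] 1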
plateau/serialization/_generic.py

Lines changed: 65 additions & 19 deletions
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 """This module contains functionality for persisting/serialising DataFrames.

 Available constants
@@ -13,7 +12,11 @@
 :meta public:
 """

+import pdb
 import warnings
+from duckdb import arrow
+import pyarrow as pa
+import pyarrow.compute as pc
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, TypeVar

@@ -265,6 +268,61 @@ def columns_in_predicates(predicates: PredicatesType) -> set[str]:
     return columns


+def _filter_df_or_table_from_predicates(
+    df_or_table: pd.DataFrame | pa.Table,
+    predicates: PredicatesType | None,
+    strict_date_types: bool = False,
+    arrow_mode: bool = False,
+) -> pd.DataFrame | pa.Table:
+    if predicates is None:
+        return df_or_table
+    indexer: npt.NDArray[np.bool_] = np.zeros(len(df_or_table), dtype=bool)
+    for conjunction in predicates:
+        inner_indexer: npt.NDArray[np.bool_] = np.ones(len(df_or_table), dtype=bool)
+        for column, op, value in conjunction:
+            column_name = ensure_unicode_string_type(column)
+            values = (
+                df_or_table.column(column_name).to_numpy()
+                if arrow_mode
+                else df_or_table[column_name].values
+            )
+            filter_array_like(
+                values,
+                op,
+                value,
+                inner_indexer,
+                inner_indexer,
+                strict_date_types=strict_date_types,
+                column_name=column_name,
+            )
+        indexer = inner_indexer | indexer
+
+    if not arrow_mode:
+        return df_or_table[indexer]
+
+    table_mask = pa.array(indexer, type=pa.bool_())
+    return df_or_table.filter(table_mask)
+
+
+# Casting pyarrow structures to numpy ones might introduce some overhead
+# but we do not have to maintain twice the logic for filtering from predicates
+def filter_table_from_predicates(table: pa.Table, predicates: PredicatesType):
+    """Filter a `pyarrow.Table` based on predicates in disjunctive normal
+    form.
+
+    See Also
+    --------
+    * :ref:`predicate_pushdown`
+    * :ref:`filter_df_from_predicates`
+    """
+    return _filter_df_or_table_from_predicates(
+        df_or_table=table,
+        predicates=predicates,
+        strict_date_types=False,
+        arrow_mode=True,
+    )
+
+
 def filter_df_from_predicates(
     df: pd.DataFrame,
     predicates: PredicatesType | None,
@@ -288,24 +346,12 @@ def filter_df_from_predicates(
     --------
     * :ref:`predicate_pushdown`
     """
-    if predicates is None:
-        return df
-    indexer: npt.NDArray[np.bool_] = np.zeros(len(df), dtype=bool)
-    for conjunction in predicates:
-        inner_indexer: npt.NDArray[np.bool_] = np.ones(len(df), dtype=bool)
-        for column, op, value in conjunction:
-            column_name = ensure_unicode_string_type(column)
-            filter_array_like(
-                df[column_name].values,
-                op,
-                value,
-                inner_indexer,
-                inner_indexer,
-                strict_date_types=strict_date_types,
-                column_name=column_name,
-            )
-        indexer = inner_indexer | indexer
-    return df[indexer]
+    return _filter_df_or_table_from_predicates(
+        df_or_table=df,
+        predicates=predicates,
+        strict_date_types=strict_date_types,
+        arrow_mode=False,
+    )


 def _handle_categorical_data(array_like, require_ordered):

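Note: since filter_table_from_predicates is the heart of this commit, a quick sketch of the predicate semantics (data made up). Predicates are in disjunctive normal form, exactly as for filter_df_from_predicates: inner lists are ANDed, outer lists are ORed.

import pyarrow as pa

from plateau.serialization._generic import filter_table_from_predicates

table = pa.table({"country": ["DE", "DE", "US", "FR"], "value": [1, 5, 3, 7]})

# (country == "DE" AND value > 2) OR country == "US"
predicates = [
    [("country", "==", "DE"), ("value", ">", 2)],
    [("country", "==", "US")],
]

filtered = filter_table_from_predicates(table, predicates)
print(filtered.to_pydict())
# expected: {'country': ['DE', 'US'], 'value': [5, 3]}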
plateau/serialization/_parquet.py

Lines changed: 17 additions & 5 deletions
@@ -20,6 +20,7 @@
     check_predicates,
     filter_df,
     filter_df_from_predicates,
+    filter_table_from_predicates,
 )
 from ._io_buffer import BlockBuffer
 from ._util import ensure_unicode_string_type, schema_metadata_bytes_to_object
@@ -47,7 +48,7 @@ def _empty_table_from_schema(parquet_file: ParquetFile) -> pa.Table:
     return schema.empty_table()


-def _reset_dictionary_columns(table: pa.Table, exclude=None):
+def _reset_dictionary_columns(table: pa.Table, exclude=None) -> pa.Table:
     """We need to ensure that the dtype is exactly as requested, see GH227."""
     if exclude is None:
         exclude = []
@@ -197,9 +198,7 @@ def _restore_dataframe(
        # otherwise full read en block is the better option.
        if (not predicate_pushdown_to_io) or (columns is None and predicates is None):
            with pa.BufferReader(store.get(key)) as reader:
-                table = pq.read_pandas(
-                    reader, columns=columns
-                ) # TODO: is this relevant?
+                table = pq.read_pandas(reader, columns=columns)
        else:
            if HAVE_BOTO and isinstance(store, BotoStore):
                # Parquet and seeks on S3 currently leak connections thus
@@ -281,7 +280,20 @@ def _restore_dataframe(
        table = table.cast(schema_metadata_bytes_to_object(table.schema))

        if return_pyarrow_table:
-            return table
+            table.rename_columns(
+                [ensure_unicode_string_type(name) for name in table.column_names]
+            )
+
+            if filter_query:
+                raise ValueError(
+                    "filter_query is not supported when 'return_pyarrow_table' is True (if you use arrow_mode)."
+                    "Hint: please express your filter query as predicates."
+                )
+
+            if predicates:
+                table = filter_table_from_predicates(table, predicates)
+
+            return table if columns is None else table.select(columns)

        _coerce = {"coerce_temporal_nanoseconds": True}
        df = table.to_pandas(date_as_object=date_as_object, **_coerce)

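Note: an illustrative, self-contained sketch (not part of the commit) of what the new return_pyarrow_table branch does with the table it just read: filter by predicates first, then narrow to the requested columns. Plain pyarrow is used here instead of the serializer; the data is made up.

import io

import pyarrow as pa
import pyarrow.parquet as pq

from plateau.serialization._generic import filter_table_from_predicates

# Round-trip a tiny table through an in-memory parquet buffer.
buf = io.BytesIO()
pq.write_table(pa.table({"x": [1, 2, 3], "y": ["a", "b", "c"]}), buf)
table = pq.read_table(io.BytesIO(buf.getvalue()))

predicates = [[("x", ">", 1)]]
columns = ["y"]

# Same order as the new branch: predicate filtering first, column projection last.
table = filter_table_from_predicates(table, predicates)
table = table if columns is None else table.select(columns)
print(table.to_pydict())  # expected: {'y': ['b', 'c']}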