Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit e90e1d8

Browse files
authored
chore: Add a function to traverse BFET and encode type usage (#2390)
The next step is to add the encoded usage to the job_config.label before SQL dispatch. Related bug: b/406578908
1 parent 74150c5 commit e90e1d8

File tree

4 files changed

+230
-19
lines changed

4 files changed

+230
-19
lines changed

bigframes/core/logging/data_types.py

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,115 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
17+
import functools
1518

1619
from bigframes import dtypes
20+
from bigframes.core import agg_expressions, bigframe_node, expression, nodes
21+
from bigframes.core.rewrite import schema_binding
22+
23+
# Node types that _encode_type_refs_from_node recognizes but deliberately
# skips: it adds no type bits of its own for them.
IGNORED_NODES = (
    nodes.SelectionNode,
    nodes.ReadLocalNode,
    nodes.ReadTableNode,
    nodes.ConcatNode,
    nodes.RandomSampleNode,
    nodes.FromRangeNode,
    nodes.PromoteOffsetsNode,
    nodes.ReversedNode,
    nodes.SliceNode,
    nodes.ResultNode,
)
35+
36+
37+
def encode_type_refs(root: bigframe_node.BigFrameNode) -> str:
    """Encode the data types referenced by the tree under ``root``.

    The tree is folded with ``root.reduce_up`` using
    ``_encode_type_refs_from_node`` as the per-node callback, and the
    resulting integer is rendered as lowercase hexadecimal without a prefix.
    """
    type_mask = root.reduce_up(_encode_type_refs_from_node)
    return format(type_mask, "x")
39+
40+
41+
def _encode_type_refs_from_node(
    node: bigframe_node.BigFrameNode, child_results: tuple[int, ...]
) -> int:
    """Combine the type bits of ``node``'s own expressions with its children's.

    Args:
        node: The node currently being visited.
        child_results: Bitmasks already computed for the children of ``node``.

    Returns:
        The bitwise OR of every entry in ``child_results`` and of the masks of
        the expressions that ``node`` itself evaluates.
    """
    children_mask = functools.reduce(
        lambda acc, mask: acc | mask, child_results, 0
    )

    own_mask = 0

    def collect(
        expr: expression.Expression, source: bigframe_node.BigFrameNode
    ) -> None:
        # Fold the types referenced by one expression into this node's mask.
        nonlocal own_mask
        own_mask |= _encode_type_refs_from_expr(expr, source)

    if isinstance(node, nodes.FilterNode):
        collect(node.predicate, node.child)
    elif isinstance(node, nodes.ProjectionNode):
        for assignment in node.assignments:
            assigned_expr = assignment[0]
            # Ignore direct assignments in projection nodes.
            if not isinstance(assigned_expr, expression.DerefOp):
                collect(assigned_expr, node.child)
    elif isinstance(node, nodes.OrderByNode):
        for ordering in node.by:
            collect(ordering.scalar_expression, node.child)
    elif isinstance(node, nodes.JoinNode):
        # Each side of a join condition is encoded against its own child.
        for left_expr, right_expr in node.conditions:
            collect(left_expr, node.left_child)
            collect(right_expr, node.right_child)
    elif isinstance(node, nodes.InNode):
        collect(node.left_col, node.left_child)
    elif isinstance(node, nodes.AggregateNode):
        for aggregation, _ in node.aggregations:
            collect(aggregation, node.child)
    elif isinstance(node, nodes.WindowOpNode):
        for key in node.window_spec.grouping_keys:
            collect(key, node.child)
        for ordering in node.window_spec.ordering:
            collect(ordering.scalar_expression, node.child)
        for column_def in node.agg_exprs:
            collect(column_def.expression, node.child)
    elif isinstance(node, nodes.ExplodeNode):
        for column_id in node.column_ids:
            collect(column_id, node.child)
    elif isinstance(node, IGNORED_NODES):
        # Known node types that add no type bits of their own.
        pass
    else:
        # Unseen node types are tolerated silently because this is the logging
        # path. They should eventually get a branch above or an entry in
        # IGNORED_NODES.
        pass

    return children_mask | own_mask
101+
102+
103+
def _encode_type_refs_from_expr(
    expr: expression.Expression, child_node: bigframe_node.BigFrameNode
) -> int:
    """Return the type bitmask of ``expr`` and all of its sub-expressions.

    An unresolved expression is first bound using ``child_node``:
    aggregations through ``schema_binding._bind_schema_to_aggregation_expr``
    and everything else through ``expression.bind_schema_fields``. The mask
    of the bound expression's output type is then OR-ed with the masks
    obtained by recursing into each of its children.
    """
    # TODO(b/409387790): Remove the binding branches once SQLGlot compiler
    # fully replaces Ibis compiler
    if expr.is_resolved:
        bound = expr
    elif isinstance(expr, agg_expressions.Aggregation):
        bound = schema_binding._bind_schema_to_aggregation_expr(expr, child_node)
    else:
        bound = expression.bind_schema_fields(expr, child_node.field_by_id)

    mask = _get_dtype_mask(bound.output_type)
    for sub_expr in bound.children:
        mask |= _encode_type_refs_from_expr(sub_expr, child_node)

    return mask
21118

22119

23-
def _get_dtype_mask(dtype: dtypes.Dtype) -> int:
120+
def _get_dtype_mask(dtype: dtypes.Dtype | None) -> int:
121+
if dtype is None:
122+
# If the dtype is not given, ignore
123+
return 0
24124
if dtype == dtypes.INT_DTYPE:
25125
return 1 << 1
26126
if dtype == dtypes.FLOAT_DTYPE:
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Sequence
16+
17+
import pandas as pd
18+
import pyarrow as pa
19+
20+
from bigframes import dtypes
21+
from bigframes.core.logging import data_types
22+
import bigframes.pandas as bpd
23+
24+
25+
def encode_types(inputs: Sequence[dtypes.Dtype]) -> str:
    """Build the expected hex encoding for the given dtypes.

    ORs together ``data_types._get_dtype_mask`` of every entry in ``inputs``
    and formats the result as lowercase hexadecimal.
    """
    combined_mask = 0
    for dtype in inputs:
        combined_mask |= data_types._get_dtype_mask(dtype)

    return format(combined_mask, "x")
31+
32+
33+
def test_get_type_refs_no_op(scalars_df_index):
    """The fixture DataFrame's own node is expected to encode no types."""
    plan = scalars_df_index._block._expr.node
    no_types: list[dtypes.Dtype] = []

    actual = data_types.encode_type_refs(plan)

    assert actual == encode_types(no_types)
38+
39+
40+
def test_get_type_refs_projection(scalars_df_index):
    """Datetime minus datetime is expected to encode datetime and timedelta."""
    minuend = scalars_df_index["datetime_col"]
    subtrahend = scalars_df_index["datetime_col"]
    plan = (minuend - subtrahend)._block._expr.node

    actual = data_types.encode_type_refs(plan)

    assert actual == encode_types(
        [dtypes.DATETIME_DTYPE, dtypes.TIMEDELTA_DTYPE]
    )
47+
48+
49+
def test_get_type_refs_filter(scalars_df_index):
    """Filtering on ``int64_col > 0`` is expected to encode int and bool."""
    is_positive = scalars_df_index["int64_col"] > 0
    filtered = scalars_df_index[is_positive]
    plan = filtered._block._expr.node

    actual = data_types.encode_type_refs(plan)

    assert actual == encode_types([dtypes.INT_DTYPE, dtypes.BOOL_DTYPE])
54+
55+
56+
def test_get_type_refs_order_by(scalars_df_index):
    """Sorting the fixture by its index is expected to encode only int."""
    sorted_df = scalars_df_index.sort_index()
    plan = sorted_df._block._expr.node

    actual = data_types.encode_type_refs(plan)

    assert actual == encode_types([dtypes.INT_DTYPE])
61+
62+
63+
def test_get_type_refs_join(scalars_df_index):
    """Merging an int column with a float column is expected to encode both."""
    left_side = scalars_df_index[["int64_col"]]
    right_side = scalars_df_index[["float64_col"]]
    joined = left_side.merge(
        right_side, left_on="int64_col", right_on="float64_col"
    )
    plan = joined._block._expr.node

    actual = data_types.encode_type_refs(plan)

    assert actual == encode_types([dtypes.INT_DTYPE, dtypes.FLOAT_DTYPE])
74+
75+
76+
def test_get_type_refs_isin(scalars_df_index):
    """``isin`` on the string column is expected to encode string and bool."""
    membership = scalars_df_index["string_col"].isin(["a"])
    plan = membership._block._expr.node

    actual = data_types.encode_type_refs(plan)

    assert actual == encode_types([dtypes.STRING_DTYPE, dtypes.BOOL_DTYPE])
81+
82+
83+
def test_get_type_refs_agg(scalars_df_index):
    """Counting a bool and a string column.

    The encoding is expected to cover int, bool, string and float.
    """
    counted = scalars_df_index[["bool_col", "string_col"]].count()
    plan = counted._block._expr.node

    actual = data_types.encode_type_refs(plan)

    assert actual == encode_types(
        [
            dtypes.INT_DTYPE,
            dtypes.BOOL_DTYPE,
            dtypes.STRING_DTYPE,
            dtypes.FLOAT_DTYPE,
        ]
    )
93+
94+
95+
def test_get_type_refs_window(scalars_df_index):
    """A grouped rolling count is expected to encode string, bool and int."""
    subset = scalars_df_index[["string_col", "bool_col"]]
    rolling_counts = subset.groupby("string_col").rolling(window=3).count()
    plan = rolling_counts._block._expr.node

    actual = data_types.encode_type_refs(plan)

    assert actual == encode_types(
        [dtypes.STRING_DTYPE, dtypes.BOOL_DTYPE, dtypes.INT_DTYPE]
    )
106+
107+
108+
def test_get_type_refs_explode():
    """Exploding a list-of-int64 column is expected to encode that list type."""
    frame = bpd.DataFrame({"A": ["a", "b"], "B": [[1, 2], [3, 4, 5]]})
    plan = frame.explode("B")._block._expr.node
    int_list_dtype = pd.ArrowDtype(pa.list_(pa.int64()))

    actual = data_types.encode_type_refs(plan)

    assert actual == encode_types([int_list_dtype])

tests/unit/core/logging/test_data_types.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
@pytest.mark.parametrize(
3030
("dtype", "expected_mask"),
3131
[
32+
(None, 0),
3233
(UNKNOWN_TYPE, 1 << 0),
3334
(dtypes.INT_DTYPE, 1 << 1),
3435
(dtypes.FLOAT_DTYPE, 1 << 2),
@@ -51,19 +52,3 @@
5152
)
5253
def test_get_dtype_mask(dtype, expected_mask):
5354
assert data_types._get_dtype_mask(dtype) == expected_mask
54-
55-
56-
def test_add_data_type__type_overlap_no_op():
57-
curr_type = dtypes.STRING_DTYPE
58-
existing_types = data_types._get_dtype_mask(curr_type)
59-
60-
assert data_types._add_data_type(existing_types, curr_type) == existing_types
61-
62-
63-
def test_add_data_type__new_type_updated():
64-
curr_type = dtypes.STRING_DTYPE
65-
existing_types = data_types._get_dtype_mask(dtypes.INT_DTYPE)
66-
67-
assert data_types._add_data_type(
68-
existing_types, curr_type
69-
) == existing_types | data_types._get_dtype_mask(curr_type)

0 commit comments

Comments
 (0)