Skip to content
This repository was archived by the owner on May 7, 2026. It is now read-only.

Commit 0d0ad68

Browse files
committed
Merge remote-tracking branch 'origin/main' into b409390651-progress-bar
2 parents e7ca461 + 56e5033 commit 0d0ad68

21 files changed

Lines changed: 366 additions & 77 deletions

File tree

bigframes/bigquery/_operations/ai.py

Lines changed: 74 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -348,20 +348,20 @@ def if_(
348348
provides optimization such that not all rows are evaluated with the LLM.
349349
350350
**Examples:**
351-
>>> import bigframes.pandas as bpd
352-
>>> import bigframes.bigquery as bbq
353-
>>> bpd.options.display.progress_bar = None
354-
>>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
355-
>>> bbq.ai.if_((us_state, " has a city called Springfield"))
356-
0 True
357-
1 True
358-
2 False
359-
dtype: boolean
360-
361-
>>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]
362-
0 Massachusetts
363-
1 Illinois
364-
dtype: string
351+
>>> import bigframes.pandas as bpd
352+
>>> import bigframes.bigquery as bbq
353+
>>> bpd.options.display.progress_bar = None
354+
>>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
355+
>>> bbq.ai.if_((us_state, " has a city called Springfield"))
356+
0 True
357+
1 True
358+
2 False
359+
dtype: boolean
360+
361+
>>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]
362+
0 Massachusetts
363+
1 Illinois
364+
dtype: string
365365
366366
Args:
367367
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
@@ -386,6 +386,56 @@ def if_(
386386
return series_list[0]._apply_nary_op(operator, series_list[1:])
387387

388388

389+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
390+
def classify(
391+
input: PROMPT_TYPE,
392+
categories: tuple[str, ...] | list[str],
393+
*,
394+
connection_id: str | None = None,
395+
) -> series.Series:
396+
"""
397+
Classifies a given input into one of the specified categories. It will always return the provided category that best fits the prompt input.
398+
399+
**Examples:**
400+
401+
>>> import bigframes.pandas as bpd
402+
>>> import bigframes.bigquery as bbq
403+
>>> bpd.options.display.progress_bar = None
404+
>>> df = bpd.DataFrame({'creature': ['Cat', 'Salmon']})
405+
>>> df['type'] = bbq.ai.classify(df['creature'], ['Mammal', 'Fish'])
406+
>>> df
407+
creature type
408+
0 Cat Mammal
409+
1 Salmon Fish
410+
<BLANKLINE>
411+
[2 rows x 2 columns]
412+
413+
Args:
414+
input (Series | List[str|Series] | Tuple[str|Series, ...]):
415+
A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series
416+
or pandas Series.
417+
categories (tuple[str, ...] | list[str]):
418+
Categories to classify the input into.
419+
connection_id (str, optional):
420+
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
421+
If not provided, the connection from the current session will be used.
422+
423+
Returns:
424+
bigframes.series.Series: A new series of strings.
425+
"""
426+
427+
prompt_context, series_list = _separate_context_and_series(input)
428+
assert len(series_list) > 0
429+
430+
operator = ai_ops.AIClassify(
431+
prompt_context=tuple(prompt_context),
432+
categories=tuple(categories),
433+
connection_id=_resolve_connection_id(series_list[0], connection_id),
434+
)
435+
436+
return series_list[0]._apply_nary_op(operator, series_list[1:])
437+
438+
389439
@log_adapter.method_logger(custom_base_name="bigquery_ai")
390440
def score(
391441
prompt: PROMPT_TYPE,
@@ -398,15 +448,16 @@ def score(
398448
rubric with examples in the prompt.
399449
400450
**Examples:**
401-
>>> import bigframes.pandas as bpd
402-
>>> import bigframes.bigquery as bbq
403-
>>> bpd.options.display.progress_bar = None
404-
>>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
405-
>>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3")) # doctest: +SKIP
406-
0 2.0
407-
1 1.0
408-
2 3.0
409-
dtype: Float64
451+
452+
>>> import bigframes.pandas as bpd
453+
>>> import bigframes.bigquery as bbq
454+
>>> bpd.options.display.progress_bar = None
455+
>>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
456+
>>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3")) # doctest: +SKIP
457+
0 2.0
458+
1 1.0
459+
2 3.0
460+
dtype: Float64
410461
411462
Args:
412463
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):

bigframes/core/array_value.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import functools
1919
import typing
2020
from typing import Iterable, List, Mapping, Optional, Sequence, Tuple
21-
import warnings
2221

2322
import google.cloud.bigquery
2423
import pandas
@@ -37,7 +36,6 @@
3736
import bigframes.core.tree_properties
3837
from bigframes.core.window_spec import WindowSpec
3938
import bigframes.dtypes
40-
import bigframes.exceptions as bfe
4139
import bigframes.operations as ops
4240
import bigframes.operations.aggregations as agg_ops
4341

@@ -101,12 +99,6 @@ def from_table(
10199
):
102100
if offsets_col and primary_key:
103101
raise ValueError("must set at most one of 'offsets', 'primary_key'")
104-
if any(i.field_type == "JSON" for i in table.schema if i.name in schema.names):
105-
msg = bfe.format_message(
106-
"JSON column interpretation as a custom PyArrow extention in `db_dtypes` "
107-
"is a preview feature and subject to change."
108-
)
109-
warnings.warn(msg, bfe.PreviewWarning)
110102
# define data source only for needed columns, this makes row-hashing cheaper
111103
table_def = nodes.GbqTable.from_table(table, columns=schema.names)
112104

bigframes/core/backports.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Helpers for working across versions of different depenencies."""
16+
17+
from typing import List
18+
19+
import pyarrow
20+
21+
22+
def pyarrow_struct_type_fields(struct_type: pyarrow.StructType) -> List[pyarrow.Field]:
23+
"""StructType.fields was added in pyarrow 18.
24+
25+
See: https://arrow.apache.org/docs/18.0/python/generated/pyarrow.StructType.html
26+
"""
27+
28+
if hasattr(struct_type, "fields"):
29+
return struct_type.fields
30+
31+
return [
32+
struct_type.field(field_index) for field_index in range(struct_type.num_fields)
33+
]

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2039,6 +2039,18 @@ def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue:
20392039
).to_expr()
20402040

20412041

2042+
@scalar_op_compiler.register_nary_op(ops.AIClassify, pass_op=True)
2043+
def ai_classify(
2044+
*values: ibis_types.Value, op: ops.AIClassify
2045+
) -> ibis_types.StructValue:
2046+
2047+
return ai_ops.AIClassify(
2048+
_construct_prompt(values, op.prompt_context), # type: ignore
2049+
op.categories, # type: ignore
2050+
op.connection_id, # type: ignore
2051+
).to_expr()
2052+
2053+
20422054
@scalar_op_compiler.register_nary_op(ops.AIScore, pass_op=True)
20432055
def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.StructValue:
20442056

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,21 @@ def _(*exprs: TypedExpr, op: ops.AIIf) -> sge.Expression:
6161
return sge.func("AI.IF", *args)
6262

6363

64+
@register_nary_op(ops.AIClassify, pass_op=True)
65+
def _(*exprs: TypedExpr, op: ops.AIClassify) -> sge.Expression:
66+
category_literals = [sge.Literal.string(cat) for cat in op.categories]
67+
categories_arg = sge.Kwarg(
68+
this="categories", expression=sge.array(*category_literals)
69+
)
70+
71+
args = [
72+
_construct_prompt(exprs, op.prompt_context, param_name="input"),
73+
categories_arg,
74+
] + _construct_named_args(op)
75+
76+
return sge.func("AI.CLASSIFY", *args)
77+
78+
6479
@register_nary_op(ops.AIScore, pass_op=True)
6580
def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression:
6681
args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)
@@ -69,7 +84,9 @@ def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression:
6984

7085

7186
def _construct_prompt(
72-
exprs: tuple[TypedExpr, ...], prompt_context: tuple[str | None, ...]
87+
exprs: tuple[TypedExpr, ...],
88+
prompt_context: tuple[str | None, ...],
89+
param_name: str = "prompt",
7390
) -> sge.Kwarg:
7491
prompt: list[str | sge.Expression] = []
7592
column_ref_idx = 0
@@ -80,7 +97,7 @@ def _construct_prompt(
8097
else:
8198
prompt.append(sge.Literal.string(elem))
8299

83-
return sge.Kwarg(this="prompt", expression=sge.Tuple(expressions=prompt))
100+
return sge.Kwarg(this=param_name, expression=sge.Tuple(expressions=prompt))
84101

85102

86103
def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:

bigframes/core/indexes/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,16 @@ def shape(self) -> typing.Tuple[int]:
171171

172172
@property
173173
def dtype(self):
174-
return self._block.index.dtypes[0] if self.nlevels == 1 else np.dtype("O")
174+
dtype = self._block.index.dtypes[0] if self.nlevels == 1 else np.dtype("O")
175+
bigframes.dtypes.warn_on_db_dtypes_json_dtype([dtype])
176+
return dtype
175177

176178
@property
177179
def dtypes(self) -> pandas.Series:
180+
dtypes = self._block.index.dtypes
181+
bigframes.dtypes.warn_on_db_dtypes_json_dtype(dtypes)
178182
return pandas.Series(
179-
data=self._block.index.dtypes,
183+
data=dtypes,
180184
index=typing.cast(typing.Tuple, self._block.index.names),
181185
)
182186

bigframes/dataframe.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,9 @@ def at(self) -> indexers.AtDataFrameIndexer:
321321

322322
@property
323323
def dtypes(self) -> pandas.Series:
324-
return pandas.Series(data=self._block.dtypes, index=self._block.column_labels)
324+
dtypes = self._block.dtypes
325+
bigframes.dtypes.warn_on_db_dtypes_json_dtype(dtypes)
326+
return pandas.Series(data=dtypes, index=self._block.column_labels)
325327

326328
@property
327329
def columns(self) -> pandas.Index:

bigframes/dtypes.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import textwrap
2121
import typing
2222
from typing import Any, Dict, List, Literal, Sequence, Union
23+
import warnings
2324

2425
import bigframes_vendored.constants as constants
2526
import db_dtypes # type: ignore
@@ -30,6 +31,9 @@
3031
import pyarrow as pa
3132
import shapely.geometry # type: ignore
3233

34+
import bigframes.core.backports
35+
import bigframes.exceptions
36+
3337
# Type hints for Pandas dtypes supported by BigQuery DataFrame
3438
Dtype = Union[
3539
pd.BooleanDtype,
@@ -62,7 +66,8 @@
6266
# No arrow equivalent
6367
GEO_DTYPE = gpd.array.GeometryDtype()
6468
# JSON
65-
# TODO: switch to pyarrow.json_(pyarrow.string()) when available.
69+
# TODO(https://github.com/pandas-dev/pandas/issues/60958): switch to
70+
# pyarrow.json_(pyarrow.string()) when pandas 3+ and pyarrow 18+ is installed.
6671
JSON_ARROW_TYPE = db_dtypes.JSONArrowType()
6772
JSON_DTYPE = pd.ArrowDtype(JSON_ARROW_TYPE)
6873
OBJ_REF_DTYPE = pd.ArrowDtype(
@@ -368,8 +373,7 @@ def get_struct_fields(type_: ExpressionType) -> dict[str, Dtype]:
368373
assert isinstance(type_.pyarrow_dtype, pa.StructType)
369374
struct_type = type_.pyarrow_dtype
370375
result: dict[str, Dtype] = {}
371-
for field_no in range(struct_type.num_fields):
372-
field = struct_type.field(field_no)
376+
for field in bigframes.core.backports.pyarrow_struct_type_fields(struct_type):
373377
result[field.name] = arrow_dtype_to_bigframes_dtype(field.type)
374378
return result
375379

@@ -547,7 +551,8 @@ def arrow_type_to_literal(
547551
return [arrow_type_to_literal(arrow_type.value_type)]
548552
if pa.types.is_struct(arrow_type):
549553
return {
550-
field.name: arrow_type_to_literal(field.type) for field in arrow_type.fields
554+
field.name: arrow_type_to_literal(field.type)
555+
for field in bigframes.core.backports.pyarrow_struct_type_fields(arrow_type)
551556
}
552557
if pa.types.is_string(arrow_type):
553558
return "string"
@@ -915,3 +920,40 @@ def lcd_type_or_throw(dtype1: Dtype, dtype2: Dtype) -> Dtype:
915920

916921

917922
TIMEDELTA_DESCRIPTION_TAG = "#microseconds"
923+
924+
925+
def contains_db_dtypes_json_arrow_type(type_):
926+
if isinstance(type_, db_dtypes.JSONArrowType):
927+
return True
928+
929+
if isinstance(type_, pa.ListType):
930+
return contains_db_dtypes_json_arrow_type(type_.value_type)
931+
932+
if isinstance(type_, pa.StructType):
933+
return any(
934+
contains_db_dtypes_json_arrow_type(field.type)
935+
for field in bigframes.core.backports.pyarrow_struct_type_fields(type_)
936+
)
937+
return False
938+
939+
940+
def contains_db_dtypes_json_dtype(dtype):
941+
if not isinstance(dtype, pd.ArrowDtype):
942+
return False
943+
944+
return contains_db_dtypes_json_arrow_type(dtype.pyarrow_dtype)
945+
946+
947+
def warn_on_db_dtypes_json_dtype(dtypes):
948+
"""Warn that the JSON dtype is changing.
949+
950+
Note: only call this function if the user is explicitly checking the
951+
dtypes.
952+
"""
953+
if any(contains_db_dtypes_json_dtype(dtype) for dtype in dtypes):
954+
msg = bigframes.exceptions.format_message(
955+
"JSON columns will be represented as pandas.ArrowDtype(pyarrow.json_()) "
956+
"instead of using `db_dtypes` in the future when available in pandas "
957+
"(https://github.com/pandas-dev/pandas/issues/60958) and pyarrow."
958+
)
959+
warnings.warn(msg, bigframes.exceptions.JSONDtypeWarning)

bigframes/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ class FunctionAxisOnePreviewWarning(PreviewWarning):
111111
"""Remote Function and Managed UDF with axis=1 preview."""
112112

113113

114+
class JSONDtypeWarning(PreviewWarning):
115+
"""JSON dtype will be pd.ArrowDtype(pa.json_()) in the future."""
116+
117+
114118
class FunctionConflictTypeHintWarning(UserWarning):
115119
"""Conflicting type hints in a BigFrames function."""
116120

bigframes/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
from bigframes.operations.ai_ops import (
18+
AIClassify,
1819
AIGenerate,
1920
AIGenerateBool,
2021
AIGenerateDouble,
@@ -419,6 +420,7 @@
419420
"geo_y_op",
420421
"GeoStDistanceOp",
421422
# AI ops
423+
"AIClassify",
422424
"AIGenerate",
423425
"AIGenerateBool",
424426
"AIGenerateDouble",

0 commit comments

Comments
 (0)