diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 74d49f6aae0e..27533a82ea66 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -67,7 +67,6 @@ TransformWithStateInPySparkRowSerializer, TransformWithStateInPySparkRowInitStateSerializer, ArrowStreamAggPandasUDFSerializer, - ArrowStreamAggArrowUDFSerializer, ArrowBatchUDFSerializer, ArrowStreamUDTFSerializer, ArrowStreamArrowUDTFSerializer, @@ -1088,26 +1087,6 @@ def wrap_window_agg_pandas_udf( ) -def wrap_window_agg_arrow_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index): - window_bound_types_str = runner_conf.get("window_bound_types") - window_bound_type = [t.strip().lower() for t in window_bound_types_str.split(",")][udf_index] - if window_bound_type == "bounded": - return wrap_bounded_window_agg_arrow_udf( - f, args_offsets, kwargs_offsets, return_type, runner_conf - ) - elif window_bound_type == "unbounded": - return wrap_unbounded_window_agg_arrow_udf( - f, args_offsets, kwargs_offsets, return_type, runner_conf - ) - else: - raise PySparkRuntimeError( - errorClass="INVALID_WINDOW_BOUND_TYPE", - messageParameters={ - "window_bound_type": window_bound_type, - }, - ) - - def wrap_unbounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) @@ -1127,27 +1106,6 @@ def wrapped(*series): ) -def wrap_unbounded_window_agg_arrow_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): - func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) - - # This is similar to wrap_unbounded_window_agg_pandas_udf, the only difference - # is that this function is for arrow udf. - arrow_return_type = to_arrow_type( - return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types - ) - - def wrapped(*series): - import pyarrow as pa - - result = func(*series) - return pa.repeat(result, len(series[0])) - - return ( - args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), - ) - - def wrap_bounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): # args_offsets should have at least 2 for begin_index, end_index. assert len(args_offsets) >= 2, len(args_offsets) @@ -1188,35 +1146,6 @@ def wrapped(begin_index, end_index, *series): ) -def wrap_bounded_window_agg_arrow_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf): - # args_offsets should have at least 2 for begin_index, end_index. - assert len(args_offsets) >= 2, len(args_offsets) - func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets[2:], kwargs_offsets) - - arrow_return_type = to_arrow_type( - return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types - ) - - def wrapped(begin_index, end_index, *series): - import pyarrow as pa - - assert isinstance(begin_index, pa.Int32Array), type(begin_index) - assert isinstance(end_index, pa.Int32Array), type(end_index) - - result = [] - for i in range(len(begin_index)): - offset = begin_index[i].as_py() - length = end_index[i].as_py() - offset - series_slices = [s.slice(offset=offset, length=length) for s in series] - result.append(func(*series_slices)) - return pa.array(result) - - return ( - args_offsets[:2] + args_kwargs_offsets, - lambda *a: (wrapped(*a), arrow_return_type), - ) - - def wrap_kwargs_support(f, args_offsets, kwargs_offsets): if len(kwargs_offsets): keys = list(kwargs_offsets.keys()) @@ -1431,9 +1360,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): func, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index ) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF: - return wrap_window_agg_arrow_udf( - func, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index - ) + return func, args_offsets, kwargs_offsets, return_type elif eval_type == PythonEvalType.SQL_BATCHED_UDF: return wrap_udf(func, args_offsets, kwargs_offsets, return_type) else: @@ -2647,7 +2574,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf): ): ser = ArrowStreamSerializer(write_start_stream=True, num_dfs=1) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF: - ser = ArrowStreamAggArrowUDFSerializer(safecheck=True, arrow_cast=True) + ser = ArrowStreamSerializer(write_start_stream=True, num_dfs=1) elif eval_type in ( PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF, @@ -2968,6 +2895,72 @@ def func(split_index: int, batches: Iterator[Any]) -> Iterator[pa.RecordBatch]: # profiling is not supported for UDF return func, None, ser, ser + if eval_type == PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF: + import pyarrow as pa + + window_bound_types_str = runner_conf.get("window_bound_types") + window_bound_types = [t.strip().lower() for t in window_bound_types_str.split(",")] + + col_names = ["_%d" % i for i in range(len(udfs))] + return_schema = to_arrow_schema( + StructType([StructField(name, rt) for name, (_, _, _, rt) in zip(col_names, udfs)]), + timezone="UTC", + prefers_large_types=runner_conf.use_large_var_types, + ) + + def func(split_index: int, batches: Iterator[Any]) -> Iterator[pa.RecordBatch]: + for group_batches in batches: + batch_list = list(group_batches) + if not batch_list: + continue + if hasattr(pa, "concat_batches"): + concatenated = pa.concat_batches(batch_list) + else: + # pyarrow.concat_batches not supported before 19.0.0 + # remove this once we drop support for old versions + concatenated = pa.RecordBatch.from_struct_array( + pa.concat_arrays([b.to_struct_array() for b in batch_list]) + ) + num_rows = concatenated.num_rows + + result_arrays = [] + for udf_index, (udf_func, args_offsets, kwargs_offsets, _) in enumerate(udfs): + bound_type = window_bound_types[udf_index] + if bound_type == "unbounded": + result = udf_func( + *[concatenated.column(o) for o in args_offsets], + **{k: concatenated.column(v) for k, v in kwargs_offsets.items()}, + ) + result_arrays.append(pa.repeat(result, num_rows)) + elif bound_type == "bounded": + begin_col = concatenated.column(args_offsets[0]) + end_col = concatenated.column(args_offsets[1]) + results = [] + for i in range(num_rows): + offset = begin_col[i].as_py() + length = end_col[i].as_py() - offset + slices = [ + concatenated.column(o).slice(offset=offset, length=length) + for o in args_offsets[2:] + ] + kw_slices = { + k: concatenated.column(v).slice(offset=offset, length=length) + for k, v in kwargs_offsets.items() + } + results.append(udf_func(*slices, **kw_slices)) + result_arrays.append(pa.array(results)) + else: + raise PySparkRuntimeError( + errorClass="INVALID_WINDOW_BOUND_TYPE", + messageParameters={"window_bound_type": bound_type}, + ) + + batch = pa.RecordBatch.from_arrays(result_arrays, col_names) + yield ArrowBatchTransformer.enforce_schema(batch, return_schema) + + # profiling is not supported for UDF + return func, None, ser, ser + is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF @@ -3392,42 +3385,6 @@ def mapper(batch_iter): ) return f(series_iter) - elif eval_type == PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF: - import pyarrow as pa - - # For SQL_WINDOW_AGG_ARROW_UDF, - # convert iterator of batch columns to a concatenated RecordBatch - def mapper(a): - # a is Iterator[Tuple[pa.Array, ...]] - convert to RecordBatch - batches = [] - for batch_columns in a: - # batch_columns is Tuple[pa.Array, ...] - convert to RecordBatch - batch = pa.RecordBatch.from_arrays( - batch_columns, names=["_%d" % i for i in range(len(batch_columns))] - ) - batches.append(batch) - - # Concatenate all batches into one - if hasattr(pa, "concat_batches"): - concatenated_batch = pa.concat_batches(batches) - else: - # pyarrow.concat_batches not supported before 19.0.0 - # remove this once we drop support for old versions - concatenated_batch = pa.RecordBatch.from_struct_array( - pa.concat_arrays([b.to_struct_array() for b in batches]) - ) - - # Extract series using offsets (concatenated_batch.columns[o] gives pa.Array) - result = tuple( - f(*[concatenated_batch.columns[o] for o in arg_offsets]) for arg_offsets, f in udfs - ) - # In the special case of a single UDF this will return a single result rather - # than a tuple of results; this is the format that the JVM side expects. - if len(result) == 1: - return result[0] - else: - return result - elif eval_type in ( PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,