Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 68 additions & 111 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
TransformWithStateInPySparkRowSerializer,
TransformWithStateInPySparkRowInitStateSerializer,
ArrowStreamAggPandasUDFSerializer,
ArrowStreamAggArrowUDFSerializer,
ArrowBatchUDFSerializer,
ArrowStreamUDTFSerializer,
ArrowStreamArrowUDTFSerializer,
Expand Down Expand Up @@ -1088,26 +1087,6 @@ def wrap_window_agg_pandas_udf(
)


def wrap_window_agg_arrow_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index):
    """Dispatch a window aggregation Arrow UDF to its bounded/unbounded wrapper.

    The per-UDF window bound type is read from the ``window_bound_types``
    runner-conf entry, a comma-separated list indexed by ``udf_index``.
    """
    bound_types = runner_conf.get("window_bound_types").split(",")
    bound_type = bound_types[udf_index].strip().lower()

    wrappers = {
        "bounded": wrap_bounded_window_agg_arrow_udf,
        "unbounded": wrap_unbounded_window_agg_arrow_udf,
    }
    wrapper = wrappers.get(bound_type)
    if wrapper is None:
        raise PySparkRuntimeError(
            errorClass="INVALID_WINDOW_BOUND_TYPE",
            messageParameters={
                "window_bound_type": bound_type,
            },
        )
    return wrapper(f, args_offsets, kwargs_offsets, return_type, runner_conf)


def wrap_unbounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)

Expand All @@ -1127,27 +1106,6 @@ def wrapped(*series):
)


def wrap_unbounded_window_agg_arrow_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
    """Wrap an unbounded window aggregation Arrow UDF.

    The aggregation runs once over the whole partition and its scalar
    result is broadcast to every input row.  This is the Arrow analogue
    of ``wrap_unbounded_window_agg_pandas_udf``.
    """
    agg_func, offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)

    arrow_return_type = to_arrow_type(
        return_type,
        timezone="UTC",
        prefers_large_types=runner_conf.use_large_var_types,
    )

    def evaluate(*cols):
        import pyarrow as pa

        # One scalar per partition, repeated once per input row.
        return pa.repeat(agg_func(*cols), len(cols[0]))

    return offsets, lambda *a: (evaluate(*a), arrow_return_type)


def wrap_bounded_window_agg_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
# args_offsets should have at least 2 for begin_index, end_index.
assert len(args_offsets) >= 2, len(args_offsets)
Expand Down Expand Up @@ -1188,35 +1146,6 @@ def wrapped(begin_index, end_index, *series):
)


def wrap_bounded_window_agg_arrow_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
    """Wrap a bounded window aggregation Arrow UDF.

    The first two offsets address the per-row begin/end window-bound
    columns; the remaining offsets address the data columns.  The UDF is
    evaluated once per row over the ``[begin, end)`` slice of each data
    column.  Arrow analogue of ``wrap_bounded_window_agg_pandas_udf``.
    """
    # The first two offsets are reserved for begin_index / end_index.
    assert len(args_offsets) >= 2, len(args_offsets)
    agg_func, offsets = wrap_kwargs_support(f, args_offsets[2:], kwargs_offsets)

    arrow_return_type = to_arrow_type(
        return_type,
        timezone="UTC",
        prefers_large_types=runner_conf.use_large_var_types,
    )

    def evaluate(begin_index, end_index, *cols):
        import pyarrow as pa

        assert isinstance(begin_index, pa.Int32Array), type(begin_index)
        assert isinstance(end_index, pa.Int32Array), type(end_index)

        values = []
        for begin, end in zip(begin_index, end_index):
            start = begin.as_py()
            size = end.as_py() - start
            window = (c.slice(offset=start, length=size) for c in cols)
            values.append(agg_func(*window))
        return pa.array(values)

    return args_offsets[:2] + offsets, lambda *a: (evaluate(*a), arrow_return_type)


def wrap_kwargs_support(f, args_offsets, kwargs_offsets):
if len(kwargs_offsets):
keys = list(kwargs_offsets.keys())
Expand Down Expand Up @@ -1431,9 +1360,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
func, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index
)
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF:
return wrap_window_agg_arrow_udf(
func, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index
)
return func, args_offsets, kwargs_offsets, return_type
elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
return wrap_udf(func, args_offsets, kwargs_offsets, return_type)
else:
Expand Down Expand Up @@ -2647,7 +2574,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
):
ser = ArrowStreamSerializer(write_start_stream=True, num_dfs=1)
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF:
ser = ArrowStreamAggArrowUDFSerializer(safecheck=True, arrow_cast=True)
ser = ArrowStreamSerializer(write_start_stream=True, num_dfs=1)
elif eval_type in (
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_ITER_UDF,
Expand Down Expand Up @@ -2968,6 +2895,72 @@ def func(split_index: int, batches: Iterator[Any]) -> Iterator[pa.RecordBatch]:
# profiling is not supported for UDF
return func, None, ser, ser

# Window aggregation over Arrow data: each element yielded by `batches` is
# one window group (an iterable of record batches).  All UDFs for the group
# are evaluated against the concatenated batch and the per-UDF result
# columns are emitted together as a single record batch.
if eval_type == PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF:
    import pyarrow as pa

    # Per-UDF window bound types ("bounded"/"unbounded"), comma-separated
    # in the runner conf and index-aligned with `udfs`.
    window_bound_types_str = runner_conf.get("window_bound_types")
    window_bound_types = [t.strip().lower() for t in window_bound_types_str.split(",")]

    # One output column per UDF; derive the Arrow schema from the declared
    # return types so the output batch can be schema-enforced below.
    col_names = ["_%d" % i for i in range(len(udfs))]
    return_schema = to_arrow_schema(
        StructType([StructField(name, rt) for name, (_, _, _, rt) in zip(col_names, udfs)]),
        timezone="UTC",
        prefers_large_types=runner_conf.use_large_var_types,
    )

    def func(split_index: int, batches: Iterator[Any]) -> Iterator[pa.RecordBatch]:
        for group_batches in batches:
            batch_list = list(group_batches)
            if not batch_list:
                # Skip empty groups rather than emitting a zero-row batch.
                continue
            if hasattr(pa, "concat_batches"):
                concatenated = pa.concat_batches(batch_list)
            else:
                # pyarrow.concat_batches not supported before 19.0.0
                # remove this once we drop support for old versions
                concatenated = pa.RecordBatch.from_struct_array(
                    pa.concat_arrays([b.to_struct_array() for b in batch_list])
                )
            num_rows = concatenated.num_rows

            result_arrays = []
            for udf_index, (udf_func, args_offsets, kwargs_offsets, _) in enumerate(udfs):
                bound_type = window_bound_types[udf_index]
                if bound_type == "unbounded":
                    # Unbounded window: aggregate once over the whole group
                    # and broadcast the scalar result to every row.
                    result = udf_func(
                        *[concatenated.column(o) for o in args_offsets],
                        **{k: concatenated.column(v) for k, v in kwargs_offsets.items()},
                    )
                    result_arrays.append(pa.repeat(result, num_rows))
                elif bound_type == "bounded":
                    # Bounded window: the first two offsets address the
                    # per-row begin/end bound columns; evaluate the UDF per
                    # row over the [begin, end) slice of each data column.
                    begin_col = concatenated.column(args_offsets[0])
                    end_col = concatenated.column(args_offsets[1])
                    results = []
                    for i in range(num_rows):
                        offset = begin_col[i].as_py()
                        length = end_col[i].as_py() - offset
                        slices = [
                            concatenated.column(o).slice(offset=offset, length=length)
                            for o in args_offsets[2:]
                        ]
                        kw_slices = {
                            k: concatenated.column(v).slice(offset=offset, length=length)
                            for k, v in kwargs_offsets.items()
                        }
                        results.append(udf_func(*slices, **kw_slices))
                    result_arrays.append(pa.array(results))
                else:
                    raise PySparkRuntimeError(
                        errorClass="INVALID_WINDOW_BOUND_TYPE",
                        messageParameters={"window_bound_type": bound_type},
                    )

            batch = pa.RecordBatch.from_arrays(result_arrays, col_names)
            # Cast/rename to the declared schema before handing back to the JVM.
            yield ArrowBatchTransformer.enforce_schema(batch, return_schema)

    # profiling is not supported for UDF
    return func, None, ser, ser

is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF

Expand Down Expand Up @@ -3392,42 +3385,6 @@ def mapper(batch_iter):
)
return f(series_iter)

elif eval_type == PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF:
import pyarrow as pa

# For SQL_WINDOW_AGG_ARROW_UDF,
# convert iterator of batch columns to a concatenated RecordBatch
def mapper(a):
    """Evaluate every window-agg Arrow UDF over one concatenated partition.

    ``a`` iterates tuples of ``pa.Array`` (one tuple per input batch); the
    columns are re-assembled into record batches, concatenated into a
    single batch, and each UDF is applied to its offset-selected columns.
    """
    # a is Iterator[Tuple[pa.Array, ...]] - convert to RecordBatch
    batches = []
    for batch_columns in a:
        # batch_columns is Tuple[pa.Array, ...] - convert to RecordBatch
        batch = pa.RecordBatch.from_arrays(
            batch_columns, names=["_%d" % i for i in range(len(batch_columns))]
        )
        batches.append(batch)

    # Concatenate all batches into one
    if hasattr(pa, "concat_batches"):
        concatenated_batch = pa.concat_batches(batches)
    else:
        # pyarrow.concat_batches not supported before 19.0.0
        # remove this once we drop support for old versions
        concatenated_batch = pa.RecordBatch.from_struct_array(
            pa.concat_arrays([b.to_struct_array() for b in batches])
        )

    # Extract series using offsets (concatenated_batch.columns[o] gives pa.Array)
    result = tuple(
        f(*[concatenated_batch.columns[o] for o in arg_offsets]) for arg_offsets, f in udfs
    )
    # In the special case of a single UDF this will return a single result rather
    # than a tuple of results; this is the format that the JVM side expects.
    if len(result) == 1:
        return result[0]
    else:
        return result

elif eval_type in (
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
Expand Down