@@ -965,7 +965,7 @@ def float_parser(row):
965965 )
966966
967967
968- def test_managed_function_df_where (session , dataset_id , scalars_dfs ):
968+ def test_managed_function_df_where_mask (session , dataset_id , scalars_dfs ):
969969 try :
970970
971971 # The return type has to be bool type for callable where condition.
@@ -987,15 +987,15 @@ def is_sum_positive(a, b):
987987 pd_int64_df = scalars_pandas_df [int64_cols ]
988988 pd_int64_df_filtered = pd_int64_df .dropna ()
989989
990- # Use callable condition in dataframe.where method.
990+ # Test callable condition in dataframe.where method.
991991 bf_result = bf_int64_df_filtered .where (is_sum_positive_mf ).to_pandas ()
992992 # Pandas doesn't support such case, use following as workaround.
993993 pd_result = pd_int64_df_filtered .where (pd_int64_df_filtered .sum (axis = 1 ) > 0 )
994994
995995 # Ignore any dtype difference.
996996 pandas .testing .assert_frame_equal (bf_result , pd_result , check_dtype = False )
997997
998- # Make sure the read_gbq_function path works for this function .
998+ # Make sure the read_gbq_function path works for dataframe.where method .
999999 is_sum_positive_ref = session .read_gbq_function (
10001000 function_name = is_sum_positive_mf .bigframes_bigquery_function
10011001 )
@@ -1012,14 +1012,27 @@ def is_sum_positive(a, b):
10121012 bf_result_gbq , pd_result_gbq , check_dtype = False
10131013 )
10141014
1015+ # Test callable condition in dataframe.mask method.
1016+ bf_result_gbq = bf_int64_df_filtered .mask (
1017+ is_sum_positive_ref , - bf_int64_df_filtered
1018+ ).to_pandas ()
1019+ pd_result_gbq = pd_int64_df_filtered .mask (
1020+ pd_int64_df_filtered .sum (axis = 1 ) > 0 , - pd_int64_df_filtered
1021+ )
1022+
1023+ # Ignore any dtype difference.
1024+ pandas .testing .assert_frame_equal (
1025+ bf_result_gbq , pd_result_gbq , check_dtype = False
1026+ )
1027+
10151028 finally :
10161029 # Clean up the gcp assets created for the managed function.
10171030 cleanup_function_assets (
10181031 is_sum_positive_mf , session .bqclient , ignore_failures = False
10191032 )
10201033
10211034
1022- def test_managed_function_df_where_series (session , dataset_id , scalars_dfs ):
1035+ def test_managed_function_df_where_mask_series (session , dataset_id , scalars_dfs ):
10231036 try :
10241037
10251038 # The return type has to be bool type for callable where condition.
@@ -1041,14 +1054,14 @@ def is_sum_positive_series(s):
10411054 pd_int64_df = scalars_pandas_df [int64_cols ]
10421055 pd_int64_df_filtered = pd_int64_df .dropna ()
10431056
1044- # Use callable condition in dataframe.where method.
1057+ # Test callable condition in dataframe.where method.
10451058 bf_result = bf_int64_df_filtered .where (is_sum_positive_series ).to_pandas ()
10461059 pd_result = pd_int64_df_filtered .where (is_sum_positive_series )
10471060
10481061 # Ignore any dtype difference.
10491062 pandas .testing .assert_frame_equal (bf_result , pd_result , check_dtype = False )
10501063
1051- # Make sure the read_gbq_function path works for this function .
1064+ # Make sure the read_gbq_function path works for dataframe.where method .
10521065 is_sum_positive_series_ref = session .read_gbq_function (
10531066 function_name = is_sum_positive_series_mf .bigframes_bigquery_function ,
10541067 is_row_processor = True ,
@@ -1070,6 +1083,19 @@ def func_for_other(x):
10701083 bf_result_gbq , pd_result_gbq , check_dtype = False
10711084 )
10721085
1086+ # Test callable condition in dataframe.mask method.
1087+ bf_result_gbq = bf_int64_df_filtered .mask (
1088+ is_sum_positive_series_ref , func_for_other
1089+ ).to_pandas ()
1090+ pd_result_gbq = pd_int64_df_filtered .mask (
1091+ is_sum_positive_series , func_for_other
1092+ )
1093+
1094+ # Ignore any dtype difference.
1095+ pandas .testing .assert_frame_equal (
1096+ bf_result_gbq , pd_result_gbq , check_dtype = False
1097+ )
1098+
10731099 finally :
10741100 # Clean up the gcp assets created for the managed function.
10751101 cleanup_function_assets (
@@ -1121,3 +1147,31 @@ def _is_positive(s):
11211147 finally :
11221148 # Clean up the gcp assets created for the managed function.
11231149 cleanup_function_assets (is_positive_mf , session .bqclient , ignore_failures = False )
1150+
1151+
1152+ def test_managed_function_series_apply_args (session , dataset_id , scalars_dfs ):
1153+ try :
1154+
1155+ with pytest .warns (bfe .PreviewWarning , match = "udf is in preview." ):
1156+
1157+ @session .udf (dataset = dataset_id , name = prefixer .create_prefix ())
1158+ def foo_list (x : int , y0 : float , y1 : bytes , y2 : bool ) -> list [str ]:
1159+ return [str (x ), str (y0 ), str (y1 ), str (y2 )]
1160+
1161+ scalars_df , scalars_pandas_df = scalars_dfs
1162+
1163+ bf_result = (
1164+ scalars_df ["int64_too" ]
1165+ .apply (foo_list , args = (12.34 , b"hello world" , False ))
1166+ .to_pandas ()
1167+ )
1168+ pd_result = scalars_pandas_df ["int64_too" ].apply (
1169+ foo_list , args = (12.34 , b"hello world" , False )
1170+ )
1171+
1172+ # Ignore any dtype difference.
1173+ pandas .testing .assert_series_equal (bf_result , pd_result , check_dtype = False )
1174+
1175+ finally :
1176+ # Clean up the gcp assets created for the managed function.
1177+ cleanup_function_assets (foo_list , session .bqclient , ignore_failures = False )
0 commit comments