Commit f0326eb

Project import generated by Copybara. (#39)
1 parent 192f794 commit f0326eb

File tree

5 files changed (+72, -120 lines)

ci/build_and_run_tests.sh

Lines changed: 1 addition & 1 deletion

@@ -117,7 +117,7 @@ fi
 
 # Compare test required dependencies with wheel pkg dependencies and exclude tests if necessary
 EXCLUDE_TESTS=$(mktemp "${TEMP_TEST_DIR}/exclude_tests_XXXXX")
-if [[ ${MODE} = "continuous_run" ]]; then
+if [[ ${MODE} = "continuous_run" || ${MODE} = "release" ]]; then
     ./ci/get_excluded_tests.sh -f "${EXCLUDE_TESTS}" -m unused -b "${BAZEL}"
 elif [[ ${MODE} = "merge_gate" ]]; then
     ./ci/get_excluded_tests.sh -f "${EXCLUDE_TESTS}" -m all -b "${BAZEL}"
ci/get_excluded_tests.sh

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@
 # -f: specify output file path
 # -m: specify the mode from the following options
 #     unused: exclude integration tests whose dependency is not part of the wheel package.
-#         The missing dependency cuold happen when a new operator is being developed,
+#         The missing dependency could happen when a new operator is being developed,
 #         but not yet released.
 #     unaffected: exclude integration tests whose dependency is not part of the affected targets
 #         compare to the the merge base to main of current revision.

snowflake/ml/_internal/utils/result.py

Lines changed: 46 additions & 41 deletions

@@ -13,49 +13,54 @@
 _RESULT_SIZE_THRESHOLD = 5 * (1024**2)  # 5MB
 
 
-class SnowflakeResult:
+# This module handles serialization, uploading, downloading, and deserialization of stored
+# procedure results. If the results are too large to be returned from a stored procedure,
+# the result will be uploaded. The client can then retrieve and deserialize the result if
+# it was uploaded.
+
+
+def serialize(session: snowpark.Session, result: Any) -> bytes:
+    """
+    Serialize a tuple containing the result (or None) and the result object filepath
+    if the result was uploaded to a stage (or None).
+
+    Args:
+        session: Snowpark session.
+        result: Object to be serialized.
+
+    Returns:
+        Cloudpickled string of bytes of the result tuple.
     """
-    Handles serialization, uploading, downloading, and deserialization of stored procedure results. If the results
-    are too large to be returned from a stored procedure, the result will be uploaded. The client can then retrieve
-    and deserialize the result if it was uploaded.
+    result_object_filepath = None
+    result_bytes = cloudpickle.dumps(result)
+    if sys.getsizeof(result_bytes) > _RESULT_SIZE_THRESHOLD:
+        stage_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
+        session.sql(f"CREATE TEMPORARY STAGE {stage_name}").collect()
+        result_object_filepath = f"@{stage_name}/{snowpark_utils.generate_random_alphanumeric()}"
+        session.file.put_stream(BytesIO(result_bytes), result_object_filepath)
+        result_object_filepath = f"{result_object_filepath}.gz"
+
+    if result_object_filepath is not None:
+        return cloudpickle.dumps((None, result_object_filepath))  # type: ignore[no-any-return]
+
+    return cloudpickle.dumps((result, None))  # type: ignore[no-any-return]
+
+
+def deserialize(session: snowpark.Session, result_bytes: bytes) -> Any:
     """
+    Loads and/or deserializes the (maybe uploaded) result.
 
-    def __init__(self, session: snowpark.Session, result: Any) -> None:
-        self.result = result
-        self.session = session
-        self.result_object_filepath = None
-        result_bytes = cloudpickle.dumps(self.result)
-        if sys.getsizeof(result_bytes) > _RESULT_SIZE_THRESHOLD:
-            stage_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
-            session.sql(f"CREATE TEMPORARY STAGE {stage_name}").collect()
-            result_object_filepath = f"@{stage_name}/{snowpark_utils.generate_random_alphanumeric()}"
-            session.file.put_stream(BytesIO(result_bytes), result_object_filepath)
-            self.result_object_filepath = f"{result_object_filepath}.gz"
-
-    def serialize(self) -> bytes:
-        """
-        Serialize a tuple containing the result (or None) and the result object filepath
-        if the result was uploaded to a stage (or None).
-
-        Returns:
-            Cloudpickled string of bytes of the result tuple.
-        """
-        if self.result_object_filepath is not None:
-            return cloudpickle.dumps((None, self.result_object_filepath))  # type: ignore[no-any-return]
-        return cloudpickle.dumps((self.result, None))  # type: ignore[no-any-return]
-
-    @staticmethod
-    def load_result_from_filepath(session: snowpark.Session, result_object_filepath: str) -> Any:
-        """
-        Loads and deserializes the uploaded result.
-
-        Args:
-            session: Snowpark session.
-            result_object_filepath: Stage filepath of the result object returned by serialize method.
-
-        Returns:
-            The original serialized result (any type).
-        """
+    Args:
+        session: Snowpark session.
+        result_bytes: String of bytes returned by serialize method.
+
+    Returns:
+        The deserialized result (any type).
+    """
+    result_object, result_object_filepath = cloudpickle.loads(result_bytes)
+    if result_object_filepath is not None:
         result_object_bytes_io = session.file.get_stream(result_object_filepath, decompress=True)
         result_bytes = result_object_bytes_io.read()
-        return cloudpickle.loads(result_bytes)
+        result_object = cloudpickle.loads(result_bytes)
+
+    return result_object
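
The new module-level API replaces the SnowflakeResult class with two free functions, which is what lets the metric sprocs below return a single bytes payload. A minimal round-trip sketch, assuming an already-created Snowpark session (named `session` here) with privileges to create temporary stages; the numpy payload is hypothetical and sized to force the stage-upload path:

    import numpy as np
    from snowflake.ml._internal.utils import result

    # Assumption: `session` is an existing snowpark.Session.
    payload = np.zeros((1024, 1024))  # ~8MB pickled, above _RESULT_SIZE_THRESHOLD

    # In the real flow this call runs inside the stored procedure: the pickled
    # payload exceeds 5MB, so it is uploaded to a temporary stage and only
    # (None, "<stage filepath>.gz") is pickled into the returned bytes.
    result_bytes = result.serialize(session, payload)

    # Client side: the tuple is unpickled; the result slot is None, so the
    # staged file is streamed back, decompressed, and unpickled.
    round_tripped = result.deserialize(session, result_bytes)
    assert (round_tripped == payload).all()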

snowflake/ml/modeling/metrics/ranking.py

Lines changed: 12 additions & 30 deletions

@@ -81,7 +81,7 @@ def precision_recall_curve(
     statement_params = telemetry.get_statement_params(_PROJECT, _SUBPROJECT)
     cols = metrics_utils.flatten_cols([y_true_col_name, probas_pred_col_name, sample_weight_col_name])
     queries = df[cols].queries["queries"]
-    pickled_snowflake_result = cloudpickle.dumps(result)
+    pickled_result_module = cloudpickle.dumps(result)
 
     @F.sproc(  # type: ignore[misc]
         is_permanent=False,
@@ -109,16 +109,10 @@ def precision_recall_curve_anon_sproc(session: snowpark.Session) -> bytes:
             pos_label=pos_label,
             sample_weight=sample_weight,
         )
-        result_module = cloudpickle.loads(pickled_snowflake_result)
-        result_object = result_module.SnowflakeResult(session, (precision, recall, thresholds))
-
-        return result_object.serialize()  # type: ignore[no-any-return]
-
-    sproc_result = precision_recall_curve_anon_sproc(session)
-    result_object, result_object_filepath = cloudpickle.loads(sproc_result)
-    if result_object_filepath is not None:
-        result_object = result.SnowflakeResult.load_result_from_filepath(session, result_object_filepath)
+        result_module = cloudpickle.loads(pickled_result_module)
+        return result_module.serialize(session, (precision, recall, thresholds))  # type: ignore[no-any-return]
 
+    result_object = result.deserialize(session, precision_recall_curve_anon_sproc(session))
     res: Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
     return res
 
@@ -223,7 +217,7 @@ class scores must correspond to the order of ``labels``,
     statement_params = telemetry.get_statement_params(_PROJECT, _SUBPROJECT)
     cols = metrics_utils.flatten_cols([y_true_col_names, y_score_col_names, sample_weight_col_name])
     queries = df[cols].queries["queries"]
-    pickled_snowflake_result = cloudpickle.dumps(result)
+    pickled_result_module = cloudpickle.dumps(result)
 
     @F.sproc(  # type: ignore[misc]
         is_permanent=False,
@@ -254,16 +248,10 @@ def roc_auc_score_anon_sproc(session: snowpark.Session) -> bytes:
             multi_class=multi_class,
             labels=labels,
         )
-        result_module = cloudpickle.loads(pickled_snowflake_result)
-        result_object = result_module.SnowflakeResult(session, auc)
-
-        return result_object.serialize()  # type: ignore[no-any-return]
-
-    sproc_result = roc_auc_score_anon_sproc(session)
-    result_object, result_object_filepath = cloudpickle.loads(sproc_result)
-    if result_object_filepath is not None:
-        result_object = result.SnowflakeResult.load_result_from_filepath(session, result_object_filepath)
+        result_module = cloudpickle.loads(pickled_result_module)
+        return result_module.serialize(session, auc)  # type: ignore[no-any-return]
 
+    result_object = result.deserialize(session, roc_auc_score_anon_sproc(session))
     auc: Union[float, npt.NDArray[np.float_]] = result_object
     return auc
 
@@ -320,7 +308,7 @@ def roc_curve(
     statement_params = telemetry.get_statement_params(_PROJECT, _SUBPROJECT)
     cols = metrics_utils.flatten_cols([y_true_col_name, y_score_col_name, sample_weight_col_name])
     queries = df[cols].queries["queries"]
-    pickled_snowflake_result = cloudpickle.dumps(result)
+    pickled_result_module = cloudpickle.dumps(result)
 
     @F.sproc(  # type: ignore[misc]
         is_permanent=False,
@@ -350,16 +338,10 @@ def roc_curve_anon_sproc(session: snowpark.Session) -> bytes:
             drop_intermediate=drop_intermediate,
         )
 
-        result_module = cloudpickle.loads(pickled_snowflake_result)
-        result_object = result_module.SnowflakeResult(session, (fpr, tpr, thresholds))
-
-        return result_object.serialize()  # type: ignore[no-any-return]
-
-    sproc_result = roc_curve_anon_sproc(session)
-    result_object, result_object_filepath = cloudpickle.loads(sproc_result)
-    if result_object_filepath is not None:
-        result_object = result.SnowflakeResult.load_result_from_filepath(session, result_object_filepath)
+        result_module = cloudpickle.loads(pickled_result_module)
+        return result_module.serialize(session, (fpr, tpr, thresholds))  # type: ignore[no-any-return]
 
+    result_object = result.deserialize(session, roc_curve_anon_sproc(session))
     res: Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
 
     return res
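
Every metric refactored here (and in regression.py below) ends up with the same shape: the client cloudpickles the result helper module, the anonymous sproc re-loads it server-side and returns result_module.serialize(...), and the client collapses its old unpickle/check/download logic into a single result.deserialize(...) call. A condensed sketch of that shape, assuming an active Snowpark `session`; run_metric, metric_anon_sproc, and the 0.42 value are illustrative stand-ins, and the real code passes more registration options to @F.sproc:

    import cloudpickle
    from snowflake import snowpark
    from snowflake.snowpark import functions as F
    from snowflake.ml._internal.utils import result

    # Ship the helper module into the sproc by value, as the diff does,
    # presumably so the server does not need it preinstalled.
    pickled_result_module = cloudpickle.dumps(result)

    def run_metric(session: snowpark.Session) -> object:
        @F.sproc(is_permanent=False, session=session)
        def metric_anon_sproc(session: snowpark.Session) -> bytes:
            value = 0.42  # stand-in for the sklearn metric computed server-side
            result_module = cloudpickle.loads(pickled_result_module)
            # serialize() stages the value only if it exceeds the 5MB threshold.
            return result_module.serialize(session, value)

        # One call replaces the old loads / filepath check / download branch.
        return result.deserialize(session, metric_anon_sproc(session))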

snowflake/ml/modeling/metrics/regression.py

Lines changed: 12 additions & 47 deletions

@@ -99,15 +99,9 @@ def d2_absolute_error_score_anon_sproc(session: snowpark.Session) -> bytes:
             multioutput=multioutput,
         )
         result_module = cloudpickle.loads(pickled_snowflake_result)
-        result_object = result_module.SnowflakeResult(session, score)
-
-        return result_object.serialize()  # type: ignore[no-any-return]
-
-    sproc_result = d2_absolute_error_score_anon_sproc(session)
-    result_object, result_object_filepath = cloudpickle.loads(sproc_result)
-    if result_object_filepath is not None:
-        result_object = result.SnowflakeResult.load_result_from_filepath(session, result_object_filepath)
+        return result_module.serialize(session, score)  # type: ignore[no-any-return]
 
+    result_object = result.deserialize(session, d2_absolute_error_score_anon_sproc(session))
     score: Union[float, npt.NDArray[np.float_]] = result_object
     return score
 
@@ -192,14 +186,9 @@ def d2_pinball_score_anon_sproc(session: snowpark.Session) -> bytes:
             multioutput=multioutput,
         )
         result_module = cloudpickle.loads(pickled_result_module)
-        result_object = result_module.SnowflakeResult(session, score)
-
-        return result_object.serialize()  # type: ignore[no-any-return]
+        return result_module.serialize(session, score)  # type: ignore[no-any-return]
 
-    sproc_result = d2_pinball_score_anon_sproc(session)
-    result_object, result_object_filepath = cloudpickle.loads(sproc_result)
-    if result_object_filepath is not None:
-        result_object = result.SnowflakeResult.load_result_from_filepath(session, result_object_filepath)
+    result_object = result.deserialize(session, d2_pinball_score_anon_sproc(session))
 
     score: Union[float, npt.NDArray[np.float_]] = result_object
     return score
@@ -301,15 +290,9 @@ def explained_variance_score_anon_sproc(session: snowpark.Session) -> bytes:
             force_finite=force_finite,
         )
         result_module = cloudpickle.loads(pickled_result_module)
-        result_object = result_module.SnowflakeResult(session, score)
-
-        return result_object.serialize()  # type: ignore[no-any-return]
-
-    sproc_result = explained_variance_score_anon_sproc(session)
-    result_object, result_object_filepath = cloudpickle.loads(sproc_result)
-    if result_object_filepath is not None:
-        result_object = result.SnowflakeResult.load_result_from_filepath(session, result_object_filepath)
+        return result_module.serialize(session, score)  # type: ignore[no-any-return]
 
+    result_object = result.deserialize(session, explained_variance_score_anon_sproc(session))
     score: Union[float, npt.NDArray[np.float_]] = result_object
     return score
 
@@ -389,15 +372,9 @@ def mean_absolute_error_anon_sproc(session: snowpark.Session) -> bytes:
         )
 
         result_module = cloudpickle.loads(pickled_result_module)
-        result_object = result_module.SnowflakeResult(session, loss)
-
-        return result_object.serialize()  # type: ignore[no-any-return]
-
-    sproc_result = mean_absolute_error_anon_sproc(session)
-    result_object, result_object_filepath = cloudpickle.loads(sproc_result)
-    if result_object_filepath is not None:
-        result_object = result.SnowflakeResult.load_result_from_filepath(session, result_object_filepath)
+        return result_module.serialize(session, loss)  # type: ignore[no-any-return]
 
+    result_object = result.deserialize(session, mean_absolute_error_anon_sproc(session))
     loss: Union[float, npt.NDArray[np.float_]] = result_object
     return loss
 
@@ -485,15 +462,9 @@ def mean_absolute_percentage_error_anon_sproc(session: snowpark.Session) -> byte
             multioutput=multioutput,
        )
         result_module = cloudpickle.loads(pickled_result_module)
-        result_object = result_module.SnowflakeResult(session, loss)
-
-        return result_object.serialize()  # type: ignore[no-any-return]
-
-    sproc_result = mean_absolute_percentage_error_anon_sproc(session)
-    result_object, result_object_filepath = cloudpickle.loads(sproc_result)
-    if result_object_filepath is not None:
-        result_object = result.SnowflakeResult.load_result_from_filepath(session, result_object_filepath)
+        return result_module.serialize(session, loss)  # type: ignore[no-any-return]
 
+    result_object = result.deserialize(session, mean_absolute_percentage_error_anon_sproc(session))
     loss: Union[float, npt.NDArray[np.float_]] = result_object
     return loss
 
@@ -571,15 +542,9 @@ def mean_squared_error_anon_sproc(session: snowpark.Session) -> bytes:
             squared=squared,
         )
         result_module = cloudpickle.loads(pickled_result_module)
-        result_object = result_module.SnowflakeResult(session, loss)
-
-        return result_object.serialize()  # type: ignore[no-any-return]
-
-    sproc_result = mean_squared_error_anon_sproc(session)
-    result_object, result_object_filepath = cloudpickle.loads(sproc_result)
-    if result_object_filepath is not None:
-        result_object = result.SnowflakeResult.load_result_from_filepath(session, result_object_filepath)
+        return result_module.serialize(session, loss)  # type: ignore[no-any-return]
 
+    result_object = result.deserialize(session, mean_squared_error_anon_sproc(session))
     loss: Union[float, npt.NDArray[np.float_]] = result_object
     return loss
