@@ -13,6 +13,9 @@ ESTIMATOR_TEMPLATE_BAZEL_PATH = "//codegen:sklearn_wrapper_template.py_template"
13
13
ESTIMATOR_TEST_TEMPLATE_BAZEL_PATH = (
14
14
"//codegen:transformer_autogen_test_template.py_template"
15
15
)
16
+ SNOWPARK_PANDAS_TEST_TEMPLATE_BAZEL_PATH = (
17
+ "//codegen:snowpark_pandas_autogen_test_template.py_template"
18
+ )
16
19
INIT_TEMPLATE_BAZEL_PATH = "//codegen:init_template.py_template"
17
20
SRC_OUTPUT_PATH = ""
18
21
TEST_OUTPUT_PATH = "tests/integ"
@@ -113,7 +116,7 @@ def autogen_tests_for_estimators(module, module_root_dir, estimator_info_list):
113
116
List of generated build rules for every class in the estimator_info_list
114
117
1. `genrule` with label `generate_test_<estimator-class-name-snakecase>` to auto-generate
115
118
integration test for the estimator's wrapper class.
116
- 2. `py_test` rule with label `test_ <estimator-class-name-snakecase>` to build the auto-generated
119
+ 2. `py_test` rule with label `<estimator-class-name-snakecase>_test ` to build the auto-generated
117
120
test files from the `generate_test_<estimator-class-name-snakecase>` rule.
118
121
"""
119
122
cmd = get_genrule_cmd (
@@ -145,3 +148,42 @@ def autogen_tests_for_estimators(module, module_root_dir, estimator_info_list):
145
148
shard_count = 5 ,
146
149
tags = ["autogen" ],
147
150
)
151
+
152
+ def autogen_snowpark_pandas_tests (module , module_root_dir , snowpark_pandas_estimator_info_list ):
153
+ """Generates `genrules` and `py_test` rules for every snowpark pandas estimator
154
+ List of generated build rules for every class in the snowpark_pandas_estimator_info_list
155
+ 1. `genrule` with label `generate_test_snowpark_pandas_<estimator-class-name-snakecase>` to auto-generate
156
+ integration test for the estimator.
157
+ 2. `py_test` rule with label `estimator-class-name-snakecase>_snowpark_pandas_test` to build the auto-generated
158
+ test files from the `generate_test_snowpark_pandas_<estimator-class-name-snakecase>` rule.
159
+ """
160
+ cmd = get_genrule_cmd (
161
+ gen_mode = "SNOWPARK_PANDAS_TEST" ,
162
+ template_path = SNOWPARK_PANDAS_TEST_TEMPLATE_BAZEL_PATH ,
163
+ module = module ,
164
+ output_path = TEST_OUTPUT_PATH ,
165
+ )
166
+
167
+ for e in snowpark_pandas_estimator_info_list :
168
+ py_genrule (
169
+ name = "generate_test_snowpark_pandas_{}" .format (e .normalized_class_name ),
170
+ outs = ["{}_snowpark_pandas_test.py" .format (e .normalized_class_name )],
171
+ tools = [AUTO_GEN_TOOL_BAZEL_PATH ],
172
+ srcs = [SNOWPARK_PANDAS_TEST_TEMPLATE_BAZEL_PATH ],
173
+ cmd = cmd .format (e .class_name ),
174
+ tags = ["autogen_build" ],
175
+ )
176
+
177
+ py_test (
178
+ name = "{}_snowpark_pandas_test" .format (e .normalized_class_name ),
179
+ srcs = [":generate_test_snowpark_pandas_{}" .format (e .normalized_class_name )],
180
+ deps = [
181
+ "//snowflake/ml/snowpark_pandas:snowpark_pandas_lib" ,
182
+ "//snowflake/ml/utils:connection_params" ,
183
+ ],
184
+ compatible_with_snowpark = False ,
185
+ timeout = "long" ,
186
+ legacy_create_init = 0 ,
187
+ shard_count = 5 ,
188
+ tags = ["snowpark_pandas_autogen" ],
189
+ )
0 commit comments