88from eval_protocol .pytest .types import (
99 Dataset ,
1010 DatasetPathParam ,
11+ EvaluationInputParam ,
1112 EvaluationTestMode ,
1213 InputMessagesParam ,
13- InputParam ,
1414 ModelParam ,
15+ RolloutInputParam ,
1516 RolloutProcessor ,
1617 RolloutProcessorConfig ,
1718 TestFunction ,
@@ -32,8 +33,9 @@ def evaluation_test(
3233 input_messages : Optional [List [InputMessagesParam ]] = None ,
3334 input_dataset : Optional [List [DatasetPathParam ]] = None ,
3435 dataset_adapter : Optional [Callable [[List [Dict [str , Any ]]], Dataset ]] = lambda x : x ,
35- input_params : Optional [List [InputParam ]] = None ,
36+ rollout_input_params : Optional [List [RolloutInputParam ]] = None ,
3637 rollout_processor : RolloutProcessor = default_no_op_rollout_processor ,
38+ evaluation_test_kwargs : Optional [List [EvaluationInputParam ]] = None ,
3739 aggregation_method : AggregationMethod = "mean" ,
3840 threshold_of_success : Optional [float ] = None ,
3941 num_runs : int = 1 ,
@@ -56,8 +58,9 @@ def evaluation_test(
5658 to a list of EvaluationRows if you have a custom dataset format.
5759 dataset_adapter: Function to convert the input dataset to a list of
5860 EvaluationRows. This is useful if you have a custom dataset format.
59- input_params : Generation parameters for the model .
61+ rollout_input_params : Generation parameters for the rollout .
6062 rollout_processor: Function used to perform the rollout.
63+ evaluation_test_kwargs: Kwargs for the evaluation function.
6164 aggregation_method: How to aggregate scores across rows.
6265 threshold_of_success: If set, fail the test if the aggregated score is
6366 below this threshold.
@@ -104,12 +107,19 @@ def execute_with_params(
104107 test_func : TestFunction ,
105108 row : EvaluationRow | None = None ,
106109 input_dataset : List [EvaluationRow ] | None = None ,
110+ evaluation_test_kwargs : Optional [EvaluationInputParam ] = None ,
107111 ):
108112 kwargs = {}
109113 if input_dataset is not None :
110114 kwargs ["rows" ] = input_dataset
111115 if row is not None :
112116 kwargs ["row" ] = row
117+ if evaluation_test_kwargs is not None :
118+ if "row" in evaluation_test_kwargs :
119+ raise ValueError ("'row' is a reserved parameter for the evaluation function" )
120+ if "rows" in evaluation_test_kwargs :
121+ raise ValueError ("'rows' is a reserved parameter for the evaluation function" )
122+ kwargs .update (evaluation_test_kwargs )
113123 return execute_function (test_func , ** kwargs )
114124
115125 # Calculate all possible combinations of parameters
@@ -118,21 +128,23 @@ def generate_combinations():
118128
119129 # Handle optional parameters with defaults
120130 datasets : List [Optional [DatasetPathParam ]] = input_dataset if input_dataset is not None else [None ] # type: ignore
121- params : List [Optional [InputParam ]] = input_params if input_params is not None else [None ] # type: ignore
131+ params : List [Optional [RolloutInputParam ]] = rollout_input_params if rollout_input_params is not None else [None ] # type: ignore
122132 messages : List [Optional [InputMessagesParam ]] = input_messages if input_messages is not None else [None ] # type: ignore
133+ kwargs : List [Optional [EvaluationInputParam ]] = evaluation_test_kwargs if evaluation_test_kwargs is not None else [None ] # type: ignore
123134
124135 # Generate all combinations
125136 for m in model :
126137 for ds in datasets :
127138 for ip in params :
128139 for im in messages :
129- # Skip combinations that don't make sense
130- # If we have a dataset, we should have params for rollout
131- if ds is not None and ip is None :
132- continue
133- # If we have messages but no dataset, that's fine
134- # If we have no dataset and no messages, that's also fine
135- combinations .append ((m , ds , ip , im ))
140+ for etk in kwargs :
141+ # Skip combinations that don't make sense
142+ # If we have a dataset, we should have params for rollout
143+ if ds is not None and ip is None :
144+ continue
145+ # If we have messages but no dataset, that's fine
146+ # If we have no dataset and no messages, that's also fine
147+ combinations .append ((m , ds , ip , im , etk ))
136148
137149 return combinations
138150
@@ -141,27 +153,31 @@ def generate_combinations():
141153 # Create parameter tuples for pytest.mark.parametrize
142154 param_tuples = []
143155 for combo in combinations :
144- model_name , dataset , params , messages = combo
156+ model_name , dataset , params , messages , etk = combo
145157 param_tuple = [model_name ]
146158 if input_dataset is not None :
147159 param_tuple .append (dataset )
148- if input_params is not None :
160+ if rollout_input_params is not None :
149161 param_tuple .append (params )
150162 if input_messages is not None :
151163 param_tuple .append (messages )
164+ if evaluation_test_kwargs is not None :
165+ param_tuple .append (etk )
152166 param_tuples .append (tuple (param_tuple ))
153167
154168 # For batch mode, use the original parameter names
155169 test_param_names = ["model" ]
156170 if input_dataset is not None :
157171 test_param_names .append ("dataset_path" )
158- if input_params is not None :
172+ if rollout_input_params is not None :
159173 test_param_names .append ("input_params" )
160174 if input_messages is not None :
161175 test_param_names .append ("input_messages" )
176+ if evaluation_test_kwargs is not None :
177+ test_param_names .append ("evaluation_test_kwargs" )
162178
163179 # Create wrapper function with exact signature that pytest expects
164- def create_wrapper_with_signature ():
180+ def create_wrapper_with_signature () -> Callable :
165181 # Create the function body that will be used
166182 def wrapper_body (** kwargs ):
167183 model_name = kwargs ["model" ]
@@ -193,6 +209,7 @@ def wrapper_body(**kwargs):
193209 result = execute_with_params (
194210 test_func ,
195211 row = row ,
212+ evaluation_test_kwargs = kwargs .get ("evaluation_test_kwargs" ) or {},
196213 )
197214 if result is None or not isinstance (result , EvaluationRow ):
198215 raise ValueError (
@@ -204,6 +221,7 @@ def wrapper_body(**kwargs):
204221 results = execute_with_params (
205222 test_func ,
206223 input_dataset = input_dataset ,
224+ evaluation_test_kwargs = kwargs .get ("evaluation_test_kwargs" ) or {},
207225 )
208226 if results is None :
209227 raise ValueError (
@@ -234,6 +252,7 @@ def wrapper_body(**kwargs):
234252
235253 wrapper = create_wrapper_with_signature ()
236254 wrapper = pytest .mark .parametrize (test_param_names , param_tuples )(wrapper )
255+ wrapper .original_evaluation_test_func = test_func
237256
238257 return wrapper
239258
0 commit comments