@@ -283,6 +283,30 @@ def tune(
283283 if max_failed_trial_count is not None :
284284 experiment .spec .max_failed_trial_count = max_failed_trial_count
285285
286+ # Iterate over input parameters.
287+ input_params = {}
288+ experiment_params = []
289+ trial_params = []
290+ base_image = constants .BASE_IMAGE_TENSORFLOW ,
291+
292+ for p_name , p_value in parameters .items ():
293+ # If input parameter value is Katib Experiment parameter sample.
294+ if isinstance (p_value , models .V1beta1ParameterSpec ):
295+ # Wrap value for the function input.
296+ input_params [p_name ] = f"${{trialParameters.{ p_name } }}"
297+
298+ # Add value to the Katib Experiment parameters.
299+ p_value .name = p_name
300+ experiment_params .append (p_value )
301+
302+ # Add value to the Katib Experiment's Trial parameters.
303+ trial_params .append (
304+ models .V1beta1TrialParameterSpec (name = p_name , reference = p_name )
305+ )
306+ else :
307+ # Otherwise, add value to the function input.
308+ input_params [p_name ] = p_value
309+
286310 # Handle different types of objective input
287311 if callable (objective ):
288312 # Validate objective function.
@@ -295,29 +319,6 @@ def tune(
295319 # (e.g. in another function). We need to dedent the function code.
296320 objective_code = textwrap .dedent (objective_code )
297321
298- # Iterate over input parameters.
299- input_params = {}
300- experiment_params = []
301- trial_params = []
302- base_image = constants .BASE_IMAGE_TENSORFLOW ,
303- for p_name , p_value in parameters .items ():
304- # If input parameter value is Katib Experiment parameter sample.
305- if isinstance (p_value , models .V1beta1ParameterSpec ):
306- # Wrap value for the function input.
307- input_params [p_name ] = f"${{trialParameters.{ p_name } }}"
308-
309- # Add value to the Katib Experiment parameters.
310- p_value .name = p_name
311- experiment_params .append (p_value )
312-
313- # Add value to the Katib Experiment's Trial parameters.
314- trial_params .append (
315- models .V1beta1TrialParameterSpec (name = p_name , reference = p_name )
316- )
317- else :
318- # Otherwise, add value to the function input.
319- input_params [p_name ] = p_value
320-
321322 # Wrap objective function to execute it from the file. For example
322323 # def objective(parameters):
323324 # print(f'Parameters are {parameters}')
@@ -407,12 +408,12 @@ def tune(
407408 trial_template = models .V1beta1TrialTemplate (
408409 primary_container_name = constants .DEFAULT_PRIMARY_CONTAINER_NAME ,
409410 retain = retain_trials ,
410- trial_parameters = trial_params if callable ( objective ) else [] ,
411+ trial_parameters = trial_params ,
411412 trial_spec = trial_spec ,
412413 )
413414
414415 # Add parameters to the Katib Experiment.
415- experiment .spec .parameters = experiment_params if callable ( objective ) else []
416+ experiment .spec .parameters = experiment_params
416417
417418 # Add Trial template to the Katib Experiment.
418419 experiment .spec .trial_template = trial_template
0 commit comments