Skip to content

Commit 3e133aa

Browse files
resolving review comments
1 parent e0ea84e commit 3e133aa

File tree

1 file changed

+26
-25
lines changed

1 file changed

+26
-25
lines changed

sdk/python/v1beta1/kubeflow/katib/api/katib_client.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,30 @@ def tune(
283283
if max_failed_trial_count is not None:
284284
experiment.spec.max_failed_trial_count = max_failed_trial_count
285285

286+
# Iterate over input parameters.
287+
input_params = {}
288+
experiment_params = []
289+
trial_params = []
290+
base_image = constants.BASE_IMAGE_TENSORFLOW,
291+
292+
for p_name, p_value in parameters.items():
293+
# If input parameter value is Katib Experiment parameter sample.
294+
if isinstance(p_value, models.V1beta1ParameterSpec):
295+
# Wrap value for the function input.
296+
input_params[p_name] = f"${{trialParameters.{p_name}}}"
297+
298+
# Add value to the Katib Experiment parameters.
299+
p_value.name = p_name
300+
experiment_params.append(p_value)
301+
302+
# Add value to the Katib Experiment's Trial parameters.
303+
trial_params.append(
304+
models.V1beta1TrialParameterSpec(name=p_name, reference=p_name)
305+
)
306+
else:
307+
# Otherwise, add value to the function input.
308+
input_params[p_name] = p_value
309+
286310
# Handle different types of objective input
287311
if callable(objective):
288312
# Validate objective function.
@@ -295,29 +319,6 @@ def tune(
295319
# (e.g. in another function). We need to dedent the function code.
296320
objective_code = textwrap.dedent(objective_code)
297321

298-
# Iterate over input parameters.
299-
input_params = {}
300-
experiment_params = []
301-
trial_params = []
302-
base_image = constants.BASE_IMAGE_TENSORFLOW,
303-
for p_name, p_value in parameters.items():
304-
# If input parameter value is Katib Experiment parameter sample.
305-
if isinstance(p_value, models.V1beta1ParameterSpec):
306-
# Wrap value for the function input.
307-
input_params[p_name] = f"${{trialParameters.{p_name}}}"
308-
309-
# Add value to the Katib Experiment parameters.
310-
p_value.name = p_name
311-
experiment_params.append(p_value)
312-
313-
# Add value to the Katib Experiment's Trial parameters.
314-
trial_params.append(
315-
models.V1beta1TrialParameterSpec(name=p_name, reference=p_name)
316-
)
317-
else:
318-
# Otherwise, add value to the function input.
319-
input_params[p_name] = p_value
320-
321322
# Wrap objective function to execute it from the file. For example
322323
# def objective(parameters):
323324
# print(f'Parameters are {parameters}')
@@ -407,12 +408,12 @@ def tune(
407408
trial_template = models.V1beta1TrialTemplate(
408409
primary_container_name=constants.DEFAULT_PRIMARY_CONTAINER_NAME,
409410
retain=retain_trials,
410-
trial_parameters=trial_params if callable(objective) else [],
411+
trial_parameters=trial_params,
411412
trial_spec=trial_spec,
412413
)
413414

414415
# Add parameters to the Katib Experiment.
415-
experiment.spec.parameters = experiment_params if callable(objective) else []
416+
experiment.spec.parameters = experiment_params
416417

417418
# Add Trial template to the Katib Experiment.
418419
experiment.spec.trial_template = trial_template

0 commit comments

Comments
 (0)