Skip to content

Commit e0ea84e

Browse files
Support Docker image as objective in the tune API
Signed-off-by: akhilsaivenkata <[email protected]>
1 parent 55e283e commit e0ea84e

File tree

1 file changed

+70
-63
lines changed

1 file changed

+70
-63
lines changed

sdk/python/v1beta1/kubeflow/katib/api/katib_client.py

Lines changed: 70 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,9 @@ def tune(
153153
self,
154154
# TODO (andreyvelich): How to be consistent with other APIs (name) ?
155155
name: str,
156-
objective: Callable,
156+
objective: Union[Callable, str],
157157
parameters: Dict[str, Any],
158-
base_image: str = constants.BASE_IMAGE_TENSORFLOW,
158+
#base_image: str = constants.BASE_IMAGE_TENSORFLOW,
159159
namespace: Optional[str] = None,
160160
env_per_trial: Optional[
161161
Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]
@@ -283,65 +283,72 @@ def tune(
283283
if max_failed_trial_count is not None:
284284
experiment.spec.max_failed_trial_count = max_failed_trial_count
285285

286-
# Validate objective function.
287-
utils.validate_objective_function(objective)
288-
289-
# Extract objective function implementation.
290-
objective_code = inspect.getsource(objective)
291-
292-
# Objective function might be defined in some indented scope
293-
# (e.g. in another function). We need to dedent the function code.
294-
objective_code = textwrap.dedent(objective_code)
295-
296-
# Iterate over input parameters.
297-
input_params = {}
298-
experiment_params = []
299-
trial_params = []
300-
for p_name, p_value in parameters.items():
301-
# If input parameter value is Katib Experiment parameter sample.
302-
if isinstance(p_value, models.V1beta1ParameterSpec):
303-
# Wrap value for the function input.
304-
input_params[p_name] = f"${{trialParameters.{p_name}}}"
305-
306-
# Add value to the Katib Experiment parameters.
307-
p_value.name = p_name
308-
experiment_params.append(p_value)
309-
310-
# Add value to the Katib Experiment's Trial parameters.
311-
trial_params.append(
312-
models.V1beta1TrialParameterSpec(name=p_name, reference=p_name)
313-
)
314-
else:
315-
# Otherwise, add value to the function input.
316-
input_params[p_name] = p_value
317-
318-
# Wrap objective function to execute it from the file. For example
319-
# def objective(parameters):
320-
# print(f'Parameters are {parameters}')
321-
# objective({'lr': '${trialParameters.lr}', 'epochs': '${trialParameters.epochs}', 'is_dist': False})
322-
objective_code = f"{objective_code}\n{objective.__name__}({input_params})\n"
323-
324-
# Prepare execute script template.
325-
exec_script = textwrap.dedent(
326-
"""
327-
program_path=$(mktemp -d)
328-
read -r -d '' SCRIPT << EOM\n
329-
{objective_code}
330-
EOM
331-
printf "%s" "$SCRIPT" > $program_path/ephemeral_objective.py
332-
python3 -u $program_path/ephemeral_objective.py"""
333-
)
334-
335-
# Add objective code to the execute script.
336-
exec_script = exec_script.format(objective_code=objective_code)
337-
338-
# Install Python packages if that is required.
339-
if packages_to_install is not None:
340-
exec_script = (
341-
utils.get_script_for_python_packages(packages_to_install, pip_index_url)
342-
+ exec_script
286+
# Handle different types of objective input
287+
if callable(objective):
288+
# Validate objective function.
289+
utils.validate_objective_function(objective)
290+
291+
# Extract objective function implementation.
292+
objective_code = inspect.getsource(objective)
293+
294+
# Objective function might be defined in some indented scope
295+
# (e.g. in another function). We need to dedent the function code.
296+
objective_code = textwrap.dedent(objective_code)
297+
298+
# Iterate over input parameters.
299+
input_params = {}
300+
experiment_params = []
301+
trial_params = []
302+
base_image = constants.BASE_IMAGE_TENSORFLOW,
303+
for p_name, p_value in parameters.items():
304+
# If input parameter value is Katib Experiment parameter sample.
305+
if isinstance(p_value, models.V1beta1ParameterSpec):
306+
# Wrap value for the function input.
307+
input_params[p_name] = f"${{trialParameters.{p_name}}}"
308+
309+
# Add value to the Katib Experiment parameters.
310+
p_value.name = p_name
311+
experiment_params.append(p_value)
312+
313+
# Add value to the Katib Experiment's Trial parameters.
314+
trial_params.append(
315+
models.V1beta1TrialParameterSpec(name=p_name, reference=p_name)
316+
)
317+
else:
318+
# Otherwise, add value to the function input.
319+
input_params[p_name] = p_value
320+
321+
# Wrap objective function to execute it from the file. For example
322+
# def objective(parameters):
323+
# print(f'Parameters are {parameters}')
324+
# objective({'lr': '${trialParameters.lr}', 'epochs': '${trialParameters.epochs}', 'is_dist': False})
325+
objective_code = f"{objective_code}\n{objective.__name__}({input_params})\n"
326+
327+
# Prepare execute script template.
328+
exec_script = textwrap.dedent(
329+
"""
330+
program_path=$(mktemp -d)
331+
read -r -d '' SCRIPT << EOM\n
332+
{objective_code}
333+
EOM
334+
printf "%s" "$SCRIPT" > $program_path/ephemeral_objective.py
335+
python3 -u $program_path/ephemeral_objective.py"""
343336
)
344337

338+
# Add objective code to the execute script.
339+
exec_script = exec_script.format(objective_code=objective_code)
340+
341+
# Install Python packages if that is required.
342+
if packages_to_install is not None:
343+
exec_script = (
344+
utils.get_script_for_python_packages(packages_to_install, pip_index_url)
345+
+ exec_script
346+
)
347+
elif isinstance(objective, str):
348+
base_image=objective
349+
else:
350+
raise ValueError("The objective must be a callable function or a docker image.")
351+
345352
if isinstance(resources_per_trial, dict):
346353
if "gpu" in resources_per_trial:
347354
resources_per_trial["nvidia.com/gpu"] = resources_per_trial.pop("gpu")
@@ -384,8 +391,8 @@ def tune(
384391
client.V1Container(
385392
name=constants.DEFAULT_PRIMARY_CONTAINER_NAME,
386393
image=base_image,
387-
command=["bash", "-c"],
388-
args=[exec_script],
394+
command=["bash", "-c"] if callable(objective) else None,
395+
args=[exec_script] if callable(objective) else None,
389396
env=env,
390397
env_from=env_from,
391398
resources=resources_per_trial,
@@ -400,12 +407,12 @@ def tune(
400407
trial_template = models.V1beta1TrialTemplate(
401408
primary_container_name=constants.DEFAULT_PRIMARY_CONTAINER_NAME,
402409
retain=retain_trials,
403-
trial_parameters=trial_params,
410+
trial_parameters=trial_params if callable(objective) else [],
404411
trial_spec=trial_spec,
405412
)
406413

407414
# Add parameters to the Katib Experiment.
408-
experiment.spec.parameters = experiment_params
415+
experiment.spec.parameters = experiment_params if callable(objective) else []
409416

410417
# Add Trial template to the Katib Experiment.
411418
experiment.spec.trial_template = trial_template

0 commit comments

Comments
 (0)