@@ -153,9 +153,9 @@ def tune(
153153 self ,
154154 # TODO (andreyvelich): How to be consistent with other APIs (name) ?
155155 name : str ,
156- objective : Callable ,
156+ objective : Union [ Callable , str ] ,
157157 parameters : Dict [str , Any ],
158- base_image : str = constants .BASE_IMAGE_TENSORFLOW ,
158+ # base_image: str = constants.BASE_IMAGE_TENSORFLOW,
159159 namespace : Optional [str ] = None ,
160160 env_per_trial : Optional [
161161 Union [Dict [str , str ], List [Union [client .V1EnvVar , client .V1EnvFromSource ]]]
@@ -283,65 +283,72 @@ def tune(
283283 if max_failed_trial_count is not None :
284284 experiment .spec .max_failed_trial_count = max_failed_trial_count
285285
286- # Validate objective function.
287- utils .validate_objective_function (objective )
288-
289- # Extract objective function implementation.
290- objective_code = inspect .getsource (objective )
291-
292- # Objective function might be defined in some indented scope
293- # (e.g. in another function). We need to dedent the function code.
294- objective_code = textwrap .dedent (objective_code )
295-
296- # Iterate over input parameters.
297- input_params = {}
298- experiment_params = []
299- trial_params = []
300- for p_name , p_value in parameters .items ():
301- # If input parameter value is Katib Experiment parameter sample.
302- if isinstance (p_value , models .V1beta1ParameterSpec ):
303- # Wrap value for the function input.
304- input_params [p_name ] = f"${{trialParameters.{ p_name } }}"
305-
306- # Add value to the Katib Experiment parameters.
307- p_value .name = p_name
308- experiment_params .append (p_value )
309-
310- # Add value to the Katib Experiment's Trial parameters.
311- trial_params .append (
312- models .V1beta1TrialParameterSpec (name = p_name , reference = p_name )
313- )
314- else :
315- # Otherwise, add value to the function input.
316- input_params [p_name ] = p_value
317-
318- # Wrap objective function to execute it from the file. For example
319- # def objective(parameters):
320- # print(f'Parameters are {parameters}')
321- # objective({'lr': '${trialParameters.lr}', 'epochs': '${trialParameters.epochs}', 'is_dist': False})
322- objective_code = f"{ objective_code } \n { objective .__name__ } ({ input_params } )\n "
323-
324- # Prepare execute script template.
325- exec_script = textwrap .dedent (
326- """
327- program_path=$(mktemp -d)
328- read -r -d '' SCRIPT << EOM\n
329- {objective_code}
330- EOM
331- printf "%s" "$SCRIPT" > $program_path/ephemeral_objective.py
332- python3 -u $program_path/ephemeral_objective.py"""
333- )
334-
335- # Add objective code to the execute script.
336- exec_script = exec_script .format (objective_code = objective_code )
337-
338- # Install Python packages if that is required.
339- if packages_to_install is not None :
340- exec_script = (
341- utils .get_script_for_python_packages (packages_to_install , pip_index_url )
342- + exec_script
286+ # Handle different types of objective input
287+ if callable (objective ):
288+ # Validate objective function.
289+ utils .validate_objective_function (objective )
290+
291+ # Extract objective function implementation.
292+ objective_code = inspect .getsource (objective )
293+
294+ # Objective function might be defined in some indented scope
295+ # (e.g. in another function). We need to dedent the function code.
296+ objective_code = textwrap .dedent (objective_code )
297+
298+ # Iterate over input parameters.
299+ input_params = {}
300+ experiment_params = []
301+ trial_params = []
302+ base_image = constants .BASE_IMAGE_TENSORFLOW ,
303+ for p_name , p_value in parameters .items ():
304+ # If input parameter value is Katib Experiment parameter sample.
305+ if isinstance (p_value , models .V1beta1ParameterSpec ):
306+ # Wrap value for the function input.
307+ input_params [p_name ] = f"${{trialParameters.{ p_name } }}"
308+
309+ # Add value to the Katib Experiment parameters.
310+ p_value .name = p_name
311+ experiment_params .append (p_value )
312+
313+ # Add value to the Katib Experiment's Trial parameters.
314+ trial_params .append (
315+ models .V1beta1TrialParameterSpec (name = p_name , reference = p_name )
316+ )
317+ else :
318+ # Otherwise, add value to the function input.
319+ input_params [p_name ] = p_value
320+
321+ # Wrap objective function to execute it from the file. For example
322+ # def objective(parameters):
323+ # print(f'Parameters are {parameters}')
324+ # objective({'lr': '${trialParameters.lr}', 'epochs': '${trialParameters.epochs}', 'is_dist': False})
325+ objective_code = f"{ objective_code } \n { objective .__name__ } ({ input_params } )\n "
326+
327+ # Prepare execute script template.
328+ exec_script = textwrap .dedent (
329+ """
330+ program_path=$(mktemp -d)
331+ read -r -d '' SCRIPT << EOM\n
332+ {objective_code}
333+ EOM
334+ printf "%s" "$SCRIPT" > $program_path/ephemeral_objective.py
335+ python3 -u $program_path/ephemeral_objective.py"""
343336 )
344337
338+ # Add objective code to the execute script.
339+ exec_script = exec_script .format (objective_code = objective_code )
340+
341+ # Install Python packages if that is required.
342+ if packages_to_install is not None :
343+ exec_script = (
344+ utils .get_script_for_python_packages (packages_to_install , pip_index_url )
345+ + exec_script
346+ )
347+ elif isinstance (objective , str ):
348+ base_image = objective
349+ else :
350+ raise ValueError ("The objective must be a callable function or a docker image." )
351+
345352 if isinstance (resources_per_trial , dict ):
346353 if "gpu" in resources_per_trial :
347354 resources_per_trial ["nvidia.com/gpu" ] = resources_per_trial .pop ("gpu" )
@@ -384,8 +391,8 @@ def tune(
384391 client .V1Container (
385392 name = constants .DEFAULT_PRIMARY_CONTAINER_NAME ,
386393 image = base_image ,
387- command = ["bash" , "-c" ],
388- args = [exec_script ],
394+ command = ["bash" , "-c" ] if callable ( objective ) else None ,
395+ args = [exec_script ] if callable ( objective ) else None ,
389396 env = env ,
390397 env_from = env_from ,
391398 resources = resources_per_trial ,
@@ -400,12 +407,12 @@ def tune(
400407 trial_template = models .V1beta1TrialTemplate (
401408 primary_container_name = constants .DEFAULT_PRIMARY_CONTAINER_NAME ,
402409 retain = retain_trials ,
403- trial_parameters = trial_params ,
410+ trial_parameters = trial_params if callable ( objective ) else [] ,
404411 trial_spec = trial_spec ,
405412 )
406413
407414 # Add parameters to the Katib Experiment.
408- experiment .spec .parameters = experiment_params
415+ experiment .spec .parameters = experiment_params if callable ( objective ) else []
409416
410417 # Add Trial template to the Katib Experiment.
411418 experiment .spec .trial_template = trial_template
0 commit comments