@@ -216,11 +216,13 @@ def evaluate(self, metrics):
         error = "ERR" in metrics.values() or "ERR" in reference.values()
         not_found = "N/A" in metrics.values() or "N/A" in reference.values()
         if error or not_found:
-            return ERROR_METRIC
+            return (ERROR_METRIC, "-", "-")
         ppa = self.get_ppa(metrics)
         gamma = ppa / 10
         score = ppa * (self.step_ / 100) ** (-1) + (gamma * metrics["num_drc"])
-        return score
+        effective_clk_period = metrics["clk_period"] - metrics["worst_slack"]
+        num_drc = metrics["num_drc"]
+        return (score, effective_clk_period, num_drc)


 def parse_arguments():
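With this change evaluate() returns a 3-tuple (score, effective_clk_period, num_drc) on both the error path and the success path, instead of a bare score. A minimal sketch of the arithmetic with made-up metric values; the numbers and the stand-in PPA value below are illustrative only, not project data:

# Illustrative values only: fake metrics and a stand-in for get_ppa()'s result.
metrics = {"clk_period": 5.0, "worst_slack": -0.3, "num_drc": 2}
ppa, step = 120.0, 25                      # pretend get_ppa() returned 120.0 at step 25
gamma = ppa / 10                           # DRC penalty weight, as in the diff
score = ppa * (step / 100) ** (-1) + gamma * metrics["num_drc"]        # 480 + 24 = 504
effective_clk_period = metrics["clk_period"] - metrics["worst_slack"]  # 5.0 - (-0.3) = 5.3
result = (score, effective_clk_period, metrics["num_drc"])             # (504.0, 5.3, 2)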
@@ -464,32 +466,34 @@ def parse_arguments():
     return args


-def set_algorithm(experiment_name, config):
+def set_algorithm(
+    algorithm_name, experiment_name, best_params, seed, perturbation, jobs, config
+):
     """
     Configure search algorithm.
     """
     # Pre-set seed if user sets seed to 0
-    if args.seed == 0:
+    if seed == 0:
         print(
             "Warning: you have chosen not to set a seed. Do you wish to continue? (y/n)"
         )
         if input().lower() != "y":
             sys.exit(0)
-        args.seed = None
+        seed = None
     else:
-        torch.manual_seed(args.seed)
-        np.random.seed(args.seed)
-        random.seed(args.seed)
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        random.seed(seed)

-    if args.algorithm == "hyperopt":
+    if algorithm_name == "hyperopt":
         algorithm = HyperOptSearch(
             points_to_evaluate=best_params,
-            random_state_seed=args.seed,
+            random_state_seed=seed,
         )
-    elif args.algorithm == "ax":
+    elif algorithm_name == "ax":
         ax_client = AxClient(
             enforce_sequential_optimization=False,
-            random_seed=args.seed,
+            random_seed=seed,
         )
         AxClientMetric = namedtuple("AxClientMetric", "minimize")
         ax_client.create_experiment(
@@ -498,25 +502,25 @@ def set_algorithm(experiment_name, config):
             objectives={METRIC: AxClientMetric(minimize=True)},
         )
         algorithm = AxSearch(ax_client=ax_client, points_to_evaluate=best_params)
-    elif args.algorithm == "optuna":
-        algorithm = OptunaSearch(points_to_evaluate=best_params, seed=args.seed)
-    elif args.algorithm == "pbt":
-        print("Warning: PBT does not support seed values. args.seed will be ignored.")
+    elif algorithm_name == "optuna":
+        algorithm = OptunaSearch(points_to_evaluate=best_params, seed=seed)
+    elif algorithm_name == "pbt":
+        print("Warning: PBT does not support seed values. seed will be ignored.")
         algorithm = PopulationBasedTraining(
             time_attr="training_iteration",
-            perturbation_interval=args.perturbation,
+            perturbation_interval=perturbation,
             hyperparam_mutations=config,
             synch=True,
         )
-    elif args.algorithm == "random":
+    elif algorithm_name == "random":
         algorithm = BasicVariantGenerator(
-            max_concurrent=args.jobs,
-            random_state=args.seed,
+            max_concurrent=jobs,
+            random_state=seed,
         )

     # A wrapper algorithm for limiting the number of concurrent trials.
-    if args.algorithm not in ["random", "pbt"]:
-        algorithm = ConcurrencyLimiter(algorithm, max_concurrent=args.jobs)
+    if algorithm_name not in ["random", "pbt"]:
+        algorithm = ConcurrencyLimiter(algorithm, max_concurrent=jobs)

     return algorithm

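Because set_algorithm() now takes every input as an explicit parameter instead of reading the module-level args, it can be exercised in isolation. A hedged sketch, assuming the script in this diff is importable as a module named "distributed" (an assumption, not confirmed by the diff) and that Ray Tune is installed; only the "random" branch is exercised, so no extra search library is needed:

# Assumption: the script shown in this diff is importable as "distributed".
import distributed

algo = distributed.set_algorithm(
    "random",      # algorithm_name
    "smoke-test",  # experiment_name (not used by the "random" branch)
    None,          # best_params
    42,            # seed (also seeds torch/numpy/random inside the function)
    25,            # perturbation (only relevant for "pbt")
    2,             # jobs -> BasicVariantGenerator(max_concurrent=2)
    {},            # config (only used as hyperparam_mutations for "pbt")
)
print(type(algo).__name__)  # expected: BasicVariantGenerator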
@@ -607,7 +611,15 @@ def main():

     if args.mode == "tune":
         best_params = set_best_params(args.platform, args.design)
-        search_algo = set_algorithm(args.experiment, config_dict)
+        search_algo = set_algorithm(
+            args.algorithm,
+            args.experiment,
+            best_params,
+            args.seed,
+            args.perturbation,
+            args.jobs,
+            config_dict,
+        )
         TrainClass = set_training_class(args.eval)
         # PPAImprov requires a reference file to compute training scores.
         if args.eval == "ppa-improv":
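The 3-tuple from evaluate() has to be unpacked by whichever class set_training_class() returns; that class is not part of this hunk, so the snippet below is a hedged stand-in rather than the project's code: a generic ray.tune.Trainable whose step() reports the two extra fields alongside the objective. The metric key names are assumptions.

# Hedged stand-in, not the repository's training class: one way a Trainable
# could surface (score, effective_clk_period, num_drc) to Ray Tune.
from ray import tune


class ExampleTrainable(tune.Trainable):
    def setup(self, config):
        self.config_ = config

    def step(self):
        # Real runs would compute these from EDA metrics; fake values here.
        score, eff_clk_period, num_drc = 504.0, 5.3, 2
        return {
            "minimum": score,                      # objective key name assumed
            "effective_clk_period": eff_clk_period,
            "num_drc": num_drc,
        }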