2525
2626
2727class IntimeModelSelector (Widget ):
28- def __init__ (self , weigh_by_local_iter = False , aggregation_weights = None ):
28+ def __init__ (
29+ self , weigh_by_local_iter = False , aggregation_weights = None , validation_metric_name = MetaKey .INITIAL_METRICS
30+ ):
2931 """Handler to determine if the model is globally best.
3032
3133 Args:
3234 weigh_by_local_iter (bool, optional): whether the metrics should be weighted by trainer's iteration number.
3335 aggregation_weights (dict, optional): a mapping of client name to float for aggregation. Defaults to None.
36+ validation_metric_name (str, optional): key used to save initial validation metric in the DXO meta properties (defaults to MetaKey.INITIAL_METRICS).
3437 """
3538 super ().__init__ ()
3639
3740 self .val_metric = self .best_val_metric = - np .inf
3841 self .weigh_by_local_iter = weigh_by_local_iter
39- self .validation_metric_name = MetaKey . INITIAL_METRICS
42+ self .validation_metric_name = validation_metric_name
4043 self .aggregation_weights = aggregation_weights or {}
4144
42- self .logger .debug (f"model selection weights control: { aggregation_weights } " )
45+ self .logger .info (f"model selection weights control: { aggregation_weights } " )
4346 self ._reset_stats ()
4447
4548 def handle_event (self , event_type : str , fl_ctx : FLContext ):
4649 if event_type == EventType .START_RUN :
4750 self ._startup (fl_ctx )
48- elif event_type == EventType .BEFORE_PROCESS_SUBMISSION :
51+ elif event_type == AppEventType .ROUND_STARTED :
52+ self ._reset_stats ()
53+ elif event_type == AppEventType .BEFORE_CONTRIBUTION_ACCEPT :
4954 self ._before_accept (fl_ctx )
5055 elif event_type == AppEventType .BEFORE_AGGREGATION :
5156 self ._before_aggregate (fl_ctx )
@@ -84,7 +89,7 @@ def _before_accept(self, fl_ctx: FLContext):
8489 return False # There is no aggregated model at round 0
8590
8691 if contribution_round != current_round :
87- self .log_debug (
92+ self .log_warning (
8893 fl_ctx ,
8994 f"discarding shareable from { client_name } for round: { contribution_round } . Current round is: { current_round } " ,
9095 )
@@ -105,8 +110,9 @@ def _before_accept(self, fl_ctx: FLContext):
105110 aggregation_weights = self .aggregation_weights .get (client_name , 1.0 )
106111 self .log_debug (fl_ctx , f"aggregation weight: { aggregation_weights } " )
107112
108- self .validation_metric_weighted_sum += validation_metric * n_iter * aggregation_weights
109- self .validation_metric_sum_of_weights += n_iter
113+ weight = n_iter * aggregation_weights
114+ self .validation_metric_weighted_sum += validation_metric * weight
115+ self .validation_metric_sum_of_weights += weight
110116 return True
111117
112118 def _before_aggregate (self , fl_ctx ):
0 commit comments