Skip to content

Commit af3e20b

Browse files
authored
update model selection (#1134)
1 parent 3607791 commit af3e20b

File tree

3 files changed

+14
-11
lines changed

3 files changed

+14
-11
lines changed

nvflare/app_common/app_event_type.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@
1616
class AppEventType(object):
1717
"""Defines application events."""
1818

19-
START_ROUND = "_start_round"
20-
END_ROUND = "_end_round"
21-
2219
BEFORE_AGGREGATION = "_before_aggregation"
2320
END_AGGREGATION = "_end_aggregation"
2421

nvflare/app_common/widgets/intime_model_selector.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,27 +25,32 @@
2525

2626

2727
class IntimeModelSelector(Widget):
28-
def __init__(self, weigh_by_local_iter=False, aggregation_weights=None):
28+
def __init__(
29+
self, weigh_by_local_iter=False, aggregation_weights=None, validation_metric_name=MetaKey.INITIAL_METRICS
30+
):
2931
"""Handler to determine if the model is globally best.
3032
3133
Args:
3234
weigh_by_local_iter (bool, optional): whether the metrics should be weighted by trainer's iteration number.
3335
aggregation_weights (dict, optional): a mapping of client name to float for aggregation. Defaults to None.
36+
validation_metric_name (str, optional): key used to save initial validation metric in the DXO meta properties (defaults to MetaKey.INITIAL_METRICS).
3437
"""
3538
super().__init__()
3639

3740
self.val_metric = self.best_val_metric = -np.inf
3841
self.weigh_by_local_iter = weigh_by_local_iter
39-
self.validation_metric_name = MetaKey.INITIAL_METRICS
42+
self.validation_metric_name = validation_metric_name
4043
self.aggregation_weights = aggregation_weights or {}
4144

42-
self.logger.debug(f"model selection weights control: {aggregation_weights}")
45+
self.logger.info(f"model selection weights control: {aggregation_weights}")
4346
self._reset_stats()
4447

4548
def handle_event(self, event_type: str, fl_ctx: FLContext):
4649
if event_type == EventType.START_RUN:
4750
self._startup(fl_ctx)
48-
elif event_type == EventType.BEFORE_PROCESS_SUBMISSION:
51+
elif event_type == AppEventType.ROUND_STARTED:
52+
self._reset_stats()
53+
elif event_type == AppEventType.BEFORE_CONTRIBUTION_ACCEPT:
4954
self._before_accept(fl_ctx)
5055
elif event_type == AppEventType.BEFORE_AGGREGATION:
5156
self._before_aggregate(fl_ctx)
@@ -84,7 +89,7 @@ def _before_accept(self, fl_ctx: FLContext):
8489
return False # There is no aggregated model at round 0
8590

8691
if contribution_round != current_round:
87-
self.log_debug(
92+
self.log_warning(
8893
fl_ctx,
8994
f"discarding shareable from {client_name} for round: {contribution_round}. Current round is: {current_round}",
9095
)
@@ -105,8 +110,9 @@ def _before_accept(self, fl_ctx: FLContext):
105110
aggregation_weights = self.aggregation_weights.get(client_name, 1.0)
106111
self.log_debug(fl_ctx, f"aggregation weight: {aggregation_weights}")
107112

108-
self.validation_metric_weighted_sum += validation_metric * n_iter * aggregation_weights
109-
self.validation_metric_sum_of_weights += n_iter
113+
weight = n_iter * aggregation_weights
114+
self.validation_metric_weighted_sum += validation_metric * weight
115+
self.validation_metric_sum_of_weights += weight
110116
return True
111117

112118
def _before_aggregate(self, fl_ctx):

tests/unit_test/app_common/widgets/in_time_model_selector_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,6 @@ def test_model_selection(self, initial, received, expected):
8686
fl_ctx = engine.fl_ctx_mgr.new_context()
8787
fl_ctx.set_prop(FLContextKey.PEER_CONTEXT, peer_ctx)
8888

89-
handler.handle_event(EventType.BEFORE_PROCESS_SUBMISSION, fl_ctx)
89+
handler.handle_event(AppEventType.BEFORE_CONTRIBUTION_ACCEPT, fl_ctx)
9090
handler.handle_event(AppEventType.BEFORE_AGGREGATION, fl_ctx)
9191
assert (engine.last_event == AppEventType.GLOBAL_BEST_MODEL_AVAILABLE) == expected

0 commit comments

Comments
 (0)