Skip to content

Commit 5b88e43

Browse files
Patch just once (#2416)
1 parent eca20b9 commit 5b88e43

File tree

1 file changed

+8
-7
lines changed
  • nvflare/app_opt/lightning

1 file changed

+8
-7
lines changed

nvflare/app_opt/lightning/api.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,17 @@ def __init__(self):
7575
self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"}
7676
7777
"""
78-
fl_callback = FLCallback(rank=trainer.global_rank, load_state_dict_strict=load_state_dict_strict)
7978
callbacks = trainer.callbacks
80-
if isinstance(callbacks, list):
79+
if isinstance(callbacks, Callback):
80+
callbacks = [callbacks]
81+
elif not isinstance(callbacks, list):
82+
callbacks = []
83+
84+
if not any(isinstance(cb, FLCallback) for cb in callbacks):
85+
fl_callback = FLCallback(rank=trainer.global_rank, load_state_dict_strict=load_state_dict_strict)
8186
callbacks.append(fl_callback)
82-
elif isinstance(callbacks, Callback):
83-
callbacks = [callbacks, fl_callback]
84-
else:
85-
callbacks = [fl_callback]
8687

87-
if restore_state:
88+
if restore_state and not any(isinstance(cb, RestoreState) for cb in callbacks):
8889
callbacks.append(RestoreState())
8990

9091
trainer.callbacks = callbacks

0 commit comments

Comments
 (0)