File tree Expand file tree Collapse file tree 1 file changed +8
-7
lines changed
nvflare/app_opt/lightning Expand file tree Collapse file tree 1 file changed +8
-7
lines changed Original file line number Diff line number Diff line change @@ -75,16 +75,17 @@ def __init__(self):
7575 self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"}
7676
7777 """
78- fl_callback = FLCallback (rank = trainer .global_rank , load_state_dict_strict = load_state_dict_strict )
7978 callbacks = trainer .callbacks
80- if isinstance (callbacks , list ):
79+ if isinstance (callbacks , Callback ):
80+ callbacks = [callbacks ]
81+ elif not isinstance (callbacks , list ):
82+ callbacks = []
83+
84+ if not any (isinstance (cb , FLCallback ) for cb in callbacks ):
85+ fl_callback = FLCallback (rank = trainer .global_rank , load_state_dict_strict = load_state_dict_strict )
8186 callbacks .append (fl_callback )
82- elif isinstance (callbacks , Callback ):
83- callbacks = [callbacks , fl_callback ]
84- else :
85- callbacks = [fl_callback ]
8687
87- if restore_state :
88+ if restore_state and not any ( isinstance ( cb , RestoreState ) for cb in callbacks ) :
8889 callbacks .append (RestoreState ())
8990
9091 trainer .callbacks = callbacks
You can’t perform that action at this time.
0 commit comments