Skip to content

Commit c1c8a44

Browse files
Fix hello numpy (#3702)
Fix issue with hello numpy. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated. --------- Co-authored-by: Chester Chen <[email protected]>
1 parent b6ccae6 commit c1c8a44

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

nvflare/app_common/np/recipes/fedavg.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from nvflare.apis.dxo import DataKind
2121
from nvflare.app_common.abstract.aggregator import Aggregator
2222
from nvflare.app_common.aggregators import InTimeAccumulateWeightedAggregator
23+
from nvflare.app_common.np.np_model_persistor import NPModelPersistor
2324
from nvflare.app_common.shareablegenerators import FullModelShareableGenerator
2425
from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
2526
from nvflare.client.config import ExchangeFormat, TransferType
@@ -157,21 +158,25 @@ def __init__(
157158
shareable_generator_id = job.to_server(shareable_generator, id="shareable_generator")
158159
aggregator_id = job.to_server(self.aggregator, id="aggregator")
159160

161+
# Handle initial model if provided
162+
persistor_id = ""
163+
if self.initial_model is not None:
164+
# Add persistor and initial model directly
165+
persistor_id = job.to_server(NPModelPersistor(), id="persistor")
166+
job.to(self.initial_model, "server")
167+
160168
controller = ScatterAndGather(
161169
min_clients=self.min_clients,
162170
num_rounds=self.num_rounds,
163171
wait_time_after_min_received=0,
164172
aggregator_id=aggregator_id,
165-
persistor_id=job.comp_ids["persistor_id"] if self.initial_model is not None else "",
173+
persistor_id=persistor_id,
166174
shareable_generator_id=shareable_generator_id,
175+
allow_empty_global_weights=True, # Allow empty weights if no initial model
167176
)
168177
# Send the controller to the server
169178
job.to_server(controller)
170179

171-
# Send initial model to server if provided
172-
if self.initial_model is not None:
173-
job.to(self.initial_model, "server")
174-
175180
# Add clients with NUMPY framework
176181
executor = ScriptRunner(
177182
script=self.train_script,

0 commit comments

Comments
 (0)