Skip to content

Commit 59c589a

Browse files
authored
[2.6] Fix lightning api for existing NeMo examples. (#3518)
Fixes # . ### Description Fix lightning api for existing NeMo examples. - Fixes the error in NeMo container when `load_state_dict()` returns None type in Lightning API by wrapping the code into a try/except block. ``` 2025-05-29 19:54:41,781 - SubprocessLauncher - INFO - File "/workspace/code/nvflare/app_opt/lightning/api.py", line 201, in _receive_and_update_model 2025-05-29 19:54:41,782 - SubprocessLauncher - INFO - missing_keys, unexpected_keys = pl_module.load_state_dict( 2025-05-29 19:54:41,782 - SubprocessLauncher - INFO - TypeError: cannot unpack non-iterable NoneType object ``` ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated.
1 parent bc8af47 commit 59c589a

File tree

4 files changed

+18
-12
lines changed

4 files changed

+18
-12
lines changed

integration/nemo/examples/peft/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ docker run --runtime=nvidia -it --rm --shm-size=16g -p 8888:8888 -p 6006:6006 --
2626

2727
Next, install NVFlare.
2828
```
29-
pip install nvflare~=2.5.0rc
29+
pip install "nvflare>2.6"
3030
```
3131

3232
## Examples

integration/nemo/examples/prompt_learning/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ docker run --runtime=nvidia -it --rm --shm-size=16g -p 8888:8888 -p 6006:6006 --
2626

2727
For easy experimentation with NeMo, install NVFlare and mount the code inside the [nemo_nvflare](./nemo_nvflare) folder.
2828
```
29-
pip install nvflare~=2.5.0rc
29+
pip install "nvflare>2.6"
3030
pip install protobuf==3.20
3131
export PYTHONPATH=${PYTHONPATH}:/workspace
3232
```

integration/nemo/examples/supervised_fine_tuning/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ docker run --runtime=nvidia -it --rm --shm-size=16g -p 8888:8888 -p 6006:6006 --
2525

2626
For easy experimentation with NeMo, install NVFlare and mount the code inside the [nemo_nvflare](./nemo_nvflare) folder.
2727
```
28-
pip install nvflare~=2.5.0rc
28+
pip install "nvflare>2.6"
2929
export PYTHONPATH=${PYTHONPATH}:/workspace
3030
```
3131

nvflare/app_opt/lightning/api.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,21 @@ def _receive_and_update_model(self, trainer, pl_module):
198198
model = self._receive_model(trainer)
199199
if model:
200200
if model.params:
201-
missing_keys, unexpected_keys = pl_module.load_state_dict(
202-
model.params, strict=self._load_state_dict_strict
203-
)
204-
if len(missing_keys) > 0:
205-
self.logger.warning(f"There were missing keys when loading the global state_dict: {missing_keys}")
206-
if len(unexpected_keys) > 0:
207-
self.logger.warning(
208-
f"There were unexpected keys when loading the global state_dict: {unexpected_keys}"
209-
)
201+
try:
202+
result = pl_module.load_state_dict(model.params, strict=self._load_state_dict_strict)
203+
if result is not None:
204+
missing_keys, unexpected_keys = result
205+
if len(missing_keys) > 0:
206+
self.logger.warning(
207+
f"There were missing keys when loading the global state_dict: {missing_keys}"
208+
)
209+
if len(unexpected_keys) > 0:
210+
self.logger.warning(
211+
f"There were unexpected keys when loading the global state_dict: {unexpected_keys}"
212+
)
213+
except Exception as e:
214+
self.logger.error(f"Failed to load state dict: {str(e)}")
215+
raise RuntimeError(f"Failed to load model state dict: {str(e)}")
210216
if model.current_round is not None:
211217
self.current_round = model.current_round
212218

0 commit comments

Comments
 (0)