
Commit 966f84f

Merge pull request #970 from scap3yvt/968-feature-add-new-optimizers

Added new optimizers

2 parents d2f8c2f + 882fd51

File tree

10 files changed: +600 −27 lines changed

GANDLF/optimizers/README.md

Lines changed: 8 additions & 2 deletions
```diff
@@ -3,10 +3,16 @@
 ## Adding a new algorithm
 
 - For an optimizer defined in PyTorch [[ref](https://pytorch.org/docs/stable/optim.html#algorithms)], update the `GANDLF.optimizers.wrap_torch.py` submodule.
-- For a custom optimizer, create a new submodule called `GANDLF.optimizers.${awesome_optimizer}.py`. Ensure that it inherits from PyTorch's base optimizer class [[ref](https://pytorch.org/docs/stable/optim.html#base-class)]
+- For a custom optimizer, create a new submodule called `GANDLF.optimizers.${awesome_optimizer}.py`.
+- For a third-party optimizer (i.e., where the code is available from an external source/repository):
+  - Add the relevant code under the `GANDLF.optimizers.thirdparty` submodule.
+  - Add a wrapper which takes in GaNDLF's `parameter` dictionary as input and creates a `torch.optim.Optimizer` object as output.
+  - Add the wrapper to `GANDLF.optimizers.thirdparty.__init__.py` so that it can be called from `GANDLF.optimizers.__init__.py`.
+  - See `GANDLF.optimizers.thirdparty.adopt.py` as an example.
 - If a new dependency needs to be used, update GaNDLF's [`setup.py`](https://github.com/mlcommons/GaNDLF/blob/master/setup.py) with the new requirement.
 - Define a new submodule under `GANDLF.optimizers` as `GANDLF.optimizers.wrap_${package_name}.py`.
 - Ensure that the new algorithm is wrapped in a function which returns an object with the PyTorch optimizer type. Use any of the optimizers in `GANDLF.optimizers.wrap_torch.py` as an example.
 - Add the algorithm's identifier to `GANDLF.optimizers.__init__.global_optimizer_dict` with an appropriate key.
 - Call the new algorithm from the config using the `optimizer` key.
-- [Update the tests!](https://mlcommons.github.io/GaNDLF/extending/#update-tests)
+- [If appropriate, please update the tests!](https://mlcommons.github.io/GaNDLF/extending/#update-tests)
+- All wrappers should return the type `torch.optim.optimizer.Optimizer`.
```
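The wrapper contract in the list above is the important part: every optimizer, however it is implemented, is exposed through a function that takes GaNDLF's parameter dictionary and returns a `torch.optim.Optimizer`. A minimal sketch of such a wrapper, using `torch.optim.SGD` as a stand-in; the `model_parameters` and `learning_rate` keys are illustrative assumptions, not a confirmed schema (see the existing wrappers in `GANDLF.optimizers.wrap_torch.py` for the real one):

```python
import torch
from torch.optim.optimizer import Optimizer


def awesome_optimizer_wrapper(parameters: dict) -> Optimizer:
    """Build a torch.optim.Optimizer from GaNDLF's parameter dictionary.

    SGD is used as a stand-in optimizer; the dictionary key names
    here are illustrative assumptions.
    """
    return torch.optim.SGD(
        parameters["model_parameters"],
        lr=parameters.get("learning_rate", 1e-3),
        momentum=0.9,
    )
```

Once registered in `global_optimizer_dict`, the config-side selection described above would look something like `optimizer: {type: awesome_optimizer}`.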

GANDLF/optimizers/__init__.py

Lines changed: 9 additions & 6 deletions
```diff
@@ -15,7 +15,7 @@
 
 from .wrap_monai import novograd_wrapper
 
-from .ademamix import ademamix_wrapper
+from .thirdparty import ademamix_wrapper, lion_wrapper, adopt_wrapper
 
 global_optimizer_dict = {
     "sgd": sgd,
@@ -32,6 +32,8 @@
     "novograd": novograd_wrapper,
     "nadam": nadam,
     "ademamix": ademamix_wrapper,
+    "lion": lion_wrapper,
+    "adopt": adopt_wrapper,
 }
 
 
@@ -49,9 +51,10 @@ def get_optimizer(params):
     # Retrieve the optimizer type from the input parameters
     optimizer_type = params["optimizer"]["type"]
 
+    assert (
+        optimizer_type in global_optimizer_dict
+    ), f"Optimizer type {optimizer_type} not found"
+
     # Create the optimizer instance using the specified type and input parameters
-    if optimizer_type in global_optimizer_dict:
-        optimizer_function = global_optimizer_dict[optimizer_type]
-        return optimizer_function(params)
-    else:
-        raise ValueError("Optimizer type %s not found" % optimizer_type)
+    optimizer_function = global_optimizer_dict[optimizer_type]
+    return optimizer_function(params)
```
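After this refactor, an unknown optimizer type fails the assert instead of raising a `ValueError`. A hypothetical usage sketch follows; only the `params["optimizer"]["type"]` lookup is taken from the diff above, and the remaining dictionary keys are assumptions for illustration:

```python
import torch.nn as nn

from GANDLF.optimizers import get_optimizer

model = nn.Linear(10, 2)
params = {
    "optimizer": {"type": "adopt"},  # new key; "lion" works the same way
    # The keys below are illustrative assumptions, not GaNDLF's
    # confirmed schema -- check the wrappers for the real expectations.
    "model_parameters": model.parameters(),
    "learning_rate": 1e-3,
}

optimizer = get_optimizer(params)  # AssertionError for an unknown type
```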
GANDLF/optimizers/thirdparty/__init__.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -0,0 +1,5 @@
+from .ademamix import ademamix_wrapper
+
+from .lion import lion_wrapper
+
+from .adopt import adopt_wrapper
```
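Each imported name above points at a thin wrapper module under `GANDLF.optimizers.thirdparty`. As a sketch of what such a module might contain, assuming the third-party `lion-pytorch` package and illustrative parameter-dictionary keys (not necessarily GaNDLF's actual implementation):

```python
# Hypothetical sketch of GANDLF/optimizers/thirdparty/lion.py.
from lion_pytorch import Lion  # third-party package; would be declared in setup.py
from torch.optim.optimizer import Optimizer


def lion_wrapper(parameters: dict) -> Optimizer:
    """Create a Lion optimizer from GaNDLF's parameter dictionary.

    The dictionary keys used here are illustrative assumptions.
    """
    return Lion(
        parameters["model_parameters"],
        lr=parameters.get("learning_rate", 1e-4),
        betas=(0.9, 0.99),
        weight_decay=0.0,
    )
```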
GANDLF/optimizers/ademamix.py → GANDLF/optimizers/thirdparty/ademamix.py (file renamed without changes)
