diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py
index 75d1442d6..270d6316e 100644
--- a/toolkit/optimizer.py
+++ b/toolkit/optimizer.py
@@ -1,5 +1,5 @@
 import torch
-
+import importlib
 
 def get_optimizer(
         params,
@@ -39,7 +39,7 @@ def get_optimizer(
         # let net be the neural network you want to train
         # you can choose weight decay value based on your problem, 0 by default
         optimizer = Prodigy8bit(params, lr=use_lr, eps=1e-6, **optimizer_params)
-    elif lower_type.startswith("prodigy"):
+    elif lower_type == "prodigy":
         from prodigyopt import Prodigy
 
         print("Using Prodigy optimizer")
@@ -60,19 +60,18 @@ def get_optimizer(
 
         from toolkit.optimizers.adam8bit import Adam8bit
         optimizer = Adam8bit(params, lr=learning_rate, eps=1e-6, decouple=True, **optimizer_params)
-    elif lower_type.endswith("8bit"):
+    elif lower_type == "adam8bit":
         import bitsandbytes
-
-        if lower_type == "adam8bit":
-            return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
-        if lower_type == "ademamix8bit":
-            return bitsandbytes.optim.AdEMAMix8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
-        elif lower_type == "adamw8bit":
-            return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
-        elif lower_type == "lion8bit":
-            return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
-        else:
-            raise ValueError(f'Unknown optimizer type {optimizer_type}')
+        return bitsandbytes.optim.Adam8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
+    elif lower_type == "ademamix8bit":
+        import bitsandbytes
+        return bitsandbytes.optim.AdEMAMix8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
+    elif lower_type == "adamw8bit":
+        import bitsandbytes
+        return bitsandbytes.optim.AdamW8bit(params, lr=learning_rate, eps=1e-6, **optimizer_params)
+    elif lower_type == "lion8bit":
+        import bitsandbytes
+        return bitsandbytes.optim.Lion8bit(params, lr=learning_rate, **optimizer_params)
     elif lower_type == 'adam':
         optimizer = torch.optim.Adam(params, lr=float(learning_rate), eps=1e-6, **optimizer_params)
     elif lower_type == 'adamw':
@@ -98,5 +97,33 @@ def get_optimizer(
         from toolkit.optimizers.automagic import Automagic
         optimizer = Automagic(params, lr=float(learning_rate), **optimizer_params)
     else:
-        raise ValueError(f'Unknown optimizer type {optimizer_type}')
+        # Try to dynamically import a user-defined optimizer
+        try:
+            # Split the string into module path and class name
+            parts = optimizer_type.split(".")
+            if len(parts) < 2:
+                raise ValueError(f"Unknown optimizer type {optimizer_type}")
+
+            module_path = ".".join(parts[:-1])
+            class_name = parts[-1]
+
+            # Import module dynamically
+            mod = importlib.import_module(module_path)
+
+            # Get optimizer class from module
+            opt_class = getattr(mod, class_name)
+
+            # Instantiate optimizer
+            try:
+                optimizer = opt_class(params, lr=float(learning_rate), **optimizer_params)
+            except TypeError:
+                # In case the optimizer does not take lr or eps the same way
+                optimizer = opt_class(params, **optimizer_params)
+
+            print(f"Using user-defined optimizer: {optimizer_type}")
+            return optimizer
+
+        except Exception as e:
+            raise ValueError(f'Unknown optimizer type. Make sure your optimizer is installed in the virtual environment (venv). {optimizer_type}. '
+                             f'Failed to import dynamically. Error: {e}')
     return optimizer
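After this change, any optimizer class that is importable in the environment can be selected by passing its dotted path as the optimizer type. Below is a minimal usage sketch of that fallback, not part of the patch itself; it assumes optimizer_type, learning_rate and optimizer_params are keyword arguments of get_optimizer (as the function body suggests), and torch.optim.RAdam is only an illustrative target class.

    import torch
    from toolkit.optimizer import get_optimizer

    net = torch.nn.Linear(8, 2)

    # "torch.optim.RAdam" should match none of the explicit elif branches, so it
    # reaches the dynamic fallback: importlib.import_module("torch.optim") loads
    # the module, getattr(mod, "RAdam") resolves the class, and the class is
    # instantiated with lr=float(learning_rate) plus **optimizer_params.
    optimizer = get_optimizer(
        net.parameters(),
        optimizer_type="torch.optim.RAdam",       # hypothetical example path
        learning_rate=1e-4,
        optimizer_params={"weight_decay": 0.01},
    )

If the target class does not accept lr, the inner TypeError fallback constructs it with only params and **optimizer_params, so learning-rate handling is then entirely up to that optimizer.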