Refactor and generalize loss.py
#635
base: main
Conversation
Thanks for the PR @ppegolo! I think it looks great; I have some comments for discussion inline. One thing that is missing for approval, though, is tests.
src/metatrain/utils/custom_loss.py
Outdated
# Initialize the base loss weight on the first call
if not self.scheduler.initialized:
    self.sliding_weight = self.scheduler.initialize(self.base, targets)
This could go in the constructor (`__init__`) so it isn't called every time `compute` is used.
It's called there because we have access to the targets only at compute time, not when the loss is initialized in the trainer. Maybe there's a way to redesign it, but we implemented it this way to keep the change minimal with respect to the current implementation of Trainer.
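A minimal sketch of the lazy-initialization pattern under discussion (class and attribute names here are illustrative, not the actual metatrain API):

```python
class ScheduledLoss:
    """Wraps a base loss; the scheduler state is initialized lazily, on the
    first call to compute(), because the targets are only available then
    (not when the trainer constructs the loss)."""

    def __init__(self, base_loss, scheduler):
        self._base_loss = base_loss
        self._scheduler = scheduler
        self._sliding_weight = None  # set on the first compute() call

    def compute(self, predictions, targets):
        if self._sliding_weight is None:
            # First call: initialize from the targets we now have access to.
            self._sliding_weight = self._scheduler.initialize(
                self._base_loss, targets
            )
        return self._sliding_weight * self._base_loss.compute(predictions, targets)
```

This keeps the `Trainer` interface unchanged while still deferring the target-dependent setup to the first forward pass.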
src/metatrain/utils/custom_loss.py
Outdated
self.base = base_loss
self.scheduler = scheduler
self.target = base_loss.target
self.reduction = base_loss.reduction
If we're always composing / delegating, then it might be nicer to store the object and use it, e.g.
self._base_loss = base_loss
Just in case we forget to delegate anything else later.
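One way to "delegate everything" without listing each attribute is a `__getattr__` fallback; a sketch (names are illustrative, not the actual implementation):

```python
class WeightedLoss:
    """Composes a base loss and delegates unknown attributes to it."""

    def __init__(self, base_loss, weight):
        self._base_loss = base_loss
        self._weight = weight

    def compute(self, predictions, targets):
        return self._weight * self._base_loss.compute(predictions, targets)

    def __getattr__(self, name):
        # Only called for attributes not found on this wrapper, so
        # target, reduction, and anything added later are never forgotten.
        return getattr(self._base_loss, name)
```

The trade-off is that the delegated attributes become implicit, which can make the wrapper's public surface harder to see at a glance.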
src/metatrain/utils/custom_loss.py
Outdated
""" | ||
Metaclass to auto-register :py:class:`LossInterface` subclasses. | ||
|
||
Maintains a mapping from ``registry_name`` to the subclass type. |
I think perhaps, if we're only discovering internal plugins, then an enum might be clearer (single source of truth). So something like:

class CoreLoss(Enum):
    MSE = TensorMapMSELoss
    ...

If/when we want to allow third-party loss functions (i.e. in a separate pip-installable package), then we'd add in a Registry which discovers plugins off of entry_points(group='metatrain.losses') and provides a unified view of all the loss functions present?
Note that this is a minor design nit; I think we can easily do that later too. Mostly I'm thinking an explicit interface for internal losses is clearer (and in keeping with the "Zen of Python") than auto-discovery here.
An enum, or even just a (factory) function somewhere doing:

def get_loss(name, hypers):
    if name == "mse":
        return TensorMapMSELoss(**hypers)
    ...

We can then extend the factory however we want down the line.
Thanks for the comments! If we'll need a registry at some point, when we allow external loss functions, is it so bad to have it already now, to avoid having to re-implement it then?
I'm not sure we'll want to use a registry for external loss functions. The goal is not so much to have user A use the loss function provided by user B in some other Python package, but rather to allow user A to directly provide a Python script with a loss. This can be done without any kind of global registry.
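Loading a user-supplied loss directly from a Python file, as described, needs no registry at all; a sketch with importlib (function and argument names are placeholders):

```python
import importlib.util


def load_loss_from_script(path, class_name):
    """Import the file at `path` as a module and return the loss class
    named `class_name` from it; no global registry involved."""
    spec = importlib.util.spec_from_file_location("user_loss", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name)
```

The user would then point the input file at their script and class name, and the trainer instantiates whatever it finds there.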
FWIW, regarding external-package entry-point registration, there's also this (draft) spec: https://scientific-python.org/specs/spec-0002/
src/metatrain/utils/custom_loss.py
Outdated
# Use explicit registry_name if given, else snake_case from class name
key = getattr(cls, "registry_name", None)
if key is None:
    key = "".join(
        f"_{c.lower()}" if c.isupper() else c for c in name
    ).lstrip("_")
# only register the very first class under each key
mcs._registry.setdefault(key, cls)
Suggested change:

- # Use explicit registry_name if given, else snake_case from class name
- key = getattr(cls, "registry_name", None)
- if key is None:
-     key = "".join(
-         f"_{c.lower()}" if c.isupper() else c for c in name
-     ).lstrip("_")
- # only register the very first class under each key
- mcs._registry.setdefault(key, cls)
+ if name != "LossInterface" and issubclass(cls, LossInterface):
+     if "registry_name" not in attrs:
+         raise TypeError(
+             f"Class '{name}' must define a 'registry_name' class attribute"
+         )
Indentation might be off. It is good to have a fallback in general, but not if we only have a fixed set of loss functions in metatrain, since we can easily enforce a fixed name for each new loss (especially with the enum suggestion). If we end up with third-party loss support, then this would make much more sense :)
@Luthaf what's your opinion on this?
I agree that the names for the losses should be explicit, but I don't think we need to go through a registration mechanism. For now, doing something like:

def get_loss(name: str):
    if name == "mse":
        return MSELoss(...)
    elif name == "mae":
        return MAELoss(...)
    elif name == "whatever":
        return WhateverLoss(...)
    else:
        raise ValueError(f"unknown loss function {name}")
The main advantage of a registry would be to allow losses defined outside of metatrain, but this is not something we need yet, so let's apply YAGNI.
@@ -211,5 +171,35 @@
      "uniqueItems": true
    }
  },
- "additionalProperties": false
+ "additionalProperties": false,
This has to go into every model that uses the new loss function, right? Maybe it makes sense to write a general JSON schema for the loss and only "import" it in the architecture-specific schemas.
I will figure out if there is a clever way to do it.
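One common way to do that "import" is a JSON-Schema `$ref` from each architecture schema to a shared loss schema; a sketch (the file name `loss-schema.json` is hypothetical):

```json
{
  "loss": { "$ref": "loss-schema.json" }
}
```

The loss schema would then live in one place and each architecture's schema would only reference it.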
Fixes #629
Refactoring of the loss modules as discussed in #629 (with @SanggyuChong).
- `<target_name>` as a top-level section in the `loss` field of the input file
- `MSELoss`, `L1Loss`, `HuberLoss` refactored to a common interface

Contributor (creator of pull-request) checklist
Reviewer checklist
📚 Documentation preview 📚: https://metatrain--635.org.readthedocs.build/en/635/