
Refactor and generalize loss.py #635


Open · wants to merge 42 commits into main

Conversation

@ppegolo (Contributor) commented Jun 23, 2025

Fixes #629

Refactoring of the loss modules as discussed in #629 (with @SanggyuChong).

  • dynamic lookup and registration, as suggested by @HaoZeke
  • moved <target_name> to a top-level section within the loss field of the input file
  • grouped "torch-like" losses such as MSELoss, L1Loss, and HuberLoss under a common interface (sketched below)
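
For illustration, here is a minimal sketch of what such a common wrapper over torch losses could look like; the class name, constructor arguments, and compute signature are hypothetical, not the actual metatrain API:

import torch

class TorchLossWrapper:
    """Hypothetical common interface over torch losses (MSELoss, L1Loss, HuberLoss, ...)."""

    def __init__(self, target: str, torch_loss_cls, reduction: str = "mean", **kwargs):
        self.target = target
        self.reduction = reduction
        # e.g. torch.nn.MSELoss, torch.nn.L1Loss, torch.nn.HuberLoss
        self.loss_fn = torch_loss_cls(reduction=reduction, **kwargs)

    def compute(self, predictions: dict, targets: dict) -> torch.Tensor:
        # both dicts map target names to tensors
        return self.loss_fn(predictions[self.target], targets[self.target])

For example, TorchLossWrapper("energy", torch.nn.HuberLoss, delta=0.1) would behave like any other loss while delegating the actual computation to torch.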

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • Issue referenced (for PRs that solve an issue)?

Reviewer checklist

  • CHANGELOG updated with public API or any other important changes?

📚 Documentation preview 📚: https://metatrain--635.org.readthedocs.build/en/635/

@ppegolo ppegolo marked this pull request as ready for review July 2, 2025 15:48
@ppegolo ppegolo requested a review from frostedoyster as a code owner July 2, 2025 15:48
@ppegolo ppegolo requested a review from HaoZeke July 3, 2025 07:26
@HaoZeke left a comment

Thanks for the PR @ppegolo! I think it looks great. I have some comments for discussion inline; the one thing still missing for approval, though, is tests.

Comment on lines 214 to 216
# Initialize the base loss weight on the first call
if not self.scheduler.initialized:
    self.sliding_weight = self.scheduler.initialize(self.base, targets)

This could go in the constructor (__init__) so it isn't called every time compute is used.

@ppegolo (Contributor, Author)

It's called there because we only have access to the targets at compute time, not when the loss is initialized in the trainer. Maybe there's a way to redesign it, but we've implemented it this way to keep changes minimal with respect to the current implementation of Trainer.

Comment on lines 201 to 204
self.base = base_loss
self.scheduler = scheduler
self.target = base_loss.target
self.reduction = base_loss.reduction

If we're always composing / delegating, then it might be nicer to store the object and use it, e.g.

self._base_loss = base_loss


Just in case we forget to delegate anything else later.
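
For context, one hypothetical way to delegate everything that is not explicitly overridden to the stored object (class and attribute names are illustrative, not taken from the PR):

class ScheduledLoss:
    def __init__(self, base_loss, scheduler):
        self._base_loss = base_loss
        self._scheduler = scheduler

    def __getattr__(self, name):
        # only called when normal attribute lookup fails, so target,
        # reduction, etc. automatically fall through to the wrapped loss
        return getattr(self._base_loss, name)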

"""
Metaclass to auto-register :py:class:`LossInterface` subclasses.

Maintains a mapping from ``registry_name`` to the subclass type.

I think perhaps, if we're only discovering internal plugins, then an enum might be clearer (single source of truth). So something like:

class CoreLoss(Enum):
    MSE = TensorMapMSELoss
    ...

If/when we want to allow third-party loss functions (i.e. in a separate pip-installable package), then we'd add a Registry which discovers plugins off of entry_points(group='metatrain.losses') and provides a unified view of all the loss functions present?
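
As a rough sketch of that entry-point discovery (the 'metatrain.losses' group name and the helper function are assumptions, and entry_points(group=...) needs Python 3.10+ or the importlib_metadata backport):

from importlib.metadata import entry_points

def discover_external_losses() -> dict:
    # each entry point maps a loss name to a class exported by a third-party package
    return {ep.name: ep.load() for ep in entry_points(group="metatrain.losses")}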


Note that this is a minor design nit, I think we can easily do that later too... mostly I'm thinking an explicit interface for internal losses is clearer (and in keeping with the "Zen of Python") rather than auto-discovery here.

Member

enum or even just a (factory) function somewhere doing

def get_loss(name, hypers):
    if name == "mse":
        return TensorMapMSELoss(**hypers)
    ...

We can then extend the factory however we want down the line.

@ppegolo (Contributor, Author)

Thanks for the comments! If we'll need a registry at some point anyway, once we allow for external loss functions, is it so bad to have it already now, to avoid having to re-implement it later?

Member

I'm not sure we'll want to use a registry for external loss functions. The goal is not so much to have user A use the loss function provided by user B in some other Python package, but rather to allow user A to directly provide a Python script with a loss. This can be done without any kind of global registry.
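
A minimal sketch of that approach, assuming the user points the input file at a Python script containing a loss class (the function name and its arguments are illustrative):

import importlib.util

def load_loss_from_script(path: str, class_name: str):
    # import the user's script as an ad-hoc module; no global registry involved
    spec = importlib.util.spec_from_file_location("user_loss", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name)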


FWIW, regarding external-package entry-point registration, there's also this (draft) spec: https://scientific-python.org/specs/spec-0002/

Comment on lines 25 to 32
# Use explicit registry_name if given, else snake_case from class name
key = getattr(cls, "registry_name", None)
if key is None:
    key = "".join(
        f"_{c.lower()}" if c.isupper() else c for c in name
    ).lstrip("_")
# only register the very first class under each key
mcs._registry.setdefault(key, cls)

Suggested change

-# Use explicit registry_name if given, else snake_case from class name
-key = getattr(cls, "registry_name", None)
-if key is None:
-    key = "".join(
-        f"_{c.lower()}" if c.isupper() else c for c in name
-    ).lstrip("_")
-# only register the very first class under each key
-mcs._registry.setdefault(key, cls)
+if name != "LossInterface" and issubclass(cls, LossInterface):
+    if "registry_name" not in attrs:
+        raise TypeError(
+            f"Class '{name}' must define a 'registry_name' class attribute"
+        )

Indentation might be off. It is good to have a fallback in general, but not if we only have a set number of loss functions in metatrain, since we can easily enforce a fixed name for each new loss (esp. with the enum suggestion).

If we end up with third party loss support, then this would make much more sense :)

Member

@Luthaf what's your opinion on this?

Member

I agree that names for the losses should be explicit, but I don't think we need to go through a registration mechanism. For now, something like this would do:

def get_loss(name: str):
    if name == "mse":
        return MSELoss(...)
    elif name == "mae":
        return MAELoss(...)
    elif name == "whatever":
        return WhateverLoss(...)
    else:
        raise ValueError(f"unknown loss function {name}")

The main advantage of a registry would be to allow losses defined outside of metatrain, but this is not something we need yet, so let's apply YAGNI.

@@ -211,5 +171,35 @@
     "uniqueItems": true
   }
 },
-"additionalProperties": false
+"additionalProperties": false,
Contributor

This has to go in any of the models that use the new loss function, right? Maybe it makes sense to write a general JSON schema for the loss and only "import" it in the architecture-specific ones.

I will figure out if there is a clever way to do it.
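
As a loose illustration of that idea (the property names and layout below are made up, and with actual JSON files this would more likely be done with a "$ref"), the architecture schemas could reuse one shared loss definition instead of repeating it:

LOSS_SCHEMA = {
    "type": "object",
    "properties": {
        "type": {"type": "string"},
        "weights": {"type": "object"},
    },
    "additionalProperties": False,
}

ARCHITECTURE_SCHEMA = {
    "type": "object",
    "properties": {
        # every architecture "imports" the same loss definition
        "loss": LOSS_SCHEMA,
    },
    "additionalProperties": False,
}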


@jwa7 jwa7 requested a review from abmazitov as a code owner July 8, 2025 09:13

Successfully merging this pull request may close these issues.

Customize the loss when training with gradient descent
6 participants