
Conversation

@MattsonCam (Member)

Added the ability to compute the exponential moving average of the mask weights in the BCE loss function. Also removed the ability to specify custom mask weights.
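
For context, a minimal sketch of what an exponential-moving-average update of the mask weights could look like; the function name, parameters, and the way batch weights are obtained are illustrative assumptions, not the PR's actual implementation:

```python
import torch

def ema_update_mask_weights(
    ema_weights: torch.Tensor,    # running per-mask weights, shape (num_masks,)
    batch_weights: torch.Tensor,  # weights computed from the current batch
    alpha: float = 0.9,           # smoothing factor (hypothetical default)
) -> torch.Tensor:
    """Blend the running mask weights with the current batch's weights."""
    # Keep everything in float32 so the weighted BCE loss matches the model dtype.
    batch_weights = batch_weights.to(torch.float32)
    return alpha * ema_weights + (1.0 - alpha) * batch_weights

# The running weights would then scale the per-mask BCE terms, e.g.:
# loss = (ema_weights * per_mask_bce).sum()
```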

Cameron Mattson added 2 commits November 17, 2025 13:37
…to the loss function and removed the ability to specify custom mask weights
@MattsonCam MattsonCam requested a review from wli51 November 17, 2025 21:10

@wli51 wli51 left a comment

LGTM! Please see my comments for minor things.

```python
self.device = (
    device if isinstance(device, torch.device) else torch.device(device)
)
self.mask_weights_alpha = mask_weights_alpha
```

@wli51 wli51 Nov 17, 2025

Since you are not doing any checks here, and I see from your Optuna trials that this gets specified by the sampler, it is safest to cast all weights explicitly to the fp32 dtype (or to the dtype of your model weights). I think torch automatically casts numpy values to the equivalent tensor precision, and in multiplication casts the lower-precision tensor to match the higher-precision tensor.

In the event that mask_weights_alpha happened to be np.float64, multiplying it with your torch.ones(3, dtype=torch.float32) weights would make the weights torch.float64, which can cause downstream errors if it hits any layer/metric/loss that does not accept different precisions for the input/target/nn module weights. Since you are weighting after the BCE loss you shouldn't get any precision errors, though if you call loss.backward() on a float64 loss I think torch by default computes the gradients in 64-bit precision too, which is expensive (roughly twice as slow on data-center-grade GPUs like A100s and H100s, and many times slower on the consumer-grade cards we have) and not useful if the gradients get applied back to a float32 model.
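
Something like this explicit cast would be the defensive option (the variable names here are illustrative):

```python
import numpy as np
import torch

# The sampler may hand back a numpy/Python float64 rather than a float32.
mask_weights_alpha = np.float64(0.3)

# Cast once, up front, so everything downstream stays in the model's dtype.
alpha = torch.as_tensor(mask_weights_alpha, dtype=torch.float32)

mask_weights = torch.ones(3, dtype=torch.float32)
updated = alpha * mask_weights
print(updated.dtype)  # torch.float32
```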

@MattsonCam (Member, Author)

Hadn't thought about this, but it is good to know. Do you know how this behaves with torch AMP?

@wli51 wli51 Nov 17, 2025

If torch auto-casts the alpha/weights to float32 or float16, that's probably fine. However, torch AMP is explicitly documented to not work with float64 (see the autocast documentation page). My guess is that unless some casting is explicitly enforced, the float64 intermediates will eventually hit some float32 or float16 module weights and result in an error.
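
A minimal sketch of the AMP-friendly pattern, assuming a CUDA device and keeping the mask weights in float32; the model and shapes are placeholders:

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(8, 3).cuda()
x = torch.randn(4, 8, device="cuda")
target = torch.rand(4, 3, device="cuda")
mask_weights = torch.ones(3, device="cuda", dtype=torch.float32)  # float32, not float64

with torch.autocast(device_type="cuda", dtype=torch.float16):
    logits = model(x)  # runs in float16 under autocast
    per_mask_bce = F.binary_cross_entropy_with_logits(
        logits, target, reduction="none"
    ).mean(dim=0)      # autocast runs this op in float32
    loss = (mask_weights * per_mask_bce).sum()  # stays float32

loss.backward()  # gradients land in the model's float32 parameters
```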

Cameron Mattson added 5 commits November 21, 2025 11:41
…predicted images to semantic segmentations and added ability to filter the original target image in the whole image dataset
…dataloader to compute initial semantic mask weights. Also handle the situation where a semantic mask has no pixels by setting that mask's weight to a small number in the weight exponential moving average.
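
A minimal sketch of the empty-mask handling when computing per-mask weights from a dataloader batch; the inverse-frequency weighting and the epsilon value here are assumptions for illustration, not the PR's actual scheme:

```python
import torch

EPS = 1e-6  # illustrative stand-in for "a small number"

def batch_mask_weights(masks: torch.Tensor, eps: float = EPS) -> torch.Tensor:
    """Per-mask weights for one batch of semantic masks, shape (B, num_masks, H, W).

    Masks with no positive pixels get a small fixed weight instead of an
    inverse-frequency weight, which would otherwise blow up.
    """
    pixel_counts = masks.float().sum(dim=(0, 2, 3))                 # pixels per mask channel
    total = float(masks.shape[0] * masks.shape[2] * masks.shape[3]) # pixels available per channel
    inv_freq = total / pixel_counts.clamp_min(1.0)
    return torch.where(pixel_counts > 0, inv_freq, torch.full_like(inv_freq, eps))

# These batch weights would then be folded into the running EMA
# (see the sketch under the PR description above).
```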