Modified mask weights functionality #5
base: main
Conversation
Added an exponential moving average of the mask weights to the loss function and removed the ability to specify custom mask weights
wli51
left a comment
LGTM! Please see my comments for minor things.
self.device = (
    device if isinstance(device, torch.device) else torch.device(device)
)
self.mask_weights_alpha = mask_weights_alpha
Since you are not doing any checks here, and I see from your Optuna trials that this value gets specified by the sampler, it is safest to cast all weights explicitly to the fp32 dtype (or the dtype of your model weights). I think torch automatically converts numpy values to tensors of the equivalent precision and, in multiplication, promotes the lower-precision tensor to match the higher-precision one.
In the event that mask_weights_alpha happened to be np.float64, multiplying it with your torch.ones(3, dtype=torch.float32) weights would make the weights torch.float64, which can cause downstream errors if they hit any layer/metric/loss that does not accept mixed precisions for input/target/nn module weights. Since you are weighting after the BCE loss you shouldn't get any precision errors, but if you call loss.backward() on a float64 loss I think torch by default computes the gradients in 64-bit precision too, which is expensive (twice as slow on data-center-grade GPUs like A100s and H100s, and many times slower on the consumer-grade cards we have) and not useful if the gradients get applied back to a float32 model.
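A minimal sketch of the cast I mean; the class name, the hard-coded three masks, and the EMA update method are illustrative assumptions, only the attribute names come from the diff above:

```python
import numpy as np
import torch


class WeightedMaskBCELoss(torch.nn.Module):
    """Hypothetical stand-in for the project's loss class, not its actual code."""

    def __init__(self, mask_weights_alpha, device="cpu"):
        super().__init__()
        self.device = (
            device if isinstance(device, torch.device) else torch.device(device)
        )
        # Cast the sampler-provided alpha (possibly np.float64) to float32 up
        # front so later multiplications can never promote tensors to float64.
        self.mask_weights_alpha = torch.as_tensor(
            mask_weights_alpha, dtype=torch.float32, device=self.device
        )
        self.mask_weights = torch.ones(3, dtype=torch.float32, device=self.device)

    def update_mask_weights(self, new_weights):
        # Both operands are float32, so the running average stays float32.
        new_weights = torch.as_tensor(
            new_weights, dtype=torch.float32, device=self.device
        )
        self.mask_weights = (
            self.mask_weights_alpha * new_weights
            + (1.0 - self.mask_weights_alpha) * self.mask_weights
        )


loss_fn = WeightedMaskBCELoss(mask_weights_alpha=np.float64(0.9))
loss_fn.update_mask_weights(np.array([0.2, 0.3, 0.5]))
print(loss_fn.mask_weights.dtype)  # torch.float32, even with float64 inputs
```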
Hadn't thought about this, but it is good to know. Do you know how this behaves with torch AMP?
If torch auto-casts the alpha/weights to float32 or float16, that's probably fine. However, torch AMP is explicitly documented not to work with float64 (see the AMP documentation page). My guess is that unless some casting is explicitly enforced, the float64 intermediates will eventually hit some float32 or float16 module weights and result in an error.
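For what it's worth, here is a hedged sketch of an AMP training step that keeps the per-mask weighting in float32 so nothing 64-bit ever reaches backward(); the model, shapes, and names are placeholders, not this PR's code:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = torch.nn.Conv2d(1, 3, kernel_size=3, padding=1).to(device)  # stand-in segmentation head
optimizer = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
criterion = torch.nn.BCEWithLogitsLoss(reduction="none")

# EMA mask weights held explicitly in float32 so no float64 value can leak in.
mask_weights = torch.ones(3, dtype=torch.float32, device=device)

images = torch.randn(2, 1, 32, 32, device=device)
targets = torch.randint(0, 2, (2, 3, 32, 32), device=device).float()

optimizer.zero_grad()
with torch.autocast(device_type=device, enabled=use_amp):
    logits = model(images)
    # BCEWithLogitsLoss is kept in float32 under CUDA autocast.
    per_pixel = criterion(logits, targets)
# Weight the per-mask losses outside the autocast block; everything here is
# float32, so backward() never has to compute float64 gradients.
loss = (per_pixel.mean(dim=(0, 2, 3)) * mask_weights).sum()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```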
predicted images to semantic segmentations and added ability to filter the original target image in the whole image dataset
dataloader to compute initial semantic mask weights. Also handle the situation where a semantic mask has no pixels by setting that mask's weight to a small number in the weight exponential moving average.
doesn't upcast
Added the ability to compute the exponential moving average of the mask weights in the BCE loss function. Also removed the ability to specify custom mask weights.
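To make the empty-mask handling described in the commits above concrete, here is a hedged sketch of one possible EMA weight update; the inverse-frequency weighting, the epsilon value, and all names are assumptions rather than this PR's implementation:

```python
import torch


def update_mask_weight_ema(
    ema_weights: torch.Tensor,   # running per-mask weights, float32, shape (num_masks,)
    masks: torch.Tensor,         # binary semantic masks, shape (batch, num_masks, H, W)
    alpha: float = 0.9,          # EMA smoothing factor
    empty_weight: float = 1e-6,  # small weight assigned to masks with no pixels
) -> torch.Tensor:
    # Fraction of positive pixels per semantic mask in this batch.
    pixel_fraction = masks.float().mean(dim=(0, 2, 3))
    # Inverse-frequency weight; a mask with no pixels gets a small fixed weight
    # instead of an undefined (divide-by-zero) one.
    new_weights = torch.where(
        pixel_fraction > 0,
        1.0 / pixel_fraction.clamp_min(torch.finfo(torch.float32).eps),
        torch.full_like(pixel_fraction, empty_weight),
    )
    new_weights = new_weights / new_weights.sum()
    # Standard exponential moving average of the normalized weights.
    return alpha * ema_weights + (1.0 - alpha) * new_weights


# Example: the second semantic mask has no positive pixels in this batch.
masks = torch.zeros(2, 3, 8, 8)
masks[:, 0, :4] = 1.0
masks[:, 2, ::2] = 1.0
ema = torch.ones(3) / 3
print(update_mask_weight_ema(ema, masks))
```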