🚀 Feature
Improve the `_bincount` utility to avoid the unnecessary fallback on CUDA when deterministic mode is enabled and conditions are safe for native `torch.bincount` use (PyTorch ≥ 2.1).
Motivation
The current `_bincount` implementation in TorchMetrics falls back to a slower and more memory-intensive workaround whenever `torch.are_deterministic_algorithms_enabled()` returns `True`, regardless of the PyTorch version or backend.
However, since PyTorch v2.1, `torch.bincount` is allowed in deterministic mode on CUDA as long as:
- No weights are passed
- Gradients are not required
Avoiding the fallback in this case would improve performance and reduce memory usage.
This is particularly relevant when running large-scale evaluations on modern GPU systems.
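To illustrate where the overhead comes from, here is a rough sketch of the kind of deterministic workaround the fallback path relies on (a simplified illustration, not TorchMetrics' exact code; `bincount_fallback` is a hypothetical name): it materializes an (N, minlength) comparison matrix, which is what makes it memory-intensive for large inputs.

```python
import torch


def bincount_fallback(x: torch.Tensor, minlength: int) -> torch.Tensor:
    # Deterministic workaround: compare every element against every bin,
    # producing an (N, minlength) boolean matrix, then sum per column.
    # Avoids the non-deterministic atomic adds of CUDA bincount, at the
    # cost of O(N * minlength) memory.
    mesh = torch.arange(minlength, device=x.device).unsqueeze(0)
    return (x.unsqueeze(-1) == mesh).sum(dim=0)


x = torch.tensor([0, 1, 1, 3])
print(bincount_fallback(x, minlength=4).tolist())  # [1, 2, 0, 1]
```

For a tensor with N elements and C classes this allocates an N×C intermediate, whereas native `torch.bincount` needs only the C-element output.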
Pitch
Update the `_bincount` utility logic to:
- Use native `torch.bincount` if:
  - `x.is_cuda` is `True`
  - `torch.__version__` >= 2.1
  - No weights are involved
  - Gradients are not required
- Only fall back when:
  - `x.is_mps` or an XLA backend is detected, or
  - PyTorch version is < 2.1 and deterministic algorithms are enabled
Alternatives
Continue using the current fallback unconditionally under deterministic mode, but this leads to unnecessary compute and memory overhead on newer CUDA-enabled systems.
Additional context
This proposed change aligns with the improvements introduced in PyTorch PR #105244, which enabled deterministic `torch.bincount` on CUDA under safe conditions starting from v2.1.
A PR will follow shortly to implement this enhancement.