Avoid unnecessary fallback in _bincount when deterministic mode is enabled on CUDA (PyTorch ≥ 2.1) #3086

🚀 Feature

Improve the _bincount utility to avoid an unnecessary fallback on CUDA when deterministic mode is enabled but conditions are safe for native torch.bincount use (PyTorch ≥ 2.1).

Motivation

The current _bincount implementation in TorchMetrics falls back to a slower and more memory-intensive workaround whenever torch.are_deterministic_algorithms_enabled() returns True, regardless of the PyTorch version or backend.
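
For reference, the deterministic workaround amounts to a one-hot style comparison against every bin. A simplified sketch (the exact TorchMetrics code may differ):

    import torch
    from torch import Tensor

    def _bincount_fallback(x: Tensor, minlength: int) -> Tensor:
        # Deterministic, but allocates an O(len(x) * minlength) boolean matrix:
        # compare every element against every bin index, then sum the matches.
        mesh = torch.arange(minlength, device=x.device)
        return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)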

However, since PyTorch v2.1, torch.bincount is allowed in deterministic mode on CUDA as long as:

  • No weights are passed
  • Gradients are not required

Avoiding the fallback in this case would improve performance and reduce memory usage.
This is particularly relevant when running large-scale evaluations on modern GPU systems.
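
As a quick illustration (assuming a CUDA-capable environment with PyTorch ≥ 2.1; the shapes and bin count are arbitrary):

    import torch

    torch.use_deterministic_algorithms(True)

    x = torch.randint(0, 100, (1_000_000,), device="cuda")

    # No weights and no autograd involvement: on PyTorch >= 2.1 this runs
    # natively on CUDA, whereas older versions raise a RuntimeError because
    # bincount had no deterministic CUDA implementation there.
    counts = torch.bincount(x, minlength=100)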

Pitch

Update the _bincount utility logic (a sketch follows this list) to:

  • Use native torch.bincount if:
    • x.is_cuda is True
    • The PyTorch version is ≥ 2.1
    • No weights are involved
    • Gradients are not required
  • Only fall back when:
    • x.is_mps or
    • XLA backend is detected or
    • PyTorch version is < 2.1 and deterministic algorithms are enabled
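
A minimal sketch of the proposed logic (the version flag, the XLA check, and the minlength default below are illustrative stand-ins for helpers TorchMetrics already maintains, so the actual PR may differ):

    import torch
    from torch import Tensor
    from typing import Optional

    # Illustrative flag: torch.__version__ is a TorchVersion and supports
    # PEP 440-style comparison against version strings.
    _TORCH_GREATER_EQUAL_2_1 = torch.__version__ >= "2.1"

    def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
        """Count occurrences of each value in x, preferring native torch.bincount."""
        if minlength is None:
            # Illustrative default: one bin per value up to the maximum.
            minlength = int(x.max()) + 1 if x.numel() > 0 else 0
        needs_fallback = (
            x.is_mps
            or x.device.type == "xla"
            or (
                not _TORCH_GREATER_EQUAL_2_1
                and torch.are_deterministic_algorithms_enabled()
            )
        )
        if not needs_fallback:
            # Safe even under deterministic CUDA on PyTorch >= 2.1:
            # integer input (no gradients) and no weights are passed.
            return torch.bincount(x, minlength=minlength)
        # Deterministic fallback: memory-heavy one-hot style comparison.
        mesh = torch.arange(minlength, device=x.device)
        return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)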

Alternatives

Continue using the current fallback unconditionally under deterministic mode. This preserves the status quo but incurs unnecessary compute and memory overhead on newer CUDA-enabled systems.

Additional context

This proposed change aligns with the improvements introduced in PyTorch PR #105244, which enabled deterministic torch.bincount on CUDA under safe conditions starting from v2.1.

A PR will follow shortly to implement this enhancement.
