You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Since torch 2.8, torch.bmm supports an out_dtype parameter. From the docs "the dtype of the output tensor, Supported only on CUDA and for torch.float32 given torch.float16/torch.bfloat16 input dtypes"