
flip: improve error handling and error messages. #9550


Merged: 3 commits, Aug 14, 2025
14 changes: 14 additions & 0 deletions test/test_operations.py
@@ -2497,6 +2497,20 @@ def test_div_raises_error_on_invalid_rounding_mode(self):
"'trunc', 'floor', or be left unspecified.")
self.assertEqual(str(e), expected_error)

  def test_flip_raises_error_on_duplicated_dims(self):
    a = torch.rand(2, 2, 2, 2, device=torch_xla.device())
    dims = [0, 0, 0, 1, 2, 3, -1]
    dims_suggestion = [0, 1, 2, 3]

    try:
      torch.flip(a, dims=dims)
    except RuntimeError as e:
      expected_error = (
          "flip(): expected each dimension to appear at most once. Found "
          "dimensions: 0 (3 times), 3 (2 times). Consider changing dims "
          f"from {dims} to {dims_suggestion}.")
      self.assertEqual(str(e), expected_error)


class MNISTComparator(nn.Module):

6 changes: 4 additions & 2 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1804,8 +1804,10 @@ at::Tensor& XLANativeFunctions::fill_(at::Tensor& self,
at::Tensor XLANativeFunctions::flip(const at::Tensor& self,
                                    at::IntArrayRef dims) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
  return bridge::AtenFromXlaTensor(tensor_methods::flip(
      GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(dims)));
  auto xself = GetValueOrThrow(bridge::GetXlaTensor(self));
  auto output =
      GetValueOrThrow(tensor_methods::flip(xself, XlaHelpers::I64List(dims)));
  return bridge::AtenFromXlaTensor(std::move(output));
}

at::Tensor XLANativeFunctions::floor_divide(const at::Tensor& self,
Expand Down
75 changes: 69 additions & 6 deletions torch_xla/csrc/tensor_methods.cpp
@@ -8,6 +8,7 @@

#include <algorithm>
#include <functional>
#include <iterator>

#include "absl/log/absl_check.h"
#include "absl/strings/str_cat.h"
@@ -345,6 +346,69 @@ XLATensorPtr DispatchComparisonOp(c10::Symbol kind, const XLATensorPtr& input,
  return XLATensor::Create(node, input->GetDevice(), at::ScalarType::Bool);
}

// Checks that the canonical dimensions out of the given dimensions are unique
// for the `flip` operation.
//
// This function fails if any canonical dimension appears more than once.
// Notice that its error message is specialized for the `flip` operation.
//
// @param rank Input rank
// @param dims Original `dims` argument of the `flip` operation (used only
//        when composing the error message)
// @param canonical_dims Canonical dimensions extracted from the `dims`
//        argument
absl::Status CheckFlipDimensionsAreUnique(
    int64_t rank, absl::Span<const int64_t> dims,
    absl::Span<const int64_t> canonical_dims) {
  // Counter that maps each given dimension to the number of times it has
  // appeared.
  std::vector<int64_t> count(rank, 0);

  // Count the number of times each dimension appears.
  for (auto dim : canonical_dims) {
    count[dim] += 1;
  }

  bool any_dimension_appears_more_than_once = std::any_of(
      count.begin(), count.end(), [](const auto n) { return n > 1; });

  if (any_dimension_appears_more_than_once) {
    // Suggestion for a value of `dims` that would not raise an error.
    std::vector<int64_t> dims_suggestion;
    // Each "bad" dimension is represented as a string of the form:
    //
    //   <dimension> (<count> times)
    //
    // to be joined with commas later.
    std::vector<std::string> bad_count_str;

    // Iterate over each dimension, populating both `dims_suggestion` and
    // `bad_count_str`.
    for (int64_t i : c10::irange(rank)) {
      // Dimension does not appear. Do nothing.
      if (count[i] == 0) {
        continue;
      }

      // Dimension appears in `dims`. Add it to the suggestion list.
      dims_suggestion.push_back(i);

      // Dimension appears more than once. Add it to the "bad" list.
      if (count[i] > 1) {
        bad_count_str.push_back(absl::StrCat(i, " (", count[i], " times)"));
      }
    }

    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
        "flip(): expected each dimension to appear at most once. Found "
        "dimensions: ",
        absl::StrJoin(bad_count_str, /* sep= */ ", "),
        ". Consider changing dims from [", absl::StrJoin(dims, /* sep= */ ", "),
        "] to [", absl::StrJoin(dims_suggestion, /* sep= */ ", "), "].")));
  }

  return absl::OkStatus();
}

} // namespace

//////////////////////////////////////////////////////////////////////////////
@@ -1680,12 +1744,11 @@ void fill_(XLATensorPtr& input, const at::Scalar& value) {
  input->SetInPlaceIrValue(std::move(constant));
}

XLATensorPtr flip(const XLATensorPtr& input, absl::Span<const int64_t> dims) {
  auto dimensions = torch::lazy::GetCanonicalDimensionIndices(
      torch_xla::runtime::util::ToVector<int64_t>(dims),
      input->shape().get().dimensions_size());
  std::set<int64_t> unique_dims(dimensions.begin(), dimensions.end());
  XLA_CHECK_EQ(unique_dims.size(), dimensions.size());
absl::StatusOr<absl_nonnull XLATensorPtr> flip(const XLATensorPtr& input,
                                               absl::Span<const int64_t> dims) {
  auto rank = input->shape().get().dimensions_size();
  auto dimensions = torch::lazy::GetCanonicalDimensionIndices(dims, rank);
  XLA_RETURN_IF_ERROR(CheckFlipDimensionsAreUnique(rank, dims, dimensions));
  return input->CreateFrom(
      torch_xla::MakeNode<Flip>(input->GetIrValue(), dimensions));
}
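For reference, the new check can be traced by hand on the inputs used in the test above (dims = [0, 0, 0, 1, 2, 3, -1] on a rank-4 tensor). The following standalone sketch mirrors the counting and suggestion logic with the standard library only; it is an illustration, not the PR's code, and it canonicalizes negative dimensions by hand where the real implementation relies on torch::lazy::GetCanonicalDimensionIndices.

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  const int64_t rank = 4;
  const std::vector<int64_t> dims = {0, 0, 0, 1, 2, 3, -1};

  // Canonicalize: negative dimensions count from the back, so -1 becomes 3.
  std::vector<int64_t> canonical;
  for (int64_t d : dims) canonical.push_back(d < 0 ? d + rank : d);

  // Count how many times each canonical dimension appears.
  std::vector<int64_t> count(rank, 0);
  for (int64_t d : canonical) ++count[d];

  // Collect the duplicated dimensions and a de-duplicated suggestion.
  std::vector<std::string> bad;
  std::vector<int64_t> suggestion;
  for (int64_t i = 0; i < rank; ++i) {
    if (count[i] == 0) continue;
    suggestion.push_back(i);
    if (count[i] > 1) {
      bad.push_back(std::to_string(i) + " (" + std::to_string(count[i]) + " times)");
    }
  }

  // Joins a vector's elements with ", ", like absl::StrJoin in the real code.
  auto join = [](const auto& items) {
    std::ostringstream os;
    for (size_t i = 0; i < items.size(); ++i) {
      if (i > 0) os << ", ";
      os << items[i];
    }
    return os.str();
  };

  // Prints the same message the Python test expects:
  //   flip(): expected each dimension to appear at most once. Found
  //   dimensions: 0 (3 times), 3 (2 times). Consider changing dims from
  //   [0, 0, 0, 1, 2, 3, -1] to [0, 1, 2, 3].
  std::cout << "flip(): expected each dimension to appear at most once. "
            << "Found dimensions: " << join(bad)
            << ". Consider changing dims from [" << join(dims) << "] to ["
            << join(suggestion) << "].\n";
}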
3 changes: 2 additions & 1 deletion torch_xla/csrc/tensor_methods.h
@@ -450,7 +450,8 @@ void eye_out(XLATensorPtr& out, int64_t lines, int64_t cols);
void fill_(XLATensorPtr& input, const at::Scalar& value);

// Flips (reverses) the values in the dimensions of the input tensor.
XLATensorPtr flip(const XLATensorPtr& input, absl::Span<const int64_t> dims);
absl::StatusOr<absl_nonnull XLATensorPtr> flip(const XLATensorPtr& input,
                                               absl::Span<const int64_t> dims);

XLATensorPtr fmod(
    const XLATensorPtr& input, const XLATensorPtr& other,
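With this signature change, tensor_methods::flip reports invalid dims through absl::StatusOr instead of aborting on XLA_CHECK_EQ; callers either propagate the status (XLA_RETURN_IF_ERROR) or unwrap it (GetValueOrThrow), as the aten_xla_type.cpp hunk above does. The snippet below is only a generic sketch of that consumption pattern using Abseil; FlipOrError is a hypothetical stand-in, not a torch_xla API.

#include <algorithm>
#include <iostream>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical stand-in for a StatusOr-returning op such as
// tensor_methods::flip: fails with InvalidArgument on duplicated dims.
absl::StatusOr<std::vector<int>> FlipOrError(std::vector<int> values,
                                             bool dims_are_unique) {
  if (!dims_are_unique) {
    return absl::InvalidArgumentError(
        "flip(): expected each dimension to appear at most once.");
  }
  std::reverse(values.begin(), values.end());
  return values;
}

int main() {
  absl::StatusOr<std::vector<int>> result =
      FlipOrError({1, 2, 3}, /*dims_are_unique=*/false);
  if (!result.ok()) {
    // In torch_xla, GetValueOrThrow turns this status into an exception that
    // surfaces to Python as the RuntimeError checked by the new test; here it
    // is simply printed.
    std::cerr << result.status().message() << "\n";
    return 1;
  }
  for (int v : *result) std::cout << v << " ";
  std::cout << "\n";
}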