diff --git a/test/test_operations.py b/test/test_operations.py
index 4c0395ff286..cb790a07414 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -2497,6 +2497,20 @@ def test_div_raises_error_on_invalid_rounding_mode(self):
                         "'trunc', 'floor', or be left unspecified.")
       self.assertEqual(str(e), expected_error)
 
+  def test_flip_raises_error_on_duplicated_dims(self):
+    a = torch.rand(2, 2, 2, 2, device=torch_xla.device())
+    dims = [0, 0, 0, 1, 2, 3, -1]
+    dims_suggestion = [0, 1, 2, 3]
+
+    try:
+      torch.flip(a, dims=dims)
+    except RuntimeError as e:
+      expected_error = (
+          "flip(): expected each dimension to appear at most once. Found "
+          "dimensions: 0 (3 times), 3 (2 times). Consider changing dims "
+          f"from {dims} to {dims_suggestion}.")
+      self.assertEqual(str(e), expected_error)
+
 
 class MNISTComparator(nn.Module):
 
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 4d5286c2b04..7d8ba352059 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1804,8 +1804,10 @@ at::Tensor& XLANativeFunctions::fill_(at::Tensor& self,
 at::Tensor XLANativeFunctions::flip(const at::Tensor& self,
                                     at::IntArrayRef dims) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return bridge::AtenFromXlaTensor(tensor_methods::flip(
-      GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(dims)));
+  auto xself = GetValueOrThrow(bridge::GetXlaTensor(self));
+  auto output =
+      GetValueOrThrow(tensor_methods::flip(xself, XlaHelpers::I64List(dims)));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::floor_divide(const at::Tensor& self,
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index ffeff7bab88..4a749d50ac7 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -8,6 +8,7 @@
 
 #include <algorithm>
 #include <functional>
+#include <string>
 
 #include "absl/log/absl_check.h"
 #include "absl/strings/str_cat.h"
@@ -345,6 +346,69 @@ XLATensorPtr DispatchComparisonOp(c10::Symbol kind, const XLATensorPtr& input,
   return XLATensor::Create(node, input->GetDevice(), at::ScalarType::Bool);
 }
 
+// Checks that the canonical dimensions out of the given dimensions are unique
+// for the `flip` operation.
+//
+// This function fails if any canonical dimension appears more than once.
+// Note that its error message is specialized for the `flip` operation.
+//
+// @param rank Input rank
+// @param dims (error message only) the `flip` operation's original `dims`
+//        argument
+// @param canonical_dims (error message only) canonical dimensions extracted
+//        from the `dims` argument
+absl::Status CheckFlipDimensionsAreUnique(
+    int64_t rank, absl::Span<const int64_t> dims,
+    absl::Span<const int64_t> canonical_dims) {
+  // Counter that maps each given dimension to the number of times it has
+  // appeared.
+  std::vector<int64_t> count(rank, 0);
+
+  // Count the number of times each dimension appears.
+  for (auto dim : canonical_dims) {
+    count[dim] += 1;
+  }
+
+  bool any_dimension_appears_more_than_once = std::any_of(
+      count.begin(), count.end(), [](const auto n) { return n > 1; });
+
+  if (any_dimension_appears_more_than_once) {
+    // Suggestion for a value of `dims` that would not raise an error.
+    std::vector<int64_t> dims_suggestion;
+    // Each "bad" dimension is represented as a string of the form:
+    //
+    //     <dimension> (<count> times)
+    //
+    // to be later joined with commas.
+    std::vector<std::string> bad_count_str;
+
+    // Iterate over each dimension, populating both `dims_suggestion` and
+    // `bad_count_str`.
+    for (int64_t i : c10::irange(rank)) {
+      // Dimension does not appear. Do nothing.
+      if (count[i] == 0) {
+        continue;
+      }
+
+      // Dimension appears in `dims`. Add it to the suggestion list.
+      dims_suggestion.push_back(i);
+
+      // Dimension appears more than once. Add it to the "bad" list.
+      if (count[i] > 1) {
+        bad_count_str.push_back(absl::StrCat(i, " (", count[i], " times)"));
+      }
+    }
+
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "flip(): expected each dimension to appear at most once. Found "
+        "dimensions: ",
+        absl::StrJoin(bad_count_str, /* sep= */ ", "),
+        ". Consider changing dims from [", absl::StrJoin(dims, /* sep= */ ", "),
+        "] to [", absl::StrJoin(dims_suggestion, /* sep= */ ", "), "]."))); 
+  }
+
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////
@@ -1680,12 +1744,11 @@ void fill_(XLATensorPtr& input, const at::Scalar& value) {
   input->SetInPlaceIrValue(std::move(constant));
 }
 
-XLATensorPtr flip(const XLATensorPtr& input, absl::Span<const int64_t> dims) {
-  auto dimensions = torch::lazy::GetCanonicalDimensionIndices(
-      torch_xla::runtime::util::ToVector<int64_t>(dims),
-      input->shape().get().dimensions_size());
-  std::set<int64_t> unique_dims(dimensions.begin(), dimensions.end());
-  XLA_CHECK_EQ(unique_dims.size(), dimensions.size());
+absl::StatusOr<XLATensorPtr> flip(const XLATensorPtr& input,
+                                  absl::Span<const int64_t> dims) {
+  auto rank = input->shape().get().dimensions_size();
+  auto dimensions = torch::lazy::GetCanonicalDimensionIndices(dims, rank);
+  XLA_RETURN_IF_ERROR(CheckFlipDimensionsAreUnique(rank, dims, dimensions));
   return input->CreateFrom(
       torch_xla::MakeNode<Flip>(input->GetIrValue(), dimensions));
 }
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 395768fc867..fb7eae93f8d 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -450,7 +450,8 @@ void eye_out(XLATensorPtr& out, int64_t lines, int64_t cols);
 void fill_(XLATensorPtr& input, const at::Scalar& value);
 
 // Flips (reverses) the values in the dimensions of the input tensor.
-XLATensorPtr flip(const XLATensorPtr& input, absl::Span<const int64_t> dims);
+absl::StatusOr<XLATensorPtr> flip(const XLATensorPtr& input,
+                                  absl::Span<const int64_t> dims);
 
 XLATensorPtr fmod(
     const XLATensorPtr& input, const XLATensorPtr& other,
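For reference, a minimal sketch of the user-facing behavior this patch enables, assuming a torch_xla build with the change applied; the error text follows the format asserted in the test above, and the dims values here are illustrative:

```python
import torch
import torch_xla

# Flipping over a duplicated dimension previously tripped an internal
# XLA_CHECK_EQ; with this change it raises a RuntimeError that names the
# repeated dimensions and suggests a de-duplicated `dims`.
a = torch.rand(2, 2, 2, 2, device=torch_xla.device())

try:
  torch.flip(a, dims=[0, 0, 3])
except RuntimeError as e:
  print(e)
  # flip(): expected each dimension to appear at most once. Found
  # dimensions: 0 (2 times). Consider changing dims from [0, 0, 3]
  # to [0, 3].
```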