diff --git a/src/ATen/native/xpu/sycl/Indexing.cpp b/src/ATen/native/xpu/sycl/Indexing.cpp
index 47a7f76e5..c940dd1a9 100644
--- a/src/ATen/native/xpu/sycl/Indexing.cpp
+++ b/src/ATen/native/xpu/sycl/Indexing.cpp
@@ -609,6 +609,21 @@ void index_put_kernel(
   }
 }
 
+DimVector valsShape(
+    IntArrayRef self_sizes,
+    int64_t dims_before,
+    int64_t dims_indexed,
+    IntArrayRef replacement_shape) {
+  auto shape = DimVector(self_sizes);
+  int64_t end = dims_before + dims_indexed;
+  shape.erase(shape.begin() + dims_before, shape.begin() + end);
+  shape.insert(
+      shape.begin() + dims_before,
+      replacement_shape.begin(),
+      replacement_shape.end());
+  return shape;
+}
+
 void index_put_deterministic_kernel(
     Tensor& self,
     const c10::List<std::optional<Tensor>>& indices,
@@ -633,30 +648,21 @@
   bool self_contiguous = self.is_contiguous();
   auto self_ = self_contiguous ? self : self.contiguous();
   Tensor linearIndex, src, expandedValue = value;
-  int64_t nElemBefore, strideBefore, sliceSize;
+  int64_t nElemBefore, strideBefore, sliceSize, dims_before, dims_indexed;
   std::vector<int64_t> inversePerm;
   std::tie(
-      linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) =
-      makeLinearIndex(self_, indices, !unsafe);
+      linearIndex,
+      src,
+      nElemBefore,
+      strideBefore,
+      sliceSize,
+      inversePerm,
+      dims_before,
+      dims_indexed) = makeLinearIndex(self_, indices, !unsafe);
+  auto vals_shape =
+      valsShape(src.sizes(), dims_before, dims_indexed, linearIndex.sizes());
   int64_t num_indices = linearIndex.numel();
-
-  if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) {
-    auto expanded_size = at::DimVector(expandedValue.sizes());
-
-    auto size1 = expandedValue.sizes();
-    auto size2 = linearIndex.sizes();
-    if (are_expandable(size1, size2)) {
-      expanded_size = infer_size_dimvector(size1, size2);
-    }
-    if (nElemBefore > 1) {
-      expanded_size.insert(expanded_size.begin(), nElemBefore);
-    }
-    if (sliceSize > 1) {
-      expanded_size.insert(expanded_size.end(), sliceSize);
-    }
-    expandedValue = expandedValue.expand(expanded_size);
-  }
-  expandedValue = expandedValue.contiguous();
+  expandedValue = expandedValue.expand(vals_shape).contiguous();
 
   if (num_indices > 0 && sliceSize > 0) {
     const bool permuted = !src.is_contiguous();
diff --git a/src/ATen/native/xpu/sycl/IndexingUtils.h b/src/ATen/native/xpu/sycl/IndexingUtils.h
index 1c6d9c373..e4bab3ff0 100644
--- a/src/ATen/native/xpu/sycl/IndexingUtils.h
+++ b/src/ATen/native/xpu/sycl/IndexingUtils.h
@@ -57,10 +57,8 @@ static std::vector<int64_t> computeLinearStride(const Tensor& tensor) {
   return stride;
 }
 
-static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
-    const Tensor& src,
-    TensorList indices,
-    bool check_range) {
+static std::tuple<Tensor, int64_t, int64_t, int64_t, int64_t, int64_t>
+computeLinearIndex(const Tensor& src, TensorList indices, bool check_range) {
   auto strides = computeLinearStride(src);
   const auto& device = src.options().device();
 
@@ -70,8 +68,10 @@ static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
   // are not being index.
   Tensor linearIndex;
   int64_t nElemBefore = 1, nElemAfter = 1, strideBefore = 0;
+  int64_t dims_before = 0, dims_indexed = 0;
   for (const auto i : c10::irange(src.dim())) {
     if (indices[i].defined()) {
+      dims_indexed++;
       // Cast index to the longType matching src's device
       // This allows us to support ie indexing a xpu tensor with a cpu tensor
       Tensor index =
@@ -88,17 +88,30 @@ static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
     } else if (linearIndex.defined()) {
       nElemAfter *= src.size(i);
     } else {
+      dims_before++;
       nElemBefore *= src.size(i);
     }
   }
 
   return std::make_tuple(
-      std::move(linearIndex), nElemBefore, strideBefore, nElemAfter);
+      std::move(linearIndex),
+      nElemBefore,
+      strideBefore,
+      nElemAfter,
+      dims_before,
+      dims_indexed);
 }
 
-static std::
-    tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>>
-    makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
+static std::tuple<
+    Tensor,
+    Tensor,
+    int64_t,
+    int64_t,
+    int64_t,
+    std::vector<int64_t>,
+    int64_t,
+    int64_t>
+makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
   checkIndexTensorTypes(orig, /*allow_int*/ true);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more
   // LongTensors
@@ -121,10 +134,22 @@ static std::
     std::tie(self, indices, inversePerm) =
         transposeToFrontAndInvPerm(self, indices);
   }
-  auto [linearIndex, nElemBefore, strideBefore, nElemAfter] =
-      computeLinearIndex(self, indices, check_range);
+  auto
+      [linearIndex,
+       nElemBefore,
+       strideBefore,
+       nElemAfter,
+       dims_before,
+       dims_indexed] = computeLinearIndex(self, indices, check_range);
   return std::make_tuple(
-      linearIndex, self, nElemBefore, strideBefore, nElemAfter, inversePerm);
+      linearIndex,
+      self,
+      nElemBefore,
+      strideBefore,
+      nElemAfter,
+      inversePerm,
+      dims_before,
+      dims_indexed);
 }
 
 } // namespace at::native::xpu
diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py
index da15d41ab..77b875e38 100644
--- a/test/xpu/skip_list_common.py
+++ b/test/xpu/skip_list_common.py
@@ -1031,8 +1031,6 @@
         # https://github.com/intel/torch-xpu-ops/issues/461
         "test_index_put_src_datatype_xpu_float8_e5m2",
         "test_index_put_src_datatype_xpu_float8_e4m3fn",
-        # https://github.com/intel/torch-xpu-ops/issues/1702
-        "test_index_put_deterministic_with_optional_tensors_xpu",
     ),
     "nn/test_pooling_xpu.py": None,
     "nn/test_dropout_xpu.py": None,
diff --git a/test/xpu/test_indexing_xpu.py b/test/xpu/test_indexing_xpu.py
index 338922db3..cfa9eccc2 100644
--- a/test/xpu/test_indexing_xpu.py
+++ b/test/xpu/test_indexing_xpu.py
@@ -1,7 +1,7 @@
 # Owner(s): ["module: intel"]
 
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.common_utils import DeterministicGuard, run_tests
 
 try:
     from xpu_test_utils import XPUPatchForImport
@@ -14,15 +14,15 @@
 
     torch.Tensor.is_cuda = torch.Tensor.is_xpu
 
-    def __test_index_put_accumulate_with_optional_tensors(self, device):
-        # TODO: replace with a better solution.
-        # Currently, here using torchscript to put None into indices.
-        # on C++ it gives indices as a list of 2 optional tensors: first is null and
-        # the second is a valid tensor.
-        @torch.jit.script
+    def __test_index_put_deterministic_with_optional_tensors(self, device):
         def func(x, i, v):
-            idx = [None, i]
-            x.index_put_(idx, v, accumulate=True)
+            with DeterministicGuard(True):
+                x[..., i] = v
+            return x
+
+        def func1(x, i, v):
+            with DeterministicGuard(True):
+                x[i] = v
             return x
 
         n = 4
@@ -32,17 +32,38 @@ def func(x, i, v):
         indices_dev = indices.to(device)
         value0d = torch.tensor(10.0)
         value1d = torch.tensor([1.0, 2.0])
+        values2d = torch.randn(n, 1)
 
-        out_cuda = func(t_dev, indices_dev, value0d.xpu())
-        out_cpu = func(t, indices, value0d)
+        for val in (value0d, value1d, values2d):
+            out_cuda = func(t_dev, indices_dev, val.to(device))
+            out_cpu = func(t, indices, val)
+            self.assertEqual(out_cuda.cpu(), out_cpu)
+
+        t = torch.zeros((5, 4))
+        t_dev = t.to(device)
+        indices = torch.tensor([1, 4, 3])
+        indices_dev = indices.to(device)
+        val = torch.randn(4)
+        out_cuda = func1(t_dev, indices_dev, val.xpu())
+        out_cpu = func1(t, indices, val)
         self.assertEqual(out_cuda.cpu(), out_cpu)
 
-        out_cuda = func(t_dev, indices_dev, value1d.xpu())
-        out_cpu = func(t, indices, value1d)
+        t = torch.zeros(2, 3, 4)
+        ind = torch.tensor([0, 1])
+        val = torch.randn(6, 2)
+        with self.assertRaisesRegex(RuntimeError, "shape mismatch"):
+            func(t, ind, val)
+
+        with self.assertRaisesRegex(RuntimeError, "must match"):
+            func(t.to(device), ind.to(device), val.to(device))
+
+        val = torch.randn(2, 3, 1)
+        out_cuda = func1(t.to(device), ind.to(device), val.to(device))
+        out_cpu = func1(t, ind, val)
         self.assertEqual(out_cuda.cpu(), out_cpu)
 
-    TestIndexing.test_index_put_accumulate_with_optional_tensors = (
-        __test_index_put_accumulate_with_optional_tensors
+    TestIndexing.test_index_put_deterministic_with_optional_tensors = (
+        __test_index_put_deterministic_with_optional_tensors
     )
 
 instantiate_device_type_tests(NumpyTests, globals(), only_for=("xpu"), allow_xpu=True)
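
Note (illustration, not part of the patch): with this change the deterministic index_put path expands value to the shape that advanced indexing itself produces, i.e. self's sizes with the indexed dimensions replaced by the broadcast index shape. That is what the new valsShape() helper computes, and a value that cannot broadcast to it now raises, which the added "shape mismatch" / "must match" assertions in the test exercise. A minimal Python sketch of the shape rule; vals_shape() below is a hypothetical mirror of the C++ helper, not code from this patch:

    import torch

    def vals_shape(self_sizes, dims_before, dims_indexed, replacement_shape):
        # Illustrative mirror of the C++ valsShape() helper: drop the indexed
        # dims from self's sizes and splice in the (broadcast) index shape.
        shape = list(self_sizes)
        shape[dims_before:dims_before + dims_indexed] = list(replacement_shape)
        return shape

    # x[..., i] = v on a (5, 2) tensor indexes only the last dim with a 1-D
    # index of length 2: dims_before = 1, dims_indexed = 1, index shape (2,).
    t = torch.zeros(5, 2)
    i = torch.tensor([0, 0])
    print(vals_shape(t.shape, 1, 1, i.shape))   # [5, 2]

    t[..., i] = torch.randn(5, 1)       # broadcasts to [5, 2]: accepted
    try:
        t[..., i] = torch.randn(6, 2)   # cannot broadcast to [5, 2]
    except RuntimeError as e:
        print("rejected:", e)           # message mentions "shape mismatch"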