48 changes: 27 additions & 21 deletions src/ATen/native/xpu/sycl/Indexing.cpp
@@ -609,6 +609,21 @@ void index_put_kernel(
}
}

DimVector valsShape(
IntArrayRef self_sizes,
int64_t dims_before,
int64_t dims_indexed,
IntArrayRef replacement_shape) {
auto shape = DimVector(self_sizes);
int64_t end = dims_before + dims_indexed;
shape.erase(shape.begin() + dims_before, shape.begin() + end);
shape.insert(
shape.begin() + dims_before,
replacement_shape.begin(),
replacement_shape.end());
return shape;
}

void index_put_deterministic_kernel(
Tensor& self,
const c10::List<std::optional<Tensor>>& indices,
@@ -633,30 +648,21 @@ void index_put_deterministic_kernel(
bool self_contiguous = self.is_contiguous();
auto self_ = self_contiguous ? self : self.contiguous();
Tensor linearIndex, src, expandedValue = value;
int64_t nElemBefore, strideBefore, sliceSize;
int64_t nElemBefore, strideBefore, sliceSize, dims_before, dims_indexed;
std::vector<int64_t> inversePerm;
std::tie(
linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) =
makeLinearIndex(self_, indices, !unsafe);
linearIndex,
src,
nElemBefore,
strideBefore,
sliceSize,
inversePerm,
dims_before,
dims_indexed) = makeLinearIndex(self_, indices, !unsafe);
auto vals_shape =
valsShape(src.sizes(), dims_before, dims_indexed, linearIndex.sizes());
int64_t num_indices = linearIndex.numel();

if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) {
auto expanded_size = at::DimVector(expandedValue.sizes());

auto size1 = expandedValue.sizes();
auto size2 = linearIndex.sizes();
if (are_expandable(size1, size2)) {
expanded_size = infer_size_dimvector(size1, size2);
}
if (nElemBefore > 1) {
expanded_size.insert(expanded_size.begin(), nElemBefore);
}
if (sliceSize > 1) {
expanded_size.insert(expanded_size.end(), sliceSize);
}
expandedValue = expandedValue.expand(expanded_size);
}
expandedValue = expandedValue.contiguous();
expandedValue = expandedValue.expand(vals_shape).contiguous();

if (num_indices > 0 && sliceSize > 0) {
const bool permuted = !src.is_contiguous();
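Note: a minimal Python sketch (illustrative only, not the ATen code itself) of what the new `valsShape` helper computes. The indexed dimensions of the (possibly permuted) self are replaced by the shape of the linear index, giving the target shape that `value` is expanded to before the deterministic scatter; the concrete shapes below are assumptions for illustration.

```python
# Illustrative sketch of the valsShape logic: splice the linear-index shape
# into self's sizes in place of the indexed dims.
def vals_shape(self_sizes, dims_before, dims_indexed, replacement_shape):
    shape = list(self_sizes)
    end = dims_before + dims_indexed
    return shape[:dims_before] + list(replacement_shape) + shape[end:]

# Example: self is (5, 4, 3) and only dim 1 is indexed by 2 indices,
# so value must broadcast/expand to (5, 2, 3).
assert vals_shape([5, 4, 3], dims_before=1, dims_indexed=1,
                  replacement_shape=[2]) == [5, 2, 3]
```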
47 changes: 36 additions & 11 deletions src/ATen/native/xpu/sycl/IndexingUtils.h
@@ -57,10 +57,8 @@ static std::vector<int64_t> computeLinearStride(const Tensor& tensor) {
return stride;
}

static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
const Tensor& src,
TensorList indices,
bool check_range) {
static std::tuple<Tensor, int64_t, int64_t, int64_t, int64_t, int64_t>
computeLinearIndex(const Tensor& src, TensorList indices, bool check_range) {
auto strides = computeLinearStride(src);
const auto& device = src.options().device();

@@ -70,8 +68,10 @@ static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
// are not being indexed.
Tensor linearIndex;
int64_t nElemBefore = 1, nElemAfter = 1, strideBefore = 0;
int64_t dims_before = 0, dims_indexed = 0;
for (const auto i : c10::irange(src.dim())) {
if (indices[i].defined()) {
dims_indexed++;
// Cast index to the longType matching src's device
// This allows us to support e.g. indexing an xpu tensor with a cpu tensor
Tensor index =
@@ -88,17 +88,30 @@ static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
} else if (linearIndex.defined()) {
nElemAfter *= src.size(i);
} else {
dims_before++;
nElemBefore *= src.size(i);
}
}

return std::make_tuple(
std::move(linearIndex), nElemBefore, strideBefore, nElemAfter);
std::move(linearIndex),
nElemBefore,
strideBefore,
nElemAfter,
dims_before,
dims_indexed);
}

static std::
tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>>
makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
static std::tuple<
Tensor,
Tensor,
int64_t,
int64_t,
int64_t,
std::vector<int64_t>,
int64_t,
int64_t>
makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
checkIndexTensorTypes(orig, /*allow_int*/ true);
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more
// LongTensors
@@ -121,10 +134,22 @@ static std::
std::tie(self, indices, inversePerm) =
transposeToFrontAndInvPerm(self, indices);
}
auto [linearIndex, nElemBefore, strideBefore, nElemAfter] =
computeLinearIndex(self, indices, check_range);
auto
[linearIndex,
nElemBefore,
strideBefore,
nElemAfter,
dims_before,
dims_indexed] = computeLinearIndex(self, indices, check_range);
return std::make_tuple(
linearIndex, self, nElemBefore, strideBefore, nElemAfter, inversePerm);
linearIndex,
self,
nElemBefore,
strideBefore,
nElemAfter,
inversePerm,
dims_before,
dims_indexed);
}

} // namespace at::native::xpu
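Note: a small Python sketch (illustrative only, assuming the indexed dims have already been made adjacent, as `transposeToFrontAndInvPerm` arranges) of the extra bookkeeping `computeLinearIndex` now returns: `dims_before` counts the un-indexed dims in front of the indexed block and `dims_indexed` counts the indexed dims, alongside the existing element counts.

```python
# Illustrative sketch of the dims_before / dims_indexed bookkeeping.
def count_index_dims(self_sizes, indices):
    # indices[i] is None for dims that are not indexed
    n_elem_before, n_elem_after = 1, 1
    dims_before, dims_indexed = 0, 0
    seen_indexed = False
    for size, idx in zip(self_sizes, indices):
        if idx is not None:
            dims_indexed += 1
            seen_indexed = True
        elif seen_indexed:
            n_elem_after *= size
        else:
            dims_before += 1
            n_elem_before *= size
    return n_elem_before, n_elem_after, dims_before, dims_indexed

# self is (5, 4, 3), only dim 1 is indexed: one dim (size 5) before the
# indexed block, one indexed dim, and a slice of size 3 after it.
print(count_index_dims((5, 4, 3), [None, [1, 3], None]))  # (5, 3, 1, 1)
```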
51 changes: 36 additions & 15 deletions test/xpu/test_indexing_xpu.py
@@ -1,7 +1,8 @@
# Owner(s): ["module: intel"]

from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import DeterministicGuard

try:
from xpu_test_utils import XPUPatchForImport
@@ -13,16 +14,15 @@
from test_indexing import NumpyTests, TestIndexing

torch.Tensor.is_cuda = torch.Tensor.is_xpu

def __test_index_put_accumulate_with_optional_tensors(self, device):
# TODO: replace with a better solution.
# Currently, here using torchscript to put None into indices.
# on C++ it gives indices as a list of 2 optional tensors: first is null and
# the second is a valid tensor.
@torch.jit.script
def __test_index_put_deterministic_with_optional_tensors(self, device):

Check warning (GitHub Actions / preci-lint-check) on line 17 in test/xpu/test_indexing_xpu.py: FLAKE8 E301, expected 1 blank line, found 0. See https://www.flake8rules.com/rules/E301.html
def func(x, i, v):
idx = [None, i]
x.index_put_(idx, v, accumulate=True)
with DeterministicGuard(True):
x[..., i] = v
return x

def func1(x, i, v):
with DeterministicGuard(True):
x[i] = v
return x

n = 4
@@ -32,17 +32,38 @@
indices_dev = indices.to(device)
value0d = torch.tensor(10.0)
value1d = torch.tensor([1.0, 2.0])
values2d = torch.randn(n, 1)

for val in (value0d, value1d, values2d):
out_cuda = func(t_dev, indices_dev, val.to(device))
out_cpu = func(t, indices, val)
self.assertEqual(out_cuda.cpu(), out_cpu)

out_cuda = func(t_dev, indices_dev, value0d.xpu())
out_cpu = func(t, indices, value0d)
t = torch.zeros((5, 4))
t_dev = t.to(device)
indices = torch.tensor([1, 4, 3])
indices_dev = indices.to(device)
val = torch.randn(4)
out_cuda = func1(t_dev, indices_dev, val.xpu())
out_cpu = func1(t, indices, val)
self.assertEqual(out_cuda.cpu(), out_cpu)

out_cuda = func(t_dev, indices_dev, value1d.xpu())
out_cpu = func(t, indices, value1d)
t = torch.zeros(2, 3, 4)
ind = torch.tensor([0, 1])
val = torch.randn(6, 2)
with self.assertRaisesRegex(RuntimeError, "shape mismatch"):
func(t, ind, val)

with self.assertRaisesRegex(RuntimeError, "must match"):
func(t.to(device), ind.to(device), val.to(device))

val = torch.randn(2, 3, 1)
out_cuda = func1(t.to(device), ind.to(device), val.to(device))
out_cpu = func1(t, ind, val)
self.assertEqual(out_cuda.cpu(), out_cpu)

TestIndexing.test_index_put_accumulate_with_optional_tensors = (
__test_index_put_accumulate_with_optional_tensors
TestIndexing.test_index_put_deterministic_with_optional_tensors = (
__test_index_put_deterministic_with_optional_tensors
)

instantiate_device_type_tests(NumpyTests, globals(), only_for=("xpu"), allow_xpu=True)
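Note: a minimal usage sketch of the behavior the renamed test exercises, assuming an XPU build and device are available; the shapes here are illustrative, not taken from the test. Under `DeterministicGuard`, advanced-indexing assignment routes through the deterministic index_put path, and the value broadcasts to the shape of the indexed slice on both CPU and XPU.

```python
import torch
from torch.testing._internal.common_utils import DeterministicGuard

t = torch.zeros(3, 4)
idx = torch.tensor([1, 3])
val = torch.randn(3, 1)  # broadcasts over the two indexed columns

with DeterministicGuard(True):
    cpu_out = t.clone()
    cpu_out[..., idx] = val                      # deterministic index_put on CPU
    xpu_out = t.to("xpu")
    xpu_out[..., idx.to("xpu")] = val.to("xpu")  # deterministic index_put on XPU

torch.testing.assert_close(xpu_out.cpu(), cpu_out)
```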