From 0181a494e79fea44fe0cd58bf8c036b237176aa7 Mon Sep 17 00:00:00 2001 From: Marc Glisse Date: Mon, 13 Apr 2020 09:21:47 +0200 Subject: [PATCH] Handle PyTorchTensor[ndarray] Workaround for https://github.com/pytorch/pytorch/issues/34452 --- eagerpy/tensor/pytorch.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/eagerpy/tensor/pytorch.py b/eagerpy/tensor/pytorch.py index 17f082f..97b8015 100644 --- a/eagerpy/tensor/pytorch.py +++ b/eagerpy/tensor/pytorch.py @@ -554,8 +554,11 @@ def __ge__(self: TensorType, other: TensorOrScalar) -> TensorType: def __getitem__(self: TensorType, index: Any) -> TensorType: if isinstance(index, tuple): index = tuple(x.raw if isinstance(x, Tensor) else x for x in index) - elif isinstance(index, Tensor): - index = index.raw + else: + if isinstance(index, Tensor): + index = index.raw + if isinstance(index, np.ndarray): + index = torch.as_tensor(index) return type(self)(self.raw[index]) def take_along_axis(self: TensorType, index: TensorType, axis: int) -> TensorType: