[MPS] Add tensor::index_put op (#85672)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85672
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index df23895..ca875f9 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6046,9 +6046,10 @@
class TestAdvancedIndexing(TestCase):
supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
+ supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]
# examples from https://www.tutorialspoint.com/numpy/numpy_advanced_indexing.htm
- def test_indexing_1(self):
+ def test_indexing_get(self):
def helper(dtype):
x_cpu = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype)
x_mps = x_cpu.detach().clone().to("mps")
@@ -6107,7 +6108,148 @@
# self.assertEqual(res_cpu, res_mps, str(dtype))
# [helper(dtype) for dtype in self.supported_dtypes]
- # tests from test_indexing.py
+
+ def test_advanced_indexing_3D_get(self):
+ def helper(x_cpu):
+ x_mps = x_cpu.detach().clone().to("mps")
+ self.assertEqual(x_cpu[[1, 2], 3, :], x_mps[[1, 2], 3, :])
+ self.assertEqual(x_cpu[[0, 2], :, :], x_mps[[0, 2], :, :])
+ self.assertEqual(x_cpu[:, [1, 0], [1]], x_mps[:, [1, 0], [1]])
+
+ x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
+ [0.5, 0.6, 0.7, 0.8],
+ [0.9, 1.0, 1.1, 1.2],
+ [1.3, 1.4, 1.5, 1.6]],
+
+ [[2.0, 2.1, 2.2, 2.3],
+ [2.4, 2.5, 2.6, 2.7],
+ [2.8, 2.9, 3.0, 3.1],
+ [3.2, 3.3, 3.4, 3.5]],
+
+ [[4.0, 4.1, 4.2, 4.3],
+ [4.4, 4.5, 4.6, 4.7],
+ [4.8, 4.9, 5.0, 5.1],
+ [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
+ helper(x_cpu)
+ for idx in range(len(self.supported_np_dtypes)):
+ # torch.randn / torch.rand don't work with all dtypes
+ # Generate input data for all dtypes with NumPy, then move it to torch
+ input_t = np.random.random_sample(size=[3, 4, 4]).astype(self.supported_np_dtypes[idx])
+ inputCPU = torch.tensor(input_t, device='cpu', dtype=self.supported_dtypes[idx])
+
+ helper(inputCPU)
+
+ def test_advanced_indexing_3D_put(self):
+ def helper(x_cpu):
+ dtype = x_cpu.dtype
+ x_mps = x_cpu.detach().clone().to("mps")
+
+ out_tensor_cpu = torch.tensor([88, 99], dtype=dtype, device="cpu")
+ out_tensor_cpu_view = out_tensor_cpu[1:]
+
+ out_tensor_mps = torch.tensor([88, 99], dtype=dtype, device="mps")
+ out_tensor_mps_view = out_tensor_mps[1:]
+
+ x_cpu[[1, 2], 3, :] = out_tensor_cpu_view
+ x_mps[[1, 2], 3, :] = out_tensor_mps_view
+ self.assertEqual(x_cpu, x_mps)
+
+ x_cpu[[0, 2], :, :] = out_tensor_cpu_view
+ x_mps[[0, 2], :, :] = out_tensor_mps_view
+ self.assertEqual(x_cpu, x_mps)
+
+ x_cpu[:, [1, 0], [1]] = out_tensor_cpu_view
+ x_mps[:, [1, 0], [1]] = out_tensor_mps_view
+ self.assertEqual(x_cpu, x_mps)
+
+ x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
+ [0.5, 0.6, 0.7, 0.8],
+ [0.9, 1.0, 1.1, 1.2],
+ [1.3, 1.4, 1.5, 1.6]],
+
+ [[2.0, 2.1, 2.2, 2.3],
+ [2.4, 2.5, 2.6, 2.7],
+ [2.8, 2.9, 3.0, 3.1],
+ [3.2, 3.3, 3.4, 3.5]],
+
+ [[4.0, 4.1, 4.2, 4.3],
+ [4.4, 4.5, 4.6, 4.7],
+ [4.8, 4.9, 5.0, 5.1],
+ [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
+ helper(x_cpu)
+ for idx in range(len(self.supported_np_dtypes)):
+ # torch.randn / torch.rand don't work with all dtypes
+ # Generate input data for all dtypes with NumPy, then move it to torch
+ input_t = np.random.random_sample(size=[3, 4, 4]).astype(self.supported_np_dtypes[idx])
+ inputCPU = torch.tensor(input_t, device='cpu', dtype=self.supported_dtypes[idx])
+
+ helper(inputCPU)
+
+ def test_index_put_with_view_indices(self):
+ def helper(dtype):
+ target_cpu = torch.zeros([5, 3], device="cpu", dtype=dtype)
+ target_mps = torch.zeros([5, 3], device="mps", dtype=dtype)
+
+ indices_cpu = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="cpu")
+ indices_mps = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="mps")
+
+ value_cpu = torch.ones(indices_cpu.shape[0], device="cpu", dtype=dtype)
+ value_mps = torch.ones(indices_mps.shape[0], device="mps", dtype=dtype)
+
+ target_cpu.index_put_(tuple(indices_cpu.t()), value_cpu, accumulate=True)
+ target_mps.index_put_(tuple(indices_mps.t()), value_mps, accumulate=True)
+
+ self.assertEqual(target_cpu, target_mps)
+
+ [helper(dtype) for dtype in [torch.int32, torch.float]]
+
+ # tests from 'test_indexing.py'
+ def test_advancedindex_big(self, device="mps"):
+ reference = torch.arange(0, 123344, dtype=torch.int, device=device)
+
+ self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ],
+ torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int))
+
+ def test_set_item_to_scalar_tensor(self, device="mps"):
+ m = random.randint(1, 10)
+ n = random.randint(1, 10)
+ z = torch.randn([m, n], device=device)
+ a = 1.0
+ w = torch.tensor(a, requires_grad=True, device=device)
+ z[:, 0] = w
+ z.sum().backward()
+ self.assertEqual(w.grad, m * a)
+
+ def test_single_int(self, device="mps"):
+ v = torch.randn(5, 7, 3, device=device)
+ self.assertEqual(v[4].shape, (7, 3))
+
+ def test_multiple_int(self, device="mps"):
+ v = torch.randn(5, 7, 3, device=device)
+ self.assertEqual(v[4].shape, (7, 3))
+ self.assertEqual(v[4, :, 1].shape, (7,))
+
+ def test_none(self, device="mps"):
+ v = torch.randn(5, 7, 3, device=device)
+ self.assertEqual(v[None].shape, (1, 5, 7, 3))
+ self.assertEqual(v[:, None].shape, (5, 1, 7, 3))
+ self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3))
+ self.assertEqual(v[..., None].shape, (5, 7, 3, 1))
+
+ def test_step(self, device="mps"):
+ v = torch.arange(10, device=device)
+ self.assertEqual(v[::1], v)
+ self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])
+ self.assertEqual(v[::3].tolist(), [0, 3, 6, 9])
+ self.assertEqual(v[::11].tolist(), [0])
+ self.assertEqual(v[1:6:2].tolist(), [1, 3, 5])
+
+ def test_step_assignment(self, device="mps"):
+ v = torch.zeros(4, 4, device=device)
+ v[0, 1::2] = torch.tensor([3., 4.], device=device)
+ self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
+ self.assertEqual(v[1:].sum(), 0)
+
def test_bool_indices(self, device="mps"):
v = torch.randn(5, 7, 3, device=device)
boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool, device=device)
@@ -6123,6 +6265,13 @@
self.assertEqual(v[boolIndices], torch.tensor([True], dtype=torch.bool, device=device))
self.assertEqual(len(w), 2)
+ def test_bool_indices_accumulate(self, device="mps"):
+ mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
+ mask = mask > 0
+ y = torch.ones(size=(10, 10), device=device)
+ y.index_put_((mask, ), y[mask], accumulate=True)
+ self.assertEqual(y, torch.ones(size=(10, 10), device=device))
+
def test_multiple_bool_indices(self, device="mps"):
v = torch.randn(5, 7, 3, device=device)
# note: these broadcast together and are transposed to the first dim
@@ -6130,14 +6279,6 @@
mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
- def test_step(self, device="mps"):
- v = torch.arange(10, device=device)
- self.assertEqual(v[::1], v)
- self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])
- self.assertEqual(v[::3].tolist(), [0, 3, 6, 9])
- self.assertEqual(v[::11].tolist(), [0])
- self.assertEqual(v[1:6:2].tolist(), [1, 3, 5])
-
def test_byte_mask(self, device="mps"):
v = torch.randn(5, 7, 3, device=device)
mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
@@ -6149,6 +6290,189 @@
v = torch.tensor([1.], device=device)
self.assertEqual(v[v == 0], torch.tensor([], device=device))
+ def test_byte_mask_accumulate(self, device="mps"):
+ mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
+ y = torch.ones(size=(10, 10), device=device)
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ y.index_put_((mask, ), y[mask], accumulate=True)
+ self.assertEqual(y, torch.ones(size=(10, 10), device=device))
+ self.assertEqual(len(w), 2)
+
+ def test_index_put_accumulate_expanded_values(self, device="mps"):
+ t = torch.zeros((5, 2))
+ t_dev = t.to(device)
+ indices = [
+ torch.tensor([0, 1, 2, 3]),
+ torch.tensor([1, ]),
+ ]
+ indices_dev = [i.to(device) for i in indices]
+ values0d = torch.tensor(1.0)
+ values1d = torch.tensor([1.0, ])
+
+ out_mps = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True)
+ out_cpu = t.index_put_(indices, values0d, accumulate=True)
+ self.assertEqual(out_mps.cpu(), out_cpu)
+
+ out_mps = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
+ out_cpu = t.index_put_(indices, values1d, accumulate=True)
+ self.assertEqual(out_mps.cpu(), out_cpu)
+
+ t = torch.zeros(4, 3, 2)
+ t_dev = t.to(device)
+
+ indices = [
+ torch.tensor([0, ]),
+ torch.arange(3)[:, None],
+ torch.arange(2)[None, :],
+ ]
+ indices_dev = [i.to(device) for i in indices]
+ values1d = torch.tensor([-1.0, -2.0])
+ values2d = torch.tensor([[-1.0, -2.0], ])
+
+ out_mps = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
+ out_cpu = t.index_put_(indices, values1d, accumulate=True)
+ self.assertEqual(out_mps.cpu(), out_cpu)
+
+ out_mps = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True)
+ out_cpu = t.index_put_(indices, values2d, accumulate=True)
+ self.assertEqual(out_mps.cpu(), out_cpu)
+
+ def test_index_put_accumulate_non_contiguous(self, device="mps"):
+ t = torch.zeros((5, 2, 2))
+ t_dev = t.to(device)
+ t1 = t_dev[:, 0, :]
+ t2 = t[:, 0, :]
+ self.assertTrue(not t1.is_contiguous())
+ self.assertTrue(not t2.is_contiguous())
+
+ indices = [torch.tensor([0, 1]), ]
+ indices_dev = [i.to(device) for i in indices]
+ value = torch.randn(2, 2)
+ out_mps = t1.index_put_(indices_dev, value.to(device), accumulate=True)
+ out_cpu = t2.index_put_(indices, value, accumulate=True)
+ self.assertTrue(not t1.is_contiguous())
+ self.assertTrue(not t2.is_contiguous())
+
+ self.assertEqual(out_mps.cpu(), out_cpu)
+
+ def test_index_put_accumulate_with_optional_tensors(self, device="mps"):
+ # TODO: replace with a better solution.
+ # Currently, TorchScript is used here to put None into the indices.
+ # On the C++ side this gives the indices as a list of 2 optional tensors:
+ # the first is null and the second is a valid tensor.
+ @torch.jit.script
+ def func(x, i, v):
+ idx = [None, i]
+ x.index_put_(idx, v, accumulate=True)
+ return x
+
+ n = 4
+ t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
+ t_dev = t.to(device)
+ indices = torch.tensor([1, 0])
+ indices_dev = indices.to(device)
+ value0d = torch.tensor(10.0)
+ value1d = torch.tensor([1.0, 2.0])
+
+ out_mps = func(t_dev, indices_dev, value0d.to("mps"))
+ out_cpu = func(t, indices, value0d)
+ self.assertEqual(out_mps.cpu(), out_cpu)
+
+ out_mps = func(t_dev, indices_dev, value1d.to("mps"))
+ out_cpu = func(t, indices, value1d)
+ self.assertEqual(out_mps.cpu(), out_cpu)
+
+ def test_index_put_accumulate_duplicate_indices(self, device="mps"):
+ for i in range(1, 128):
+ # generate indices by random walk, this will create indices with
+ # lots of duplicates interleaved with each other
+ delta = torch.empty(i, dtype=torch.float32, device=device).uniform_(-1, 1)
+
+ # cumsum is not supported on 'mps'; fall back to 'cpu'
+ indices = delta.to("cpu").cumsum(0).long().to("mps")
+
+ # abs for int64 is not supported on 'mps'; fall back to 'cpu' to calculate it
+ input = torch.randn(indices.to("cpu").abs().to("mps").max() + 1, device=device)
+ values = torch.randn(indices.size(0), device=device)
+ output = input.index_put((indices,), values, accumulate=True)
+
+ input_list = input.tolist()
+ indices_list = indices.tolist()
+ values_list = values.tolist()
+ for i, v in zip(indices_list, values_list):
+ input_list[i] += v
+
+ self.assertEqual(output, input_list)
+
+ def test_multiple_byte_mask(self, device="mps"):
+ v = torch.randn(5, 7, 3, device=device)
+ # note: these broadcast together and are transposed to the first dim
+ mask1 = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
+ mask2 = torch.ByteTensor([1, 1, 1]).to(device)
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
+ self.assertEqual(len(w), 2)
+
+ def test_byte_mask2d(self, device="mps"):
+ v = torch.randn(5, 7, 3, device=device)
+ c = torch.randn(5, 7, device=device)
+ num_ones = (c > 0).sum()
+ r = v[c > 0]
+ self.assertEqual(r.shape, (num_ones, 3))
+
+ # FIXME: conditional indexing not working
+ # def test_jit_indexing(self, device="mps"):
+ # def fn1(x):
+ # x[x < 50] = 1.0
+ # return x
+
+ # def fn2(x):
+ # x[0:50] = 1.0
+ # return x
+
+ # scripted_fn1 = torch.jit.script(fn1)
+ # scripted_fn2 = torch.jit.script(fn2)
+ # data = torch.arange(100, device=device, dtype=torch.float)
+ # out = scripted_fn1(data.detach().clone())
+ # ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
+ # self.assertEqual(out, ref)
+ # out = scripted_fn2(data.detach().clone())
+ # self.assertEqual(out, ref)
+
+ def test_int_indices(self, device="mps"):
+ v = torch.randn(5, 7, 3, device=device)
+ self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))
+ self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
+ self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
+
+ def test_index_put_src_datatype(self):
+ def helper(device, dtype):
+ src = torch.ones(3, 2, 4, device=device, dtype=dtype)
+ vals = torch.ones(3, 2, 4, device=device, dtype=dtype)
+ indices = (torch.tensor([0, 2, 1]),)
+ res = src.index_put_(indices, vals, accumulate=True)
+ self.assertEqual(res.shape, src.shape)
+ [helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.int32]]
+
+ def test_index_src_datatype(self):
+ def helper(device, dtype):
+ orig_dtype = dtype
+ if dtype is torch.bool:
+ dtype = torch.uint8
+
+ src = torch.ones(3, 2, 4, device=device, dtype=dtype)
+ if orig_dtype is torch.bool:
+ src = src == 1
+ # test index
+ res = src[[0, 2, 1], :, :]
+ self.assertEqual(res.shape, src.shape)
+ # test index_put, no accum
+ src[[0, 2, 1], :, :] = res
+ self.assertEqual(res.shape, src.shape)
+ [helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.float16, torch.long, torch.bool]]
+
def test_int_indices2d(self, device="mps"):
# From the NumPy indexing example
x = torch.arange(0, 12, device=device).view(4, 3)
@@ -6164,6 +6488,20 @@
result = x[rows[:, None], columns]
self.assertEqual(result.tolist(), [[0, 2], [9, 11]])
+ def test_empty_index(self, device="mps"):
+ x = torch.arange(0, 12, device=device).view(4, 3)
+ idx = torch.tensor([], dtype=torch.long, device=device)
+ self.assertEqual(x[idx].numel(), 0)
+
+ # empty assignment should have no effect but not throw an exception
+ y = x.clone()
+ y[idx] = -1
+ self.assertEqual(x, y)
+
+ mask = torch.zeros(4, 3, device=device).bool()
+ y[mask] = -1
+ self.assertEqual(x, y)
+
def test_empty_ndim_index(self, device="mps"):
x = torch.randn(5, device=device)
self.assertEqual(torch.empty(0, 2, device=device), x[torch.empty(0, 2, dtype=torch.int64, device=device)])
@@ -6182,6 +6520,15 @@
x = torch.randn(5, device=device)
self.assertRaises(IndexError, lambda: x[torch.empty(0, 2, dtype=torch.uint8, device=device)])
+ def test_empty_slice(self, device="mps"):
+ x = torch.randn(2, 3, 4, 5, device=device)
+ y = x[:, :, :, 1]
+ z = y[:, 1:1, :]
+ self.assertEqual((2, 0, 4), z.shape)
+ # this isn't technically necessary, but matches NumPy stride calculations.
+ self.assertEqual((60, 20, 5), z.stride())
+ self.assertTrue(z.is_contiguous())
+
def test_index_getitem_copy_bools_slices(self, device="mps"):
true = torch.tensor(1, dtype=torch.uint8, device=device)
false = torch.tensor(0, dtype=torch.uint8, device=device)
@@ -6196,6 +6543,33 @@
self.assertEqual(a.data_ptr(), a[None].data_ptr())
self.assertEqual(a.data_ptr(), a[...].data_ptr())
+ def test_index_setitem_bools_slices(self, device="mps"):
+ true = torch.tensor(1, dtype=torch.uint8, device=device)
+ false = torch.tensor(0, dtype=torch.uint8, device=device)
+
+ tensors = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]
+
+ for a in tensors:
+ # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
+ # (some of these ops already prefix a 1 to the size)
+ neg_ones = torch.ones_like(a) * -1
+ neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)
+ a[True] = neg_ones_expanded
+ self.assertEqual(a, neg_ones)
+ a[False] = 5
+ self.assertEqual(a, neg_ones)
+ a[true] = neg_ones_expanded * 2
+ self.assertEqual(a, neg_ones * 2)
+ a[false] = 5
+ self.assertEqual(a, neg_ones * 2)
+ a[None] = neg_ones_expanded * 3
+ self.assertEqual(a, neg_ones * 3)
+ a[...] = neg_ones_expanded * 4
+ self.assertEqual(a, neg_ones * 4)
+ if a.dim() == 0:
+ with self.assertRaises(IndexError):
+ a[:] = neg_ones_expanded * 5
+
def test_index_scalar_with_bool_mask(self, device="mps"):
a = torch.tensor(1, device=device)
uintMask = torch.tensor(True, dtype=torch.uint8, device=device)
@@ -6207,6 +6581,17 @@
self.assertEqual(a[uintMask], a[boolMask])
self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)
+ def test_setitem_expansion_error(self, device="mps"):
+ true = torch.tensor(True, device=device)
+ a = torch.randn(2, 3, device=device)
+ # check prefix with non-1s doesn't work
+ a_expanded = a.expand(torch.Size([5, 1]) + a.size())
+ # NumPy: ValueError
+ with self.assertRaises(RuntimeError):
+ a[True] = a_expanded
+ with self.assertRaises(RuntimeError):
+ a[true] = a_expanded
+
def test_getitem_scalars(self, device="mps"):
zero = torch.tensor(0, dtype=torch.int64, device=device)
one = torch.tensor(1, dtype=torch.int64, device=device)
@@ -6231,6 +6616,69 @@
r[zero]
self.assertEqual(r, r[...])
+ def test_setitem_scalars(self, device="mps"):
+ zero = torch.tensor(0, dtype=torch.int64)
+
+ # non-scalar indexed with scalars
+ a = torch.randn(2, 3, device=device)
+ a_set_with_number = a.clone()
+ a_set_with_scalar = a.clone()
+ b = torch.randn(3, device=device)
+
+ a_set_with_number[0] = b
+ a_set_with_scalar[zero] = b
+ self.assertEqual(a_set_with_number, a_set_with_scalar)
+ a[1, zero] = 7.7
+ self.assertEqual(7.7, a[1, 0])
+
+ # scalar indexed with scalars
+ r = torch.randn((), device=device)
+ with self.assertRaises(IndexError):
+ r[:] = 8.8
+ with self.assertRaises(IndexError):
+ r[zero] = 8.8
+ r[...] = 9.9
+ self.assertEqual(9.9, r)
+
+ def test_basic_advanced_combined(self, device="mps"):
+ # From the NumPy indexing example
+ x = torch.arange(0, 12, device=device).view(4, 3)
+ self.assertEqual(x[1:2, 1:3], x[1:2, [1, 2]])
+ self.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]])
+
+ # Check that it is a copy
+ unmodified = x.clone()
+ x[1:2, [1, 2]].zero_()
+ self.assertEqual(x, unmodified)
+
+ # But assignment should modify the original
+ unmodified = x.clone()
+ x[1:2, [1, 2]] = 0
+ self.assertNotEqual(x, unmodified)
+
+ def test_int_assignment(self, device="mps"):
+ x = torch.arange(0, 4, device=device).view(2, 2)
+ x[1] = 5
+ self.assertEqual(x.tolist(), [[0, 1], [5, 5]])
+
+ x = torch.arange(0, 4, device=device).view(2, 2)
+ x[1] = torch.arange(5, 7, device=device)
+ self.assertEqual(x.tolist(), [[0, 1], [5, 6]])
+
+ def test_byte_tensor_assignment(self, device="mps"):
+ x = torch.arange(0., 16, device=device).view(4, 4)
+ b = torch.ByteTensor([True, False, True, False]).to(device)
+ value = torch.tensor([3., 4., 5., 6.], device=device)
+
+ with warnings.catch_warnings(record=True) as w:
+ x[b] = value
+ self.assertEqual(len(w), 1)
+
+ self.assertEqual(x[0], value)
+ self.assertEqual(x[1], torch.arange(4., 8, device=device))
+ self.assertEqual(x[2], value)
+ self.assertEqual(x[3], torch.arange(12., 16, device=device))
+
def test_variable_slicing(self, device="mps"):
x = torch.arange(0, 16, device=device).view(4, 4)
indices = torch.IntTensor([0, 1]).to(device)
@@ -6250,6 +6698,36 @@
x = torch.arange(0, 16, device=device).view(4, 4)
self.assertRaisesRegex(TypeError, 'slice indices', lambda: x["0":"1"])
+ def test_out_of_bound_index(self, device="mps"):
+ x = torch.arange(0, 100, device=device).view(2, 5, 10)
+ self.assertRaisesRegex(IndexError, 'index 5 is out of bounds for dimension 1 with size 5', lambda: x[0, 5])
+ self.assertRaisesRegex(IndexError, 'index 4 is out of bounds for dimension 0 with size 2', lambda: x[4, 5])
+ self.assertRaisesRegex(IndexError, 'index 15 is out of bounds for dimension 2 with size 10',
+ lambda: x[0, 1, 15])
+ self.assertRaisesRegex(IndexError, 'index 12 is out of bounds for dimension 2 with size 10',
+ lambda: x[:, :, 12])
+
+ def test_zero_dim_index(self, device="mps"):
+ x = torch.tensor(10, device=device)
+ self.assertEqual(x, x.item())
+
+ def runner():
+ print(x[0])
+ return x[0]
+
+ self.assertRaisesRegex(IndexError, 'invalid index', runner)
+
+ def test_cpu_indices(self, device="mps"):
+ idx = torch.tensor([0, 1])
+ b = torch.zeros(2, device=device)
+ x = torch.ones(10, device=device)
+ x[idx] = b # index_put_
+ ref = torch.ones(10, device=device)
+ ref[:2] = 0
+ self.assertEqual(x, ref, atol=0, rtol=0)
+ out = x[idx] # index
+ self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0)
+
class TestRNNMPS(TestCase):
def test_lstm_1(self, device="mps", dtype=torch.float32):