[MPS] Add tensor::index_put op (#85672)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85672
Approved by: https://github.com/malfet
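
For context, a minimal usage sketch of the op this change wires up (assuming a macOS build with the MPS backend available; shapes and values are illustrative only):

    import torch

    t = torch.zeros(5, 3, device="mps")
    idx = torch.tensor([0, 2, 4], device="mps")
    vals = torch.ones(3, 3, device="mps")

    # basic index_put_: overwrite the rows selected by idx
    t.index_put_((idx,), vals)

    # with accumulate=True, values are added into the target instead,
    # and duplicate indices sum their contributions
    dup = torch.tensor([1, 1], device="mps")
    t.index_put_((dup,), torch.ones(2, 3, device="mps"), accumulate=True)
    # row 1 is now 2.0 everywhere: both updates accumulated

    # advanced-index assignment exercises the same op under the hood
    t[[0, 2], 1] = torch.tensor([7.0, 8.0], device="mps")
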
diff --git a/test/test_mps.py b/test/test_mps.py
index df23895..ca875f9 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6046,9 +6046,10 @@
 
 class TestAdvancedIndexing(TestCase):
     supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
+    supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]
 
     # examples from https://www.tutorialspoint.com/numpy/numpy_advanced_indexing.htm
-    def test_indexing_1(self):
+    def test_indexing_get(self):
         def helper(dtype):
             x_cpu = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dtype)
             x_mps = x_cpu.detach().clone().to("mps")
@@ -6107,7 +6108,148 @@
     #         self.assertEqual(res_cpu, res_mps, str(dtype))
     #     [helper(dtype) for dtype in self.supported_dtypes]
 
-    # tests from test_indexing.py
+
+    def test_advanced_indexing_3D_get(self):
+        def helper(x_cpu):
+            x_mps = x_cpu.detach().clone().to("mps")
+            self.assertEqual(x_cpu[[1, 2], 3, :], x_mps[[1, 2], 3, :])
+            self.assertEqual(x_cpu[[0, 2], :, :], x_mps[[0, 2], :, :])
+            self.assertEqual(x_cpu[:, [1, 0], [1]], x_mps[:, [1, 0], [1]])
+
+        x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
+                               [0.5, 0.6, 0.7, 0.8],
+                               [0.9, 1.0, 1.1, 1.2],
+                               [1.3, 1.4, 1.5, 1.6]],
+
+                              [[2.0, 2.1, 2.2, 2.3],
+                               [2.4, 2.5, 2.6, 2.7],
+                               [2.8, 2.9, 3.0, 3.1],
+                               [3.2, 3.3, 3.4, 3.5]],
+
+                              [[4.0, 4.1, 4.2, 4.3],
+                               [4.4, 4.5, 4.6, 4.7],
+                               [4.8, 4.9, 5.0, 5.1],
+                               [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
+        helper(x_cpu)
+        for idx in range(len(self.supported_np_dtypes)):
+            # torch.randn / torch.rand don't work with all dtypes
+            # Generate input data for all dtypes in NumPy, then move it to torch
+            input_t = np.random.random_sample(size=[3, 4, 4]).astype(self.supported_np_dtypes[idx])
+            inputCPU = torch.tensor(input_t, device='cpu', dtype=self.supported_dtypes[idx])
+
+            helper(inputCPU)
+
+    def test_advanced_indexing_3D_put(self):
+        def helper(x_cpu):
+            dtype = x_cpu.dtype
+            x_mps = x_cpu.detach().clone().to("mps")
+
+            out_tensor_cpu = torch.tensor([88, 99], dtype=dtype, device="cpu")
+            out_tensor_cpu_view = out_tensor_cpu[1:]
+
+            out_tensor_mps = torch.tensor([88, 99], dtype=dtype, device="mps")
+            out_tensor_mps_view = out_tensor_mps[1:]
+
+            x_cpu[[1, 2], 3, :] = out_tensor_cpu_view
+            x_mps[[1, 2], 3, :] = out_tensor_mps_view
+            self.assertEqual(x_cpu, x_mps)
+
+            x_cpu[[0, 2], :, :] = out_tensor_cpu_view
+            x_mps[[0, 2], :, :] = out_tensor_mps_view
+            self.assertEqual(x_cpu, x_mps)
+
+            x_cpu[:, [1, 0], [1]] = out_tensor_cpu_view
+            x_mps[:, [1, 0], [1]] = out_tensor_mps_view
+            self.assertEqual(x_cpu, x_mps)
+
+        x_cpu = torch.tensor([[[0.1, 0.2, 0.3, 0.4],
+                               [0.5, 0.6, 0.7, 0.8],
+                               [0.9, 1.0, 1.1, 1.2],
+                               [1.3, 1.4, 1.5, 1.6]],
+
+                              [[2.0, 2.1, 2.2, 2.3],
+                               [2.4, 2.5, 2.6, 2.7],
+                               [2.8, 2.9, 3.0, 3.1],
+                               [3.2, 3.3, 3.4, 3.5]],
+
+                              [[4.0, 4.1, 4.2, 4.3],
+                               [4.4, 4.5, 4.6, 4.7],
+                               [4.8, 4.9, 5.0, 5.1],
+                               [5.1, 5.2, 5.3, 5.4]]], device="cpu", dtype=torch.float32)
+        helper(x_cpu)
+        for idx in range(len(self.supported_np_dtypes)):
+            # torch.randn / torch.rand don't work with all dtypes
+            # Generate input data for all dtypes in NumPy, then move it to torch
+            input_t = np.random.random_sample(size=[3, 4, 4]).astype(self.supported_np_dtypes[idx])
+            inputCPU = torch.tensor(input_t, device='cpu', dtype=self.supported_dtypes[idx])
+
+            helper(inputCPU)
+
+    def test_index_put_with_view_indices(self):
+        def helper(dtype):
+            target_cpu = torch.zeros([5, 3], device="cpu", dtype=dtype)
+            target_mps = torch.zeros([5, 3], device="mps", dtype=dtype)
+
+            indices_cpu = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="cpu")
+            indices_mps = torch.tensor([[0, 1], [0, 1]], dtype=torch.int64, device="mps")
+
+            value_cpu = torch.ones(indices_cpu.shape[0], device="cpu", dtype=dtype)
+            value_mps = torch.ones(indices_mps.shape[0], device="mps", dtype=dtype)
+
+            target_cpu.index_put_(tuple(indices_cpu.t()), value_cpu, accumulate=True)
+            target_mps.index_put_(tuple(indices_mps.t()), value_mps, accumulate=True)
+
+            self.assertEqual(target_cpu, target_mps)
+
+        [helper(dtype) for dtype in [torch.int32, torch.float]]
+
+    # tests from 'test_indexing.py'
+    def test_advancedindex_big(self, device="mps"):
+        reference = torch.arange(0, 123344, dtype=torch.int, device=device)
+
+        self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ],
+                         torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int))
+
+    def test_set_item_to_scalar_tensor(self, device="mps"):
+        m = random.randint(1, 10)
+        n = random.randint(1, 10)
+        z = torch.randn([m, n], device=device)
+        a = 1.0
+        w = torch.tensor(a, requires_grad=True, device=device)
+        z[:, 0] = w
+        z.sum().backward()
+        self.assertEqual(w.grad, m * a)
+
+    def test_single_int(self, device="mps"):
+        v = torch.randn(5, 7, 3, device=device)
+        self.assertEqual(v[4].shape, (7, 3))
+
+    def test_multiple_int(self, device="mps"):
+        v = torch.randn(5, 7, 3, device=device)
+        self.assertEqual(v[4].shape, (7, 3))
+        self.assertEqual(v[4, :, 1].shape, (7,))
+
+    def test_none(self, device="mps"):
+        v = torch.randn(5, 7, 3, device=device)
+        self.assertEqual(v[None].shape, (1, 5, 7, 3))
+        self.assertEqual(v[:, None].shape, (5, 1, 7, 3))
+        self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3))
+        self.assertEqual(v[..., None].shape, (5, 7, 3, 1))
+
+    def test_step(self, device="mps"):
+        v = torch.arange(10, device=device)
+        self.assertEqual(v[::1], v)
+        self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])
+        self.assertEqual(v[::3].tolist(), [0, 3, 6, 9])
+        self.assertEqual(v[::11].tolist(), [0])
+        self.assertEqual(v[1:6:2].tolist(), [1, 3, 5])
+
+    def test_step_assignment(self, device="mps"):
+        v = torch.zeros(4, 4, device=device)
+        v[0, 1::2] = torch.tensor([3., 4.], device=device)
+        self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
+        self.assertEqual(v[1:].sum(), 0)
+
     def test_bool_indices(self, device="mps"):
         v = torch.randn(5, 7, 3, device=device)
         boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool, device=device)
@@ -6123,6 +6265,13 @@
             self.assertEqual(v[boolIndices], torch.tensor([True], dtype=torch.bool, device=device))
             self.assertEqual(len(w), 2)
 
+    def test_bool_indices_accumulate(self, device="mps"):
+        mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
+        mask = mask > 0
+        y = torch.ones(size=(10, 10), device=device)
+        y.index_put_((mask, ), y[mask], accumulate=True)
+        self.assertEqual(y, torch.ones(size=(10, 10), device=device))
+
     def test_multiple_bool_indices(self, device="mps"):
         v = torch.randn(5, 7, 3, device=device)
         # note: these broadcast together and are transposed to the first dim
@@ -6130,14 +6279,6 @@
         mask2 = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
         self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
 
-    def test_step(self, device="mps"):
-        v = torch.arange(10, device=device)
-        self.assertEqual(v[::1], v)
-        self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])
-        self.assertEqual(v[::3].tolist(), [0, 3, 6, 9])
-        self.assertEqual(v[::11].tolist(), [0])
-        self.assertEqual(v[1:6:2].tolist(), [1, 3, 5])
-
     def test_byte_mask(self, device="mps"):
         v = torch.randn(5, 7, 3, device=device)
         mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
@@ -6149,6 +6290,189 @@
         v = torch.tensor([1.], device=device)
         self.assertEqual(v[v == 0], torch.tensor([], device=device))
 
+    def test_byte_mask_accumulate(self, device="mps"):
+        mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
+        y = torch.ones(size=(10, 10), device=device)
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            y.index_put_((mask, ), y[mask], accumulate=True)
+            self.assertEqual(y, torch.ones(size=(10, 10), device=device))
+            self.assertEqual(len(w), 2)
+
+    def test_index_put_accumulate_expanded_values(self, device="mps"):
+        t = torch.zeros((5, 2))
+        t_dev = t.to(device)
+        indices = [
+            torch.tensor([0, 1, 2, 3]),
+            torch.tensor([1, ]),
+        ]
+        indices_dev = [i.to(device) for i in indices]
+        values0d = torch.tensor(1.0)
+        values1d = torch.tensor([1.0, ])
+
+        out_mps = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True)
+        out_cpu = t.index_put_(indices, values0d, accumulate=True)
+        self.assertEqual(out_mps.cpu(), out_cpu)
+
+        out_mps = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
+        out_cpu = t.index_put_(indices, values1d, accumulate=True)
+        self.assertEqual(out_mps.cpu(), out_cpu)
+
+        t = torch.zeros(4, 3, 2)
+        t_dev = t.to(device)
+
+        indices = [
+            torch.tensor([0, ]),
+            torch.arange(3)[:, None],
+            torch.arange(2)[None, :],
+        ]
+        indices_dev = [i.to(device) for i in indices]
+        values1d = torch.tensor([-1.0, -2.0])
+        values2d = torch.tensor([[-1.0, -2.0], ])
+
+        out_mps = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
+        out_cpu = t.index_put_(indices, values1d, accumulate=True)
+        self.assertEqual(out_mps.cpu(), out_cpu)
+
+        out_mps = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True)
+        out_cpu = t.index_put_(indices, values2d, accumulate=True)
+        self.assertEqual(out_mps.cpu(), out_cpu)
+
+    def test_index_put_accumulate_non_contiguous(self, device="mps"):
+        t = torch.zeros((5, 2, 2))
+        t_dev = t.to(device)
+        t1 = t_dev[:, 0, :]
+        t2 = t[:, 0, :]
+        self.assertTrue(not t1.is_contiguous())
+        self.assertTrue(not t2.is_contiguous())
+
+        indices = [torch.tensor([0, 1]), ]
+        indices_dev = [i.to(device) for i in indices]
+        value = torch.randn(2, 2)
+        out_mps = t1.index_put_(indices_dev, value.to(device), accumulate=True)
+        out_cpu = t2.index_put_(indices, value, accumulate=True)
+        self.assertTrue(not t1.is_contiguous())
+        self.assertTrue(not t2.is_contiguous())
+
+        self.assertEqual(out_mps.cpu(), out_cpu)
+
+    def test_index_put_accumulate_with_optional_tensors(self, device="mps"):
+        # TODO: replace with a better solution.
+        # Currently we use TorchScript to put None into the indices.
+        # On the C++ side this passes the indices as a list of two optional
+        # tensors: the first is null and the second is a valid tensor.
+        @torch.jit.script
+        def func(x, i, v):
+            idx = [None, i]
+            x.index_put_(idx, v, accumulate=True)
+            return x
+
+        n = 4
+        t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
+        t_dev = t.to(device)
+        indices = torch.tensor([1, 0])
+        indices_dev = indices.to(device)
+        value0d = torch.tensor(10.0)
+        value1d = torch.tensor([1.0, 2.0])
+
+        out_mps = func(t_dev, indices_dev, value0d.to("mps"))
+        out_cpu = func(t, indices, value0d)
+        self.assertEqual(out_mps.cpu(), out_cpu)
+
+        out_mps = func(t_dev, indices_dev, value1d.to("mps"))
+        out_cpu = func(t, indices, value1d)
+        self.assertEqual(out_mps.cpu(), out_cpu)
+
+    def test_index_put_accumulate_duplicate_indices(self, device="mps"):
+        for i in range(1, 128):
+            # generate indices by a random walk; this creates indices with
+            # lots of duplicates interleaved with each other
+            delta = torch.empty(i, dtype=torch.float32, device=device).uniform_(-1, 1)
+
+            # cumsum is not supported on 'mps'; fall back to 'cpu'
+            indices = delta.to("cpu").cumsum(0).long().to("mps")
+
+            # abs for int64 is not supported on 'mps'; fall back to 'cpu' to compute it
+            input = torch.randn(indices.to("cpu").abs().to("mps").max() + 1, device=device)
+            values = torch.randn(indices.size(0), device=device)
+            output = input.index_put((indices,), values, accumulate=True)
+
+            input_list = input.tolist()
+            indices_list = indices.tolist()
+            values_list = values.tolist()
+            for i, v in zip(indices_list, values_list):
+                input_list[i] += v
+
+            self.assertEqual(output, input_list)
+
+    def test_multiple_byte_mask(self, device="mps"):
+        v = torch.randn(5, 7, 3, device=device)
+        # note: these broadcast together and are transposed to the first dim
+        mask1 = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
+        mask2 = torch.ByteTensor([1, 1, 1]).to(device)
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
+            self.assertEqual(len(w), 2)
+
+    def test_byte_mask2d(self, device="mps"):
+        v = torch.randn(5, 7, 3, device=device)
+        c = torch.randn(5, 7, device=device)
+        num_ones = (c > 0).sum()
+        r = v[c > 0]
+        self.assertEqual(r.shape, (num_ones, 3))
+
+    # FIXME: conditional indexing not working
+    # def test_jit_indexing(self, device="mps"):
+    #     def fn1(x):
+    #         x[x < 50] = 1.0
+    #         return x
+
+    #     def fn2(x):
+    #         x[0:50] = 1.0
+    #         return x
+
+    #     scripted_fn1 = torch.jit.script(fn1)
+    #     scripted_fn2 = torch.jit.script(fn2)
+    #     data = torch.arange(100, device=device, dtype=torch.float)
+    #     out = scripted_fn1(data.detach().clone())
+    #     ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
+    #     self.assertEqual(out, ref)
+    #     out = scripted_fn2(data.detach().clone())
+    #     self.assertEqual(out, ref)
+
+    def test_int_indices(self, device="mps"):
+        v = torch.randn(5, 7, 3, device=device)
+        self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))
+        self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
+        self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
+
+    def test_index_put_src_datatype(self):
+        def helper(device, dtype):
+            src = torch.ones(3, 2, 4, device=device, dtype=dtype)
+            vals = torch.ones(3, 2, 4, device=device, dtype=dtype)
+            indices = (torch.tensor([0, 2, 1]),)
+            res = src.index_put_(indices, vals, accumulate=True)
+            self.assertEqual(res.shape, src.shape)
+        [helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.int32]]
+
+    def test_index_src_datatype(self):
+        def helper(device, dtype):
+            orig_dtype = dtype
+            if dtype is torch.bool:
+                dtype = torch.uint8
+
+            src = torch.ones(3, 2, 4, device=device, dtype=dtype)
+            if orig_dtype is torch.bool:
+                src = src == 1
+            # test index
+            res = src[[0, 2, 1], :, :]
+            self.assertEqual(res.shape, src.shape)
+            # test index_put, no accum
+            src[[0, 2, 1], :, :] = res
+            self.assertEqual(res.shape, src.shape)
+        [helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.float16, torch.long, torch.bool]]
+
     def test_int_indices2d(self, device="mps"):
         # From the NumPy indexing example
         x = torch.arange(0, 12, device=device).view(4, 3)
@@ -6164,6 +6488,20 @@
         result = x[rows[:, None], columns]
         self.assertEqual(result.tolist(), [[0, 2], [9, 11]])
 
+    def test_empty_index(self, device="mps"):
+        x = torch.arange(0, 12, device=device).view(4, 3)
+        idx = torch.tensor([], dtype=torch.long, device=device)
+        self.assertEqual(x[idx].numel(), 0)
+
+        # an empty assignment should have no effect and should not throw an exception
+        y = x.clone()
+        y[idx] = -1
+        self.assertEqual(x, y)
+
+        mask = torch.zeros(4, 3, device=device).bool()
+        y[mask] = -1
+        self.assertEqual(x, y)
+
     def test_empty_ndim_index(self, device="mps"):
         x = torch.randn(5, device=device)
         self.assertEqual(torch.empty(0, 2, device=device), x[torch.empty(0, 2, dtype=torch.int64, device=device)])
@@ -6182,6 +6520,15 @@
         x = torch.randn(5, device=device)
         self.assertRaises(IndexError, lambda: x[torch.empty(0, 2, dtype=torch.uint8, device=device)])
 
+    def test_empty_slice(self, device="mps"):
+        x = torch.randn(2, 3, 4, 5, device=device)
+        y = x[:, :, :, 1]
+        z = y[:, 1:1, :]
+        self.assertEqual((2, 0, 4), z.shape)
+        # this isn't technically necessary, but matches NumPy stride calculations.
+        self.assertEqual((60, 20, 5), z.stride())
+        self.assertTrue(z.is_contiguous())
+
     def test_index_getitem_copy_bools_slices(self, device="mps"):
         true = torch.tensor(1, dtype=torch.uint8, device=device)
         false = torch.tensor(0, dtype=torch.uint8, device=device)
@@ -6196,6 +6543,33 @@
             self.assertEqual(a.data_ptr(), a[None].data_ptr())
             self.assertEqual(a.data_ptr(), a[...].data_ptr())
 
+    def test_index_setitem_bools_slices(self, device="mps"):
+        true = torch.tensor(1, dtype=torch.uint8, device=device)
+        false = torch.tensor(0, dtype=torch.uint8, device=device)
+
+        tensors = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]
+
+        for a in tensors:
+            # prefix with 1,1 to ensure we are compatible with NumPy, which cuts off leading 1s
+            # (some of these ops already prefix a 1 to the size)
+            neg_ones = torch.ones_like(a) * -1
+            neg_ones_expanded = neg_ones.unsqueeze(0).unsqueeze(0)
+            a[True] = neg_ones_expanded
+            self.assertEqual(a, neg_ones)
+            a[False] = 5
+            self.assertEqual(a, neg_ones)
+            a[true] = neg_ones_expanded * 2
+            self.assertEqual(a, neg_ones * 2)
+            a[false] = 5
+            self.assertEqual(a, neg_ones * 2)
+            a[None] = neg_ones_expanded * 3
+            self.assertEqual(a, neg_ones * 3)
+            a[...] = neg_ones_expanded * 4
+            self.assertEqual(a, neg_ones * 4)
+            if a.dim() == 0:
+                with self.assertRaises(IndexError):
+                    a[:] = neg_ones_expanded * 5
+
     def test_index_scalar_with_bool_mask(self, device="mps"):
         a = torch.tensor(1, device=device)
         uintMask = torch.tensor(True, dtype=torch.uint8, device=device)
@@ -6207,6 +6581,17 @@
         self.assertEqual(a[uintMask], a[boolMask])
         self.assertEqual(a[uintMask].dtype, a[boolMask].dtype)
 
+    def test_setitem_expansion_error(self, device="mps"):
+        true = torch.tensor(True, device=device)
+        a = torch.randn(2, 3, device=device)
+        # check that prefixing with non-1s doesn't work
+        a_expanded = a.expand(torch.Size([5, 1]) + a.size())
+        # NumPy: ValueError
+        with self.assertRaises(RuntimeError):
+            a[True] = a_expanded
+        with self.assertRaises(RuntimeError):
+            a[true] = a_expanded
+
     def test_getitem_scalars(self, device="mps"):
         zero = torch.tensor(0, dtype=torch.int64, device=device)
         one = torch.tensor(1, dtype=torch.int64, device=device)
@@ -6231,6 +6616,69 @@
             r[zero]
         self.assertEqual(r, r[...])
 
+    def test_setitem_scalars(self, device="mps"):
+        zero = torch.tensor(0, dtype=torch.int64)
+
+        # non-scalar indexed with scalars
+        a = torch.randn(2, 3, device=device)
+        a_set_with_number = a.clone()
+        a_set_with_scalar = a.clone()
+        b = torch.randn(3, device=device)
+
+        a_set_with_number[0] = b
+        a_set_with_scalar[zero] = b
+        self.assertEqual(a_set_with_number, a_set_with_scalar)
+        a[1, zero] = 7.7
+        self.assertEqual(7.7, a[1, 0])
+
+        # scalar indexed with scalars
+        r = torch.randn((), device=device)
+        with self.assertRaises(IndexError):
+            r[:] = 8.8
+        with self.assertRaises(IndexError):
+            r[zero] = 8.8
+        r[...] = 9.9
+        self.assertEqual(9.9, r)
+
+    def test_basic_advanced_combined(self, device="mps"):
+        # From the NumPy indexing example
+        x = torch.arange(0, 12, device=device).view(4, 3)
+        self.assertEqual(x[1:2, 1:3], x[1:2, [1, 2]])
+        self.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]])
+
+        # Check that it is a copy
+        unmodified = x.clone()
+        x[1:2, [1, 2]].zero_()
+        self.assertEqual(x, unmodified)
+
+        # But assignment should modify the original
+        unmodified = x.clone()
+        x[1:2, [1, 2]] = 0
+        self.assertNotEqual(x, unmodified)
+
+    def test_int_assignment(self, device="mps"):
+        x = torch.arange(0, 4, device=device).view(2, 2)
+        x[1] = 5
+        self.assertEqual(x.tolist(), [[0, 1], [5, 5]])
+
+        x = torch.arange(0, 4, device=device).view(2, 2)
+        x[1] = torch.arange(5, 7, device=device)
+        self.assertEqual(x.tolist(), [[0, 1], [5, 6]])
+
+    def test_byte_tensor_assignment(self, device="mps"):
+        x = torch.arange(0., 16, device=device).view(4, 4)
+        b = torch.ByteTensor([True, False, True, False]).to(device)
+        value = torch.tensor([3., 4., 5., 6.], device=device)
+
+        with warnings.catch_warnings(record=True) as w:
+            x[b] = value
+            self.assertEqual(len(w), 1)
+
+        self.assertEqual(x[0], value)
+        self.assertEqual(x[1], torch.arange(4., 8, device=device))
+        self.assertEqual(x[2], value)
+        self.assertEqual(x[3], torch.arange(12., 16, device=device))
+
     def test_variable_slicing(self, device="mps"):
         x = torch.arange(0, 16, device=device).view(4, 4)
         indices = torch.IntTensor([0, 1]).to(device)
@@ -6250,6 +6698,36 @@
         x = torch.arange(0, 16, device=device).view(4, 4)
         self.assertRaisesRegex(TypeError, 'slice indices', lambda: x["0":"1"])
 
+    def test_out_of_bound_index(self, device="mps"):
+        x = torch.arange(0, 100, device=device).view(2, 5, 10)
+        self.assertRaisesRegex(IndexError, 'index 5 is out of bounds for dimension 1 with size 5', lambda: x[0, 5])
+        self.assertRaisesRegex(IndexError, 'index 4 is out of bounds for dimension 0 with size 2', lambda: x[4, 5])
+        self.assertRaisesRegex(IndexError, 'index 15 is out of bounds for dimension 2 with size 10',
+                               lambda: x[0, 1, 15])
+        self.assertRaisesRegex(IndexError, 'index 12 is out of bounds for dimension 2 with size 10',
+                               lambda: x[:, :, 12])
+
+    def test_zero_dim_index(self, device="mps"):
+        x = torch.tensor(10, device=device)
+        self.assertEqual(x, x.item())
+
+        def runner():
+            print(x[0])
+            return x[0]
+
+        self.assertRaisesRegex(IndexError, 'invalid index', runner)
+
+    def test_cpu_indices(self, device="mps"):
+        idx = torch.tensor([0, 1])
+        b = torch.zeros(2, device=device)
+        x = torch.ones(10, device=device)
+        x[idx] = b  # index_put_
+        ref = torch.ones(10, device=device)
+        ref[:2] = 0
+        self.assertEqual(x, ref, atol=0, rtol=0)
+        out = x[idx]  # index
+        self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0)
+
 class TestRNNMPS(TestCase):
     def test_lstm_1(self, device="mps", dtype=torch.float32):