Add scatter support for view operations (#79939)

* Add scatter support for view operations (fixes #78074, #78886, #79672; see the example below)
* Update test_slicing_replace_column to properly test different sizes
* Handle in-place changes for binary ops; add new testcase
* Add new view ops testing scatter; add MPSDebugConfig.h config file for debugging purposes
* Merge gatherViewTensor and scatterViewTensor into a generic function
* Add scatter on demand in scatterViewOperation instead of caching it into a generic graph
* Create separate graphs for scatter and gather
* Create scatter graph at scatter time
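
As a quick illustration (not part of this diff, and assuming a macOS build with the `mps` backend available), these are the kinds of strided-view writes this change is meant to make match CPU; they mirror the new tests added to `test_mps.py` below:

```python
import torch

x_cpu = torch.zeros(3, 3)
x_mps = x_cpu.to("mps")

# Assignment through a stepped slice (https://github.com/pytorch/pytorch/issues/78886)
x_cpu[::2] = 1.0
x_mps[::2] = 1.0

# Replacing a column through a view (https://github.com/pytorch/pytorch/issues/78074)
x_cpu[:, 0] = 7
x_mps[:, 0] = 7

# In-place binary op on a view (https://github.com/pytorch/pytorch/issues/79672)
x_cpu[:, 1] += x_cpu[:, 2]
x_mps[:, 1] += x_mps[:, 2]

torch.testing.assert_close(x_cpu, x_mps.cpu())
```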

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79939
Approved by: https://github.com/razarmehr
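
For readers unfamiliar with the gather/scatter terminology used above: on MPS, reading a non-contiguous view amounts to gathering its elements into a contiguous buffer, and writing through the view amounts to scattering values back into the base storage. A rough conceptual sketch in plain PyTorch ops (the real implementation builds MPSGraph gather/scatter graphs in the Objective-C++ backend; the flat indices below are illustrative only):

```python
import torch

base = torch.arange(12.).reshape(3, 4)
view = base[:, 1]                        # non-contiguous column view

# "gather": read the view by collecting its elements into a contiguous buffer
flat_idx = torch.tensor([1, 5, 9])       # flattened positions covered by the view
gathered = base.reshape(-1)[flat_idx]    # same values as view.contiguous()

# "scatter": write through the view by putting new values back at those positions
new_vals = torch.tensor([10., 20., 30.])
base.reshape(-1).index_copy_(0, flat_idx, new_vals)   # equivalent to view.copy_(new_vals)
assert torch.equal(base[:, 1], new_vals)
```
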
diff --git a/test/test_mps.py b/test/test_mps.py
index 0bb7ffc..6264541 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -15,10 +15,12 @@
 import itertools
 from torch._six import inf
 from torch.nn import Parameter
-from torch.testing._internal.common_utils import run_tests, TestCase, download_file, TEST_WITH_UBSAN
+from torch.testing._internal.common_utils import run_tests, TestCase, download_file, TEST_WITH_UBSAN, gradcheck, gradgradcheck
+from torch.testing import make_tensor
 from torch.testing._comparison import TensorLikePair
 import torch.backends.mps
 from torch.distributions import Uniform, Exponential
+from functools import partial
 
 from torch.testing._internal.common_nn import NNTestCase
 import numpy as np
@@ -4782,6 +4784,841 @@
         m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
         self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)
 
+class TestGatherScatter(TestCase):
+    def test_slicing_with_step(self):
+        # Slicing with step
+        # https://github.com/pytorch/pytorch/issues/78886
+        x_mps = torch.zeros(10, dtype=torch.float32, device="mps")
+        x_mps[::2] = 1.0
+
+        x_cpu = torch.zeros(10, dtype=torch.float32, device="cpu")
+        x_cpu[::2] = 1.0
+
+        self.assertEqual(x_cpu, x_mps)
+
+    def test_slicing_replace_column(self):
+        # https://github.com/pytorch/pytorch/issues/78074
+        def _helper(tensor_data):
+            x_cpu = torch.tensor(tensor_data)
+            x_mps = x_cpu.to('mps')
+
+            x_cpu[:, 0] = 7
+            x_mps[:, 0] = 7
+
+            self.assertEqual(x_cpu, x_mps)
+
+        _helper([[1, 2, 3], [4, 5, 6]])
+        _helper([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+        _helper([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
+
+    def test_inplace_scatter(self):
+        # https://github.com/pytorch/pytorch/issues/79672
+        a_mps = torch.ones((2, 2),).to(torch.device("mps"))
+        b_mps = torch.ones((2, 2),).to(torch.device("mps"))
+
+        a_cpu = torch.ones((2, 2),).to(torch.device("cpu"))
+        b_cpu = torch.ones((2, 2),).to(torch.device("cpu"))
+
+        a_mps[:, 0] += b_mps[:, 0]
+        a_cpu[:, 0] += b_cpu[:, 0]
+        self.assertEqual(a_cpu, a_mps)
+
+        a_mps[:, 0] = a_mps[:, 0] + b_mps[:, 0]
+        a_cpu[:, 0] = a_cpu[:, 0] + b_cpu[:, 0]
+        self.assertEqual(a_cpu, a_mps)
+
+class TestViewOpsMPS(TestCase):
+    exact_dtype = True
+
+    def is_view_of(self, base, other):
+        if (not other._is_view() or
+                other is base or
+                other._base is not base or
+                base.device != other.device):
+            return False
+        # Note: only validates storage on native device types
+        # because some accelerators, like XLA, do not expose storage
+        if base.device.type == 'cpu' or base.device.type == 'cuda':
+            if base.storage().data_ptr() != other.storage().data_ptr():
+                return False
+
+        return True
+
+    # Returns true if v1 and v2 are views of the same base
+    def is_view_of_same_base(self, v1, v2):
+        if (not v1._is_view() or v1 is v2):
+            return False
+        return self.is_view_of(v1._base, v2)
+
+    # Returns the input tensor as-is if contiguous=True, else returns its transpose
+    def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
+        if contiguous:
+            return x
+        else:
+            return x.transpose(dim0, dim1)
+
+    def test_diagonal_view(self, device="mps"):
+        t = torch.ones((5, 5), device=device)
+        v = torch.diagonal(t)
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[0] = 0
+        self.assertEqual(t[0, 0], v[0])
+
+        t = torch.ones((3, 3, 3), device="mps")
+        v = torch.diagonal(t, offset=1, dim1=1, dim2=2)
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[0, 0] = 0
+        self.assertEqual(t[0, 0, 1], v[0, 0])
+
+    def test_select_view(self, device="mps") -> None:
+        t = torch.ones((5, 5), device=device)
+        v = t.select(0, 2)
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[0] = 0
+        self.assertEqual(t[2, 0], v[0])
+
+    def test_unbind_view(self, device="mps") -> None:
+        t = torch.zeros((5, 5), device=device)
+        tup = torch.unbind(t)
+
+        for idx, v in enumerate(tup):
+            self.assertTrue(self.is_view_of(t, v))
+
+            v[0] = idx + 1
+            self.assertEqual(t[idx, 0], v[0])
+
+    def test_expand_view(self, device="mps") -> None:
+        t = torch.ones((5, 1), device=device)
+        v = t.expand(5, 5)
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[2, 2] = 0
+        self.assertEqual(t[2, 0], v[2, 2])
+
+    def test_expand_as_view(self, device="mps"):
+        t = torch.ones((5, 1), device=device)
+        e = torch.empty((5, 5), device=device)
+        v = t.expand_as(e)
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[2, 2] = 0
+        self.assertEqual(t[2, 0], v[2, 2])
+
+    def test_narrow_view(self, device="mps"):
+        t = torch.ones((5, 5), device=device)
+        v = torch.narrow(t, 1, 2, 2)
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[0, 0] = 0
+        self.assertEqual(t[0, 2], v[0, 0])
+
+    def test_permute_view(self, device="mps") -> None:
+        t = torch.ones((5, 5), device=device)
+        v = t.permute(1, 0)
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[0, 1] = 0
+        self.assertEqual(t[1, 0], v[0, 1])
+
+    def test_transpose_view(self, device="mps"):
+        for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
+            t = torch.ones((5, 5), device=device)
+            v = fn(t, 0, 1)
+            self.assertTrue(self.is_view_of(t, v))
+
+            v[0, 1] = 0
+            self.assertEqual(t[1, 0], v[0, 1])
+
+    def test_transpose_inplace_view(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        v = t.view_as(t)
+        v = v.swapdims_(0, 1)
+        self.assertTrue(self.is_view_of(t, v))
+        v[0, 1] = 0
+        self.assertEqual(t[1, 0], v[0, 1])
+
+        t = torch.ones(5, 5, device=device)
+        v = t.view_as(t)
+        v = v.swapaxes_(0, 1)
+        self.assertTrue(self.is_view_of(t, v))
+        v[0, 1] = 0
+        self.assertEqual(t[1, 0], v[0, 1])
+
+        t = torch.ones(5, 5, device=device)
+        v = t.view_as(t)
+        v = v.transpose_(0, 1)
+        self.assertTrue(self.is_view_of(t, v))
+        v[0, 1] = 0
+        self.assertEqual(t[1, 0], v[0, 1])
+
+    def test_t_view(self, device="mps"):
+        t = torch.ones((5, 5), device=device)
+        v = t.t()
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[0, 1] = 0
+        self.assertEqual(t[1, 0], v[0, 1])
+
+    def test_t_inplace_view(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        v = t.view_as(t)
+        v = v.t_()
+        self.assertTrue(self.is_view_of(t, v))
+        v[0, 1] = 0
+        self.assertEqual(t[1, 0], v[0, 1])
+
+    def test_T_view(self, device="mps"):
+        for op in ("T", "H", "mT", "mH"):
+            t = torch.ones((5, 5), device=device)
+            v = getattr(t, op)
+            self.assertTrue(self.is_view_of(t, v))
+
+            v[0, 1] = 0
+            self.assertEqual(t[1, 0], v[0, 1])
+
+    # requires aten::unfold
+    # def test_unfold_view(self, device="mps"):
+    #     t = torch.ones(10, device=device)
+    #     v = t.unfold(0, 3, 2)
+    #     self.assertTrue(self.is_view_of(t, v))
+
+    #     v[1, 0] = 0
+    #     self.assertEqual(t[2], v[1, 0])
+
+    def test_squeeze_view(self, device="mps"):
+        t = torch.ones(5, 1, 5, device=device)
+        v = torch.squeeze(t)
+        self.assertTrue(self.is_view_of(t, v))
+        v[0, 1] = 0
+        self.assertEqual(t, v._base)
+
+    def test_squeeze_inplace_view(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        v = t.view_as(t)
+        v = v.squeeze_()
+        self.assertTrue(self.is_view_of(t, v))
+        v[0, 1] = 0
+        self.assertEqual(t, v._base)
+
+    def test_unsqueeze_view(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        v = torch.unsqueeze(t, 1)
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[0, 0, 1] = 0
+        self.assertEqual(t[0, 1], v[0, 0, 1])
+
+    def test_unsqueeze_inplace_view(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        v = t.view_as(t)
+        v = v.unsqueeze_(1)
+        self.assertTrue(self.is_view_of(t, v))
+        v[0, 0, 1] = 0
+        self.assertEqual(t[0, 1], v[0, 0, 1])
+
+    def test_as_strided_view(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        v = torch.as_strided(t, (25,), (1,))
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[6] = 0
+        self.assertEqual(t[1, 1], v[6])
+
+    def test_as_strided_inplace_view(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        v = t.view_as(t)
+        v = v.as_strided_((25,), (1,))
+        self.assertTrue(self.is_view_of(t, v))
+        v[6] = 0
+        self.assertEqual(t[1, 1], v[6])
+
+    def test_view_view(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        v = t.view(25)
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[6] = 0
+        self.assertEqual(t[1, 1], v[6])
+
+    def test_view_as_view(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        e = torch.empty((25,))
+        v = t.view_as(e)
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[6] = 0
+        self.assertEqual(t[1, 1], v[6])
+
+    def test_contiguous_self(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        s = t.contiguous()
+        self.assertTrue(s is t)
+
+    def test_contiguous_nonview(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        nv = t.t().contiguous()
+        self.assertTrue(not self.is_view_of(t, nv))
+
+        nv[0, 0] = 0
+        self.assertNotEqual(t[0, 0], nv[0, 0])
+
+    def test_reshape_view(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        v = torch.reshape(t, (25,))
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[6] = 0
+        self.assertEqual(t[1, 1], v[6])
+
+    def test_reshape_as_view(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        e = torch.empty((25,), device=device)
+        v = t.reshape_as(e)
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[6] = 0
+        self.assertEqual(t[1, 1], v[6])
+
+    def test_reshape_nonview(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        nv = torch.reshape(t.t(), (25,))
+        self.assertTrue(not self.is_view_of(t, nv))
+
+        nv[6] = 0
+        self.assertNotEqual(t[1, 1], nv[6])
+
+    def test_flatten_view(self, device="mps"):
+        def test_writes_propagate(t, v):
+            idx_t = (0,) * t.ndim
+            idx_v = (0,) * v.ndim
+            v[idx_v] = 0
+            self.assertEqual(t[idx_t], v[idx_v])
+
+        t = torch.ones(1, 2, 3, 4, device=device)
+        v = t.flatten()
+        self.assertTrue(self.is_view_of(t, v))
+        test_writes_propagate(t, v)
+
+        # zero-dimensional tensor
+        t = torch.tensor(1, device=device)
+        v = t.flatten()
+        test_writes_propagate(t, v)
+        self.assertTrue(self.is_view_of(t, v))
+
+        t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
+        v = t.flatten(0, 1)
+        test_writes_propagate(t, v)
+        self.assertTrue(self.is_view_of_same_base(t, v))
+
+        # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
+        t = torch.ones(720, device=device) \
+            .as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0))
+        #               [--1--|---2---|-3-] [--1--|----2---|-3-]
+        v1 = t.flatten(0, 1)
+        v2 = v1.flatten(1, 3)
+        v3 = v2.flatten(2, 2)
+        test_writes_propagate(t, v1)
+        self.assertTrue(self.is_view_of_same_base(t, v1))
+        test_writes_propagate(t, v2)
+        self.assertTrue(self.is_view_of_same_base(t, v2))
+        test_writes_propagate(t, v3)
+        self.assertTrue(self.is_view_of_same_base(t, v3))
+
+    def test_flatten_nonview(self, device="mps"):
+        def assert_is_nonview(t, nv):
+            idx_t = (0,) * t.ndim
+            idx_nv = (0,) * nv.ndim
+            self.assertTrue(not nv._is_view())
+            nv[idx_nv] = 0
+            self.assertNotEqual(t[idx_t], nv[idx_nv])
+        t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
+        nv = t.flatten(1, 3)
+        assert_is_nonview(t, nv)
+
+        t = torch.ones(2, 2, device=device).T
+        nv = t.flatten()
+        assert_is_nonview(t, nv)
+
+        # flatten returns the original object if start_dim=end_dim
+        t = torch.ones(2, 2, device=device)
+        nv = t.flatten(1, 1)
+        self.assertTrue(t is nv)
+
+    def test_basic_indexing_slice_view(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        v = t[:2, :3]
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[0, 0] = 0
+        self.assertEqual(t[0, 0], v[0, 0])
+
+    def test_basic_indexing_ellipses_view(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        v = t[..., :2]
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[0, 0] = 0
+        self.assertEqual(t[0, 0], v[0, 0])
+
+    def test_basic_indexing_newaxis_view(self, device="mps"):
+        t = torch.ones(5, 5, device=device)
+        v = t[None, :2, 3]
+        self.assertTrue(self.is_view_of(t, v))
+
+        v[0, 0] = 0
+        self.assertEqual(t[0, 3], v[0, 0])
+
+    def test_chunk_view(self, device="mps"):
+        t = torch.zeros(3, 3, device=device)
+        l = torch.chunk(t, 3)
+
+        for idx, v in enumerate(l):
+            self.assertTrue(self.is_view_of(t, v))
+
+            v[0, 0] = idx + 1
+            self.assertEqual(t[idx, 0], v[0, 0])
+
+    def test_split_view(self, device="mps"):
+        t = torch.zeros(3, 3, device=device)
+        l = torch.split(t, [1, 1, 1])
+
+        for idx, v in enumerate(l):
+            self.assertTrue(self.is_view_of(t, v))
+
+            v[0, 0] = idx + 1
+            self.assertEqual(t[idx, 0], v[0, 0])
+
+    def test_movedim_view(self, device="mps"):
+        def run_test(device, op):
+            t = torch.zeros(3, 3, device=device)
+            out = op(t)
+
+            self.assertTrue(self.is_view_of(t, out))
+
+            # Randomly change values in output
+            # and verify that original is changed
+            # as well.
+            for _ in range(3):
+                idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2)
+                out[idx_1, idx_2] = random.random()
+                self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2])
+
+        for fn in [torch.movedim, torch.moveaxis]:
+            op = partial(fn, source=(0, 1), destination=(1, 0))
+            run_test(device, op)
+
+            op = partial(fn, source=0, destination=1)
+            run_test(device, op)
+
+    # Testing that the generated view_copy kernel and its derivative are implemented correctly
+    def test_view_copy(self, device="mps"):
+        a = torch.randn(4, device=device, requires_grad=True)
+        a_ref = a.clone().detach().requires_grad_()
+        a_view = a_ref.view(2, 2)
+        a_view_copy = torch.view_copy(a, (2, 2))
+
+        # view_copy ops don't preserve view relationship
+        self.assertTrue(self.is_view_of(a_ref, a_view))
+        self.assertFalse(self.is_view_of(a, a_view_copy))
+
+        a_view_copy.sum().backward()
+        a_view.sum().backward()
+
+        # forward and backward give the same shape + result
+        self.assertEqual(a_view_copy, a_view)
+        self.assertEqual(a.grad, a_ref.grad)
+
+    def test_view_copy_out(self, device="mps"):
+        a = torch.randn(2, 2, device=device)
+        out = torch.empty(2, device=device)
+
+        torch.diagonal_copy(a, out=out)
+        expected = torch.diagonal_copy(a)
+
+        self.assertEqual(expected, out)
+
+        a = torch.randn(4, device=device)
+        out1 = torch.empty(2, device=device)
+        out2 = torch.empty(2, device=device)
+
+        torch.split_copy(a, 2, out=(out1, out2))
+        expected1, expected2 = torch.split_copy(a, 2)
+
+        self.assertEqual(expected1, out1)
+        self.assertEqual(expected2, out2)
+
+    def test_empty_reshape(self, device="mps"):
+        x = torch.randn(0, 6, device=device)
+        self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape)
+        # should be viewable -- i.e. data_ptr is the same.
+        self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr())
+
+        # match NumPy semantics -- don't infer the size of dimension with a degree of freedom
+        self.assertRaises(RuntimeError, lambda: x.reshape(0, -1))
+
+    def test_expand(self, device="mps"):
+        tensor = torch.rand(1, 8, 1, device=device)
+        tensor2 = torch.rand(5, device=device)
+        template = torch.rand(4, 8, 5, device=device)
+        target = template.size()
+        self.assertEqual(tensor.expand_as(template).size(), target)
+        self.assertEqual(tensor.expand(4, 8, 5).size(), target)
+        self.assertEqual(tensor.expand(target).size(), target)
+        self.assertEqual(tensor2.expand_as(template).size(), target)
+        self.assertEqual(tensor2.expand(4, 8, 5).size(), target)
+        self.assertEqual(tensor2.expand(target).size(), target)
+
+        # test double expand
+        self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1))
+
+        # test non-contiguous
+        noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0]
+        self.assertFalse(noncontig.is_contiguous())
+        self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1))
+
+        # make sure it's compatible with unsqueeze
+        expanded = tensor2.expand(1, 1, 5)
+        unsqueezed = tensor2.unsqueeze(0).unsqueeze(1)
+        self.assertEqual(expanded, unsqueezed)
+        self.assertEqual(expanded.stride(), unsqueezed.stride())
+
+        # test -1 as target size
+        self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5))
+        self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1))
+
+        # test expanding empty to empty
+        self.assertEqual(torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device))
+
+    def test_view_empty(self, device="mps"):
+        x = torch.randn(0, 6, device=device)
+        self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape)
+
+    def test_reshape(self, device="mps"):
+        x = torch.randn(3, 3, device=device)
+        self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr())
+        self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr())
+        self.assertEqual(torch.reshape(x, (9,)), x.reshape(9))
+        self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))
+
+        y = torch.randn(4, 4, 4, device=device)[:, 0, :]
+        # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape
+        if device != "meta":
+            self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
+        self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
+        self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())
+
+        s = torch.randn((), device=device)
+        self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr())
+        self.assertEqual(s.reshape(-1).shape, (1,))
+        self.assertRaises(RuntimeError, lambda: s.reshape(2))
+
+        empty = torch.tensor([], device=device)
+        self.assertEqual(empty, empty.reshape(-1))
+        self.assertEqual(empty, empty.reshape([0]))
+        # TODO: fix these once we have multi-dimensional empty tensors
+        self.assertEqual(empty.reshape([0, 1]).shape, (0, 1))
+        self.assertEqual(empty.reshape([1, -1]).shape, (1, 0))
+        self.assertRaises(RuntimeError, lambda: empty.reshape(1))
+
+        x = torch.randn(3, 3, device=device)
+        self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr())
+        self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr())
+        self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device)))
+
+    def test_narrow(self, device="mps"):
+        x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
+        self.assertEqual(x.narrow(0, 0, 1), torch.tensor([[0, 1, 2]]))
+        self.assertEqual(x.narrow(0, 0, 2), torch.tensor([[0, 1, 2], [3, 4, 5]]))
+        self.assertEqual(x.narrow(0, 1, 1), torch.tensor([[3, 4, 5]]))
+        self.assertEqual(x.narrow(0, -1, 1), torch.tensor([[6, 7, 8]]))
+        self.assertEqual(x.narrow(0, -2, 2), torch.tensor([[3, 4, 5], [6, 7, 8]]))
+        self.assertEqual(x.narrow(0, -3, 3), torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]))
+        self.assertEqual(x.narrow(-1, -1, 1), torch.tensor([[2], [5], [8]]))
+        self.assertEqual(x.narrow(-2, -1, 1), torch.tensor([[6, 7, 8]]))
+
+    def test_narrow_tensor(self, device="mps"):
+        x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
+        self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]]))
+        with self.assertRaises(Exception):
+            x.narrow(0, torch.tensor(0.), 1)
+        with self.assertRaises(Exception):
+            x.narrow(0, torch.tensor([0]), 1)
+        with self.assertRaises(Exception):
+            x.narrow(0, torch.tensor([0, 1]), 1)
+
+    def test_t(self, device="mps"):
+        # Test 0D tensors
+        x = torch.randn(())
+        self.assertEqual(x, x.t())
+        x = x.to_sparse()
+        self.assertEqual(x, x.t())
+
+        # Test 1D tensors
+        x = torch.arange(4)
+        self.assertEqual(x, x.t())
+        x = x.to_sparse()
+        self.assertEqual(x, x.t())
+
+        # Test 2D tensors
+        x = torch.rand((2, 2))
+        self.assertEqual(x.t(), x.transpose(0, 1))
+        x = x.to_sparse()
+        self.assertEqual(x.t(), x.transpose(0, 1))
+
+        # Test 3D tensor
+        x = torch.rand((2, 2, 2))
+        with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'):
+            x.t()
+        x = x.to_sparse()
+        with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'):
+            x.t()
+
+    def test_split(self, device="mps"):
+        tensor = torch.rand(7, 4)
+        split_size = 3
+        dim = 0
+        target_sizes = ([3, 4], [3, 4], [1, 4])
+        splits = tensor.split(split_size, dim)
+        start = 0
+        for target_size, split in zip(target_sizes, splits):
+            self.assertEqual(split.size(), target_size)
+            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
+            start = start + target_size[dim]
+
+        # Variable sections split
+        tensor = torch.randn(20, 10)
+        dim = 0
+        split_sizes = [5, 5, 10]
+        target_sizes = ([[5, 10], [5, 10], [10, 10]])
+        splits = tensor.split(split_sizes, dim)
+        start = 0
+        for target_size, split in zip(target_sizes, splits):
+            self.assertEqual(split.size(), target_size)
+            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
+            start = start + target_size[dim]
+
+        split_sizes = [2, 2, 6]
+        target_sizes = ([20, 2], [20, 2], [20, 6])
+        dim = 1
+        splits = tensor.split(split_sizes, dim)
+        start = 0
+        for target_size, split in zip(target_sizes, splits):
+            self.assertEqual(split.size(), target_size)
+            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
+            start = start + target_size[dim]
+
+    def test_chunk(self, device="mps"):
+        tensor = torch.rand(4, 7)
+        num_chunks = 3
+        dim = 1
+        target_sizes = ([4, 3], [4, 3], [4, 1])
+        splits = tensor.chunk(num_chunks, dim)
+        start = 0
+        for target_size, split in zip(target_sizes, splits):
+            self.assertEqual(split.size(), target_size)
+            self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split,
+                             atol=0, rtol=0)
+            start = start + target_size[dim]
+
+        # Invalid chunk sizes
+        error_regex = 'chunk expects.*greater than 0'
+        with self.assertRaisesRegex(RuntimeError, error_regex):
+            tensor.chunk(0)
+        with self.assertRaisesRegex(RuntimeError, error_regex):
+            tensor.chunk(-2)
+
+    def test_unsqueeze(self, device="mps") -> None:
+        x = torch.randn(2, 3, 4)
+        y = x.unsqueeze(1)
+        self.assertEqual(y, x.view(2, 1, 3, 4))
+        y = x.clone().unsqueeze_(2)
+        self.assertEqual(y, x.view(2, 3, 1, 4))
+
+        x = x[:, 1]
+        self.assertFalse(x.is_contiguous())
+        y = x.unsqueeze(1)
+        self.assertEqual(y, x.contiguous().view(2, 1, 4))
+        y = x.clone().unsqueeze_(2)
+        self.assertEqual(y, x.contiguous().view(2, 4, 1))
+
+    # unit test for special case transposed copy (see ATen/native/Copy.cpp for details)
+    def test_big_transpose(self, device="mps"):
+        t = torch.rand(456, 789, device=device)
+        t1 = t.t().contiguous()
+        t2 = torch.from_numpy(t.cpu().numpy().transpose())
+        self.assertEqual(t1, t2)
+
+    def test_T(self, device="mps"):
+        a = torch.randn(2, 3, 4, device=device)
+        t1 = a.T
+        t2 = a.permute(2, 1, 0)
+        self.assertEqual(t2, t1)
+        b = torch.randn(10, device=device)
+        self.assertEqual(b, b.T)
+        scalar = torch.tensor(5, device=device)
+        self.assertEqual(scalar, scalar.T)
+
+    def test_transposes(self, device="mps", dtype=torch.float32):
+        for op in ("T", "H", "mT", "mH", "adjoint"):
+            shapes = ((), (2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((), (2, 3),)
+            for shape in shapes:
+                a = make_tensor(shape, device=device, dtype=dtype)
+                t1 = getattr(a, op)
+                if op == "adjoint":
+                    t1 = t1()
+                t2 = a
+                if a.ndim != 0:
+                    t2 = t2.transpose(-2, -1)
+                if op[-1] == "H" or op == "adjoint":
+                    t2 = t2.conj()
+                self.assertEqual(t2, t1)
+
+    def test_transposes_errors(self, device="mps", dtype=torch.float32):
+        for op in ("H", "mT", "mH", "adjoint"):
+            shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),)
+            for shape in shapes:
+                a = make_tensor(shape, device=device, dtype=dtype)
+                with self.assertRaisesRegex(RuntimeError, "only supported on matrices"):
+                    t1 = getattr(a, op)
+                    if op == "adjoint":
+                        t1 = t1()
+
+    def test_python_types(self, device="mps"):
+        a1 = torch.randn((1, 2), device=device, dtype=torch.float32)
+        a2 = torch.randn((1, 2), device=device, dtype=torch.float32)
+        self.assertEqual(a1.dtype, a2.dtype)
+
+        b1 = torch.arange(10, 20, dtype=torch.int64, device=device)
+        b2 = torch.arange(10, 20, dtype=int, device=device)
+        self.assertEqual(b1.dtype, b2.dtype)
+
+        c1 = torch.tensor([True, False], dtype=torch.bool, device=device)
+        c2 = torch.tensor([True, False], dtype=bool, device=device)
+        self.assertEqual(c1.dtype, c2.dtype)
+
+    # TODO: is resize best put in test_view_ops?
+    def test_resize_as_preserves_strides(self, device="mps"):
+        x = torch.empty(2, 3).t()
+        old_strides = x.stride()
+        x.resize_as_(x)
+        self.assertEqual(x.stride(), old_strides)
+
+    def test_memory_format_resize_as(self, device="mps"):
+        def test_helper(shape, memory_format, device="mps"):
+            xc = torch.randn(shape, device=device).contiguous(memory_format=memory_format)
+            flat = torch.randn(xc.numel(), device=device)
+            flat.resize_as_(xc, memory_format=torch.preserve_format)
+            self.assertTrue(flat.is_contiguous(memory_format=memory_format))
+
+        test_helper((10, 3, 32, 32), torch.channels_last, device="mps")
+        test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device="mps")
+
+    def test_memory_format_resize_(self, device="mps"):
+        def test_helper(shape, numel, memory_format, device="mps"):
+            flat = torch.randn(numel, device=device)
+            flat.resize_(shape, memory_format=memory_format)
+            self.assertTrue(flat.is_contiguous(memory_format=memory_format))
+
+        test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device="mps")
+        test_helper((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device="mps")
+
+    # TODO: OpInfo this
+    def _test_atleast(self, device, torch_fn):
+        # 0-dim
+        s = torch.tensor(0.5, dtype=torch.double, requires_grad=True)
+
+        gradcheck(lambda x: torch_fn(x), s)
+        gradgradcheck(lambda x: torch_fn(x), s)
+
+        # 1-dim
+        a = torch.rand(4, dtype=torch.double, requires_grad=True)
+
+        gradcheck(lambda x: torch_fn(x), a)
+        gradgradcheck(lambda x: torch_fn(x), a)
+
+        # 2,3,4-dim
+        b = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
+        c = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True)
+        d = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True)
+
+        input_tuple = (s, a, b, c, d)
+        gradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)
+        gradgradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)
+
+    def test_atleast_gradient(self, device="mps"):
+        self._test_atleast(device, torch.atleast_1d)
+        self._test_atleast(device, torch.atleast_2d)
+        self._test_atleast(device, torch.atleast_3d)
+
+    def test_view(self, device="mps"):
+        tensor = torch.rand(15, device=device)
+        template = torch.rand(3, 5, device=device)
+        empty = torch.empty(0, device=device)
+        target = template.size()
+        self.assertEqual(tensor.view_as(template).size(), target)
+        self.assertEqual(tensor.view(3, 5).size(), target)
+        self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target)
+        self.assertEqual(tensor.view(-1, 5).size(), target)
+        self.assertEqual(tensor.view(3, -1).size(), target)
+        tensor_view = tensor.view(5, 3)
+        tensor_view.fill_(random.uniform(0, 1))
+        self.assertEqual(empty.view_as(empty), empty)
+        self.assertEqual(empty.view(0), empty)
+        self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1]))
+        self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty)
+
+        # test size inference with empty tensors
+        self.assertEqual(empty.view(-1).size(), torch.Size([0]))
+        self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0]))
+
+        with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
+            empty.view(-1, 0)
+
+        with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
+            empty.view(3, 0, -1, 0)
+
+        self.assertRaises(RuntimeError, lambda: tensor.view(15, 0))
+        self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
+        self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))
+
+    # RuntimeError: Invalid device for storage: mps
+    # def test_contiguous(self, device="mps"):
+    #     x = torch.randn(1, 16, 5, 5, device=device)
+    #     self.assertTrue(x.is_contiguous())
+    #     stride = list(x.stride())
+    #     stride[0] = 20
+    #     # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
+    #     x.set_(x.storage(), 0, x.size(), stride)
+    #     self.assertTrue(x.is_contiguous())
+
+    def test_resize_all_dtypes_and_devices(self, device="mps"):
+        shape = (2, 2)
+        for dt in (torch.half, torch.bfloat16, torch.bool):
+            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
+            x.resize_(shape)
+            self.assertEqual(shape, x.shape)
+
+    def test_resize_as_all_dtypes_and_devices(self, device="mps"):
+        for dt in (torch.half, torch.bfloat16, torch.bool):
+            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
+            y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device)
+            x.resize_as_(y)
+            self.assertEqual(y.shape, x.shape)
+
+    def test_resize_overflow(self, device="mps"):
+        x = torch.empty((), dtype=torch.float64)
+        with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
+            x.resize_([2, 4, 2**29, 2**29])
+        with self.assertRaisesRegex(RuntimeError, 'overflow'):
+            x.resize_([8, 8, 2**29, 2**29])
+
+    def test_view_all_dtypes_and_devices(self, device="mps"):
+        for dt in (torch.float, torch.bool):
+            x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
+            self.assertEqual(x.view(6).shape, [6])
 
 class TestRNNMPS(TestCase):
     def test_lstm_1(self, device="mps", dtype=torch.float32):