[MPS] Disallow reshape in slice (#95905)

Disallow reshapes for arrayViews.
The current code allows a base of shape `[2, 4, 256]` to be sliced into a view of shape `[4, 1, 256]`, which is not possible: slicing can never enlarge a dimension, so slicing a smaller base dimension into a bigger view dimension must always error out.

Fixes https://github.com/pytorch/pytorch/issues/95883
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95905
Approved by: https://github.com/razarmehr, https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 9877612..9949e9c 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1954,6 +1954,78 @@
         x_cpu = x_cpu + 2
         self.assertEqual(x, x_cpu)
 
+    def test_reshape_storage_offset(self):
+        # https://github.com/pytorch/pytorch/issues/95883
+        B = 4
+        T = 1
+
+        lin_cpu = nn.Linear(10, 256)
+        lin_mps = nn.Linear(10, 256, device="mps")
+
+        # Use the same weights and bias as the ones from the cpu
+        lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_()
+        lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_()
+
+        x_mps = torch.rand([B, T, 10], device="mps", requires_grad=True)
+        x_cpu = x_mps.detach().clone().cpu().requires_grad_()
+        x_mps = lin_mps(x_mps)
+        x_cpu = lin_cpu(x_cpu)
+
+        self.assertEqual(x_mps.shape, (B, T, 256))
+        self.assertEqual(x_cpu.shape, (B, T, 256))
+
+        cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1)
+        cls_token_cpu = cls_token_mps.detach().clone().cpu()
+        x_mps = torch.cat([cls_token_mps, x_mps], dim=1)
+        x_cpu = torch.cat([cls_token_cpu, x_cpu], dim=1)
+
+        x_mps = x_mps.transpose(0, 1)
+        x_cpu = x_cpu.transpose(0, 1)
+
+        target_mps = torch.rand_like(x_mps)
+        target_cpu = target_mps.detach().clone().cpu()
+        loss_mps = F.mse_loss(x_mps, target_mps)
+        loss_cpu = F.mse_loss(x_cpu, target_cpu)
+        self.assertEqual(loss_mps, loss_cpu)
+
+        loss_mps.backward()
+        loss_cpu.backward()
+        self.assertEqual(x_mps.grad, x_cpu.grad)
+
+    def test_stack(self):
+        # https://github.com/pytorch/pytorch/issues/87856
+        x_cpu = torch.tensor([[1, 2]])
+        x_mps = x_cpu.detach().clone().to("mps")
+
+        y_cpu = torch.stack((x_cpu[:, :1], x_cpu[:, -1:]), dim=-1)
+        y_mps = torch.stack((x_mps[:, :1], x_mps[:, -1:]), dim=-1)
+
+        self.assertEqual(y_cpu, y_mps)
+
+        t_mps = torch.tensor([1, 2, 3, 4], device="mps")
+        t_cpu = t_mps.detach().cpu().detach()
+
+        x_mps = t_mps[2:]
+        y_mps = t_mps[:2]
+
+        x_cpu = t_cpu[2:]
+        y_cpu = t_cpu[:2]
+
+        res_mps = torch.stack((y_mps, x_mps), dim=-1)
+        res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)
+
+        self.assertEqual(res_mps, res_cpu)
+
+    def test_unsafe_chunk(self):
+        # https://github.com/pytorch/pytorch/issues/91065
+        a = torch.rand(5, dtype=torch.float32, device="cpu")
+        ret = a.unsafe_chunk(4, 0)
+        y = ret[0] * ret[2]
+        a_mps = a.to("mps")
+        ret_mps = a_mps.unsafe_chunk(4, 0)
+        y_mps = ret_mps[0] * ret_mps[2]
+        self.assertEqual(y, y_mps)
+
     def test_slice_casting(self):
         # generate random binary numbers
         cpu_in = torch.bernoulli(torch.empty(1, 1, 128, 128).uniform_(0, 1)).to(torch.uint8)