[MPS] Add native strided API for MPSNDArray starting with macOS 15 (#128393)

Add support for native strides in MPS starting with macOS Sequoia (macOS 15). This eliminates the extra gather and scatter operations that were previously needed to resolve the strides or storage offsets of tensors.

Summary of changes (starting with macOS 15):
- Add support for the **MPS strided API** (strides, storage offsets, etc.):
   - [initWithBuffer:offset:descriptor:](https://developer.apple.com/documentation/metalperformanceshaders/mpsndarray/4391636-initwithbuffer?language=objc)
   - [arrayViewWithCommandBuffer:descriptor:aliasing:](https://developer.apple.com/documentation/metalperformanceshaders/mpsndarray/3114040-arrayviewwithcommandbuffer?language=objc)
   - [arrayViewWithShape:strides:](https://developer.apple.com/documentation/metalperformanceshaders/mpsndarray/4408694-arrayviewwithshape?language=objc)
   - [reshapeWithCommandBuffer:sourceArray:shape:destinationArray:](https://developer.apple.com/documentation/metalperformanceshaders/mpsndarrayidentity/4438557-reshapewithcommandbuffer?language=objc)
- Add native support for NHWC convolutions, without incurring the extra NCHW -> NHWC -> NCHW round-trip copies (a sketch follows this list).
- Add support for strided output buffers (previously we would create a contiguous buffer and scatter the result back into the strided tensor).
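
As a rough illustration of the NHWC path, here is a minimal sketch (not code from this PR; the module, shapes, and sizes are arbitrary) of a `channels_last` convolution on the MPS device. On macOS 15 the backend can consume the NHWC tensor directly instead of converting to NCHW and back:

```python
import torch

# Minimal sketch, assuming a macOS 15+ machine with the MPS backend available.
# Put both the module parameters and the input into channels_last (NHWC) layout.
conv = torch.nn.Conv2d(3, 8, kernel_size=3).to("mps", memory_format=torch.channels_last)
x = torch.randn(1, 3, 32, 32, device="mps").to(memory_format=torch.channels_last)

y = conv(x)
# On macOS 15+ the convolution runs on the NHWC data directly, and the output
# is expected to remain in channels_last layout.
print(y.is_contiguous(memory_format=torch.channels_last))
```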

macOS versions older than 15 will keep running the old gather/scatter code path to resolve strides and storage offsets, so results are unchanged there; the sketch below shows the kind of strided-view op this affects.
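
For example, the pattern below (mirroring the new `test_slice_view_api` test in the diff) builds a non-contiguous view with a storage offset and runs an op on it. This is a minimal sketch with illustrative sizes:

```python
import torch

# Minimal sketch: a strided view with a nonzero storage offset.
x = torch.randn(4, 6, device="mps")
v = x[1].reshape(3, 2).t()  # shape (2, 3), strides (1, 2), storage offset 6

# On macOS 15+ the view's strides/offset are handed to MPSNDArray natively;
# on older macOS versions the same call silently takes the gather/scatter
# fallback path, so the result is identical either way.
y = v + 1
torch.testing.assert_close(y.cpu(), x.cpu()[1].reshape(3, 2).t() + 1)
```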

---

A few performance numbers collected from torchbench, comparing macOS 15 against macOS 14:
```
- test_train[functorch_maml_omniglot-mps]: 27% faster
- test_train[timm_vision_transformer-mps]: 12% faster
- test_train[hf_T5-mps]: 9.46% faster
```
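
For anyone wanting to spot-check similar wins outside torchbench, here is a hedged micro-benchmark sketch (the op, sizes, and iteration count are arbitrary, not from the PR; `torch.mps.synchronize()` is needed because MPS dispatch is asynchronous):

```python
import time
import torch

a = torch.randn(2048, 2048, device="mps")
b = torch.randn(2048, 2048, device="mps")

torch.mps.synchronize()  # drain pending work before timing
start = time.perf_counter()
for _ in range(100):
    c = a @ b.t()  # strided (transposed) operand; no gather copy on macOS 15+
torch.mps.synchronize()  # wait for the GPU to finish
print(f"{(time.perf_counter() - start) / 100 * 1e3:.3f} ms/iter")
```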

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128393
Approved by: https://github.com/albanD

Co-authored-by: Siddharth Kotapati <[email protected]>
diff --git a/test/test_mps.py b/test/test_mps.py
index 4f0e847..86b181e 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -23,7 +23,7 @@
 from torch.testing._internal import opinfo
 from torch.testing._internal.common_utils import \
     (gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, IS_CI,
-     NoTest, skipIfSlowGradcheckEnv, suppress_warnings, serialTest)
+     NoTest, skipIfSlowGradcheckEnv, suppress_warnings, serialTest, instantiate_parametrized_tests)
 from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
 import torch.backends.mps
@@ -728,7 +728,6 @@
         'masked.median': None,
         'matrix_exp': None,
         'mode': None,
-        'nanquantile': None,
         'nanmedian': None,
         'native_dropout_backward': None,
         'normnuc': None,
@@ -758,7 +757,6 @@
         'ormqr': None,
         'pca_lowrank': None,
         'qr': None,
-        'quantile': None,
         'rsub': None,
         'scatter_reduceamax': None,
         'scatter_reduceamin': None,
@@ -922,6 +920,12 @@
             'nn.functional.max_pool2d': [torch.uint8],
         })
 
+    if product_version < 15.0:
+        UNIMPLEMENTED_XFAILLIST.update({
+            'quantile': None,
+            'nanquantile': None,
+        })
+
     UNDEFINED_XFAILLIST = {
         # Top 60 operators
         # topk fails with duplicate indices
@@ -3540,6 +3544,76 @@
         mps_slice4 = mps_x[1, :].to('cpu')
         self.assertEqual(cpu_slice4, mps_slice4)
 
+    @parametrize("torch_type", arg_values=[torch.float16, torch.float32, torch.bfloat16])
+    def test_slice_view_api(self, torch_type: torch.dtype):
+
+        def helper(x_tensor, y_func, z_func, r_func=None):
+            x_mps = x_tensor.detach().clone().to("mps")
+
+            y = y_func(x_tensor)
+            y_mps = y_func(x_mps)
+            self.assertEqual(y, y_mps)
+
+            z = z_func(y)
+            z_mps = z_func(y_mps)
+            self.assertEqual(z, z_mps)
+            self.assertEqual(z.storage_offset(), z_mps.storage_offset())
+
+            if r_func:
+                r = r_func(z)
+                r_mps = r_func(z_mps)
+                self.assertEqual(r, r_mps)
+
+        # Skip bfloat16 before macOS 15
+        if not (product_version < 15.0 and torch_type == torch.bfloat16):
+            # Tests for previously encountered MPS bugs
+            helper(
+                torch.randn(4, 4, dtype=torch_type),
+                lambda x: x[1],
+                lambda y: y.reshape(2, 2),
+                lambda z: z + 1
+            )
+            helper(
+                torch.randn(2, 4, dtype=torch_type),
+                lambda x: x[1],
+                lambda y: y + torch.ones(4, device=y.device)
+            )
+            helper(
+                torch.randn(4, 6, dtype=torch_type),
+                lambda x: x[1],
+                lambda y: y.reshape(3, 2).t(),
+                lambda z: z + 1
+            )
+            helper(
+                torch.arange(4, dtype=torch_type).resize(1, 2, 2),
+                lambda x: x.permute(2, 0, 1),
+                lambda y: y + 1
+            )
+            helper(
+                torch.randn(4, 8, dtype=torch_type),
+                lambda x: x.transpose(0, 1).reshape(-1),
+                lambda y: y[:2],
+                lambda z: z + 1
+            )
+            helper(
+                torch.randn(1, dtype=torch_type),
+                lambda x: x.expand(2, 3),
+                lambda y: y + torch.ones(2, 3, device=y.device)
+            )
+
+    def test_slice_reshape_contiguous(self):
+        x = torch.randn(4, 4)
+        x_mps = x.detach().clone().to("mps")
+
+        y = x[1]
+        y_mps = x_mps[1]
+        self.assertEqual(y, y_mps)
+
+        z = y.reshape(2, 2)
+        z_mps = y_mps.reshape(2, 2)
+        self.assertEqual(z, z_mps)
+        self.assertEqual(z.storage_offset(), z_mps.storage_offset())
+
     def test_scalar_from_slice_unary(self):
         # https://github.com/pytorch/pytorch/issues/82543
         tensor_list = torch.tensor([1.0, 1.2], device="mps")
@@ -3951,6 +4025,11 @@
         y = torch.tensor([0, 1], dtype=torch.bool, device='mps')
         self.assertFalse(torch.equal(x.expand(2, 2), y.expand(2, 2)))
 
+    def test_int_expand(self):
+        x = torch.tensor([[1], [0]], dtype=torch.int8, device='mps')
+        y = torch.tensor([0, 1], dtype=torch.int8, device='mps')
+        self.assertFalse(torch.equal(x.expand(2, 2), y.expand(2, 2)))
+
     # Empty unary op should return tensor of the same size
     def test_empty_neg(self):
         x = torch.tensor([[]], device='mps')
@@ -10343,10 +10422,10 @@
                     res_cpu = conv_cpu(x_cpu)
                     res_mps = conv_mps(x_mps)
                     self.assertEqual(res_cpu, res_mps.cpu(), rtol=1e-03, atol=1e-05)
-
                     res_cpu = res_cpu.sum().backward()
                     res_mps = res_mps.sum().backward()
                     self.assertEqual(res_cpu, res_mps, rtol=2.6e-05, atol=2e-04)
+
                     self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
                     self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad)
                     self.assertEqual(x_cpu.grad, x_mps.grad)
@@ -10776,12 +10855,12 @@
 
     def test_nonzero_non_diff(self):
         device = "mps"
-        x = torch.randn(10, requires_grad=True)
+        x = torch.randn(10, requires_grad=True, device=device)
         nz = x.nonzero()
         self.assertFalse(nz.requires_grad)
 
     def test_nonzero_multi_threading(self):
-        # Test that MPS does not crash if nonzero called concurrently
+        # Test that MPS doesn't crash if nonzero called concurrently
         # See https://github.com/pytorch/pytorch/issues/100285
         x = torch.rand(3, 3, device="mps")
         t1 = threading.Thread(target=torch.nonzero, args=(x,))
@@ -12160,6 +12239,7 @@
 instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
 instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
 instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")
+instantiate_parametrized_tests(TestMPS)
 
 if __name__ == "__main__":
     run_tests()