[MPS] Add Conv3D support for MPS (#114183)

Fixes #77818

PR #99246 was approved, but its rebase conflicts were never resolved, so I am resubmitting the change here so that it can be merged.
This builds on @mattiaspaul's work. Quoting the original description:

> * this pull request enables 3D convolutions (forward/backward) for MPS (Apple Silicon) within the same Convolution.mm file as conv2d.
> * does not support channel_last (since pytorch doesn't implement channel_last for 3D tensors)
> * does not support conv3d_transpose and treats depth-separable convolutions not as normal case (there are no MPS kernels available for either of those so far)
> * requires MacOS >=13.2 (Ventura)

Please let me know if any other changes are needed, and I'll be happy to implement them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114183
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 6347d2e..078ef90 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2641,8 +2641,38 @@
                 cpu_norm = torch.nn.utils.weight_norm(cpu_conv, dim=dim)
                 norm = torch.nn.utils.weight_norm(conv, dim=dim)
 
-                cpu_out = cpu_conv(cpu_x)
-                out = conv(x)
+                cpu_out = cpu_norm(cpu_x)
+                out = norm(x)
+
+                self.assertEqual(cpu_out, out)
+
+                cpu_grad = torch.randn(cpu_out.shape)
+                grad = cpu_grad.to('mps')
+                cpu_out.backward(gradient=cpu_grad)
+                out.backward(gradient=grad)
+
+                self.assertEqual(cpu_conv.weight_g.grad, conv.weight_g.grad)
+                self.assertEqual(cpu_conv.weight_v.grad, conv.weight_v.grad)
+
+                self.assertEqual(x.grad, cpu_x.grad)
+
+                # conv3d layer
+            if layer == 'conv3d':
+                cpu_x = torch.randn((3, 5, 5, 4), device='cpu', dtype=dtype, requires_grad=True)
+                x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+                cpu_conv = torch.nn.Conv3d(3, 3, 3, device='cpu')
+                conv = torch.nn.Conv3d(3, 3, 3, device='mps')
+
+                with torch.no_grad():
+                    conv.weight.copy_(cpu_conv.weight)
+                    conv.bias.copy_(cpu_conv.bias)
+
+                cpu_norm = torch.nn.utils.weight_norm(cpu_conv, dim=dim)
+                norm = torch.nn.utils.weight_norm(conv, dim=dim)
+
+                cpu_out = cpu_norm(cpu_x)
+                out = norm(x)
 
                 self.assertEqual(cpu_out, out)
 
@@ -2666,6 +2696,15 @@
         helper(3, layer='conv')
         helper(-1, layer='conv')
 
+        if product_version >= 13.2:
+            # Conv3d on MPS is only available from macOS 13.2 onwards
+            helper(0, layer='conv3d')
+            helper(1, layer='conv3d')
+            helper(2, layer='conv3d')
+            helper(3, layer='conv3d')
+            helper(4, layer='conv3d')
+            helper(-1, layer='conv3d')
+
     # Test conv2d
     def test_conv2d_unit(self):
         def helper(input_shape, wt_shape,
@@ -8229,6 +8268,17 @@
         # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
         y2.sum().backward()
 
+    @unittest.skipIf(product_version < 13.2, "Skipped on macOS 12")
+    def test_conv3d_backward_collision(self):
+        # Conv3D on MPS needs macOS 13.2+; m1 and m2 differ in kernel size but produce equal output shapes
+        x = torch.rand(1, 1, 10, 10, 20, device="mps", requires_grad=True)
+        m1 = nn.Conv3d(1, 1, 3, stride=2, padding=1).to("mps")
+        m2 = nn.Conv3d(1, 1, 4, stride=2, padding=1).to("mps")
+        y1, y2 = m1(x), m2(x)
+        self.assertEqual(y1.shape, y2.shape)
+        y1.sum().backward()
+        # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
+        y2.sum().backward()
 
     def test_gemm_permute_transpose(self):
         batch_size = 32
@@ -9556,6 +9606,18 @@
             x_gpu = conv_gpu(y_gpu)
             self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)
 
+    @unittest.skipIf(product_version < 13.2, "Skipped on macOS 12")
+    def test_conv3d_single_stride(self):
+        # Conv3d on MPS needs macOS 13.2+; checks MPS forward output against CPU for strides 1-3 on a 4-D input
+        y_cpu = torch.randn(2, 2, 3, 6)
+        y_gpu = y_cpu.to(device='mps')
+        for stride in range(1, 4):
+            conv_cpu = torch.nn.Conv3d(in_channels=2, out_channels=2, kernel_size=2, stride=stride)
+            conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
+            x_cpu = conv_cpu(y_cpu)
+            x_gpu = conv_gpu(y_gpu)
+            self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)
+
     def test_grid_sample(self):
         def test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad):
             def test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners):