[MPS] Fix convolution `Source and weight input channels mismatch` crash (#91822)

Fixes crashes in the convolution backward passes (gradients with respect to input and weight) caused by handling of NCHW / NHWC memory formats.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91822
Approved by: https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index 86a6564..a39ffc6 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -7085,17 +7085,33 @@
         self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04)
 
     def test_conv_backward_1d_channels_last(self):
-        # https://github.com/pytorch/pytorch/issues/84511
-        conv_cpu = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)
-        conv_mps = copy.deepcopy(conv_cpu).to(device='mps')
+        def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
+            # https://github.com/pytorch/pytorch/issues/84511
+            conv_cpu = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups)
+            conv_mps = torch.nn.Conv1d(
+                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps")
+            conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True)
+            conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_(True)
 
-        data = torch.rand(1, 176, 1, dtype=torch.float32)
-        x_cpu = data.permute(0, 2, 1).contiguous()
-        x_mps = data.permute(0, 2, 1).contiguous().to("mps")
-        res_cpu = conv_cpu(x_cpu).sum().backward()
-        res_mps = conv_mps(x_mps).sum().backward()
 
-        self.assertEqual(res_cpu, res_mps)
+            data = torch.rand(shape, dtype=torch.float32)
+            x_cpu = data.permute(0, 2, 1).contiguous().requires_grad_(True)
+            x_mps = data.permute(0, 2, 1).detach().clone().to("mps").contiguous().requires_grad_(True)
+            res_cpu = conv_cpu(x_cpu)
+            res_mps = conv_mps(x_mps)
+            self.assertEqual(res_cpu, res_mps)
+            res_cpu = res_cpu.sum().backward()
+            res_mps = res_mps.sum().backward()
+
+            self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
+            self.assertEqual(x_cpu.grad, x_mps.grad)
+
+        helper(shape=(1, 176, 1))
+        helper(shape=(2, 12, 1))
+        helper(shape=(3, 176, 1))
+        helper(shape=(4, 376, 1))
+        helper(shape=(1024, 376, 9), in_channels=9, out_channels=1, groups=1)
+        helper(shape=(1024, 376, 9), in_channels=9, out_channels=9, groups=3)
 
     def test_conv1d_contiguous(self):
         model_cpu = torch.nn.Conv1d(1, 128, 3)