[MPS] Fix convolution `Source and weight input channels mismatch` crash (#91822)
Fixes crashes in the conv input/weight backward passes caused by NCHW/NHWC memory-format handling.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91822
Approved by: https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index 86a6564..a39ffc6 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -7085,17 +7085,33 @@
self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04)
def test_conv_backward_1d_channels_last(self):
- # https://github.com/pytorch/pytorch/issues/84511
- conv_cpu = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)
- conv_mps = copy.deepcopy(conv_cpu).to(device='mps')
+ def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
+ # Compare CPU vs MPS Conv1d forward/backward; see https://github.com/pytorch/pytorch/issues/84511
+ conv_cpu = torch.nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups)
+ conv_mps = torch.nn.Conv1d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, groups=groups).to("mps")
+ conv_mps.weight.data = conv_cpu.weight.data.detach().clone().to("mps").requires_grad_(True)
+ conv_mps.bias.data = conv_cpu.bias.data.detach().clone().to("mps").requires_grad_(True)
- data = torch.rand(1, 176, 1, dtype=torch.float32)
- x_cpu = data.permute(0, 2, 1).contiguous()
- x_mps = data.permute(0, 2, 1).contiguous().to("mps")
- res_cpu = conv_cpu(x_cpu).sum().backward()
- res_mps = conv_mps(x_mps).sum().backward()
- self.assertEqual(res_cpu, res_mps)
+ data = torch.rand(shape, dtype=torch.float32)  # (N, L, C); permuted below to Conv1d's (N, C, L)
+ x_cpu = data.permute(0, 2, 1).contiguous().requires_grad_(True)
+ x_mps = data.permute(0, 2, 1).detach().clone().to("mps").contiguous().requires_grad_(True)
+ res_cpu = conv_cpu(x_cpu)
+ res_mps = conv_mps(x_mps)
+ self.assertEqual(res_cpu, res_mps)
+ res_cpu.sum().backward()  # backward() returns None; gradients are accumulated into .grad
+ res_mps.sum().backward()
+ self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad, rtol=2.6e-05, atol=2e-04)
+ self.assertEqual(conv_cpu.bias.grad, conv_mps.bias.grad, rtol=2.6e-05, atol=2e-04)
+ self.assertEqual(x_cpu.grad, x_mps.grad)
+
+ helper(shape=(1, 176, 1))
+ helper(shape=(2, 12, 1))
+ helper(shape=(3, 176, 1))
+ helper(shape=(4, 376, 1))
+ helper(shape=(1024, 376, 9), in_channels=9, out_channels=1, groups=1)
+ helper(shape=(1024, 376, 9), in_channels=9, out_channels=9, groups=3)
def test_conv1d_contiguous(self):
model_cpu = torch.nn.Conv1d(1, 128, 3)