[MPS] Fix conv1d backwards crash for channels last case (#85283)
Fixes pytorch#84511
Use the same logic in the backward pass as in the forward pass: in the channels-last case, manually set the shape to NHWC.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85283
Approved by: https://github.com/malfet, https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index e036f69..ccb3a29 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6005,6 +6005,19 @@
self.assertEqual(tcpu, tgpu.cpu(), rtol=2.6e-05, atol=2e-04)
+    def test_conv_backward_1d_channels_last(self):
+        # Regression test: backward of Conv1d on a channels-last input used to
+        # crash on MPS. https://github.com/pytorch/pytorch/issues/84511
+        conv_cpu = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)
+        conv_mps = copy.deepcopy(conv_cpu).to(device='mps')
+
+        data = torch.rand(1, 176, 1, dtype=torch.float32)
+        # permute(0, 2, 1) yields the NCL layout the conv expects; .contiguous()
+        # materializes the channels-last memory order that triggered the crash.
+        x_cpu = data.permute(0, 2, 1).contiguous()
+        x_mps = data.permute(0, 2, 1).contiguous().to("mps")
+        # backward() returns None, so compare the accumulated weight gradients
+        # instead of the (always-None) return values — otherwise the test only
+        # checks that backward does not crash, not that gradients are correct.
+        conv_cpu(x_cpu).sum().backward()
+        conv_mps(x_mps).sum().backward()
+
+        self.assertEqual(conv_cpu.weight.grad, conv_mps.weight.grad.cpu())
+
def test_conv1d_contiguous(self):
model_cpu = torch.nn.Conv1d(1, 128, 3)
a_cpu = torch.ones(128, 1, 176)