[MPS] Fix strided mse_loss (#125696) Fixes https://github.com/pytorch/pytorch/issues/124621 Summary of changes: - With a non-contiguous input, the output is non-contiguous too. Writing the result directly to a non-contiguous buffer is not currently supported, so two steps are needed: first allocate a contiguous buffer for the result, then scatter the result back to the original output. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125696 Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py index 469eaec..de773fe 100644 --- a/test/test_mps.py +++ b/test/test_mps.py
@@ -4603,6 +4603,33 @@ helper([7, 5, 2, 4, 6], 'sum') helper([8, 4, 5, 7, 6], 'mean') + def test_mse_loss_strided_output(self): + # https://github.com/pytorch/pytorch/issues/124621 + lf = nn.MSELoss(reduction='none') + model_cpu = nn.Sequential( + nn.Conv1d(3, 3, 1), + ) + model_mps = copy.deepcopy(model_cpu).to("mps") + + x = torch.randn(128, 10, 3) + x = x.permute(0, 2, 1) + + x_mps = x.detach().clone().to("mps").permute(0, 2, 1) + x_mps = x_mps.permute(0, 2, 1) + + y = model_cpu(x) + y_mps = model_mps(x_mps) + + y = y.permute(0, 2, 1)[:, :5, :] + y_mps = y_mps.permute(0, 2, 1)[:, :5, :] + + y_hat = torch.randn(128, 5, 3) + y_hat_mps = y_hat.detach().clone().to("mps") + + loss = lf(y, y_hat) + loss_mps = lf(y_mps, y_hat_mps) + self.assertEqual(loss, loss_mps) + # Binary Cross Enropy def test_bce_loss_simple(self): def helper(shape, reduction):