[MPS] Fix strided mse_loss (#125696)

Fixes https://github.com/pytorch/pytorch/issues/124621

Summary of changes:
- When the input is non-contiguous, the output is non-contiguous too. Saving the result directly to a non-contiguous buffer is not currently supported, so two steps are needed: first allocate a contiguous buffer for the result, then scatter the result back to the original output.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125696
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 469eaec..de773fe 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4603,6 +4603,33 @@
         helper([7, 5, 2, 4, 6], 'sum')
         helper([8, 4, 5, 7, 6], 'mean')
 
+    def test_mse_loss_strided_output(self):
+        # https://github.com/pytorch/pytorch/issues/124621
+        lf = nn.MSELoss(reduction='none')
+        model_cpu = nn.Sequential(
+            nn.Conv1d(3, 3, 1),
+        )
+        model_mps = copy.deepcopy(model_cpu).to("mps")
+
+        x = torch.randn(128, 10, 3)
+        x = x.permute(0, 2, 1)
+
+        x_mps = x.detach().clone().to("mps").permute(0, 2, 1)
+        x_mps = x_mps.permute(0, 2, 1)
+
+        y = model_cpu(x)
+        y_mps = model_mps(x_mps)
+
+        y = y.permute(0, 2, 1)[:, :5, :]
+        y_mps = y_mps.permute(0, 2, 1)[:, :5, :]
+
+        y_hat = torch.randn(128, 5, 3)
+        y_hat_mps = y_hat.detach().clone().to("mps")
+
+        loss = lf(y, y_hat)
+        loss_mps = lf(y_mps, y_hat_mps)
+        self.assertEqual(loss, loss_mps)
+
     # Binary Cross Enropy
     def test_bce_loss_simple(self):
         def helper(shape, reduction):