[MPS] Fix masked_fill_ in non-contiguous cases (#131957)

Fixes #131285, where `masked_fill_` produced incorrect results on MPS when the destination tensor was a non-contiguous (e.g., transposed) view.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131957
Approved by: https://github.com/DenisVieriu97
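
For context, a minimal repro sketch of the failure mode this PR guards against. The shape, the transposed view, the all-False mask, and the `-inf` fill value mirror the test added below; the standalone flow itself is illustrative and assumes an MPS-capable machine:

```python
# Illustrative repro sketch for #131285 (assumes torch.backends.mps is available).
# masked_fill_ through a transposed, non-contiguous view should match the CPU result.
import torch

if torch.backends.mps.is_available():
    x_mps = torch.randn(3, 5, device="mps")
    x_cpu = x_mps.detach().clone().cpu()

    # .T is a strided view over the same storage and is non-contiguous here
    assert not x_mps.T.is_contiguous()

    mask_cpu = torch.zeros(5, 3, dtype=torch.bool)  # all False: fill must be a no-op
    x_mps.T.masked_fill_(mask_cpu.to("mps"), float("-inf"))
    x_cpu.T.masked_fill_(mask_cpu, float("-inf"))

    # Both paths must leave the data untouched; the new test's final assertion
    # suggests the buggy kernel could write -inf even under an all-False mask.
    torch.testing.assert_close(x_mps.cpu(), x_cpu)
    assert not (x_mps == float("-inf")).any()
```

Since the mask is all False, `masked_fill_` should write nothing; any `-inf` appearing in the MPS result can only come from the kernel mishandling the view's strides, which is exactly what the new test checks.
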
diff --git a/test/test_mps.py b/test/test_mps.py
index 90fe6e8..33f1701 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2321,43 +2321,37 @@
         device = "mps"
         dtype = torch.float32
         mask_dtype = torch.bool
+        num_dest = 10
 
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter("always")
-            num_dest = 10
-            dst = torch.zeros(num_dest, dtype=dtype, device=device)
-            mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device)
-            val = random.random()
-            dst2 = torch.zeros(num_dest, dtype=dtype)
-            mask_cpu = mask.to("cpu")
+        dst = torch.zeros(num_dest, dtype=dtype, device=device)
+        mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device)
+        val = random.random()
+        dst2 = torch.zeros(num_dest, dtype=dtype)
+        mask_cpu = mask.to("cpu")
 
-            dst.masked_fill_(mask, val)
-            for i in range(num_dest):
-                if mask_cpu[i]:
-                    dst2[i] = val
-            self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0)
+        dst.masked_fill_(mask, val)
+        for i in range(num_dest):
+            if mask_cpu[i]:
+                dst2[i] = val
+        self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0)
 
-            # test non-contiguous case
-            dst = ((torch.randn(num_dest, num_dest, num_dest) * 10).to(dtype)).permute((2, 0, 1))
-            dst2 = dst.contiguous()
-            if dtype.is_complex:
-                mask = dst.abs() > 0
-            else:
-                mask = dst > 0
-            self.assertFalse(dst.is_contiguous())
-            self.assertTrue(dst2.is_contiguous())
-            dst.masked_fill_(mask.to(mask_dtype), val)
-            dst2.masked_fill_(mask.to(mask_dtype), val)
-            self.assertEqual(dst, dst2, atol=0, rtol=0)
+    def test_masked_fill__non_contiguous(self):
+        shape = (3, 5)
 
-            if mask_dtype == torch.uint8:
-                self.assertEqual(len(w), 3)
+        x_mps = torch.randn(shape, device="mps")
+        x_cpu = x_mps.detach().clone().cpu()
+        mask_mps = torch.zeros(shape, device="mps", dtype=torch.bool)
+        mask_cpu = mask_mps.detach().clone().cpu()
 
-                warn = 'masked_fill_ received a mask with dtype torch.uint8,'
-                for wi in w:
-                    self.assertEqual(str(wi.message)[0:52], str(warn))
-            else:
-                self.assertEqual(len(w), 0)
+        x_mps_strided = x_mps.T
+        x_cpu_strided = x_cpu.T
+
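+        # All-False mask: masked_fill_ must be a no-op on both transposed views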
+        x_mps_strided.masked_fill_(mask_mps.T, float("-inf"))
+        x_cpu_strided.masked_fill_(mask_cpu.T, float("-inf"))
+
+        self.assertEqual(x_mps_strided, x_cpu_strided)
+        self.assertFalse((x_mps_strided == float("-inf")).any())
 
     def test_nhwc_operation(self):
         def helper(shape, channels_last=False):