[MPS] Fix masked_fill_ in non_contiguous cases (#131957)

Fixes #131285.

On MPS, `masked_fill_` previously wrote the graph output directly into `self` even when `self` was non-contiguous (i.e. when `needsGather(self)` is true), so the result landed in the underlying storage in contiguous order instead of being scattered through `self`'s strides. This change materializes the result into a temporary contiguous tensor in that case and copies it back into `self` with `copy_` afterwards. A repro sketch follows the links below.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131957
Approved by: https://github.com/DenisVieriu97
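
A minimal repro sketch of the reported behavior, mirroring the new test added in `test_mps.py` below. It assumes an MPS-enabled build; the shape and fill value are illustrative, not taken from the issue:

```python
import torch

x_mps = torch.randn(3, 5, device="mps")
x_cpu = x_mps.detach().clone().cpu()
mask = torch.zeros(3, 5, dtype=torch.bool)  # all-False mask: fill should be a no-op

# Operate through transposed (non-contiguous) views on both devices.
x_mps.T.masked_fill_(mask.T.to("mps"), float("-inf"))
x_cpu.T.masked_fill_(mask.T, float("-inf"))

# Before this patch, the MPS result could contain spurious -inf entries,
# because the graph output was written without scattering through the
# transposed view's strides.
assert torch.equal(x_mps.cpu(), x_cpu)
assert not (x_mps == float("-inf")).any()
```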
diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm
index f9c9b5b..ac6a920 100644
--- a/aten/src/ATen/native/mps/operations/Indexing.mm
+++ b/aten/src/ATen/native/mps/operations/Indexing.mm
@@ -643,6 +643,14 @@
 
   c10::MaybeOwned<Tensor> b_mask = expand_inplace(self, mask, "masked_fill_");
 
+  bool needs_output_copy = false;
+
+  Tensor output;
+  if (needsGather(self)) {
+    output = at::empty(self.sizes(), self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
+    needs_output_copy = true;
+  }
+
   struct CachedGraph : public MPSCachedGraph {
     CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
     MPSGraphTensor* inputTensor_ = nil;
@@ -692,8 +700,11 @@
         Placeholder(cachedGraph->inputTensor_, self, /*mpsShape*/ nil, /*gatherTensorData=*/true, inputDataType);
     Placeholder maskPlaceholder =
         Placeholder(cachedGraph->maskTensor_, *b_mask, /*mpsShape*/ nil, /*gatherTensorData=*/true, maskDataType);
-    Placeholder outputPlaceholder =
-        Placeholder(cachedGraph->outputTensor_, self, /*mpsShape*/ nil, /*gatherTensorData=*/false, inputDataType);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_,
+                                                needs_output_copy ? output : self,
+                                                /*mpsShape*/ nil,
+                                                /*gatherTensorData=*/false,
+                                                inputDataType);
 
     // Create dictionary of inputs and outputs
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@@ -704,6 +715,11 @@
 
     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
   }
+
+  if (needs_output_copy) {
+    self.copy_(output);
+  }
+
   namedinference::propagate_names_if_nonempty(self, maybe_outnames);
   return self;
 }
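
For intuition about why the copy-back is needed, here is a device-independent sketch (plain PyTorch on CPU, not the actual MPS code path) of what goes wrong when a contiguous result buffer is written into a strided view without a scatter:

```python
import torch

base = torch.zeros(3, 5)
result = torch.arange(15.).reshape(5, 3)  # stand-in for the graph output

# What the broken path effectively did: dump the contiguous result straight
# into the view's storage, ignoring the view's strides.
bad = base.clone()
bad.view(-1).copy_(result.reshape(-1))

# What the patch does instead: write into a temporary, then copy_ through
# the view, which scatters values according to the view's strides.
good = base.clone()
good.T.copy_(result)

print(torch.equal(bad, good))  # False: the two layouts disagree
```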
diff --git a/test/test_mps.py b/test/test_mps.py
index 90fe6e8..33f1701 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2321,43 +2321,36 @@
         device = "mps"
         dtype = torch.float32
         mask_dtype = torch.bool
+        num_dest = 10
 
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter("always")
-            num_dest = 10
-            dst = torch.zeros(num_dest, dtype=dtype, device=device)
-            mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device)
-            val = random.random()
-            dst2 = torch.zeros(num_dest, dtype=dtype)
-            mask_cpu = mask.to("cpu")
+        dst = torch.zeros(num_dest, dtype=dtype, device=device)
+        mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device)
+        val = random.random()
+        dst2 = torch.zeros(num_dest, dtype=dtype)
+        mask_cpu = mask.to("cpu")
 
-            dst.masked_fill_(mask, val)
-            for i in range(num_dest):
-                if mask_cpu[i]:
-                    dst2[i] = val
-            self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0)
+        dst.masked_fill_(mask, val)
+        for i in range(num_dest):
+            if mask_cpu[i]:
+                dst2[i] = val
+        self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0)
 
-            # test non-contiguous case
-            dst = ((torch.randn(num_dest, num_dest, num_dest) * 10).to(dtype)).permute((2, 0, 1))
-            dst2 = dst.contiguous()
-            if dtype.is_complex:
-                mask = dst.abs() > 0
-            else:
-                mask = dst > 0
-            self.assertFalse(dst.is_contiguous())
-            self.assertTrue(dst2.is_contiguous())
-            dst.masked_fill_(mask.to(mask_dtype), val)
-            dst2.masked_fill_(mask.to(mask_dtype), val)
-            self.assertEqual(dst, dst2, atol=0, rtol=0)
+    def test_masked_fill__non_contiguous(self):
+        shape = (3, 5)
 
-            if mask_dtype == torch.uint8:
-                self.assertEqual(len(w), 3)
+        x_mps = torch.randn(shape, device="mps")
+        x_cpu = x_mps.detach().clone().cpu()
+        mask_mps = torch.zeros(shape, device="mps", dtype=torch.bool)
+        mask_cpu = mask_mps.detach().clone().cpu()
 
-                warn = 'masked_fill_ received a mask with dtype torch.uint8,'
-                for wi in w:
-                    self.assertEqual(str(wi.message)[0:52], str(warn))
-            else:
-                self.assertEqual(len(w), 0)
+        x_mps_strided = x_mps.T
+        x_cpu_strided = x_cpu.T
+
+        x_mps_strided.masked_fill_(mask_mps.T, float("-inf"))
+        x_cpu_strided.masked_fill_(mask_cpu.T, float("-inf"))
+
+        self.assertEqual(x_mps_strided, x_cpu_strided)
+        self.assertFalse((x_mps_strided == float("-inf")).any())
 
     def test_nhwc_operation(self):
         def helper(shape, channels_last=False):
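
To exercise just the new test locally (assuming an MPS-capable build; the exact runner invocation may vary), something like `pytest test/test_mps.py -k test_masked_fill__non_contiguous` should suffice.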