[MPS] Fix masked_fill_ in non-contiguous cases (#131957)
Fixes #131285
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131957
Approved by: https://github.com/DenisVieriu97
diff --git a/test/test_mps.py b/test/test_mps.py
index 90fe6e8..33f1701 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2321,43 +2321,36 @@
device = "mps"
dtype = torch.float32
mask_dtype = torch.bool
+ num_dest = 10
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter("always")
- num_dest = 10
- dst = torch.zeros(num_dest, dtype=dtype, device=device)
- mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device)
- val = random.random()
- dst2 = torch.zeros(num_dest, dtype=dtype)
- mask_cpu = mask.to("cpu")
+ dst = torch.zeros(num_dest, dtype=dtype, device=device)
+ mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device)
+ val = random.random()
+ dst2 = torch.zeros(num_dest, dtype=dtype)
+ mask_cpu = mask.to("cpu")
- dst.masked_fill_(mask, val)
- for i in range(num_dest):
- if mask_cpu[i]:
- dst2[i] = val
- self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0)
+ dst.masked_fill_(mask, val)
+ for i in range(num_dest):
+ if mask_cpu[i]:
+ dst2[i] = val
+ self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0)
- # test non-contiguous case
- dst = ((torch.randn(num_dest, num_dest, num_dest) * 10).to(dtype)).permute((2, 0, 1))
- dst2 = dst.contiguous()
- if dtype.is_complex:
- mask = dst.abs() > 0
- else:
- mask = dst > 0
- self.assertFalse(dst.is_contiguous())
- self.assertTrue(dst2.is_contiguous())
- dst.masked_fill_(mask.to(mask_dtype), val)
- dst2.masked_fill_(mask.to(mask_dtype), val)
- self.assertEqual(dst, dst2, atol=0, rtol=0)
+ def test_masked_fill__non_contiguous(self):
+ shape = (3, 5)
- if mask_dtype == torch.uint8:
- self.assertEqual(len(w), 3)
+ x_mps = torch.randn(shape, device="mps")
+ x_cpu = x_mps.detach().clone().cpu()
+ mask_mps = torch.zeros(shape, device="mps", dtype=torch.bool)
+ mask_cpu = mask_mps.detach().clone().cpu()
- warn = 'masked_fill_ received a mask with dtype torch.uint8,'
- for wi in w:
- self.assertEqual(str(wi.message)[0:52], str(warn))
- else:
- self.assertEqual(len(w), 0)
+ x_mps_strided = x_mps.T
+ x_cpu_strided = x_cpu.T
+
+ x_mps_strided.masked_fill_(mask_mps.T, float("-inf"))
+ x_cpu_strided.masked_fill_(mask_cpu.T, float("-inf"))
+
+ self.assertEqual(x_mps_strided, x_cpu_strided)
+ self.assertFalse((x_mps_strided == float("-inf")).any())
def test_nhwc_operation(self):
def helper(shape, channels_last=False):