[MPS] Fix masked_fill_ in non-contiguous cases (#131957)
fixes #131285
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131957
Approved by: https://github.com/DenisVieriu97
diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm
index f9c9b5b..ac6a920 100644
--- a/aten/src/ATen/native/mps/operations/Indexing.mm
+++ b/aten/src/ATen/native/mps/operations/Indexing.mm
@@ -643,6 +643,14 @@
c10::MaybeOwned<Tensor> b_mask = expand_inplace(self, mask, "masked_fill_");
+ bool needs_output_copy = false;
+
+ Tensor output;
+ if (needsGather(self)) {
+ output = at::empty(self.sizes(), self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
+ needs_output_copy = true;
+ }
+
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
@@ -692,8 +700,11 @@
Placeholder(cachedGraph->inputTensor_, self, /*mpsShape*/ nil, /*gatherTensorData=*/true, inputDataType);
Placeholder maskPlaceholder =
Placeholder(cachedGraph->maskTensor_, *b_mask, /*mpsShape*/ nil, /*gatherTensorData=*/true, maskDataType);
- Placeholder outputPlaceholder =
- Placeholder(cachedGraph->outputTensor_, self, /*mpsShape*/ nil, /*gatherTensorData=*/false, inputDataType);
+ Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_,
+ needs_output_copy ? output : self,
+ /*mpsShape*/ nil,
+ /*gatherTensorData=*/false,
+ inputDataType);
// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@@ -704,6 +715,11 @@
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
+
+ if (needs_output_copy) {
+ self.copy_(output);
+ }
+
namedinference::propagate_names_if_nonempty(self, maybe_outnames);
return self;
}
diff --git a/test/test_mps.py b/test/test_mps.py
index 90fe6e8..33f1701 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2321,43 +2321,36 @@
device = "mps"
dtype = torch.float32
mask_dtype = torch.bool
+ num_dest = 10
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter("always")
- num_dest = 10
- dst = torch.zeros(num_dest, dtype=dtype, device=device)
- mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device)
- val = random.random()
- dst2 = torch.zeros(num_dest, dtype=dtype)
- mask_cpu = mask.to("cpu")
+ dst = torch.zeros(num_dest, dtype=dtype, device=device)
+ mask = torch.randint(2, (num_dest,), dtype=mask_dtype, device=device)
+ val = random.random()
+ dst2 = torch.zeros(num_dest, dtype=dtype)
+ mask_cpu = mask.to("cpu")
- dst.masked_fill_(mask, val)
- for i in range(num_dest):
- if mask_cpu[i]:
- dst2[i] = val
- self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0)
+ dst.masked_fill_(mask, val)
+ for i in range(num_dest):
+ if mask_cpu[i]:
+ dst2[i] = val
+ self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0)
- # test non-contiguous case
- dst = ((torch.randn(num_dest, num_dest, num_dest) * 10).to(dtype)).permute((2, 0, 1))
- dst2 = dst.contiguous()
- if dtype.is_complex:
- mask = dst.abs() > 0
- else:
- mask = dst > 0
- self.assertFalse(dst.is_contiguous())
- self.assertTrue(dst2.is_contiguous())
- dst.masked_fill_(mask.to(mask_dtype), val)
- dst2.masked_fill_(mask.to(mask_dtype), val)
- self.assertEqual(dst, dst2, atol=0, rtol=0)
+ def test_masked_fill__non_contiguous(self):
+ shape = (3, 5)
- if mask_dtype == torch.uint8:
- self.assertEqual(len(w), 3)
+ x_mps = torch.randn(shape, device="mps")
+ x_cpu = x_mps.detach().clone().cpu()
+ mask_mps = torch.zeros(shape, device="mps", dtype=torch.bool)
+ mask_cpu = mask_mps.detach().clone().cpu()
- warn = 'masked_fill_ received a mask with dtype torch.uint8,'
- for wi in w:
- self.assertEqual(str(wi.message)[0:52], str(warn))
- else:
- self.assertEqual(len(w), 0)
+ x_mps_strided = x_mps.T
+ x_cpu_strided = x_cpu.T
+
+ x_mps_strided.masked_fill_(mask_mps.T, float("-inf"))
+ x_cpu_strided.masked_fill_(mask_cpu.T, float("-inf"))
+
+ self.assertEqual(x_mps_strided, x_cpu_strided)
+ self.assertFalse((x_mps_strided == float("-inf")).any())
def test_nhwc_operation(self):
def helper(shape, channels_last=False):