[MPS] Fix overflow in cumsum when dtype is bool (#125318)
`cumsum` and `cumprod` were (and may still be) buggy for MPS: https://github.com/pytorch/pytorch/blob/c8d2a55273757c90989fde7c6f05e957aba9a238/aten/src/ATen/native/mps/operations/UnaryOps.mm#L435-L436
An existing workaround casts the input to int32 before performing the op, which prevents overflow for certain narrow integer types.
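As a rough Python-level sketch of that workaround (`cumsum_via_int32` is a hypothetical name; the real implementation lives in Objective-C++ inside `UnaryOps.mm`):
```python
import torch

# Sketch of the existing workaround: accumulate narrow integer dtypes in
# int32 so the running sum is computed correctly, then cast back to the
# dtype cumsum would normally return.
def cumsum_via_int32(t: torch.Tensor, dim: int) -> torch.Tensor:
    if t.dtype in (torch.uint8, torch.int8, torch.int16):
        return t.to(torch.int32).cumsum(dim).to(t.dtype)
    return t.cumsum(dim)
```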
It turns out this issue also affects boolean types:
```python
import torch
print(torch.ones(128, dtype=torch.bool, device="mps").cumsum(0)[-1])
# tensor(-128, device='mps:0')  (128 wraps to -128: overflow in a signed 8-bit intermediate)
```
In this PR I'm adding logic to also cast bool inputs to int32 before `cumsum` and `cumprod`, even though the output of `cumprod` cannot overflow with bools (a running product of 0s and 1s stays 0 or 1). I'm also adding a test to prevent regressions.
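For reference, a quick sanity check of the intended behavior (assuming a machine with MPS available), mirroring the regression test added below:
```python
import torch

a = torch.ones(2**16, dtype=torch.bool)
t_cpu = a.cumsum(0)            # int64 on CPU: 1, 2, ..., 65536
t_mps = a.to("mps").cumsum(0)  # previously wrapped around; should now match
assert torch.equal(t_cpu, t_mps.cpu())
```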
Fixes #96614, #106112, #109166
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125318
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm
index c76a60e..7709c79 100644
--- a/aten/src/ATen/native/mps/operations/UnaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -434,7 +434,7 @@
// issue #103810551: cumsum / cumprod are broken for int8, int16 and as chances for overflow are pretty high, cast to
// int32 fixed in macOS 13.3
- bool castInputData = (isIntegralType(input.scalar_type(), false) && input.scalar_type() != ScalarType::Int &&
+ bool castInputData = (isIntegralType(input.scalar_type(), true) && input.scalar_type() != ScalarType::Int &&
input.scalar_type() != ScalarType::Long);
TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long,
diff --git a/test/test_mps.py b/test/test_mps.py
index 9582831..38fea5b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -156,11 +156,6 @@
# On the backward pass for `sort` both are used (values and indices), thus resulting in a mismatch between CPU and MPS.
# Running `msort` with stable `sort` passes.
'msort': [torch.float16],
-
- # See https://github.com/pytorch/pytorch/issues/106112 for more information
- 'cumprod': [torch.float32, torch.float16],
- # See https://github.com/pytorch/pytorch/issues/109166 for more information
- 'masked.cumprod': [torch.float16],
}
SKIPLIST_GRAD = {
@@ -4273,6 +4268,13 @@
self.assertEqual(e_string, "MPS does not support cumsum_out_mps op with int64 input." +
" Support has been added in macOS 13.3")
+ def test_cumsum_bool(self):
+ a = torch.ones(2**16, dtype=torch.bool)
+ t_cpu = a.cumsum(0)
+ t_mps = a.to("mps").cumsum(0)
+
+ self.assertEqual(t_cpu, t_mps)
+
def test_cumsum_minus_one_axis(self):
def helper(dtype):
# Test with axis -1