[MPS] Fix int rounding div crash on M1 (#85016)
Fixes https://github.com/pytorch/pytorch/issues/84995
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85016
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 395f1ef..e036f69 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -23,7 +23,7 @@
TEST_WITH_UBSAN, dtype_abbrs)
from torch.testing import make_tensor
from torch.testing._comparison import TensorLikePair
-from torch.testing._internal.common_dtype import get_all_dtypes
+from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
import torch.backends.mps
from torch.distributions import Uniform, Exponential
from functools import partial
@@ -1578,6 +1578,13 @@
y_cpu = torch.full((2, 2), 247, device='cpu', dtype=torch.uint8)
self.assertEqual(y_mps, y_cpu)
+ # Regression test: integer torch.div with a rounding_mode crashed on MPS, see https://github.com/pytorch/pytorch/issues/84995
+ def test_div_bugs(self):
+ for (dtype, mode) in itertools.product(integral_types(), ['trunc', 'floor']):  # each integral dtype paired with each rounding mode
+ x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)  # values 1..10, created on the MPS device
+ y = torch.div(x, 101, rounding_mode=mode)  # every element is smaller than the divisor 101
+ self.assertEqual(y.sum(), 0)  # so each quotient rounds to 0 under both 'trunc' and 'floor'
+
# See https://github.com/pytorch/pytorch/issues/82663
def test_bool_expand(self):
x = torch.tensor([[1], [0]], dtype=torch.bool, device='mps')