[MPS] Fix int rounding div crash on M1 (#85016)
Fixes https://github.com/pytorch/pytorch/issues/84995
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85016
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 810c56f..4a92048 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -159,7 +159,11 @@
MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
secondaryTensor:secondaryCastTensor
name:nil];
- if (!rounding_mode.has_value()) {
+ // Rounding is a no-op for integral types, and also a reasonable workaround
+ // For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library`
+ // See https://github.com/pytorch/pytorch/issues/84995
+ bool isFloatOutput = ([divTensor dataType] & MPSDataTypeFloatBit) != 0;
+ if (!rounding_mode.has_value() || !isFloatOutput) {
return divTensor;
} else if (*rounding_mode == "trunc") {
return trunc_tensor(mpsGraph, divTensor);
diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm
index 2231a66..97f3d18 100644
--- a/aten/src/ATen/native/mps/operations/UnaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -61,6 +61,14 @@
MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor)
{
+ // Rounding is a no-op for integral types, and also a reasonable workaround
+ // For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library`
+ // See https://github.com/pytorch/pytorch/issues/84995
+ bool isFloatInput = ([inputTensor dataType] & MPSDataTypeFloatBit) != 0;
+ if (!isFloatInput) {
+ return inputTensor;
+ }
+
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
dataType:inputTensor.dataType];
MPSGraphTensor* predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
diff --git a/test/test_mps.py b/test/test_mps.py
index 395f1ef..e036f69 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -23,7 +23,7 @@
TEST_WITH_UBSAN, dtype_abbrs)
from torch.testing import make_tensor
from torch.testing._comparison import TensorLikePair
-from torch.testing._internal.common_dtype import get_all_dtypes
+from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
import torch.backends.mps
from torch.distributions import Uniform, Exponential
from functools import partial
@@ -1578,6 +1578,13 @@
y_cpu = torch.full((2, 2), 247, device='cpu', dtype=torch.uint8)
self.assertEqual(y_mps, y_cpu)
+ # See https://github.com/pytorch/pytorch/issues/84995
+ def test_div_bugs(self):
+ for (dtype, mode) in itertools.product(integral_types(), ['trunc', 'floor']):
+ x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
+ y = torch.div(x, 101, rounding_mode=mode)
+ self.assertEqual(y.sum(), 0)
+
# See https://github.com/pytorch/pytorch/issues/82663
def test_bool_expand(self):
x = torch.tensor([[1], [0]], dtype=torch.bool, device='mps')