[MPS] Fix int rounding div crash on M1 (#85016)

Fixes https://github.com/pytorch/pytorch/issues/84995
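
A minimal repro, distilled from the regression test added below: before this change, integer division with an explicit rounding mode on the MPS backend crashed on Apple Silicon with `Function floorOp_i64 was not found in the library`.

```python
import torch

# Any integral dtype combined with either rounding mode ('trunc' or 'floor')
# hit the crash on the MPS backend before this fix.
x = torch.arange(1, 11, device='mps', dtype=torch.int64)
y = torch.div(x, 101, rounding_mode='floor')  # 1..10 divided by 101 -> all zeros
assert y.sum().item() == 0
```

The fix skips the rounding step for integral outputs, where truncation/flooring is a no-op anyway, which sidesteps the missing `floorOp_i64` kernel.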

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85016
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 810c56f..4a92048 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -159,7 +159,11 @@
     MPSGraphTensor* divTensor =  [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
                                                      secondaryTensor:secondaryCastTensor
                                                                 name:nil];
-    if (!rounding_mode.has_value()) {
+    // Rounding is a no-op for integral types, and it is also a reasonable workaround
+    // for an MPSGraph bug on Apple Silicon that throws `Function floorOp_i64 was not found in the library`
+    // See https://github.com/pytorch/pytorch/issues/84995
+    bool isFloatOutput = ([divTensor dataType] & MPSDataTypeFloatBit) != 0;
+    if (!rounding_mode.has_value() || !isFloatOutput) {
       return divTensor;
     } else if (*rounding_mode == "trunc") {
       return trunc_tensor(mpsGraph, divTensor);
diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm
index 2231a66..97f3d18 100644
--- a/aten/src/ATen/native/mps/operations/UnaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -61,6 +61,14 @@
 
 MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor)
 {
+  // Rounding is a no-op for integral types, and it is also a reasonable workaround
+  // for an MPSGraph bug on Apple Silicon that throws `Function floorOp_i64 was not found in the library`
+  // See https://github.com/pytorch/pytorch/issues/84995
+  bool isFloatInput = ([inputTensor dataType] & MPSDataTypeFloatBit) != 0;
+  if (!isFloatInput) {
+    return inputTensor;
+  }
+
   MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
                                                    dataType:inputTensor.dataType];
   MPSGraphTensor* predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
diff --git a/test/test_mps.py b/test/test_mps.py
index 395f1ef..e036f69 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -23,7 +23,7 @@
      TEST_WITH_UBSAN, dtype_abbrs)
 from torch.testing import make_tensor
 from torch.testing._comparison import TensorLikePair
-from torch.testing._internal.common_dtype import get_all_dtypes
+from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
 import torch.backends.mps
 from torch.distributions import Uniform, Exponential
 from functools import partial
@@ -1578,6 +1578,13 @@
         y_cpu = torch.full((2, 2), 247, device='cpu', dtype=torch.uint8)
         self.assertEqual(y_mps, y_cpu)
 
+    # See https://github.com/pytorch/pytorch/issues/84995
+    def test_div_bugs(self):
+        for (dtype, mode) in itertools.product(integral_types(), ['trunc', 'floor']):
+            x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
+            y = torch.div(x, 101, rounding_mode=mode)
+            self.assertEqual(y.sum(), 0)
+
     # See https://github.com/pytorch/pytorch/issues/82663
     def test_bool_expand(self):
         x = torch.tensor([[1], [0]], dtype=torch.bool, device='mps')
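
To exercise just the new test locally (assuming a standard PyTorch checkout on an MPS-capable machine), something like `pytest test/test_mps.py -k test_div_bugs` should run only the added case.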