[MPS] Add floor_divide() op and its test case (#91126)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91126
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 200889f..748b7f9 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3401,13 +3401,19 @@
# clamp to avoid division by 0
mps_y = cpu_y.detach().clone().to('mps')
- result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
- result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
- self.assertEqual(result_div_mps, result_div_cpu)
+ if (rounding_mode == "floor_divide"):
+ result_div_cpu = torch.floor_divide(cpu_x, cpu_y)
+ result_div_mps = torch.floor_divide(mps_x, mps_y)
+ self.assertEqual(result_div_mps, result_div_cpu)
+ else:
+ result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
+ result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
+ self.assertEqual(result_div_mps, result_div_cpu)
helper((2, 8, 4, 5), None)
helper((2, 8, 4, 5), "floor")
helper((2, 8, 4, 5), "trunc")
+ helper((2, 8, 4, 5), "floor_divide")
def test_rounding(self):
def helper(shape):
@@ -7450,6 +7456,7 @@
'flipud': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'float': ['f32'],
'floor': ['f32', 'f16', 'i16', 'i32', 'i64'],
+ 'floor_divide': ['f32', 'f16'],
'frac': ['f16', 'f32'],
'gradient': ['f16', 'f32', 'i16'],
'half': ['f16'],