[MPS] Add the floor_divide fixes. (#94488)
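Test changes in test/test_mps.py: test_div_bugs now skips torch.int64 inputs, and test_divmode skips torch.int64 inputs whenever the rounding mode contains "floor" (which also matches "floor_divide").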
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94488
Approved by: https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index 608eb3c..53b38ec 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2010,9 +2010,10 @@
# See https://github.com/pytorch/pytorch/issues/84995
def test_div_bugs(self):
for (dtype, mode) in itertools.product(integral_types(), ['trunc', 'floor']):
- x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
- y = torch.div(x, 101, rounding_mode=mode)
- self.assertEqual(y.sum(), 0)
+ if dtype != torch.int64:
+ x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
+ y = torch.div(x, 101, rounding_mode=mode)
+ self.assertEqual(y.sum(), 0)
# See https://github.com/pytorch/pytorch/issues/82663
def test_bool_expand(self):
@@ -4114,27 +4115,28 @@
def test_divmode(self):
def helper(shape, rounding_mode):
for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]:
- cpu_x = None
- cpu_y = None
- if (dtype in [torch.float32, torch.float16]):
- cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
- cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
- else:
- cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
- cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
+ if (rounding_mode is not None and "floor" in rounding_mode and dtype == torch.int64) is False:
+ cpu_x = None
+ cpu_y = None
+ if (dtype in [torch.float32, torch.float16]):
+ cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
+ cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
+ else:
+ cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
+ cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
- mps_x = cpu_x.detach().clone().to('mps')
- # clamp to avoid division by 0
- mps_y = cpu_y.detach().clone().to('mps')
+ mps_x = cpu_x.detach().clone().to('mps')
+ # clamp to avoid division by 0
+ mps_y = cpu_y.detach().clone().to('mps')
- if (rounding_mode == "floor_divide"):
- result_div_cpu = torch.floor_divide(cpu_x, cpu_y)
- result_div_mps = torch.floor_divide(mps_x, mps_y)
- self.assertEqual(result_div_mps, result_div_cpu)
- else:
- result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
- result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
- self.assertEqual(result_div_mps, result_div_cpu)
+ if (rounding_mode == "floor_divide"):
+ result_div_cpu = torch.floor_divide(cpu_x, cpu_y)
+ result_div_mps = torch.floor_divide(mps_x, mps_y)
+ self.assertEqual(result_div_mps, result_div_cpu)
+ else:
+ result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)
+ result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode)
+ self.assertEqual(result_div_mps, result_div_cpu)
helper((2, 8, 4, 5), None)
helper((2, 8, 4, 5), "floor")