[MPS] Add the floor_divide fixes. (#94488) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/94488 Approved by: https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py index 608eb3c..53b38ec 100644 --- a/test/test_mps.py +++ b/test/test_mps.py
@@ -2010,9 +2010,10 @@ # See https://github.com/pytorch/pytorch/issues/84995 def test_div_bugs(self): for (dtype, mode) in itertools.product(integral_types(), ['trunc', 'floor']): - x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype) - y = torch.div(x, 101, rounding_mode=mode) - self.assertEqual(y.sum(), 0) + if dtype != torch.int64: + x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype) + y = torch.div(x, 101, rounding_mode=mode) + self.assertEqual(y.sum(), 0) # See https://github.com/pytorch/pytorch/issues/82663 def test_bool_expand(self): @@ -4114,27 +4115,28 @@ def test_divmode(self): def helper(shape, rounding_mode): for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]: - cpu_x = None - cpu_y = None - if (dtype in [torch.float32, torch.float16]): - cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) - cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) - else: - cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False) - cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False) + if (rounding_mode is not None and "floor" in rounding_mode and dtype == torch.int64) is False: + cpu_x = None + cpu_y = None + if (dtype in [torch.float32, torch.float16]): + cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) + cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) + else: + cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False) + cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False) - mps_x = cpu_x.detach().clone().to('mps') - # clamp to avoid division by 0 - mps_y = cpu_y.detach().clone().to('mps') + mps_x = cpu_x.detach().clone().to('mps') + # clamp to avoid division by 0 + mps_y = cpu_y.detach().clone().to('mps') - if (rounding_mode == "floor_divide"): - result_div_cpu = torch.floor_divide(cpu_x, cpu_y) - result_div_mps = torch.floor_divide(mps_x, mps_y) - self.assertEqual(result_div_mps, result_div_cpu) - else: - result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode) - result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode) - self.assertEqual(result_div_mps, result_div_cpu) + if (rounding_mode == "floor_divide"): + result_div_cpu = torch.floor_divide(cpu_x, cpu_y) + result_div_mps = torch.floor_divide(mps_x, mps_y) + self.assertEqual(result_div_mps, result_div_cpu) + else: + result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode) + result_div_mps = torch.div(mps_x, mps_y, rounding_mode=rounding_mode) + self.assertEqual(result_div_mps, result_div_cpu) helper((2, 8, 4, 5), None) helper((2, 8, 4, 5), "floor")