[MPS] Handle casting for div operation (#84742) * Handle casting for div operation * Update divmode test to test for rounding mode in div. cc @lhoenig Pull Request resolved: https://github.com/pytorch/pytorch/pull/84742 Approved by: https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py index 448859b..354862f 100644 --- a/test/test_mps.py +++ b/test/test_mps.py
@@ -3208,11 +3208,18 @@ def test_divmode(self): def helper(shape, rounding_mode): - for dtype in [torch.float32]: - cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) + for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]: + cpu_x = None + cpu_y = None + if(dtype in [torch.float32, torch.float16]): + cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) + cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) + else: + cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False) + cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False) + mps_x = cpu_x.detach().clone().to('mps') # clamp to avoid division by 0 - cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False) mps_y = cpu_y.detach().clone().to('mps') result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)