[MPS] Handle casting for div operation (#84742)
* Handle casting for div operation
* Update divmode test to check div with each rounding mode across float32, float16, int32, and int64 inputs
cc @lhoenig
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84742
Approved by: https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index 448859b..354862f 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3208,11 +3208,18 @@
def test_divmode(self):
def helper(shape, rounding_mode):
- for dtype in [torch.float32]:
- cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
+ for dtype in [torch.float32, torch.float16, torch.int32, torch.int64]:
+ cpu_x = None
+ cpu_y = None
+ if(dtype in [torch.float32, torch.float16]):
+ cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
+ cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
+ else:
+ cpu_x = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
+ cpu_y = torch.randint(-10, 0, shape, device='cpu', dtype=dtype, requires_grad=False)
+
mps_x = cpu_x.detach().clone().to('mps')
# clamp to avoid division by 0
- cpu_y = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=False)
mps_y = cpu_y.detach().clone().to('mps')
result_div_cpu = torch.div(cpu_x, cpu_y, rounding_mode=rounding_mode)