MPS: Fix binary-op dtype casting with proper type promotion and remove spurious copy warning (#79185)
Fixes #78019, #78020
Fixes https://github.com/pytorch/pytorch/pull/79185
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79185
Approved by: https://github.com/albanD, https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index ebf56a9..8058f29 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3988,6 +3988,16 @@
helper(0.1)
helper(0.2)
+ def test_types_binary_op(self):  # Regression test (#78019, #78020): mixed-dtype binary ops on MPS must match the CPU reference
+ # Float32 * Bool: same expression on CPU and MPS, results compared against each other
+ cpu_x = torch.arange(5, dtype=torch.float32, device="cpu") * torch.tensor([True, False, True, False, True], device="cpu")
+ mps_x = torch.arange(5, dtype=torch.float32, device="mps") * torch.tensor([True, False, True, False, True], device="mps")
+ self.assertEqual(cpu_x, mps_x)
+ # Float32 * Int64 (torch.tensor on a list of Python ints yields int64); MPS result compared to CPU
+ cpu_y = torch.arange(5, dtype=torch.float32, device="cpu") * torch.tensor([1, 0, 1, 0, 1], device="cpu")
+ mps_y = torch.arange(5, dtype=torch.float32, device="mps") * torch.tensor([1, 0, 1, 0, 1], device="mps")
+ self.assertEqual(cpu_y, mps_y)
+
def test_unary_ops(self):
def helper(shape, op):
for dtypef in [torch.float32]: