[MPS] Fix bug when fill value is complex (#111937)
When the value passed to `fill` is complex, the check `value.toDouble() == 0.0` errors out, reporting that converting a complex value to double would cause overflow. So the complex value must be handled first, before this condition is evaluated.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111937
Approved by: https://github.com/malfet
ghstack dependencies: #111885
diff --git a/test/test_mps.py b/test/test_mps.py
index 17954c8..b359a4c 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1407,19 +1407,18 @@
def test_fill(self):
- def helper(val, shape):
- tensor = torch.zeros(shape, device='mps')
+ def helper(val, shape, dtype):
+ tensor = torch.zeros(shape, device='mps', dtype=dtype)
tensor_mps = tensor.fill_(val)
- tensor_mps = torch.tanh(tensor_mps)
- tensor_0 = torch.zeros(shape, device='cpu')
+ tensor_0 = torch.zeros(shape, device='cpu', dtype=dtype)
tensor_cpu = tensor_0.fill_(val)
- tensor_cpu = torch.tanh(tensor_cpu)
self.assertEqual(tensor_mps, tensor_cpu)
- helper(0, [1024])
- helper(0.2, [2, 3])
+ helper(0, [1024], torch.float32)
+ helper(0.2, [2, 3], torch.float32)
+ helper(0.2 + 0.5j, [2, 3], torch.complex64)
def test_fill_storage_offset(self):
shape = [2, 10]