[MPS] Fix `torch.full` for boolean types (#82575)
Fixed by creating an int8 tensor and casting it to bool afterwards.
This works around the MPSGraph deficiency reported in https://github.com/pytorch/pytorch/issues/82427
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82575
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index c95165b..e05b055 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1543,6 +1543,12 @@
t_mps = t.to("mps")
self.assertEqual(t, t_mps.cpu())
+    # See https://github.com/pytorch/pytorch/issues/82427
+    # MPSGraph cannot fill bool tensors directly, so the MPS backend fills an
+    # int8 tensor and casts it to bool. Check that this does not crash AND that
+    # the result is a torch.bool tensor matching the CPU reference; a
+    # crash-only check would miss a missing or wrong cast.
+    def test_bool_full(self):
+        for fill_value in (True, False):
+            x = torch.full((3, 3), fill_value, device='mps')
+            self.assertEqual(x.dtype, torch.bool)
+            self.assertEqual(x.cpu(), torch.full((3, 3), fill_value))
+
+
class TestLogical(TestCase):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)