[MPS] Fix torch.full for uint8 (#83697)
By creating a uint32 tensor and then downcasting it to uint8
Workaround for https://github.com/pytorch/pytorch/issues/83692
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83697
Approved by: https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py
index 32aa2c1..d1403fc 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1545,9 +1545,14 @@
self.assertEqual(t, t_mps.cpu())
# See https://github.com/pytorch/pytorch/issues/82427
- # Test should not crash
- def test_bool_full(self):
+ # and https://github.com/pytorch/pytorch/issues/83692
+ def test_full_bugs(self):
+ # Test should not crash
x = torch.full((3, 3), True, device='mps')
+ # torch.full should work for uint8
+ y_mps = torch.full((2, 2), 247, device='mps', dtype=torch.uint8)
+ y_cpu = torch.full((2, 2), 247, device='cpu', dtype=torch.uint8)
+ self.assertEqual(y_mps, y_cpu)
# See https://github.com/pytorch/pytorch/issues/82663
def test_bool_expand(self):