Convert MPS Tensor data using MPSGraph API (#78092)
Fixes #78091
If you are already working on this, feel free to disregard this PR or take whatever parts are helpful. This is my attempt at MPS-native Tensor datatype conversion. It works for everything tested ~~but is currently only implemented for the MPS-to-MPS copy, not MPS-to-X or X-to-MPS; the same approach could easily be extended to those paths~~.
Before:
```python
In [5]: pt.full((40,), -10.3, device="mps")
Out[5]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')
In [6]: pt.full((40,), -10.3, device="mps").int()
Out[6]:
tensor([-1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883,
        -1054552883, -1054552883, -1054552883, -1054552883, -1054552883],
       device='mps:0', dtype=torch.int32)
In [7]: pt.full((40,), -10.3, device="mps").int().float()
Out[7]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')
In [8]: pt.full((40,), -10.3, device="mps").int().float().bool()
Out[8]:
tensor([ True, False, False, True, True, False, False, True, True, False,
        False, True, True, False, False, True, True, False, False, True,
        True, False, False, True, True, False, False, True, True, False,
        False, True, True, False, False, True, True, False, False, True],
       device='mps:0')
```
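For context, the bogus integers in `Out[6]` above are simply the raw IEEE-754 bits of `-10.3` (stored as float32) reinterpreted as int32; the old copy changed the dtype tag without converting the underlying values. A quick way to confirm this in plain Python:
```python
import struct

# Pack -10.3 as a little-endian float32, then reinterpret the same
# four bytes as a signed little-endian int32.
bits = struct.unpack('<i', struct.pack('<f', -10.3))[0]
print(bits)  # -1054552883, the value shown in Out[6]
```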
After:
```python
In [3]: pt.full((40,), -10.3, device="mps")
Out[3]:
tensor([-10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000, -10.3000,
        -10.3000, -10.3000, -10.3000, -10.3000, -10.3000], device='mps:0')
In [4]: pt.full((40,), -10.3, device="mps").int()
Out[4]:
tensor([-10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
        -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10,
        -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10],
       device='mps:0', dtype=torch.int32)
In [5]: pt.full((40,), -10.3, device="mps").int().float()
Out[5]:
tensor([-10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
        -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
        -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10., -10.,
        -10., -10., -10., -10.], device='mps:0')
In [6]: pt.full((40,), -10.3, device="mps").int().float().bool()
Out[6]:
tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True], device='mps:0')
```
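The conversion also covers dtype casts combined with a device change, which the new `test_to` in the diff below exercises. As a small sketch of the expected behavior on an MPS-capable build:
```python
import torch

x = torch.tensor(4.25, device="mps")
print(x.to("cpu", torch.int))  # tensor(4, dtype=torch.int32)
print(x.int().float().cpu())   # tensor(4.)
```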
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78092
Approved by: https://github.com/kulinseth, https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 8ed2efc..e845550 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1275,6 +1275,31 @@
                 self.assertEqual(p.grad, torch.zeros_like(p.grad))
         self.assertEqual(inp.grad, torch.zeros_like(inp))
+    # Test dtype casting, with and without simultaneous device change
+    def test_to(self):
+        values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
+        cpu_x = torch.tensor(values, device='cpu')
+        mps_x = torch.tensor(values, device='mps')
+
+        self.assertEqual(cpu_x.int(), mps_x.int().cpu())
+        self.assertEqual(cpu_x.bool(), mps_x.bool().cpu())
+        self.assertEqual(cpu_x.float(), mps_x.float().cpu())
+
+        self.assertEqual(torch.tensor(1.3, device='mps').int().cpu(),
+                         torch.tensor(1, dtype=torch.int32))
+        self.assertEqual(torch.tensor(0.0, device='mps').bool().cpu(), torch.tensor(False))
+        self.assertEqual(torch.tensor(0.1, device='mps').bool().cpu(), torch.tensor(True))
+        self.assertEqual(torch.tensor(0.1, device='mps').bool().int().cpu(),
+                         torch.tensor(1, dtype=torch.int32))
+        self.assertEqual(torch.tensor(0.1, device='mps').bool().int().float().cpu(),
+                         torch.tensor(1.0))
+        self.assertEqual(torch.tensor(4.25, device='mps').to('cpu', torch.int),
+                         torch.tensor(4, dtype=torch.int32))
+        self.assertEqual(torch.tensor(4.25, device='cpu').to('mps', torch.int).cpu(),
+                         torch.tensor(4, dtype=torch.int32))
+        self.assertEqual(torch.tensor(-8.34, device='cpu').to('mps', torch.int),
+                         torch.tensor(-8.34, device='cpu').to('mps').to(torch.int))
+
 class TestSmoothL1Loss(TestCase):