[MPS] Revamp copy_to_mps_ implementation (#86956)
Tensor's view in linear storage is represented by the following parameters: `.shape`, `.stride()` and `.storage_offset()`.
Only tensors that are representable as 1d-views can be copied from host to device (and vice versa) using single [`copy(from:sourceOffset:to:destinationOffset:size:)`](https://developer.apple.com/documentation/metal/mtlblitcommandencoder/1400767-copyfrombuffer?language=objc) call.
Modify `copy_to_mps_` function to do the following steps:
- Cast `src` tensor to dst data type if needed
- Expand `src` tensor to `dst` tensor shape
- Clone `src` tensor if it is not stride contiguous (i.e. can not be represented by `src.view(src.numel())`)
- Create an empty tensor if `dst` is not stride-contiguous or if its strides are different from the potentially cloned `src` strides
- Do a 1d copy of `src` to the (potentially temporary) `dst`
- Finally do re-striding/copy on MPS if needed
Add tests to cover the cases where a stride-contiguous permuted tensor is copied to MPS, where a non-stride-contiguous tensor is copied to MPS, and where a permuted CPU tensor is copied to a differently permuted MPS tensor
Fixes https://github.com/pytorch/pytorch/issues/86954
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86956
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 9702239..9e83139 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1596,9 +1596,9 @@
tensor_list.append(t)
for i in range(0, n_tensors - 1):
- t = tensor_list[i].view(1, 784)
+ t = tensor_list[i].view(1, n_tensor_elems)
t_mps = t.to("mps")
- self.assertEqual(t, t_mps.cpu())
+ self.assertEqual(t, t_mps.cpu(), f"i={i}")
# See https://github.com/pytorch/pytorch/issues/82427
# and https://github.com/pytorch/pytorch/issues/83692
@@ -1649,6 +1649,27 @@
t_mps = torch.tensor(a, device="mps")
self.assertEqual(t_cpu, t_mps.to("cpu"))
+ # See https://github.com/pytorch/pytorch/issues/86954
+ def test_copy_non_contiguous(self):
+ x = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
+ self.assertFalse(x.is_contiguous())
+ y = x.to('mps')
+ self.assertFalse(y.is_contiguous())
+ self.assertEqual(x, y.to('cpu'))
+
+ x = torch.arange(4**3).reshape(4, 4, 4).permute((2, 0, 1))[1:, ::2]
+ y = x.to('mps')
+ self.assertEqual(x, y.to('cpu'))
+
+ x = torch.full((4, 4, 4, 4), 13, device="cpu")
+ y = torch.full((4, 4, 4, 4), 13, device="mps")
+ z = torch.arange(4**4).reshape(4, 4, 4, 4).permute(3, 2, 0, 1)[1::, ::2]
+ x.permute(3, 2, 1, 0)[1::, ::2] = z
+ # As y is on MPS and z on CPU, this dispatches to a copy operator
+ y.permute(3, 2, 1, 0)[1::, ::2] = z
+ self.assertEqual(x, y.to('cpu'))
+
+
class TestLogical(TestCase):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):