[MPS] Copy from CPU always add storageOffset (#86958)
Because why wouldn't it?
Fixes https://github.com/pytorch/pytorch/issues/86052
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86958
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 91b4863..2bfee3f 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -210,8 +210,7 @@
options:options
deallocator:nil];
sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
- if (src_.is_view() || !src_.is_contiguous())
- sourceOffset += src_.storage_offset() * src_.itemsize();
+ sourceOffset += src_.storage_offset() * src_.itemsize();
stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
[sourceBuffer release];
diff --git a/test/test_mps.py b/test/test_mps.py
index a52e6e1..af7460c 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5697,6 +5697,15 @@
self.assertEqual(expected1, out1)
self.assertEqual(expected2, out2)
+ def test_detached_view_copy(self, device="mps"):
+ # https://github.com/pytorch/pytorch/issues/86052
+ x = torch.arange(2)
+ # .detach() makes y not a view, but contig tensor
+ # with non-zero offset
+ y = x[1].detach()
+ z = y.to(device)
+ self.assertEqual(y, z.cpu())
+
def test_empty_reshape(self, device="mps"):
x = torch.randn(0, 6, device=device)
self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape)