[MPS] Revamp copy_to_mps_ implementation (#86956)
A tensor's view into its linear storage is described by three parameters: `.shape`, `.stride()` and `.storage_offset()`.
Only tensors that are representable as 1d views can be copied from host to device (and vice versa) using a single [`copy(from:sourceOffset:to:destinationOffset:size:)`](https://developer.apple.com/documentation/metal/mtlblitcommandencoder/1400767-copyfrombuffer?language=objc) call.
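A tensor is representable as a 1d view when the span of storage elements covered by its view equals `numel()`. A rough Python sketch of that check (mirroring the `compute_strided_size`/`is_strided_contiguous` helpers added below):

```python
import torch

def compute_strided_size(t: torch.Tensor) -> int:
    # Number of storage elements spanned by the view: 1 + sum((size - 1) * stride)
    if t.numel() == 0:
        return 0
    return 1 + sum((size - 1) * stride for size, stride in zip(t.shape, t.stride()))

def is_strided_contiguous(t: torch.Tensor) -> bool:
    # True when the view spans exactly numel() storage elements,
    # so it can be flattened into a single 1d copy
    return compute_strided_size(t) == t.numel()

x = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
print(is_strided_contiguous(x))           # True: permuted, but still a dense block
print(is_strided_contiguous(x[1:, ::2]))  # False: slicing leaves gaps in storage
```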
Modify the `copy_to_mps_` function to perform the following steps:
- Cast the `src` tensor to the `dst` data type if needed
- Expand the `src` tensor to the `dst` tensor's shape
- Clone the `src` tensor if it is not stride-contiguous (i.e. cannot be represented by `src.view(src.numel())`)
- Create an empty tensor if `dst` is not stride-contiguous or if its strides differ from those of the (potentially cloned) `src`
- Do a 1d copy of `src` to the (potentially temporary) `dst`
- Finally, do the re-striding/copy on MPS if needed
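In pseudo-Python, and reusing `is_strided_contiguous` from the sketch above, the flow looks roughly like this (a sketch only; the actual implementation lives in `Copy.mm`, and `blit_1d_copy` is a hypothetical stand-in for the Metal blit-encoder call):

```python
def copy_to_mps_sketch(dst_, src_, non_blocking):
    # 1. Cast to dst_'s dtype if needed, then expand to dst_'s shape (metadata only)
    src = (src_.to(dst_.dtype) if src_.dtype != dst_.dtype else src_).expand_as(dst_)
    # 2. Clone if src cannot be flattened into a single 1d view of its storage
    if not is_strided_contiguous(src):
        src = src.clone()
    # 3. If strides don't match, stage the copy through a temporary MPS tensor
    needs_copy = src.stride() != dst_.stride()
    dst = torch.empty_like(src, device="mps") if needs_copy else dst_
    # 4. Single 1d blit from host storage to device storage (hypothetical helper)
    blit_1d_copy(dst, src, non_blocking and not needs_copy)
    # 5. Re-stride/copy on the device if a temporary was used
    return dst_.copy_(dst) if needs_copy else dst_
```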
Add tests covering the cases where a stride-contiguous permuted tensor is copied to MPS, where a non-stride-contiguous tensor is copied to MPS, and where a permuted CPU tensor is copied to a differently permuted MPS tensor.
Fixes https://github.com/pytorch/pytorch/issues/86954
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86956
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 2bfee3f..99183d2 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -33,6 +33,25 @@
return (void*)alignedAddress;
}
+/**
+ * Computes the number of storage elements one needs to transfer to preserve all the elements of the view
+ */
+size_t compute_strided_size(const at::Tensor& t) {
+ size_t rc = 1;
+ if (t.numel() == 0) {
+ return 0;
+ }
+ for(const auto i: c10::irange(t.dim())) {
+ assert(t.size(i) > 0);
+ rc += (t.size(i) - 1) * t.stride(i);
+ }
+ return rc;
+}
+
+bool is_strided_contiguous(const at::Tensor& t) {
+ return compute_strided_size(t) == t.numel();
+}
+
// Copy sourceBuffer into destBuffer, casting sourceBuffer to src.scalar_type().
// The shapes and dtypes are taken from dst and src, but their storage pointers are not used.
void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
@@ -168,55 +187,60 @@
return dst_;
}
-static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
+// Copies a tensor from CPU to MPS; src and dst must have identical strides and be stride-contiguous
+static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
{
MPSStream* stream = getCurrentMPSStream();
- Tensor src;
-
id<MTLDevice> device = MPSDevice::getInstance()->device();
- auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();
- id<MTLBuffer> destBuffer = getMTLBufferStorage(dst_);
- uint64_t src_total_size = 0;
-
- // This is weird, but sometimes this function can be called
- // with contiguous destination and non-contiguous source
- if (src_.is_view() || dst_.is_contiguous() != src_.is_contiguous()) {
- src = src_.to(dst_.dtype()).expand_as(dst_).contiguous();
- // Get the actual size of a View (takes into account the storage offset)
- // For View tensors, the storage offset can be bigger than what's being reported by nbytes
- src_total_size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset());
- } else {
- TORCH_INTERNAL_ASSERT(src_.strides() == dst_.strides());
- src = src_;
- if (src.dtype() != dst_.dtype()) {
- // In case of dtype change, perform conversion on source device
- src = src.to(dst_.dtype());
- }
- src_total_size = src.nbytes();
- }
-
+ auto dst_byte_offset = dst.storage_offset() * dst.itemsize();
+ auto src_byte_offset = src.storage_offset() * src.itemsize();
+ id<MTLBuffer> destBuffer = getMTLBufferStorage(dst);
const size_t size_to_copy = src.nbytes();
- const void* host_src = src.storage().data();
- TORCH_INTERNAL_ASSERT(src_total_size >= (src.storage_offset() * src.element_size()));
+ const void* host_src = static_cast<char *>(src.storage().data()) + src_byte_offset;
- NSUInteger sourceOffset = 0;
+ TORCH_INTERNAL_ASSERT(src.dtype() == dst.dtype() && src.strides() == dst.strides() && is_strided_contiguous(src));
+
@autoreleasepool {
MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared;
NSUInteger alignedLength = 0;
+ NSUInteger sourceOffset = 0;
- void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)src_total_size, &alignedLength);
+ void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size_to_copy, &alignedLength);
id<MTLBuffer> sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
length:alignedLength
options:options
deallocator:nil];
sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
- sourceOffset += src_.storage_offset() * src_.itemsize();
stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
[sourceBuffer release];
}
+}
- return dst_;
+static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
+{
+ // Typecast to dst_'s dtype if needed and expand to dst_'s shape (expand is a metadata-only no-op)
+ Tensor src = (src_.dtype() != dst_.dtype() ? src_.to(dst_.dtype()) : src_).expand_as(dst_);
+
+ // If src is not stride-contiguous it must be cloned
+ // Stride-contiguous does not mean that the tensor is contiguous, but rather
+ // that it can be represented as a 1d view of its storage
+ if (!is_strided_contiguous(src)) {
+ src = src.clone();
+ TORCH_INTERNAL_ASSERT(is_strided_contiguous(src));
+ }
+ Tensor dst = dst_;
+ bool needs_copy = false;
+ // If src and dst_ strides do not match, then either dst_ is not representable
+ // as a 1d view or its stride order is different; in that case create an empty
+ // tensor like src on the device, copy into it and then re-stride/copy on the device
+ if (src.strides() != dst_.strides()) {
+ needs_copy = true;
+ dst = at::empty_like(src, at::device(at::kMPS));
+ }
+ copy_to_mps_stride_contig(dst, src, non_blocking && !needs_copy);
+ return needs_copy? dst_.copy_(dst) : dst_;
}
void copy_blit_mps(void* dst, const void* src, size_t size) {
diff --git a/test/test_mps.py b/test/test_mps.py
index 9702239..9e83139 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1596,9 +1596,9 @@
tensor_list.append(t)
for i in range(0, n_tensors - 1):
- t = tensor_list[i].view(1, 784)
+ t = tensor_list[i].view(1, n_tensor_elems)
t_mps = t.to("mps")
- self.assertEqual(t, t_mps.cpu())
+ self.assertEqual(t, t_mps.cpu(), f"i={i}")
# See https://github.com/pytorch/pytorch/issues/82427
# and https://github.com/pytorch/pytorch/issues/83692
@@ -1649,6 +1649,27 @@
t_mps = torch.tensor(a, device="mps")
self.assertEqual(t_cpu, t_mps.to("cpu"))
+ # See https://github.com/pytorch/pytorch/issues/86954
+ def test_copy_non_contiguous(self):
+ x = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
+ self.assertFalse(x.is_contiguous())
+ y = x.to('mps')
+ self.assertFalse(y.is_contiguous())
+ self.assertEqual(x, y.to('cpu'))
+
+ x = torch.arange(4**3).reshape(4, 4, 4).permute((2, 0, 1))[1:, ::2]
+ y = x.to('mps')
+ self.assertEqual(x, y.to('cpu'))
+
+ x = torch.full((4, 4, 4, 4), 13, device="cpu")
+ y = torch.full((4, 4, 4, 4), 13, device="mps")
+ z = torch.arange(4**4).reshape(4, 4, 4, 4).permute(3, 2, 0, 1)[1::, ::2]
+ x.permute(3, 2, 1, 0)[1::, ::2] = z
+ # As y is on MPS and z on CPU, this dispatches to a copy operator
+ y.permute(3, 2, 1, 0)[1::, ::2] = z
+ self.assertEqual(x, y.to('cpu'))
+
+
class TestLogical(TestCase):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):