[MPS] Revamp copy_to_mps_ implementation (#86956)

A tensor's view into its linear storage is described by three parameters: `.shape`, `.stride()`, and `.storage_offset()`.
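
For example, a transposed and sliced view shares storage with the original tensor but exposes it through different view parameters:

```python
import torch

x = torch.arange(12).reshape(3, 4)  # contiguous storage of 12 elements
y = x.t()[1:]                       # transpose, then slice: same storage, different view

print(y.shape)             # torch.Size([3, 3])
print(y.stride())          # (1, 4)
print(y.storage_offset())  # 1
```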

Only tensors that are representable as 1d views can be copied from host to device (and vice versa) using a single [`copy(from:sourceOffset:to:destinationOffset:size:)`](https://developer.apple.com/documentation/metal/mtlblitcommandencoder/1400767-copyfrombuffer?language=objc) call.
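
A tensor qualifies exactly when its elements occupy a single gap-free span of storage. A minimal Python sketch of that check, mirroring the `compute_strided_size`/`is_strided_contiguous` helpers added in the diff below:

```python
import torch

def strided_size(t: torch.Tensor) -> int:
    # Span of storage the tensor touches, in elements, starting at its storage_offset
    if t.numel() == 0:
        return 0
    return 1 + sum((size - 1) * stride for size, stride in zip(t.shape, t.stride()))

def is_stride_contiguous(t: torch.Tensor) -> bool:
    return strided_size(t) == t.numel()

x = torch.arange(27).reshape(3, 3, 3)
print(is_stride_contiguous(x.permute(2, 0, 1)))  # True: a permutation only reorders elements
print(is_stride_contiguous(x[1:, ::2]))          # False: slicing leaves gaps in storage
```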

Modify the `copy_to_mps_` function to do the following steps (a Python sketch of the flow follows the list):
- Cast the `src` tensor to the `dst` data type if needed
- Expand the `src` tensor to the `dst` tensor's shape
- Clone the `src` tensor if it is not stride-contiguous (i.e. cannot be represented by `src.view(src.numel())`)
- Create an empty tensor if `dst` is not stride-contiguous or if its strides differ from those of the (potentially cloned) `src`
- Do a single 1d copy from `src` to the (potentially temporary) `dst`
- Finally, do the re-striding/copy on the MPS device if needed
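
A rough Python sketch of the resulting flow, reusing `is_stride_contiguous` from the snippet above; `blit_1d_copy_to_mps` is a hypothetical placeholder for the single Metal blit, not a real API:

```python
def copy_to_mps_sketch(dst: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    # 1. Cast to dst's dtype if needed and expand to dst's shape (expand is a no-op view)
    src = (src.to(dst.dtype) if src.dtype != dst.dtype else src).expand_as(dst)
    # 2. Clone src if it is not stride-contiguous so its data becomes a gap-free block
    if not is_stride_contiguous(src):
        src = src.clone()
    # 3. If dst's strides do not match src's, blit into a temporary MPS tensor first
    needs_restride = src.stride() != dst.stride()
    tmp = torch.empty_like(src, device="mps") if needs_restride else dst
    # 4. Single 1d copy from host memory into the MPS buffer (hypothetical placeholder)
    blit_1d_copy_to_mps(tmp, src)
    # 5. Re-stride/copy on the device if a temporary was used
    return dst.copy_(tmp) if needs_restride else dst
```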

Add tests covering the cases where a stride-contiguous permuted tensor is copied to MPS, where a non-stride-contiguous tensor is copied to MPS, and where a permuted CPU tensor is copied to a differently permuted MPS tensor.

Fixes https://github.com/pytorch/pytorch/issues/86954

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86956
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 2bfee3f..99183d2 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -33,6 +33,25 @@
   return (void*)alignedAddress;
 }
 
+/**
+ * Computes number of elements one needs to transfer to preserve all the elements
+ */
+size_t compute_strided_size(const at::Tensor& t) {
+   size_t rc = 1;
+   if (t.numel() == 0) {
+       return 0;
+   }
+   for(const auto i: c10::irange(t.dim())) {
+     assert(t.size(i) > 0);
+     rc += (t.size(i) - 1) * t.stride(i);
+   }
+   return rc;
+}
+
+bool is_strided_contiguous(const at::Tensor& t) {
+  return compute_strided_size(t) == t.numel();
+}
+
 // Copy sourceBuffer into destBuffer, casting sourceBuffer to src.scalar_type().
 // The shapes and dtypes are taken from dst and src, but their storage pointers are not used.
 void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
@@ -168,55 +187,60 @@
   return dst_;
 }
 
-static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
+// Copies a tensor from CPU to MPS when both are backed by identically strided, stride-contiguous data
+static void copy_to_mps_stride_contig(at::Tensor& dst, const at::Tensor& src, bool non_blocking)
 {
   MPSStream* stream = getCurrentMPSStream();
-  Tensor src;
-
   id<MTLDevice> device = MPSDevice::getInstance()->device();
-  auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();
-  id<MTLBuffer> destBuffer = getMTLBufferStorage(dst_);
-  uint64_t src_total_size = 0;
-
-  // This is weird, but sometimes this function can be called
-  //  with contiguous destination and non-contiguous source
-  if (src_.is_view() || dst_.is_contiguous() != src_.is_contiguous()) {
-    src = src_.to(dst_.dtype()).expand_as(dst_).contiguous();
-    // Get the actual size of a View (takes into account the storage offset)
-    // For View tensors, the storage offset can be bigger than what's being reported by nbytes
-    src_total_size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset());
-  } else {
-    TORCH_INTERNAL_ASSERT(src_.strides() == dst_.strides());
-    src = src_;
-    if (src.dtype() != dst_.dtype()) {
-      // In case of dtype change, perform conversion on source device
-      src = src.to(dst_.dtype());
-    }
-    src_total_size = src.nbytes();
-  }
-
+  auto dst_byte_offset = dst.storage_offset() * dst.itemsize();
+  auto src_byte_offset = src.storage_offset() * src.itemsize();
+  id<MTLBuffer> destBuffer = getMTLBufferStorage(dst);
   const size_t size_to_copy = src.nbytes();
-  const void* host_src = src.storage().data();
-  TORCH_INTERNAL_ASSERT(src_total_size >= (src.storage_offset() * src.element_size()));
+  const void* host_src = static_cast<char *>(src.storage().data()) + src_byte_offset;
 
-  NSUInteger sourceOffset = 0;
+  TORCH_INTERNAL_ASSERT(src.dtype() == dst.dtype() && src.strides() == dst.strides() && is_strided_contiguous(src));
+
   @autoreleasepool {
     MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared;
     NSUInteger alignedLength = 0;
+    NSUInteger sourceOffset = 0;
 
-    void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)src_total_size, &alignedLength);
+    void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size_to_copy, &alignedLength);
     id<MTLBuffer> sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
                                           length:alignedLength
                                          options:options
                                      deallocator:nil];
     sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
-    sourceOffset += src_.storage_offset() * src_.itemsize();
 
     stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
     [sourceBuffer release];
   }
+}
 
-  return dst_;
+static at::Tensor& copy_to_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
+{
+  // Typecast to dst_ if needed and expand, which is a no-op
+  Tensor src = (src_.dtype() != dst_.dtype() ? src_.to(dst_.dtype()) : src_).expand_as(dst_);
+
+  // If src is not stride-contiguous it must be cloned.
+  // This does not mean the tensor must be contiguous, only that
+  // it can be represented as a 1d view of its storage
+  if (!is_strided_contiguous(src)) {
+    src = src.clone();
+    TORCH_INTERNAL_ASSERT(is_strided_contiguous(src));
+  }
+  Tensor dst = dst_;
+  bool needs_copy = false;
+  // If src and dst_ strides do not match, then either dst_ is not
+  // representable as a 1d view or its stride order differs from src's.
+  // In that case, copy into an empty tensor shaped like src first and
+  // then re-stride/copy on the device
+  if (src.strides() != dst_.strides()) {
+    needs_copy = true;
+    dst = at::empty_like(src, at::device(at::kMPS));
+  }
+  copy_to_mps_stride_contig(dst, src, non_blocking && !needs_copy);
+  return needs_copy? dst_.copy_(dst) : dst_;
 }
 
 void copy_blit_mps(void* dst, const void* src, size_t size) {
diff --git a/test/test_mps.py b/test/test_mps.py
index 9702239..9e83139 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1596,9 +1596,9 @@
             tensor_list.append(t)
 
         for i in range(0, n_tensors - 1):
-            t = tensor_list[i].view(1, 784)
+            t = tensor_list[i].view(1, n_tensor_elems)
             t_mps = t.to("mps")
-            self.assertEqual(t, t_mps.cpu())
+            self.assertEqual(t, t_mps.cpu(), f"i={i}")
 
     # See https://github.com/pytorch/pytorch/issues/82427
     # and https://github.com/pytorch/pytorch/issues/83692
@@ -1649,6 +1649,27 @@
         t_mps = torch.tensor(a, device="mps")
         self.assertEqual(t_cpu, t_mps.to("cpu"))
 
+    # See https://github.com/pytorch/pytorch/issues/86954
+    def test_copy_non_contiguous(self):
+        x = torch.arange(27).reshape(3, 3, 3).permute(2, 0, 1)
+        self.assertFalse(x.is_contiguous())
+        y = x.to('mps')
+        self.assertFalse(y.is_contiguous())
+        self.assertEqual(x, y.to('cpu'))
+
+        x = torch.arange(4**3).reshape(4, 4, 4).permute((2, 0, 1))[1:, ::2]
+        y = x.to('mps')
+        self.assertEqual(x, y.to('cpu'))
+
+        x = torch.full((4, 4, 4, 4), 13, device="cpu")
+        y = torch.full((4, 4, 4, 4), 13, device="mps")
+        z = torch.arange(4**4).reshape(4, 4, 4, 4).permute(3, 2, 0, 1)[1::, ::2]
+        x.permute(3, 2, 1, 0)[1::, ::2] = z
+        # As y is on MPS and z on CPU, this dispatches to a copy operator
+        y.permute(3, 2, 1, 0)[1::, ::2] = z
+        self.assertEqual(x, y.to('cpu'))
+
+
 
 class TestLogical(TestCase):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):