[MPS] Get the correct size of the view tensor when copying from CPU to MPS (#81730)

Fixes: https://github.com/pytorch/pytorch/issues/81567, https://github.com/pytorch/pytorch/issues/80844
* Get the correct size of the view tensor when copying from CPU to MPS

* Use 'computeStorageNbytesContiguous' to get the size only when src is a view (see the sketch below)

* Add assertions and tests to check the storage_offset

* Add test case for https://github.com/pytorch/pytorch/issues/80844

* Replace assert_allclose with assertEqual

* Replace TORCH_CHECK with TORCH_INTERNAL_ASSERT

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81730
Approved by: https://github.com/razarmehr, https://github.com/albanD
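
For context, here is a minimal Python sketch (not part of this PR) of the sizing rule the fix relies on: a contiguous view needs `(storage_offset + numel) * itemsize` bytes of backing storage, which can be larger than what `nbytes()` reports for the view itself. The helper name `storage_nbytes_contiguous` below is hypothetical and only mirrors the idea behind `at::detail::computeStorageNbytesContiguous`.

```python
import torch

def storage_nbytes_contiguous(sizes, itemsize, storage_offset):
    # Hypothetical helper mirroring at::detail::computeStorageNbytesContiguous:
    # a contiguous view needs (storage_offset + numel) * itemsize bytes of
    # backing storage, which can exceed numel * itemsize (the view's nbytes).
    numel = 1
    for s in sizes:
        numel *= s
    return (storage_offset + numel) * itemsize

elems = torch.arange(12, dtype=torch.float32)    # 48 bytes of backing storage
t = elems[8:12]                                  # contiguous view, storage_offset() == 8
print(t.numel() * t.element_size())              # 16 -> what nbytes() reports for the view
print(storage_nbytes_contiguous(t.shape, t.element_size(), t.storage_offset()))  # 48
```

As the diff below notes, sizing the shared `MTLBuffer` from `nbytes()` alone is too small for such views once the storage offset is added to the copy, which is the situation exercised by `test_storage_offset_greater_than_src_nbytes`.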
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index 4b57f62..dd899b1 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -192,7 +192,7 @@
   {
     NSInteger sz_i = (i < sz) ? t.size(i) : 1;
 
-    NSNumber* number = [NSNumber numberWithInt:sz_i];
+    NSNumber* number = [NSNumber numberWithInteger:sz_i];
     numbers[i] = number;
   }
   return [NSArray arrayWithObjects:numbers count:sz_];
@@ -213,7 +213,7 @@
   {
     NSInteger sz_i = (i < sz) ? sizes[i] : 1;
 
-    NSNumber* number = [NSNumber numberWithInt:sz_i];
+    NSNumber* number = [NSNumber numberWithInteger:sz_i];
     numbers[i] = number;
   }
   return [NSArray arrayWithObjects:numbers count:sz_];
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 69983ed..3c2ab0d 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -113,30 +113,37 @@
     src = src_;
   }
   id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
-  const size_t src_size = src.nbytes();
-  // if there's anything wrong with source, we shouldn't return dst_ silently and must error out.
-  TORCH_CHECK(sourceBuffer && src_size > 0);
+  size_t src_total_size = src_.is_view() ? at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset()) :
+                                           src.nbytes();
+  size_t size_to_copy = src.nbytes();
 
   // In case of dtype change, first convert src inplace
   if (src_.dtype() != dst_.dtype()) {
     copy_cast_mps(dst, src, sourceBuffer, sourceBuffer);
+    // Use the element size of dst to calculate the total size after casting
+    size_to_copy = (size_to_copy / src.element_size()) * dst.element_size();
   }
 
+  // If there's anything wrong with source, we shouldn't return dst_ silently and must error out.
+  TORCH_INTERNAL_ASSERT(sourceBuffer && size_to_copy > 0);
+  TORCH_INTERNAL_ASSERT(src_total_size >= storage_byte_offset);
+  TORCH_INTERNAL_ASSERT(dst.nbytes() >= (dst.storage_offset() * dst.element_size()));
+
   @autoreleasepool {
     MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared;
     NSUInteger alignedLength = 0;
 
     void* host_dst = dst.storage().data();
-    void* alignedPtr = pageAlignedBlockPtr(host_dst, (NSUInteger)src_size, &alignedLength);
+    void* alignedPtr = pageAlignedBlockPtr(host_dst, (NSUInteger)src_total_size, &alignedLength);
     id<MTLBuffer> destBuffer = [device newBufferWithBytesNoCopy:alignedPtr
                                                          length:alignedLength
                                                         options:options
                                                     deallocator:nil];
      NSUInteger destOffset = uintptr_t(host_dst) - uintptr_t(alignedPtr);
     // 4 bytes alignment required on macos for blits.
-    TORCH_CHECK(destOffset % 4 == 0, "Unaligned blit request");
+    TORCH_INTERNAL_ASSERT(destOffset % 4 == 0, "Unaligned blit request");
 
-    stream->copy_and_sync(sourceBuffer, destBuffer, src_size, storage_byte_offset, destOffset, non_blocking);
+    stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, storage_byte_offset, destOffset, non_blocking);
     [destBuffer release];
   }
   if (!dst.is_same(dst_)) {
@@ -155,26 +162,33 @@
   id<MTLDevice> device = MPSDevice::getInstance()->device();
   auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();
   id<MTLBuffer> destBuffer = getMTLBufferStorage(dst_);
+  uint64_t src_total_size = 0;
 
   if (src_.is_view()) {
     src = src_.to(dst_.dtype()).expand_as(dst_).contiguous();
+    // Get the actual size of a View (takes into account the storage offset)
+    // For View tensors, the storage offset can be bigger than what's being reported by nbytes
+    src_total_size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset());
   } else {
     src = src_;
     if (src.dtype() != dst_.dtype()) {
       // In case of dtype change, perform conversion on source device
       src = src.to(dst_.dtype());
     }
+    src_total_size = src.nbytes();
   }
 
+  const size_t size_to_copy = src.nbytes();
   const void* host_src = src.storage().data();
-  uint64_t size = src.nbytes();
+  TORCH_INTERNAL_ASSERT(src_total_size >= (src.storage_offset() * src.element_size()));
+  TORCH_INTERNAL_ASSERT(dst_.nbytes() >= dst_byte_offset);
 
   NSUInteger sourceOffset = 0;
   @autoreleasepool {
     MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared;
     NSUInteger alignedLength = 0;
 
-    void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size, &alignedLength);
+    void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)src_total_size, &alignedLength);
     id<MTLBuffer> sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
                                           length:alignedLength
                                          options:options
@@ -183,7 +197,7 @@
     if (src_.is_view() || !src_.is_contiguous())
       sourceOffset += src_.storage_offset() * src_.itemsize();
 
-    stream->copy_and_sync(sourceBuffer, destBuffer, size, sourceOffset, dst_byte_offset, non_blocking);
+    stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
     [sourceBuffer release];
   }
 
diff --git a/test/test_mps.py b/test/test_mps.py
index c6e1063..12b5643 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1490,6 +1490,58 @@
         z = y.as_strided(size=(32, 3), stride=(1, 0)).to("cpu")
         self.assertEqual(x.to("cpu").as_strided(size=(32, 3), stride=(1, 0)), z)
 
+    def test_type_casting(self):
+        # https://github.com/pytorch/pytorch/issues/81567
+        def helper(data, to_dtype):
+            a_cpu = torch.tensor(data)
+            a_mps = a_cpu.to(torch.device('mps'))
+
+            res_cpu = a_cpu.type(to_dtype)
+            res_mps = a_mps.type(to_dtype)
+            self.assertEqual(res_cpu, res_mps)
+
+        helper([9.0, 3.0, 5.0, 4.0], torch.LongTensor)
+        helper([9.0, 3.0, 5.0, 4.0], torch.FloatTensor)
+        helper([9.0, 3.0, 5.0, 4.0], torch.IntTensor)
+        helper([9.0, 3.0, 5.0, 4.0], torch.ShortTensor)
+        helper([9.0, 3.0, 5.0, 4.0], torch.HalfTensor)
+        helper([9.0, 3.0, 5.0, 4.0], torch.CharTensor)
+        helper([9.0, 3.0, 5.0, 4.0], torch.ByteTensor)
+
+    def test_to_casting(self):
+        # https://github.com/pytorch/pytorch/issues/81567
+        def helper(data, to_dtype):
+            a_cpu = torch.tensor(data)
+            a_mps = a_cpu.to(torch.device('mps'))
+
+            res_cpu = a_cpu.to(to_dtype)
+            res_mps = a_mps.to(to_dtype)
+            self.assertEqual(res_cpu, res_mps)
+
+        helper([9.0, 3.0, 5.0, 4.0], torch.int64)
+        helper([9.0, 3.0, 5.0, 4.0], torch.float)
+        helper([9.0, 3.0, 5.0, 4.0], torch.int32)
+        helper([9.0, 3.0, 5.0, 4.0], torch.short)
+        helper([9.0, 3.0, 5.0, 4.0], torch.half)
+        helper([9.0, 3.0, 5.0, 4.0], torch.int8)
+        helper([9.0, 3.0, 5.0, 4.0], torch.uint8)
+
+    def test_storage_offset_greater_than_src_nbytes(self):
+        # https://github.com/pytorch/pytorch/issues/80844
+        n_tensors = 100
+        n_tensor_elems = 784
+        elems = torch.arange(n_tensors * n_tensor_elems, dtype=torch.float32)
+
+        tensor_list = []
+        for i in range(0, n_tensors - 1):
+            # create a list of contiguous view tensors (view tensor created by the slice op)
+            t = elems[n_tensor_elems * i : n_tensor_elems * (i + 1)]
+            tensor_list.append(t)
+
+        for i in range(0, n_tensors - 1):
+            t = tensor_list[i].view(1, 784)
+            t_mps = t.to("mps")
+            self.assertEqual(t, t_mps.cpu())
 
 class TestLogical(TestCase):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):