[MPS] Get the correct size of the view tensor when copying from cpu to mps (#81730)
Fixes: https://github.com/pytorch/pytorch/issues/81567, https://github.com/pytorch/pytorch/issues/80844
* Get the correct size of the view tensor when copying from cpu to mps
* Use 'computeStorageNbytesContiguous' to compute the size only when src is a view
* Add asserts and tests to check the storage_offset
* Add a test case for https://github.com/pytorch/pytorch/issues/80844
* Replace assert_allclose with assertEqual
* Replace TORCH_CHECK with TORCH_INTERNAL_ASSERT
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81730
Approved by: https://github.com/razarmehr, https://github.com/albanD
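For context, a minimal reproduction of the storage-offset bug, distilled from the new test_storage_offset_greater_than_src_nbytes test added below (requires a build with MPS available):

```python
import torch

# Slicing produces a view that shares storage with `elems` but starts at a
# non-zero storage offset. Before this fix, the CPU->MPS blit sized its
# staging buffer from nbytes() alone, ignoring the offset, so the copy
# could touch the wrong region of the shared storage.
elems = torch.arange(2 * 784, dtype=torch.float32)
t = elems[784:].view(1, 784)   # storage_offset() == 784, nbytes() == 3136

t_mps = t.to("mps")
assert torch.equal(t, t_mps.cpu())
```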
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index 4b57f62..dd899b1 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -192,7 +192,7 @@
{
NSInteger sz_i = (i < sz) ? t.size(i) : 1;
- NSNumber* number = [NSNumber numberWithInt:sz_i];
+ NSNumber* number = [NSNumber numberWithInteger:sz_i];
numbers[i] = number;
}
return [NSArray arrayWithObjects:numbers count:sz_];
@@ -213,7 +213,7 @@
{
NSInteger sz_i = (i < sz) ? sizes[i] : 1;
- NSNumber* number = [NSNumber numberWithInt:sz_i];
+ NSNumber* number = [NSNumber numberWithInteger:sz_i];
numbers[i] = number;
}
return [NSArray arrayWithObjects:numbers count:sz_];
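The numberWithInt: to numberWithInteger: change keeps the full 64-bit NSInteger value instead of truncating it through a 32-bit int. A quick illustration of what that truncation does, sketched in Python via ctypes (the dimension value is hypothetical):

```python
import ctypes

dim = 2**31 + 5  # a (hypothetical) size too large for a signed 32-bit int
print(ctypes.c_int32(dim).value)  # -2147483643: what numberWithInt: would store
print(ctypes.c_int64(dim).value)  #  2147483653: what numberWithInteger: stores
```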
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 69983ed..3c2ab0d 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -113,30 +113,37 @@
src = src_;
}
id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
- const size_t src_size = src.nbytes();
- // if there's anything wrong with source, we shouldn't return dst_ silently and must error out.
- TORCH_CHECK(sourceBuffer && src_size > 0);
+ size_t src_total_size = src_.is_view() ? at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset()) :
+ src.nbytes();
+ size_t size_to_copy = src.nbytes();
// In case of dtype change, first convert src inplace
if (src_.dtype() != dst_.dtype()) {
copy_cast_mps(dst, src, sourceBuffer, sourceBuffer);
+ // Use the element size of dst to calculate the total size after casting
+ size_to_copy = (size_to_copy / src.element_size()) * dst.element_size();
}
+ // If there's anything wrong with source, we shouldn't return dst_ silently and must error out.
+ TORCH_INTERNAL_ASSERT(sourceBuffer && size_to_copy > 0);
+ TORCH_INTERNAL_ASSERT(src_total_size >= storage_byte_offset);
+ TORCH_INTERNAL_ASSERT(dst.nbytes() >= (dst.storage_offset() * dst.element_size()));
+
@autoreleasepool {
MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared;
NSUInteger alignedLength = 0;
void* host_dst = dst.storage().data();
- void* alignedPtr = pageAlignedBlockPtr(host_dst, (NSUInteger)src_size, &alignedLength);
+ void* alignedPtr = pageAlignedBlockPtr(host_dst, (NSUInteger)src_total_size, &alignedLength);
id<MTLBuffer> destBuffer = [device newBufferWithBytesNoCopy:alignedPtr
length:alignedLength
options:options
deallocator:nil];
NSUInteger destOffset = uintptr_t(host_dst) - uintptr_t(alignedPtr);
// 4 bytes alignment required on macos for blits.
- TORCH_CHECK(destOffset % 4 == 0, "Unaligned blit request");
+ TORCH_INTERNAL_ASSERT(destOffset % 4 == 0, "Unaligned blit request");
- stream->copy_and_sync(sourceBuffer, destBuffer, src_size, storage_byte_offset, destOffset, non_blocking);
+ stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, storage_byte_offset, destOffset, non_blocking);
[destBuffer release];
}
if (!dst.is_same(dst_)) {
@@ -155,26 +162,33 @@
id<MTLDevice> device = MPSDevice::getInstance()->device();
auto dst_byte_offset = dst_.storage_offset() * dst_.itemsize();
id<MTLBuffer> destBuffer = getMTLBufferStorage(dst_);
+ uint64_t src_total_size = 0;
if (src_.is_view()) {
src = src_.to(dst_.dtype()).expand_as(dst_).contiguous();
+ // Get the actual size of the view (takes the storage offset into account);
+ // for view tensors, the required storage size (offset + data) can exceed what nbytes() reports
+ src_total_size = at::detail::computeStorageNbytesContiguous(src.sizes(), src.element_size(), src.storage_offset());
} else {
src = src_;
if (src.dtype() != dst_.dtype()) {
// In case of dtype change, perform conversion on source device
src = src.to(dst_.dtype());
}
+ src_total_size = src.nbytes();
}
+ const size_t size_to_copy = src.nbytes();
const void* host_src = src.storage().data();
- uint64_t size = src.nbytes();
+ TORCH_INTERNAL_ASSERT(src_total_size >= (src.storage_offset() * src.element_size()));
+ TORCH_INTERNAL_ASSERT(dst_.nbytes() >= dst_byte_offset);
NSUInteger sourceOffset = 0;
@autoreleasepool {
MTLResourceOptions options = MTLResourceOptionCPUCacheModeDefault | MTLResourceStorageModeShared;
NSUInteger alignedLength = 0;
- void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)size, &alignedLength);
+ void* alignedPtr = pageAlignedBlockPtr(host_src, (NSUInteger)src_total_size, &alignedLength);
id<MTLBuffer> sourceBuffer = [device newBufferWithBytesNoCopy:alignedPtr
length:alignedLength
options:options
@@ -183,7 +197,7 @@
if (src_.is_view() || !src_.is_contiguous())
sourceOffset += src_.storage_offset() * src_.itemsize();
- stream->copy_and_sync(sourceBuffer, destBuffer, size, sourceOffset, dst_byte_offset, non_blocking);
+ stream->copy_and_sync(sourceBuffer, destBuffer, size_to_copy, sourceOffset, dst_byte_offset, non_blocking);
[sourceBuffer release];
}
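To see why nbytes() can undersize the staging buffer for a view, here is the size arithmetic the patch leans on, sketched in Python. It assumes computeStorageNbytesContiguous reduces to (numel + storage_offset) * element_size for a contiguous tensor, and reuses the numbers from the issue #80844 test:

```python
import math

def storage_nbytes_contiguous(sizes, element_size, storage_offset):
    # Sketch of at::detail::computeStorageNbytesContiguous: the storage must
    # hold the elements preceding the view (the offset) plus the view itself.
    return (math.prod(sizes) + storage_offset) * element_size

# Second 784-element float32 slice: offset 784, shape (1, 784).
print(storage_nbytes_contiguous((1, 784), 4, 784))  # 6272
# nbytes() reports only 784 * 4 == 3136 -- half of what the blit must cover.

# The MPS->CPU path also rescales the copy size after an in-place dtype cast:
src_nbytes, src_elem, dst_elem = 16, 4, 8  # e.g. four float32 values -> int64
size_to_copy = (src_nbytes // src_elem) * dst_elem  # 32 bytes after the cast
```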
diff --git a/test/test_mps.py b/test/test_mps.py
index c6e1063..12b5643 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1490,6 +1490,58 @@
z = y.as_strided(size=(32, 3), stride=(1, 0)).to("cpu")
self.assertEqual(x.to("cpu").as_strided(size=(32, 3), stride=(1, 0)), z)
+ def test_type_casting(self):
+ # https://github.com/pytorch/pytorch/issues/81567
+ def helper(data, to_dtype):
+ a_cpu = torch.tensor(data)
+ a_mps = a_cpu.to(torch.device('mps'))
+
+ res_cpu = a_cpu.type(to_dtype)
+ res_mps = a_mps.type(to_dtype)
+ self.assertEqual(res_cpu, res_mps)
+
+ helper([9.0, 3.0, 5.0, 4.0], torch.LongTensor)
+ helper([9.0, 3.0, 5.0, 4.0], torch.FloatTensor)
+ helper([9.0, 3.0, 5.0, 4.0], torch.IntTensor)
+ helper([9.0, 3.0, 5.0, 4.0], torch.ShortTensor)
+ helper([9.0, 3.0, 5.0, 4.0], torch.HalfTensor)
+ helper([9.0, 3.0, 5.0, 4.0], torch.CharTensor)
+ helper([9.0, 3.0, 5.0, 4.0], torch.ByteTensor)
+
+ def test_to_casting(self):
+ # https://github.com/pytorch/pytorch/issues/81567
+ def helper(data, to_dtype):
+ a_cpu = torch.tensor(data)
+ a_mps = a_cpu.to(torch.device('mps'))
+
+ res_cpu = a_cpu.to(to_dtype)
+ res_mps = a_mps.to(to_dtype)
+ self.assertEqual(res_cpu, res_mps)
+
+ helper([9.0, 3.0, 5.0, 4.0], torch.int64)
+ helper([9.0, 3.0, 5.0, 4.0], torch.float)
+ helper([9.0, 3.0, 5.0, 4.0], torch.int32)
+ helper([9.0, 3.0, 5.0, 4.0], torch.short)
+ helper([9.0, 3.0, 5.0, 4.0], torch.half)
+ helper([9.0, 3.0, 5.0, 4.0], torch.int8)
+ helper([9.0, 3.0, 5.0, 4.0], torch.uint8)
+
+ def test_storage_offset_greater_than_src_nbytes(self):
+ # https://github.com/pytorch/pytorch/issues/80844
+ n_tensors = 100
+ n_tensor_elems = 784
+ elems = torch.arange(n_tensors * n_tensor_elems, dtype=torch.float32)
+
+ tensor_list = []
+ for i in range(0, n_tensors - 1):
+ # create a list of contiguous view tensors (views created by the slice op)
+ t = elems[n_tensor_elems * i : n_tensor_elems * (i + 1)]
+ tensor_list.append(t)
+
+ for i in range(0, n_tensors - 1):
+ t = tensor_list[i].view(1, 784)
+ t_mps = t.to("mps")
+ self.assertEqual(t, t_mps.cpu())
class TestLogical(TestCase):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):