[MPS] Fix unary ops over sparse-mapped tensors (#100765)
If input tensor is backed by a sparse view, create a dense copy before running unary op, otherwise op will be applied against the wrong elements.
Introduce `is_dense_in_storage`, which returns true if a tensor/view is mapped to a dense area in the tensor storage.
Add unit test to validate the fix.
Fixes https://github.com/pytorch/pytorch/issues/98074
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100765
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index cd323f4..c7d8754 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -313,6 +313,18 @@
", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
}
+/**
+ * Returns the number of storage elements spanned by the given tensor (distance from lowest to highest element offset, plus one).
+ */
+size_t compute_storage_numel_distance(const at::Tensor& t);
+
+/**
+ * Checks whether tensor is mapped to a contiguous area in the storage.
+ */
+inline bool is_dense_in_storage(const at::Tensor& t) {
+ return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
+}
+
} // namespace mps
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index e91c311..001b92b 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -13,6 +13,21 @@
namespace at::native::mps {
+/**
+ * Computes the number of storage elements spanned by the given tensor (distance from lowest to highest element offset, plus one).
+ */
+size_t compute_storage_numel_distance(const at::Tensor& t) {
+ size_t rc = 1;
+ if (t.numel() == 0) {
+ return 0;
+ }
+ for (const auto i : c10::irange(t.dim())) {
+ assert(t.size(i) > 0);
+ rc += (t.size(i) - 1) * t.stride(i);
+ }
+ return rc;
+}
+
void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results) {
mpsStream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_ADAPTIVE);
}
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 79a6577..c4f95e2 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -20,25 +20,6 @@
return (void*)alignedAddress;
}
-/**
- * Computes number of elements one needs to transfer to preserve all the elements
- */
-size_t compute_strided_size(const at::Tensor& t) {
- size_t rc = 1;
- if (t.numel() == 0) {
- return 0;
- }
- for (const auto i : c10::irange(t.dim())) {
- assert(t.size(i) > 0);
- rc += (t.size(i) - 1) * t.stride(i);
- }
- return rc;
-}
-
-bool is_strided_contiguous(const at::Tensor& t) {
- return compute_strided_size(t) == static_cast<size_t>(t.numel());
-}
-
// Copy sourceBuffer into destBuffer, casting sourceBuffer to src.scalar_type().
// The shapes and dtypes are taken from dst and src, but their storage pointers are not used.
void copy_cast_mps(at::Tensor& dst,
@@ -177,7 +158,7 @@
const size_t size_to_copy = src.nbytes();
const void* host_src = static_cast<const char*>(src.storage().data()) + src_byte_offset;
- TORCH_INTERNAL_ASSERT(src.dtype() == dst.dtype() && src.strides() == dst.strides() && is_strided_contiguous(src));
+ TORCH_INTERNAL_ASSERT(src.dtype() == dst.dtype() && src.strides() == dst.strides() && is_dense_in_storage(src));
@autoreleasepool {
MTLResourceOptions options = MTLResourceCPUCacheModeDefaultCache | MTLResourceStorageModeShared;
@@ -204,12 +185,12 @@
// Typecast to dst_ if needed and expand, which is a no-op
Tensor src = (src_.dtype() != dst_.dtype() ? src_.to(dst_.dtype()) : src_).expand_as(dst_);
- // If src is not contiguously strided it must be cloned
+ // If src is not densely mapped in storage it must be cloned
// It does not mean that tensor is contiguous, but rather
// that it could be represented as 1d view
- if (!is_strided_contiguous(src)) {
+ if (!is_dense_in_storage(src)) {
src = src.clone();
- TORCH_INTERNAL_ASSERT(is_strided_contiguous(src));
+ TORCH_INTERNAL_ASSERT(is_dense_in_storage(src));
}
Tensor dst = dst_;
bool needs_copy = false;
diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm
index d2f0105..2bbd238 100644
--- a/aten/src/ATen/native/mps/operations/UnaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -1,5 +1,6 @@
// Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include <ATen/native/mps/Copy.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
@@ -80,13 +81,25 @@
newCachedGraph->outputTensor_ = unaryBlock(mpsGraph, castTensor);
});
+ // If self is not densely mapped in storage, materialize a dense copy for the op to read from
+ at::Tensor self_;
+ if (!is_dense_in_storage(self)) {
+ self_ = at::empty_like(output);
+ mps::mps_copy_(self_, self, false);
+ } else {
+ self_ = self;
+ }
+
bool gatherTensorData = true;
+ // NS: This check is wrong and needs to be fixed, as it would produce wrong results for transposed outputs
+ // See https://github.com/pytorch/pytorch/issues/100764
+
if (!output.is_contiguous() || output.is_view()) {
gatherTensorData = false;
}
- Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, /*mpsShape=*/nullptr, gatherTensorData);
- Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, /*mpsShape=*/nullptr, false);
+ auto selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self_, /*mpsShape=*/nullptr, gatherTensorData);
+ auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, /*mpsShape=*/nullptr, false);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
@{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
diff --git a/test/test_mps.py b/test/test_mps.py
index 7ae5e21..a31cc50 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6620,6 +6620,13 @@
helper((2, 8, 4, 5))
+ def test_neg_strided_input(self):
+ # See https://github.com/pytorch/pytorch/issues/98074#issuecomment-1496088337
+ x = torch.arange(18.0, device='mps').reshape(2, 3, 3)
+ y = x.permute(1, 0, 2)[..., 1]
+ z = y + y.neg()
+ self.assertEqual(z.abs().max().item(), 0.0)
+
# Test index add
def test_index_add(self):
def helper(shape, dim, index, source_shape, alpha, x_dtype=torch.float32, idx_dtype=torch.int32):