[MPS] Fix `torch.mm` correctness for large matrices (#117549)
Currently `matrixMultiplicationWithPrimaryTensor:secondaryTensor:` returns incorrect results if one of the matrix dimensions is greater than 32K (2**15).
Solve it by providing a very naive matrix-multiplication Metal shader and calling it whenever a stride exceeds 32768 elements. Slicing inside the MPSGraph does not work around the problem either, since `-sliceTensor:starts:ends:strides:` somehow affects the matmul results as well, if tiling is done as follows:
```objc
NSMutableArray<MPSGraphTensor*>* rows = [NSMutableArray new];
for (int64_t i = 0; i < M; i += tile_size) {
  const auto i_end = std::min(i + tile_size, M);
  NSMutableArray<MPSGraphTensor*>* row_chunks = [NSMutableArray new];
  for (int64_t j = 0; j < K; j += tile_size) {
    const auto j_end = std::min(j + tile_size, K);
    MPSGraphTensor* tile = nil;
    for (int64_t k = 0; k < N; k += tile_size) {
      const auto k_end = std::min(k + tile_size, N);
      auto selfChunk = [graph sliceTensor:selfTensor
                                   starts:@[ @(i), @(k) ]
                                     ends:@[ @(i_end), @(k_end) ]
                                  strides:@[ @(1), @(1) ]
                                     name:nil];
      auto otherChunk = [graph sliceTensor:otherTensor
                                    starts:@[ @(k), @(j) ]
                                      ends:@[ @(k_end), @(j_end) ]
                                   strides:@[ @(1), @(1) ]
                                      name:nil];
      auto chunkMM = [graph matrixMultiplicationWithPrimaryTensor:selfChunk secondaryTensor:otherChunk name:nil];
      tile = tile ? [graph additionWithPrimaryTensor:tile secondaryTensor:chunkMM name:nil] : chunkMM;
    }
    [row_chunks addObject:tile];
  }
  auto row = row_chunks.count > 1 ? [graph concatTensors:row_chunks dimension:1 name:nil] : row_chunks.firstObject;
  [rows addObject:row];
}
return rows.count > 1 ? [graph concatTensors:rows dimension:0 name:nil] : rows.firstObject;
```
One can always force the Metal matmul by defining the `PYTORCH_MPS_PREFER_METAL` environment variable.
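For example, one can force the naive Metal kernel for a (hypothetical) repro script; only the presence of the variable is checked, not its value:
```shell
PYTORCH_MPS_PREFER_METAL=1 python repro.py
```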
Fixes https://github.com/pytorch/pytorch/issues/116769
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117549
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/CrossKernel.mm b/aten/src/ATen/native/mps/operations/CrossKernel.mm
index 1e04a76..afabf04 100644
--- a/aten/src/ATen/native/mps/operations/CrossKernel.mm
+++ b/aten/src/ATen/native/mps/operations/CrossKernel.mm
@@ -9,6 +9,7 @@
namespace {
static const char* METAL_CROSS = R"CROSS_METAL(
+#include <metal_array>
#include <metal_stdlib>
using namespace metal;
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index 8aad3ad..66813cf 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -22,12 +22,119 @@
namespace at::native {
namespace mps {
+namespace {
+static const char* METAL_LINALG = R"MATMUL_METAL(
+#include <metal_array>
-enum LinearAlgebraOpType { ADDBMM_OP_TYPE, BADDBMM_OP_TYPE };
+using namespace metal;
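+// Dot product of two strided vectors; strides.x and strides.y are the element
+// strides of v1 and v2 respectively.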
+template<typename T>
+T dot_product(constant T *v1, constant T* v2, ulong2 strides, uint32_t size) {
+  T rc = 0.0;
+  for (uint32_t i = 0; i < size; ++i) {
+    rc += v1[i * strides.x] * v2[i * strides.y];
+  }
+  return rc;
+}
-static std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
- const Tensor& self,
- const Tensor& other) {
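+// Naive matmul kernel: one thread per output element; thread_index is
+// decomposed into the (x, y) coordinates of that element.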
+template<typename T>
+kernel void naive_matmul(
+    constant T * mat1Data [[buffer(0)]],
+    constant T * mat2Data [[buffer(1)]],
+    device T * outputData [[buffer(2)]],
+    constant array<ulong2, 3> & strides [[buffer(3)]],
+    constant uint3 & sizes [[buffer(4)]],
+    uint thread_index [[thread_position_in_grid]]) {
+  uint y = thread_index / sizes.x;
+  uint x = thread_index % sizes.x;
+  if (x >= sizes.x || y >= sizes.z) {
+    return;
+  }
+  auto rc = dot_product(mat1Data + x * strides[0].x,
+                        mat2Data + y * strides[1].y,
+                        ulong2(strides[0].y, strides[1].x),
+                        sizes.y);
+  outputData[x * strides[2].x + y * strides[2].y] = rc;
+}
+
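+// Instantiate the kernel for each supported dtype under a stable host-visible
+// name ("naive_matmul_<dtype>") so it can be looked up via newFunctionWithName:.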
+#define INSTANTIATE_NAIVE_MM(DTYPE) \
+template \
+[[host_name("naive_matmul_" #DTYPE)]] \
+kernel void naive_matmul<DTYPE>( \
+    constant DTYPE * mat1Data [[buffer(0)]], \
+    constant DTYPE * mat2Data [[buffer(1)]], \
+    device DTYPE * outputData [[buffer(2)]], \
+    constant array<ulong2, 3> & strides [[buffer(3)]], \
+    constant uint3 & sizes [[buffer(4)]], \
+    uint thread_index [[thread_position_in_grid]])
+
+INSTANTIATE_NAIVE_MM(float);
+INSTANTIATE_NAIVE_MM(half);
+)MATMUL_METAL";
+
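+// Compiles the Metal source above on first use and caches the resulting library.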
+id<MTLLibrary> compileLinalgOpLibrary(id<MTLDevice> device) {
+  static id<MTLLibrary> linalgLibrary = nil;
+  if (linalgLibrary) {
+    return linalgLibrary;
+  }
+
+  NSError* error = nil;
+  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
+  [options setLanguageVersion:MTLLanguageVersion2_3];
+  linalgLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_LINALG encoding:NSASCIIStringEncoding]
+                                       options:options
+                                         error:&error];
+  TORCH_CHECK(linalgLibrary, "Failed to create metal linalg library, error: ", [[error description] UTF8String]);
+  return linalgLibrary;
+}
+
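+// Returns the compute pipeline state for the dtype-specialized matmul kernel,
+// cached per kernel name.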
+id<MTLComputePipelineState> matmulPipelineState(id<MTLDevice> device, ScalarType scalar_type) {
+  std::string kernel = "naive_matmul_" + mps::scalarToMetalTypeString(scalar_type);
+  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
+  id<MTLComputePipelineState> pso = psoCache[kernel];
+  if (pso) {
+    return pso;
+  }
+
+  NSError* error = nil;
+  id<MTLLibrary> linalgLib = compileLinalgOpLibrary(device);
+  id<MTLFunction> matmulFunc = [linalgLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
+  TORCH_CHECK(matmulFunc, "Failed to create function state object for: ", kernel);
+  pso = [device newComputePipelineStateWithFunction:matmulFunc error:&error];
+  TORCH_CHECK(pso, "Failed to create pipeline state object, error: ", [[error description] UTF8String]);
+
+  psoCache[kernel] = pso;
+  return pso;
+}
+
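+// Runs the naive Metal matmul: binds buffers, passes sizes/strides as kernel
+// arguments, and dispatches one thread per output element.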
+Tensor& do_metal_mm(const Tensor& self, const Tensor& other, Tensor& output) {
+  auto stream = getCurrentMPSStream();
+  auto device = MPSDevice::getInstance()->device();
+  auto matmulPSO = matmulPipelineState(device, output.scalar_type());
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      getMPSProfiler().beginProfileKernel(matmulPSO, "naive_matmul", {self, other});
+      auto computeEncoder = stream->commandEncoder();
+      [computeEncoder setComputePipelineState:matmulPSO];
+      std::array<uint32_t, 3> sizes = {static_cast<uint32_t>(self.size(0)),
+                                       static_cast<uint32_t>(self.size(1)),
+                                       static_cast<uint32_t>(output.size(1))};
+      std::array<int64_t, 6> strides = {
+          self.stride(0), self.stride(1), other.stride(0), other.stride(1), output.stride(0), output.stride(1)};
+      mtl_setBuffer(computeEncoder, self, 0);
+      mtl_setBuffer(computeEncoder, other, 1);
+      mtl_setBuffer(computeEncoder, output, 2);
+      [computeEncoder setBytes:strides.data() length:sizeof(uint64_t) * strides.size() atIndex:3];
+      [computeEncoder setBytes:sizes.data() length:sizeof(uint32_t) * sizes.size() atIndex:4];
+      mtl_dispatch1DJob(computeEncoder, matmulPSO, output.numel());
+      getMPSProfiler().endProfileKernel(matmulPSO);
+    }
+  });
+  return output;
+}
+
+std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
+                                                                    const Tensor& self,
+                                                                    const Tensor& other) {
  if (self.numel() == 0 || other.numel() == 0) {
    auto output = [graph constantWithScalar:0.0
                                      shape:getMPSShape({self.size(0), other.size(1)})
@@ -40,6 +147,15 @@
  return {selfTensor, otherTensor, output};
}
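+// Bypass MPSGraph matmul either when forced via PYTORCH_MPS_PREFER_METAL or
+// when any input stride exceeds 2**15 elements (see mm_out_mps_impl below).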
+bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) {
+  static bool always_use_metal = std::getenv("PYTORCH_MPS_PREFER_METAL") != nullptr;
+  constexpr auto max_stride_size = 32768;
+  return always_use_metal || self.stride(0) > max_stride_size || self.stride(1) > max_stride_size ||
+      other.stride(0) > max_stride_size || other.stride(1) > max_stride_size;
+}
+
+} // anonymous namespace
+
static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
  using namespace mps;
  using CachedGraph = MPSBinaryCachedGraph;
@@ -58,6 +174,14 @@
    return output;
  }
+  // MPS matmul returns silently incorrect results if one of the matrix dimensions is greater than 2**15
+  // and crashes if the input is a view of a matrix with dimensions larger than 2**15.
+  // See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
+  // In such cases, fall back to the naive but accurate Metal shader.
+  if (use_metal_mm(self, other, output)) {
+    return do_metal_mm(self, other, output);
+  }
+
  @autoreleasepool {
    string key = "mm_out_mps_impl" + getTensorsStringKey({self, other});
@@ -85,6 +209,8 @@
  return output;
}
+enum LinearAlgebraOpType { ADDBMM_OP_TYPE, BADDBMM_OP_TYPE };
+
static Tensor& addbmm_or_baddbmm_out_mps_impl(const Tensor& input,
                                              const Tensor& batch1,
                                              const Tensor& batch2,
diff --git a/test/test_mps.py b/test/test_mps.py
index 741adc1..62a72e4 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6869,6 +6869,22 @@
        gc.collect()
        torch.mps.empty_cache()
+    def test_mm_large(self):
+        """ Test that MM works for matrices with index larger than 32K """
+        x = torch.rand(10, 1, device="mps")
+        y = torch.rand(1, 32769, device="mps")
+        # This used to crash with:
+        # error: subRange.start (24576) is not less than length of dimension[0] (16384)
+        # See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
+        self.assertNotEqual(torch.mm(x, y[:, 16384:32768]).abs().max().item(), 0.0)
+        # And below used to produce incorrect results
+        m, n, k = 1024, 1, 32769
+        x = torch.rand(m, n, device="mps")
+        y = torch.rand(n, k, device="mps")
+        z = torch.mm(x, y).to("cpu")
+        z_cpu = torch.mm(x.to("cpu"), y.to("cpu"))
+        self.assertEqual(z, z_cpu)
+
    # Test flip
    def test_flip(self):
        def helper(shape, dims):