[MPS] Fix `torch.mm` correctness for large matrices (#117549)
Currently `matrixMultiplicationWithPrimaryTensor:secondaryTensor:` returns incorrect results if one of the matrix dimensions is greater than 32K
Solve it by providing a very naive matrix-multiplication Metal shader and calling it when the stride size is greater than 32768 elements. Slicing inside the MPSGraph doesn't work either, since `-sliceTensor:starts:ends:strides:` somehow affects the matmul result as well when tiling is done as follows:
```objc
NSMutableArray<MPSGraphTensor*>* rows = [NSMutableArray new];
for (int64_t i = 0; i < M; i += tile_size) {
const auto i_end = std::min(i + tile_size, M);
NSMutableArray<MPSGraphTensor*>* row_chunks = [NSMutableArray new];
for (int64_t j = 0; j < K; j += tile_size) {
const auto j_end = std::min(j + tile_size, K);
MPSGraphTensor* tile = nil;
for (int64_t k = 0; k < N; k += tile_size) {
const auto k_end = std::min(k + tile_size, N);
auto selfChunk = [graph sliceTensor:selfTensor
starts:@[ @(i), @(k) ]
ends:@[ @(i_end), @(k_end) ]
strides:@[ @(1), @(1) ]
name:nil];
auto otherChunk = [graph sliceTensor:otherTensor
starts:@[ @(k), @(j) ]
ends:@[ @(k_end), @(j_end) ]
strides:@[ @(1), @(1) ]
name:nil];
auto chunkMM = [graph matrixMultiplicationWithPrimaryTensor:selfChunk secondaryTensor:otherChunk name:nil];
tile = tile ? [graph additionWithPrimaryTensor:tile secondaryTensor:chunkMM name:nil] : chunkMM;
}
[row_chunks addObject:tile];
}
auto row = row_chunks.count > 1 ? [graph concatTensors:row_chunks dimension:1 name:nil] : row_chunks.firstObject;
[rows addObject:row];
}
return rows.count > 1 ? [graph concatTensors:rows dimension:0 name:nil] : rows.firstObject;
```
One can always force the Metal matrix-multiplication path by defining the `PYTORCH_MPS_PREFER_METAL` environment variable.
Fixes https://github.com/pytorch/pytorch/issues/116769
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117549
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 741adc1..62a72e4 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6869,6 +6869,22 @@
gc.collect()
torch.mps.empty_cache()
+ def test_mm_large(self):
+ """ Test that MM works for matrices with index larger than 32K """
+ x = torch.rand(10, 1, device="mps")
+ y = torch.rand(1, 32769, device="mps")
+ # This used to crash with:
+ # error: subRange.start (24576) is not less than length of dimension[0] (16384)
+ # See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
+ self.assertNotEqual(torch.mm(x, y[:, 16384:32768]).abs().max().item(), 0.0)
+ # And below used to produce incorrect results
+ m, n, k = 1024, 1, 32769
+ x = torch.rand(m, n, device="mps")
+ y = torch.rand(n, k, device="mps")
+ z = torch.mm(x, y).to("cpu")
+ z_cpu = torch.mm(x.to("cpu"), y.to("cpu"))
+ self.assertEqual(z, z_cpu)
+
# Test flip
def test_flip(self):
def helper(shape, dims):