[MPS] Fix `torch.mm` correctness for large matrices (#117549)

Currently `matrixMultiplicationWithPrimaryTensor:secondaryTensor:` returns incorrect results if one of the matrix dimensions is greater than 32K
Solve it by providing a very naive matrix multiplication metal shader and call it if stride size is greater than 32768 elements, as slicing inside the MPSGraph doesn't work either, since `-sliceTensor:starts:ends:strides:` somehow affects matmul as well, if tiling is done as follows:
```objc
  NSMutableArray<MPSGraphTensor*>* rows = [NSMutableArray new];
  for (int64_t i = 0; i < M; i += tile_size) {
    const auto i_end = std::min(i + tile_size, M);
    NSMutableArray<MPSGraphTensor*>* row_chunks = [NSMutableArray new];
    for (int64_t j = 0; j < K; j += tile_size) {
      const auto j_end = std::min(j + tile_size, K);
      MPSGraphTensor* tile = nil;
      for (int64_t k = 0; k < N; k += tile_size) {
        const auto k_end = std::min(k + tile_size, N);
        auto selfChunk = [graph sliceTensor:selfTensor
                                     starts:@[ @(i), @(k) ]
                                       ends:@[ @(i_end), @(k_end) ]
                                    strides:@[ @(1), @(1) ]
                                       name:nil];
        auto otherChunk = [graph sliceTensor:otherTensor
                                      starts:@[ @(k), @(j) ]
                                        ends:@[ @(k_end), @(j_end) ]
                                     strides:@[ @(1), @(1) ]
                                        name:nil];
        auto chunkMM = [graph matrixMultiplicationWithPrimaryTensor:selfChunk secondaryTensor:otherChunk name:nil];

        tile = tile ? [graph additionWithPrimaryTensor:tile secondaryTensor:chunkMM name:nil] : chunkMM;
      }
      [row_chunks addObject:tile];
    }
    auto row = row_chunks.count > 1 ? [graph concatTensors:row_chunks dimension:1 name:nil] : row_chunks.firstObject;
    [rows addObject:row];
  }
  return rows.count > 1 ? [graph concatTensors:rows dimension:0 name:nil] : rows.firstObject;
```

One can always use metal MM by defining `PYTORCH_MPS_PREFER_METAL` environment variable
Fixes https://github.com/pytorch/pytorch/issues/116769
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117549
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/CrossKernel.mm b/aten/src/ATen/native/mps/operations/CrossKernel.mm
index 1e04a76..afabf04 100644
--- a/aten/src/ATen/native/mps/operations/CrossKernel.mm
+++ b/aten/src/ATen/native/mps/operations/CrossKernel.mm
@@ -9,6 +9,7 @@
 namespace {
 
 static const char* METAL_CROSS = R"CROSS_METAL(
+#include <metal_array>
 
 #include <metal_stdlib>
 using namespace metal;
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index 8aad3ad..66813cf 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -22,12 +22,119 @@
 
 namespace at::native {
 namespace mps {
+namespace {
+static const char* METAL_LINALG = R"MATMUL_METAL(
+#include <metal_array>
 
-enum LinearAlgebraOpType { ADDBMM_OP_TYPE, BADDBMM_OP_TYPE };
+using namespace metal;
+template<typename T>
+T dot_product(constant T *v1, constant T* v2, ulong2 strides, uint32_t size) {
+  T rc = 0.0;
+  for (uint32_t i = 0; i < size; ++i) {
+    rc += v1[i * strides.x] * v2[i * strides.y];
+  }
+  return rc;
+}
 
-static std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
-                                                                           const Tensor& self,
-                                                                           const Tensor& other) {
+template<typename T>
+kernel void naive_matmul(
+    constant T                 * mat1Data      [[buffer(0)]],
+    constant T                 * mat2Data      [[buffer(1)]],
+    device   T                 * outputData    [[buffer(2)]],
+    constant array<ulong2, 3>  & strides       [[buffer(3)]],
+    constant uint3             & sizes         [[buffer(4)]],
+    uint                         thread_index [[thread_position_in_grid]]) {
+    uint y = thread_index / sizes.x;
+    uint x = thread_index % sizes.x;
+    if (x >= sizes.x || y >= sizes.z) {
+        return;
+    }
+    auto rc = dot_product(mat1Data + x * strides[0].x,
+                          mat2Data + y * strides[1].y,
+                          ulong2(strides[0].y, strides[1].x),
+                          sizes.y);
+    outputData[x * strides[2].x + y * strides[2].y] = rc;
+}
+
+#define INSTANTIATE_NAIVE_MM(DTYPE)                                        \
+template                                                                   \
+[[host_name("naive_matmul_" #DTYPE)]]                                      \
+kernel void naive_matmul<DTYPE>(                                           \
+    constant DTYPE             * mat1Data      [[buffer(0)]],              \
+    constant DTYPE             * mat2Data      [[buffer(1)]],              \
+    device   DTYPE             * outputData    [[buffer(2)]],              \
+    constant array<ulong2, 3>  & strides       [[buffer(3)]],              \
+    constant uint3             & sizes         [[buffer(4)]],              \
+    uint                         thread_index [[thread_position_in_grid]])
+
+INSTANTIATE_NAIVE_MM(float);
+INSTANTIATE_NAIVE_MM(half);
+)MATMUL_METAL";
+
+id<MTLLibrary> compileLinalgOpLibrary(id<MTLDevice> device) {
+  static id<MTLLibrary> linalgLibrary = nil;
+  if (linalgLibrary) {
+    return linalgLibrary;
+  }
+
+  NSError* error = nil;
+  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
+  [options setLanguageVersion:MTLLanguageVersion2_3];
+  linalgLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_LINALG encoding:NSASCIIStringEncoding]
+                                       options:options
+                                         error:&error];
+  TORCH_CHECK(linalgLibrary, "Failed to create metal linalg library, error: ", [[error description] UTF8String]);
+  return linalgLibrary;
+}
+
+id<MTLComputePipelineState> matmulPipelineState(id<MTLDevice> device, ScalarType scalar_type) {
+  std::string kernel = "naive_matmul_" + mps::scalarToMetalTypeString(scalar_type);
+  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
+  id<MTLComputePipelineState> pso = psoCache[kernel];
+  if (pso) {
+    return pso;
+  }
+
+  NSError* error = nil;
+  id<MTLLibrary> linalgLib = compileLinalgOpLibrary(device);
+  id<MTLFunction> matmulFunc = [linalgLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
+  TORCH_CHECK(matmulFunc, "Failed to create function state object for: ", kernel);
+  pso = [device newComputePipelineStateWithFunction:matmulFunc error:&error];
+  TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
+
+  psoCache[kernel] = pso;
+  return pso;
+}
+
+Tensor& do_metal_mm(const Tensor& self, const Tensor& other, Tensor& output) {
+  auto stream = getCurrentMPSStream();
+  auto device = MPSDevice::getInstance()->device();
+  auto matmulPSO = matmulPipelineState(device, output.scalar_type());
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      getMPSProfiler().beginProfileKernel(matmulPSO, "naive_matmul", {self, other});
+      auto computeEncoder = stream->commandEncoder();
+      [computeEncoder setComputePipelineState:matmulPSO];
+      std::array<uint32_t, 3> sizes = {static_cast<uint32_t>(self.size(0)),
+                                       static_cast<uint32_t>(self.size(1)),
+                                       static_cast<uint32_t>(output.size(1))};
+      std::array<int64_t, 6> strides = {
+          self.stride(0), self.stride(1), other.stride(0), other.stride(1), output.stride(0), output.stride(1)};
+      mtl_setBuffer(computeEncoder, self, 0);
+      mtl_setBuffer(computeEncoder, other, 1);
+      mtl_setBuffer(computeEncoder, output, 2);
+      [computeEncoder setBytes:strides.data() length:sizeof(uint64_t) * strides.size() atIndex:3];
+      [computeEncoder setBytes:sizes.data() length:sizeof(uint32_t) * sizes.size() atIndex:4];
+      mtl_dispatch1DJob(computeEncoder, matmulPSO, output.numel());
+      getMPSProfiler().endProfileKernel(matmulPSO);
+    }
+  });
+  return output;
+}
+
+std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
+                                                                    const Tensor& self,
+                                                                    const Tensor& other) {
   if (self.numel() == 0 || other.numel() == 0) {
     auto output = [graph constantWithScalar:0.0
                                       shape:getMPSShape({self.size(0), other.size(1)})
@@ -40,6 +147,15 @@
   return {selfTensor, otherTensor, output};
 }
 
+bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) {
+  static bool always_use_metal = std::getenv("PYTORCH_MPS_PREFER_METAL") != nullptr;
+  constexpr auto max_stride_size = 32768;
+  return always_use_metal || self.stride(0) > max_stride_size || self.stride(1) > max_stride_size ||
+      other.stride(0) > max_stride_size || other.stride(1) > max_stride_size;
+}
+
+} // anonymous namespace
+
 static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
   using namespace mps;
   using CachedGraph = MPSBinaryCachedGraph;
@@ -58,6 +174,14 @@
     return output;
   }
 
+  // MPS matmul returns silently incorrect results if one of the matrix dimentions is greater than 2**15
+  // And crashes if its a view of matrix with dimentions larger than 2**15
+  // See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
+  // In such cases, fallback to navie but accurate metal shader
+  if (use_metal_mm(self, other, output)) {
+    return do_metal_mm(self, other, output);
+  }
+
   @autoreleasepool {
     string key = "mm_out_mps_impl" + getTensorsStringKey({self, other});
 
@@ -85,6 +209,8 @@
   return output;
 }
 
+enum LinearAlgebraOpType { ADDBMM_OP_TYPE, BADDBMM_OP_TYPE };
+
 static Tensor& addbmm_or_baddbmm_out_mps_impl(const Tensor& input,
                                               const Tensor& batch1,
                                               const Tensor& batch2,
diff --git a/test/test_mps.py b/test/test_mps.py
index 741adc1..62a72e4 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6869,6 +6869,22 @@
         gc.collect()
         torch.mps.empty_cache()
 
+    def test_mm_large(self):
+        """ Test that MM works for matrices with index larger than 32K """
+        x = torch.rand(10, 1, device="mps")
+        y = torch.rand(1, 32769, device="mps")
+        # This used to crash with:
+        # error: subRange.start (24576) is not less than length of dimension[0] (16384)
+        # See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
+        self.assertNotEqual(torch.mm(x, y[:, 16384:32768]).abs().max().item(), 0.0)
+        # And below used to produce incorrect results
+        m, n, k = 1024, 1, 32769
+        x = torch.rand(m, n, device="mps")
+        y = torch.rand(n, k, device="mps")
+        z = torch.mm(x, y).to("cpu")
+        z_cpu = torch.mm(x.to("cpu"), y.to("cpu"))
+        self.assertEqual(z, z_cpu)
+
     # Test flip
     def test_flip(self):
         def helper(shape, dims):