[MPS] Fix `use_metal_mm` condition (#118830)
One should look not only at the stride sizes but at the dimensions as well, since the strides of `torch.rand(65536, 1)` are `(1, 1)`
Extend the test to cover this situation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118830
Approved by: https://github.com/huydhn
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index 66813cf..53ce8d2 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -151,7 +151,8 @@
static bool always_use_metal = std::getenv("PYTORCH_MPS_PREFER_METAL") != nullptr;
constexpr auto max_stride_size = 32768;
return always_use_metal || self.stride(0) > max_stride_size || self.stride(1) > max_stride_size ||
- other.stride(0) > max_stride_size || other.stride(1) > max_stride_size;
+ self.size(0) > max_stride_size || self.size(1) > max_stride_size || other.stride(0) > max_stride_size ||
+ other.stride(1) > max_stride_size || other.size(0) > max_stride_size || other.size(1) > max_stride_size;
}
} // anonymous namespace
@@ -174,10 +175,10 @@
return output;
}
- // MPS matmul returns silently incorrect results if one of the matrix dimentions is greater than 2**15
- // And crashes if its a view of matrix with dimentions larger than 2**15
+ // MPS matmul returns silently incorrect results if one of the matrix dimensions is greater than 2**15
+ // And crashes if it's a view of a matrix with dimensions larger than 2**15
// See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
- // In such cases, fallback to navie but accurate metal shader
+ // In such cases, fallback to naive but accurate metal shader
if (use_metal_mm(self, other, output)) {
return do_metal_mm(self, other, output);
}
diff --git a/test/test_mps.py b/test/test_mps.py
index b0e64b3..186518b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6878,13 +6878,19 @@
# error: subRange.start (24576) is not less than length of dimension[0] (16384)
# See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
self.assertNotEqual(torch.mm(x, y[:, 16384:32768]).abs().max().item(), 0.0)
- # And below used to produce incorrect results
- m, n, k = 1024, 1, 32769
- x = torch.rand(m, n, device="mps")
- y = torch.rand(n, k, device="mps")
- z = torch.mm(x, y).to("cpu")
- z_cpu = torch.mm(x.to("cpu"), y.to("cpu"))
- self.assertEqual(z, z_cpu)
+
+ def compare_mm(m, n, k):
+ x = torch.rand(m, n, device="mps")
+ y = torch.rand(n, k, device="mps")
+ z = torch.mm(x, y).cpu()
+ z_cpu = torch.mm(x.cpu(), y.cpu())
+ self.assertEqual(z, z_cpu)
+
+ # Used to produce incorrect results with MPS on M1 running macOS 14.3, but correct with Metal
+ compare_mm(1024, 1, 32769)
+ # one more time, but with dimensions inverted
+ # see https://github.com/pytorch/pytorch/issues/116769#issuecomment-1920066984
+ compare_mm(32769, 1, 1025)
# Test flip
def test_flip(self):