[MPS] Fix `use_metal_mm` condition (#118830)
One should look not only at the stride sizes but at the dimensions as well, since the strides of `torch.rand(65536, 1)` are `(1, 1)`
Extend test to account for this situation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118830
Approved by: https://github.com/huydhn
diff --git a/test/test_mps.py b/test/test_mps.py
index b0e64b3..186518b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6878,13 +6878,19 @@
# error: subRange.start (24576) is not less than length of dimension[0] (16384)
# See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
self.assertNotEqual(torch.mm(x, y[:, 16384:32768]).abs().max().item(), 0.0)
- # And below used to produce incorrect results
- m, n, k = 1024, 1, 32769
- x = torch.rand(m, n, device="mps")
- y = torch.rand(n, k, device="mps")
- z = torch.mm(x, y).to("cpu")
- z_cpu = torch.mm(x.to("cpu"), y.to("cpu"))
- self.assertEqual(z, z_cpu)
+
+ def compare_mm(m, n, k):
+ x = torch.rand(m, n, device="mps")
+ y = torch.rand(n, k, device="mps")
+ z = torch.mm(x, y).cpu()
+ z_cpu = torch.mm(x.cpu(), y.cpu())
+ self.assertEqual(z, z_cpu)
+
+ # Used to produce incorrect results with MPS on M1 running macOS 14.3, but correct ones with Metal
+ compare_mm(1024, 1, 32769)
+ # one more time, but with dimensions inverted
+ # see https://github.com/pytorch/pytorch/issues/116769#issuecomment-1920066984
+ compare_mm(32769, 1, 1025)
# Test flip
def test_flip(self):