[MPS] Fix `use_metal_mm` condition (#118830)

One should look not only at the stride sizes but at the dimensions as well, since the strides of `torch.rand(65536, 1)` are `(1, 1)`.

Extend test to account for this situation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118830
Approved by: https://github.com/huydhn
diff --git a/test/test_mps.py b/test/test_mps.py
index b0e64b3..186518b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6878,13 +6878,19 @@
         # error: subRange.start (24576) is not less than length of dimension[0] (16384)
         # See https://github.com/pytorch/pytorch/issues/116769#issuecomment-1888302095
         self.assertNotEqual(torch.mm(x, y[:, 16384:32768]).abs().max().item(), 0.0)
-        # And below used to produce incorrect results
-        m, n, k = 1024, 1, 32769
-        x = torch.rand(m, n, device="mps")
-        y = torch.rand(n, k, device="mps")
-        z = torch.mm(x, y).to("cpu")
-        z_cpu = torch.mm(x.to("cpu"), y.to("cpu"))
-        self.assertEqual(z, z_cpu)
+
+        def compare_mm(m, n, k):
+            x = torch.rand(m, n, device="mps")
+            y = torch.rand(n, k, device="mps")
+            z = torch.mm(x, y).cpu()
+            z_cpu = torch.mm(x.cpu(), y.cpu())
+            self.assertEqual(z, z_cpu)
+
+        # Used to produce incorrect results with MPS on M1 running macOS 14.3, but correct ones with Metal
+        compare_mm(1024, 1, 32769)
+        # one more time, but with dimensions inverted
+        # see https://github.com/pytorch/pytorch/issues/116769#issuecomment-1920066984
+        compare_mm(32769, 1, 1025)
 
     # Test flip
     def test_flip(self):