MPS: fixes (#77462)

- Fix the is_available flag on x86 machines
- Fix tensor creation on older macOS versions
- Fix addmm handling of transposed inputs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77462
Approved by: https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py
index ac1ddd4..0480426 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9,6 +9,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import itertools
 from torch.nn import Parameter
 from torch.testing._internal.common_utils import run_tests, TestCase, download_file, TEST_WITH_UBSAN
 import torch.backends.mps
@@ -77,14 +78,20 @@
                 np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
                 device="mps")
 
-
 class MatmulTest(TestCase):
-    def _helper(self, shape_tensor_1, shape_tensor_2):
-        tensor1_cpu = torch.randn(shape_tensor_1, device="cpu")
-        tensor2_cpu = torch.randn(shape_tensor_2, device="cpu")
+    def _helper(self, shape_tensor_1, shape_tensor_2, expand_tensor_1_shape=None, expand_tensor_2_shape=None):
+        if expand_tensor_1_shape:
+            tensor1_mps = torch.randn(shape_tensor_1, device="mps").expand(expand_tensor_1_shape)
+        else:
+            tensor1_mps = torch.randn(shape_tensor_1, device="mps")
 
-        tensor1_mps = tensor1_cpu.to("mps")
-        tensor2_mps = tensor2_cpu.to("mps")
+        if expand_tensor_2_shape:
+            tensor2_mps = torch.randn(shape_tensor_2, device="mps").expand(expand_tensor_2_shape)
+        else:
+            tensor2_mps = torch.randn(shape_tensor_2, device="mps")
+
+        tensor1_cpu = tensor1_mps.to("cpu")
+        tensor2_cpu = tensor2_mps.to("cpu")
 
         matmul_cpu = torch.matmul(tensor1_cpu, tensor2_cpu)
         matmul_mps = torch.matmul(tensor1_mps, tensor2_mps)
@@ -3929,8 +3936,8 @@
         if beta != 0:
             res3 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
         res3 = torch.from_numpy(res3).to(dtype)
-        # self.assertEqual(res1, res2)
-        # self.assertEqual(res1, res3)
+        self.assertEqual(res1, res2)
+        self.assertEqual(res1, res3)
 
     def test_addmm(self, device="mps", dtype=torch.float32):
         M = torch.randn(10, 25, device=device).to(dtype)
@@ -3938,29 +3945,23 @@
         m2 = torch.randn(50, 25, device=device).to(dtype)
         self._test_addmm_addmv(torch.addmm, M, m1, m2)
 
-        # # Test 0-strided
-        # M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
-        # m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50)
-        # m2 = torch.randn(50, 25, device=device).to(dtype)
-        # self._test_addmm_addmv(torch.addmm, M, m1, m2)
-
         # Test beta=0, M=nan
         M = torch.full((10, 25), math.nan, device=device).to(dtype)
         m1 = torch.randn(10, 50, device=device).to(dtype)
         m2 = torch.randn(50, 25, device=device).to(dtype)
         self._test_addmm_addmv(torch.addmm, M, m1, m2, beta=0)
 
-        # # Test transpose
-        # for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
-        # def maybe_transpose(cond, m):
-        # if not cond:
-        # return m
-        # return m.t().clone(memory_format=torch.contiguous_format).t()
+        # Test transpose
+        for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
+            def maybe_transpose(cond, m):
+                if not cond:
+                    return m
+                return m.t().clone(memory_format=torch.contiguous_format).t()
 
-        # M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
-        # m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
-        # m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
-        # self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)
+            M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))  # inside the loop: one case per (t1, t2, t3, t4)
+            m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
+            m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
+            self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)  # exercises all 16 layout combinations
 
 
 class TestRNNMPS(TestCase):