MPS: fix is_available on x86, tensor creation on older macOS, and addmm transposition (#77462)
- Fix the is_available flag for x86 machines
- Fix tensor creation on older macOS versions
- Fix addmm for transposed tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77462
Approved by: https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py
index ac1ddd4..0480426 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9,6 +9,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+import itertools
from torch.nn import Parameter
from torch.testing._internal.common_utils import run_tests, TestCase, download_file, TEST_WITH_UBSAN
import torch.backends.mps
@@ -77,14 +78,20 @@
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
device="mps")
-
class MatmulTest(TestCase):
- def _helper(self, shape_tensor_1, shape_tensor_2):
- tensor1_cpu = torch.randn(shape_tensor_1, device="cpu")
- tensor2_cpu = torch.randn(shape_tensor_2, device="cpu")
+ def _helper(self, shape_tensor_1, shape_tensor_2, expand_tensor_1_shape=None, expand_tensor_2_shape=None):
+ if expand_tensor_1_shape:
+ tensor1_mps = torch.randn(shape_tensor_1, device="mps").expand(expand_tensor_1_shape)
+ else:
+ tensor1_mps = torch.randn(shape_tensor_1, device="mps")
- tensor1_mps = tensor1_cpu.to("mps")
- tensor2_mps = tensor2_cpu.to("mps")
+ if expand_tensor_2_shape:
+ tensor2_mps = torch.randn(shape_tensor_2, device="mps").expand(expand_tensor_2_shape)
+ else:
+ tensor2_mps = torch.randn(shape_tensor_2, device="mps")
+
+ tensor1_cpu = tensor1_mps.to("cpu")
+ tensor2_cpu = tensor2_mps.to("cpu")
matmul_cpu = torch.matmul(tensor1_cpu, tensor2_cpu)
matmul_mps = torch.matmul(tensor1_mps, tensor2_mps)
@@ -3929,8 +3936,8 @@
if beta != 0:
res3 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
res3 = torch.from_numpy(res3).to(dtype)
- # self.assertEqual(res1, res2)
- # self.assertEqual(res1, res3)
+ self.assertEqual(res1, res2)
+ self.assertEqual(res1, res3)
def test_addmm(self, device="mps", dtype=torch.float32):
M = torch.randn(10, 25, device=device).to(dtype)
@@ -3938,29 +3945,23 @@
m2 = torch.randn(50, 25, device=device).to(dtype)
self._test_addmm_addmv(torch.addmm, M, m1, m2)
- # # Test 0-strided
- # M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25)
- # m1 = torch.randn(10, 1, device=device).to(dtype).expand(10, 50)
- # m2 = torch.randn(50, 25, device=device).to(dtype)
- # self._test_addmm_addmv(torch.addmm, M, m1, m2)
-
# Test beta=0, M=nan
M = torch.full((10, 25), math.nan, device=device).to(dtype)
m1 = torch.randn(10, 50, device=device).to(dtype)
m2 = torch.randn(50, 25, device=device).to(dtype)
self._test_addmm_addmv(torch.addmm, M, m1, m2, beta=0)
- # # Test transpose
- # for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
- # def maybe_transpose(cond, m):
- # if not cond:
- # return m
- # return m.t().clone(memory_format=torch.contiguous_format).t()
+ # Test transpose
+ for t1, t2, t3, t4 in itertools.product([True, False], repeat=4):
+ def maybe_transpose(cond, m):
+ if not cond:
+ return m
+ return m.t().clone(memory_format=torch.contiguous_format).t()
- # M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
- # m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
- # m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
- # self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)
+ M = maybe_transpose(t1, torch.randn(10, 25, device=device).to(dtype))
+ m1 = maybe_transpose(t2, torch.randn(10, 50, device=device).to(dtype))
+ m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
+ self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)
class TestRNNMPS(TestCase):