[MPS] Add Inverse op. (#90428)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90428
Approved by: https://github.com/DenisVieriu97, https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 1d17760..200889f 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4729,6 +4729,21 @@
helper((2, 8, 4, 5), diag=-2)
helper((2, 8, 4, 5), diag=-3)
+ # Test inverse
+ def test_inverse(self):
+ def helper(n):
+ cpu_input = torch.randn(n, n, device='cpu')
+ mps_input = cpu_input.to('mps')
+
+ cpu_result = torch.linalg.inv(cpu_input)
+ mps_result = torch.linalg.inv(mps_input)
+ self.assertEqual(cpu_result, mps_result)
+
+ helper(2)
+ helper(6)
+ helper(3)
+ helper(8)
+
# Test tril
def test_tril(self):
def helper(shape, diag=0):
@@ -7796,6 +7811,7 @@
'diag_embed': [torch.uint8],
'diagonal_scatter': [torch.uint8],
'index_add': None,
+ 'linalg.inv': ['f32'],
'log1p': None,
'long': None,
'nn.functional.avg_pool1d': [torch.int64],
@@ -7814,7 +7830,6 @@
'slice_scatter': [torch.uint8],
'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8], # moved from section below
-
# ALLOW_LIST doesn't know about variants
'nn.functional.padconstant': None,