MPS: Eye op (#78408)
This can be used as a reference PR for adding an op to the MPS backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78408
Approved by: https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py
index f9778b4..f4e9568 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3560,6 +3560,29 @@
helper((2, 8, 4, 5), diag=-2)
helper((2, 8, 4, 5), diag=-3)
+    # test eye
+    def test_eye(self):
+        # Compare torch.eye on MPS against the CPU reference for a shape and dtype.
+        def helper(n, m, dtype):
+            # Square case uses the single-size overload; the rectangular case
+            # passes the column count explicitly. Both must forward `dtype`,
+            # otherwise the dtype loop below only exercises the default dtype.
+            if n == m:
+                cpu_result = torch.eye(n, dtype=dtype, device='cpu')
+                result = torch.eye(n, dtype=dtype, device='mps')
+            else:
+                cpu_result = torch.eye(n, m, dtype=dtype, device='cpu')
+                result = torch.eye(n, m, dtype=dtype, device='mps')
+            self.assertEqual(result, cpu_result)
+
+        for dtype in [torch.float32, torch.int32, torch.int64]:
+            helper(2, 2, dtype)
+            helper(2, 3, dtype)
+            helper(0, 2, dtype)
+            helper(0, 0, dtype)
+            helper(3, 8, dtype)
+            helper(8, 3, dtype)
+
# Test diag
def test_diag(self):
def helper(shape, diag=0):
@@ -4119,9 +4142,11 @@
self.assertTrue(False, "There was a warning when importing torch.")
def _get_not_implemented_op(self):
- # This can be changed once we actually implement `torch.eye`
+ # This can be changed once we actually implement `torch.bincount`
# Should return fn, args, kwargs, string_version
- return torch.eye, (2,), {"device": "mps"}, "torch.eye(2, device='mps')"
+ return (torch.bincount,
+ (torch.tensor([4, 3, 6, 3, 4], device='mps')), {},
+ "torch.bincount(torch.tensor([4, 3, 6, 3, 4], device='mps'))")
def test_error_on_not_implemented(self):
fn, args, kwargs, _ = self._get_not_implemented_op()