[MPS] Add flip (#80214)
Adds an MPS backend implementation of `torch.flip`, with a test in test/test_mps.py comparing MPS results against CPU.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80214
Approved by: https://github.com/DenisVieriu97, https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py
index f460434..c119bf2 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3726,6 +3726,28 @@
helper((2, 8, 4, 5))
+    # Test flip
+    def test_flip(self):
+        """torch.flip on MPS must match the CPU result for every (shape, dims) case below."""
+        cases = [
+            ((2, 8, 4, 5), [0]),
+            ((8, 8, 4, 5), [0, 1]),
+            ((2, 8, 4, 5), (0, 1, 2, 3)),
+            ((2, 3, 3), (-1,)),
+            ((2, 8, 4, 5), []),  # empty dims
+            ((1,), (0,)),  # input.numel() == 1
+            ((0,), (0,)),  # input.numel() == 0
+        ]
+
+        for shape, dims in cases:
+            # Random CPU reference tensor plus an identical copy moved to the MPS device.
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
+            mps_x = cpu_x.detach().clone().to('mps')
+
+            mps_flipped = torch.flip(mps_x, dims=dims)
+            cpu_flipped = torch.flip(cpu_x, dims=dims)
+            self.assertEqual(mps_flipped, cpu_flipped)
+
# Test index select
def test_index_select(self):
def helper(shape, dim, index, idx_dtype=torch.int32):