MPS: Add amax and amin Ops with tests (#79682)
* Add amax and amin with tests
Adds MPS implementations of `torch.amax` and `torch.amin` with forward and backward tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79682
Approved by: https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py
index 679564d..d5475da 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2758,6 +2758,50 @@
helper((4, 5, 6, 7))
+ # Test forward and backward amax
+ def test_amax(self):
+ def helper(shape, dim, keepdim):
+ cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
+ x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+ result = torch.amax(x, dim=dim, keepdim=keepdim)
+ result_cpu = torch.amax(cpu_x, dim=dim, keepdim=keepdim)
+
+ cpu_grad = torch.randn(result_cpu.shape)
+ grad = cpu_grad.to('mps')
+
+ result_cpu.backward(gradient=cpu_grad)
+ result.backward(gradient=grad)
+
+ self.assertEqual(result, result_cpu)
+ self.assertEqual(x.grad, cpu_x.grad)
+
+ for dim in ([], [0], [0, 1], [2, 3]):
+ for keepdim in [False, True]:
+ helper((2, 8, 4, 5), dim, keepdim)
+
+ # Test forward and backward amin
+ def test_amin(self):
+ def helper(shape, dim, keepdim):
+ cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
+ x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+ result = torch.amin(x, dim=dim, keepdim=keepdim)
+ result_cpu = torch.amin(cpu_x, dim=dim, keepdim=keepdim)
+
+ cpu_grad = torch.randn(result_cpu.shape)
+ grad = cpu_grad.to('mps')
+
+ result_cpu.backward(gradient=cpu_grad)
+ result.backward(gradient=grad)
+
+ self.assertEqual(result, result_cpu)
+ self.assertEqual(x.grad, cpu_x.grad)
+
+ for dim in ([], [0], [0, 1], [2, 3]):
+ for keepdim in [False, True]:
+ helper((2, 8, 4, 5), dim, keepdim)
+
# Test minimum and maximum
def test_minimum_maximum(self):
def helper(n, c, h, w):