MPS: Add amax and amin Ops with tests (#79682)

* Add amax and amin for MPS, with tests that compare forward results and gradients against CPU

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79682
Approved by: https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py
index 679564d..d5475da 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2758,6 +2758,50 @@
 
         helper((4, 5, 6, 7))
 
+    # Test amax forward and backward (MPS results and gradients vs. CPU)
+    def test_amax(self):
+        def helper(shape, dim, keepdim):
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
+            x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+            result = torch.amax(x, dim=dim, keepdim=keepdim)
+            result_cpu = torch.amax(cpu_x, dim=dim, keepdim=keepdim)
+
+            cpu_grad = torch.randn(result_cpu.shape)
+            grad = cpu_grad.to('mps')
+
+            result_cpu.backward(gradient=cpu_grad)
+            result.backward(gradient=grad)
+
+            self.assertEqual(result, result_cpu)
+            self.assertEqual(x.grad, cpu_x.grad)
+
+        for dim in ([], [0], [0, 1], [2, 3]):
+            for keepdim in [False, True]:
+                helper((2, 8, 4, 5), dim, keepdim)
+
+    # Test amin forward and backward (MPS results and gradients vs. CPU)
+    def test_amin(self):
+        def helper(shape, dim, keepdim):
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
+            x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+            result = torch.amin(x, dim=dim, keepdim=keepdim)
+            result_cpu = torch.amin(cpu_x, dim=dim, keepdim=keepdim)
+
+            cpu_grad = torch.randn(result_cpu.shape)
+            grad = cpu_grad.to('mps')
+
+            result_cpu.backward(gradient=cpu_grad)
+            result.backward(gradient=grad)
+
+            self.assertEqual(result, result_cpu)
+            self.assertEqual(x.grad, cpu_x.grad)
+
+        for dim in ([], [0], [0, 1], [2, 3]):
+            for keepdim in [False, True]:
+                helper((2, 8, 4, 5), dim, keepdim)
+
     # Test minimum and maximum
     def test_minimum_maximum(self):
         def helper(n, c, h, w):