[MPS] Add argmin (#80828)

This PR:

1. adds `argmin` support to the MPS backend
2. refactors `reduction_type` in `ReduceOps.mm` to use an enum (sketched below)

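For context on point 2, the refactor amounts to something like the following minimal sketch. This is an illustration only, not the exact code landed in `ReduceOps.mm` (which is Objective-C++ and builds MPSGraph nodes); the enum name, its values, and the function signature here are assumptions.

```cpp
// Hedged sketch: replace string-typed dispatch ("max" / "min") with a
// scoped enum, so invalid values cannot be constructed and switches are
// checked for exhaustiveness by the compiler.
enum class MPSReductionType {
  MAX,  // used by max/argmax
  MIN,  // used by min/argmin (new in this PR)
};

// Illustrative signature; the real helper takes tensor/dim arguments.
void reduction_out_mps(/* tensor args elided */ MPSReductionType reduction_type) {
  switch (reduction_type) {
    case MPSReductionType::MAX:
      // build the MPSGraph argmax/max reduction node here
      break;
    case MPSReductionType::MIN:
      // build the MPSGraph argmin/min reduction node here
      break;
  }
}
```

Dispatching on an enum instead of comparing strings catches typos at compile time and makes the compiler flag any switch that forgets a case when a new reduction (like argmin here) is added. The Python test below still takes a `"max"`/`"min"` string, but only to pick between `torch.argmax` and `torch.argmin`.
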
Co-authored-by: Kulin Seth <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80828
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index f8614ea..dc857a3 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2187,9 +2187,14 @@
 
         helper((2, 3, 4, 5))
 
-    # Test forward argmax
-    def test_argmax(self):
-        def helper(n, c, h, w, dtype=torch.float32):
+    # Test forward argmin and argmax
+    def test_argmin_argmax(self):
+        def helper(n, c, h, w, reduction_type, dtype=torch.float32):
+            if reduction_type == "max":
+                arg_reduction_fn = torch.argmax
+            else:
+                arg_reduction_fn = torch.argmin
+
             cpu_x = None
             x = None
             if(dtype not in [torch.float32, torch.bool]):
@@ -2202,46 +2207,50 @@
                 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
                 x = cpu_x.detach().clone().to('mps').requires_grad_()
 
-            y = torch.argmax(x)
-            ref_y = torch.argmax(cpu_x)
+            y = arg_reduction_fn(x)
+            ref_y = arg_reduction_fn(cpu_x)
             self.assertEqual(y, ref_y)
 
-            y_0 = torch.argmax(x, dim=0)
-            refy_0 = torch.argmax(cpu_x, dim=0)
+            y_0 = arg_reduction_fn(x, dim=0)
+            refy_0 = arg_reduction_fn(cpu_x, dim=0)
             self.assertEqual(y_0, refy_0)
 
-            y_0dim = torch.argmax(x, dim=0, keepdim=True)
-            refy_0dim = torch.argmax(cpu_x, dim=0, keepdim=True)
+            y_0dim = arg_reduction_fn(x, dim=0, keepdim=True)
+            refy_0dim = arg_reduction_fn(cpu_x, dim=0, keepdim=True)
             self.assertEqual(y_0dim, refy_0dim)
 
-            y_1 = torch.argmax(x, dim=1)
-            refy_1 = torch.argmax(cpu_x, dim=1)
+            y_1 = arg_reduction_fn(x, dim=1)
+            refy_1 = arg_reduction_fn(cpu_x, dim=1)
             self.assertEqual(y_1, refy_1)
 
-            y_1dim = torch.argmax(x, dim=1, keepdim=True)
-            refy_1dim = torch.argmax(cpu_x, dim=1, keepdim=True)
+            y_1dim = arg_reduction_fn(x, dim=1, keepdim=True)
+            refy_1dim = arg_reduction_fn(cpu_x, dim=1, keepdim=True)
             self.assertEqual(y_1dim, refy_1dim)
 
-            y_2 = torch.argmax(x, dim=2)
-            refy_2 = torch.argmax(cpu_x, dim=2)
+            y_2 = arg_reduction_fn(x, dim=2)
+            refy_2 = arg_reduction_fn(cpu_x, dim=2)
             self.assertEqual(y_2, refy_2)
 
-            y_2dim = torch.argmax(x, dim=2, keepdim=True)
-            refy_2dim = torch.argmax(cpu_x, dim=2, keepdim=True)
+            y_2dim = arg_reduction_fn(x, dim=2, keepdim=True)
+            refy_2dim = arg_reduction_fn(cpu_x, dim=2, keepdim=True)
             self.assertEqual(y_2dim, refy_2dim)
 
-            y_3 = torch.argmax(x, dim=3)
-            refy_3 = torch.argmax(cpu_x, dim=3)
+            y_3 = arg_reduction_fn(x, dim=3)
+            refy_3 = arg_reduction_fn(cpu_x, dim=3)
             self.assertEqual(y_3, refy_3)
 
-            y_3dim = torch.argmax(x, dim=3, keepdim=True)
-            refy_3dim = torch.argmax(cpu_x, dim=3, keepdim=True)
+            y_3dim = arg_reduction_fn(x, dim=3, keepdim=True)
+            refy_3dim = arg_reduction_fn(cpu_x, dim=3, keepdim=True)
             self.assertEqual(y_3dim, refy_3dim)
 
-        helper(2, 8, 4, 4, torch.float32)
-        helper(2, 8, 4, 4, torch.int32)
-        helper(2, 8, 4, 4, torch.float16)
-        helper(2, 8, 4, 4, torch.int64)
+        helper(2, 8, 4, 4, "max", torch.float32)
+        helper(2, 8, 4, 4, "max", torch.int32)
+        helper(2, 8, 4, 4, "max", torch.float16)
+        helper(2, 8, 4, 4, "max", torch.int64)
+        helper(2, 8, 4, 4, "min", torch.float32)
+        helper(2, 8, 4, 4, "min", torch.int32)
+        helper(2, 8, 4, 4, "min", torch.float16)
+        helper(2, 8, 4, 4, "min", torch.int64)
 
     # Test forward max
     # Note - don't test grad now