[MPS] Add argmin (#80828)

This PR:

1. adds `argmin` support on MPS (a usage sketch follows below), and
2. refactors the string-typed `reduction_type` in `ReduceOps.mm` to use an enum (`MPSReductionType`).

Co-authored-by: Kulin Seth <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80828
Approved by: https://github.com/malfet
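
With this change, `torch.argmin` on MPS tensors dispatches to the new
`argmax_argmin_out_mps` kernel (registered as `argmin_out_mps` in
`native_functions.yaml`). A minimal usage sketch, assuming an MPS-capable
machine (macOS with a Metal-capable GPU); it mirrors the cases exercised by
`test_argmin_argmax` in `test_mps.py`:

    import torch

    if torch.backends.mps.is_available():
        cpu_x = torch.randn(2, 8, 4, 4)
        x = cpu_x.detach().clone().to("mps")

        # No dim argument: the input is flattened and a single index is returned.
        assert torch.argmin(x).cpu() == torch.argmin(cpu_x)

        # With dim, with and without keepdim; indices come back as int64.
        assert torch.equal(torch.argmin(x, dim=1).cpu(), torch.argmin(cpu_x, dim=1))
        assert torch.equal(torch.argmin(x, dim=1, keepdim=True).cpu(),
                           torch.argmin(cpu_x, dim=1, keepdim=True))
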
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 0bcb2ea..bb1b058 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -13,13 +13,23 @@
 namespace at {
 namespace native {
 
-using namespace std;
-
 enum StdVarType {
   STANDARD_VARIANCE,
   STANDARD_DEVIATION
 };
 
+enum MPSReductionType {
+  MAX,
+  MIN,
+  AMAX,
+  AMIN,
+  SUM,
+  PROD,
+  MEAN,
+  COUNT_NONZERO
+};
+
+
 void set_apparent_shapes(NSMutableArray<NSNumber*> * &apparent_out_shape,
                          NSMutableArray<NSNumber*> * &apparent_in_shape,
                          int64_t num_reduce_dims,
@@ -131,8 +141,8 @@
     bool keepdim,
     c10::optional<ScalarType> dtype,
     const Tensor& output_t,
-    string reduction_type,
-    string func_name) {
+    MPSReductionType reduction_type,
+    const std::string& func_name) {
 
   IntArrayRef input_shape = input_t.sizes();
 
@@ -196,19 +206,19 @@
 
           MPSGraphTensor* castOutputTensor = nil;
 
-          if(reduction_type == "sum") {
+          if(reduction_type == MPSReductionType::SUM) {
             castOutputTensor = [mpsGraph reductionSumWithTensor:castInputTensor
                                                            axes:axes
                                                            name:nil];
-          } else if(reduction_type == "prod") {
+          } else if(reduction_type == MPSReductionType::PROD) {
             castOutputTensor = [mpsGraph reductionProductWithTensor:castInputTensor
                                                                axes:axes
                                                                name:nil];
-          } else if(reduction_type == "mean") {
+          } else if(reduction_type == MPSReductionType::MEAN) {
             castOutputTensor = [mpsGraph meanOfTensor:inputTensor
                                                  axes:axes
                                                  name:nil];
-          } else if(reduction_type == "count_nonzero") {
+          } else if(reduction_type == MPSReductionType::COUNT_NONZERO) {
             MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0
                                                         dataType:castInputTensor.dataType];
 
@@ -220,11 +230,11 @@
                                                            axes:axes
                                                            name:nil];
           }
-          else if(reduction_type == "amax") {
+          else if(reduction_type == MPSReductionType::AMAX) {
             castOutputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor
                                                                axes:axes
                                                                name:nil];
-          } else if(reduction_type == "amin") {
+          } else if(reduction_type == MPSReductionType::AMIN) {
             castOutputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor
                                                                axes:axes
                                                                name:nil];
@@ -273,7 +283,7 @@
     c10::optional<ScalarType> dtype,
     const Tensor& output_t) {
 
-    reduction_out_mps(input_t, dim, keepdim, dtype, output_t, "sum", "sum_out_mps");
+    reduction_out_mps(input_t, dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps");
 }
 
 TORCH_IMPL_FUNC(prod_out_mps)
@@ -285,7 +295,7 @@
 
     int64_t dims[1] = {dim};
 
-    reduction_out_mps(input_t, IntArrayRef(dims, 1), keepdim, dtype, output_t, "prod", "prod_out_mps");
+    reduction_out_mps(input_t, IntArrayRef(dims, 1), keepdim, dtype, output_t, MPSReductionType::PROD, "prod_out_mps");
 }
 
 // Taken from ReduceOps.cpp
@@ -309,7 +319,7 @@
     bool keepdim,
     const Tensor& output_t) {
 
-    reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, "amax", "amax_out_mps");
+    reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, MPSReductionType::AMAX, "amax_out_mps");
 }
 
 TORCH_IMPL_FUNC(amin_out_mps)
@@ -318,7 +328,7 @@
     bool keepdim,
     const Tensor& output_t) {
 
-    reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, "amin", "amin_out_mps");
+    reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, MPSReductionType::AMIN, "amin_out_mps");
 }
 
 Tensor prod_mps(const Tensor &self, c10::optional<ScalarType> opt_dtype) {
@@ -338,7 +348,7 @@
                       c10::nullopt,
                       c10::nullopt);
 
-  reduction_out_mps(self, IntArrayRef(dims, num_dims), false, opt_dtype, const_cast<Tensor&>(output_t), "prod", "prod_mps");
+  reduction_out_mps(self, IntArrayRef(dims, num_dims), false, opt_dtype, const_cast<Tensor&>(output_t), MPSReductionType::PROD, "prod_mps");
 
   return output_t;
 }
@@ -365,7 +375,7 @@
                       c10::nullopt,
                       c10::nullopt);
 
-  reduction_out_mps(self, dims, false, self.scalar_type(), const_cast<Tensor&>(output_t), "count_nonzero", "count_nonzero_mps");
+  reduction_out_mps(self, dims, false, self.scalar_type(), const_cast<Tensor&>(output_t), MPSReductionType::COUNT_NONZERO, "count_nonzero_mps");
 
   free(raw_output_shape);
 
@@ -379,135 +389,7 @@
     c10::optional<ScalarType> dtype,
     const Tensor& output_t) {
 
-    reduction_out_mps(input_t, dim, keepdim, dtype, output_t, "mean", "mean_out_mps");
-}
-
-TORCH_IMPL_FUNC(argmax_out_mps)
-   (const Tensor& input_t,
-    c10::optional<int64_t> dim,
-    bool keepdim,
-    const Tensor& output_t) {
-
-    namespace native_mps = at::native::mps;
-
-    // Derive from MPSCachedGraph
-    struct CachedGraph : public native_mps::MPSCachedGraph
-    {
-      CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
-      MPSGraphTensor *inputTensor_ = nil;
-      MPSGraphTensor *outputTensor_ = nil;
-    };
-
-    native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
-
-    int64_t dim_;
-
-    if (dim.has_value()) {
-        dim_ = maybe_wrap_dim(dim.value(), input_t.dim());
-        native::zero_numel_check_dims(input_t, dim_, "argmax()");
-    } else {
-        TORCH_CHECK_INDEX(
-        input_t.numel() != 0,
-        "argmax()", ": Expected reduction dim to be specified for input.numel() == 0.");
-        // Since input will be flattened, take argmax along 0'th dimension
-        dim_ = 0;
-    }
-
-    // Calculate the output shape according to keepdim=True
-    // If there is no dim argument, the input shape is flattened
-    IntArrayRef input_shape = input_t.sizes();
-    int64_t num_input_dims = input_shape.size();
-    NSMutableArray<NSNumber*> *apparent_in_shape = nil;
-    NSMutableArray<NSNumber*> *apparent_out_shape = nil;
-
-    if(dim.has_value()) {
-        apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
-        for(int i = 0; i < num_input_dims; i++) {
-            if(dim_ == i)
-                apparent_out_shape[i] = @1;
-            else
-                apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]];
-        }
-    }
-    else {
-        apparent_in_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
-        int64_t num_in_elements = 1;
-        for(int i = 0; i < num_input_dims; i++) {
-            num_in_elements *= input_shape[i];
-        }
-        apparent_in_shape[0] = [NSNumber numberWithInt:num_in_elements];
-
-        apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
-        apparent_out_shape[0] = @1;
-    }
-
-    if (output_t.numel() == 0) {
-        return;
-    }
-
-    auto stream = at::mps::getCurrentMPSStream();
-
-    @autoreleasepool {
-        string key = "argmax_out_mps:" + to_string(dim_) + ":" + native_mps::getMPSTypeString(input_t.scalar_type());
-        CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
-
-        if(!cachedGraph) {
-          native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
-
-            CachedGraph *newCachedGraph = nil;
-
-            @autoreleasepool {
-              MPSGraph* mpsGraph = native_mps::make_mps_graph();
-              newCachedGraph = new CachedGraph(mpsGraph);
-
-              MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
-
-              MPSGraphTensor* castInputTensor = nil;
-
-              if(input_t.scalar_type() != ScalarType::Float &&
-                 input_t.scalar_type() != ScalarType::Int   &&
-                 input_t.scalar_type() != ScalarType::Half)
-                castInputTensor =  [mpsGraph castTensor:inputTensor
-                                                 toType:MPSDataTypeFloat32
-                                                   name:@"castInputTensor"];
-              else
-                castInputTensor = inputTensor;
-
-              MPSGraphTensor* argmaxOutTensor = [mpsGraph reductionArgMaximumWithTensor:castInputTensor
-                                                                                   axis:(NSInteger)dim_
-                                                                                   name:@"argmax_out"];
-              MPSGraphTensor* outputTensor = [mpsGraph castTensor:argmaxOutTensor
-                                                           toType:MPSDataTypeInt64
-                                                             name:@"cast_out"];
-
-              newCachedGraph->inputTensor_ = inputTensor;
-              newCachedGraph->outputTensor_ = outputTensor;
-            }
-            return newCachedGraph;
-          });
-          cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
-        }
-
-        native_mps::Placeholder inputPlaceholder = native_mps::Placeholder();
-        if(apparent_in_shape)
-            inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_in_shape);
-        else
-            inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
-
-        auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape);
-
-        NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
-          inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
-        };
-
-        NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
-          outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
-        };
-
-        native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
-
-    }
-
+    reduction_out_mps(input_t, dim, keepdim, dtype, output_t, MPSReductionType::MEAN, "mean_out_mps");
 }
 
 TORCH_IMPL_FUNC(norm_out_mps)
@@ -1316,8 +1198,8 @@
 
 Tensor min_max_mps
   (const Tensor& input_t,
-   string reduction_type,
-   string func_name) {
+   MPSReductionType reduction_type,
+   const std::string& func_name) {
 
   namespace native_mps = at::native::mps;
 
@@ -1365,11 +1247,11 @@
 
           MPSGraphTensor* outputTensor = nil;
 
-          if(reduction_type == "max")
+          if(reduction_type == MPSReductionType::MAX)
             outputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor
                                                            axes:@[@0]
                                                            name:nil];
-          else if(reduction_type == "min")
+          else if(reduction_type == MPSReductionType::MIN)
             outputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor
                                                            axes:@[@0]
                                                            name:nil];
@@ -1403,13 +1285,13 @@
 // Max entire tensor into scalar result
 Tensor max_mps(const Tensor& input_t) {
 
-  return min_max_mps(input_t, "max", "max_mps");
+  return min_max_mps(input_t, MPSReductionType::MAX, "max_mps");
 }
 
 // Min entire tensor into scalar result
 Tensor min_mps(const Tensor& input_t) {
 
-  return min_max_mps(input_t, "min", "min_mps");
+  return min_max_mps(input_t, MPSReductionType::MIN, "min_mps");
 }
 
 void min_max_out_mps
@@ -1418,8 +1300,8 @@
   bool keepdim,
   const Tensor& output_t,
   const Tensor& indices_t,
-  string reduction_type,
-  string func_name) {
+  MPSReductionType reduction_type,
+  const std::string& func_name) {
 
     namespace native_mps = at::native::mps;
 
@@ -1477,11 +1359,11 @@
 
               MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
               MPSGraphTensor* outputTensor = nil;
-              if(reduction_type == "max")
+              if(reduction_type == MPSReductionType::MAX)
                 outputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor
                                                                axis:(NSInteger)dim_
                                                                name:nil];
-              else if(reduction_type == "min")
+              else if(reduction_type == MPSReductionType::MIN)
                 outputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor
                                                                axis:(NSInteger)dim_
                                                                name:nil];
@@ -1498,11 +1380,11 @@
                 castInputTensor = inputTensor;
 
               MPSGraphTensor* argreduceOutTensor = nil;
-              if(reduction_type == "max")
+              if(reduction_type == MPSReductionType::MAX)
                 argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor:castInputTensor
                                                                         axis:(NSInteger)dim_
                                                                         name:@"argmax_out"];
-              else if(reduction_type == "min")
+              else if(reduction_type == MPSReductionType::MIN)
                 argreduceOutTensor = [mpsGraph reductionArgMinimumWithTensor:castInputTensor
                                                                         axis:(NSInteger)dim_
                                                                         name:@"argmax_out"];
@@ -1550,7 +1432,7 @@
     int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
     native::zero_numel_check_dims(input_t, dim_,  "max()");
 
-    min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, "max", "max_out_mps");
+    min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, MPSReductionType::MAX, "max_out_mps");
 }
 
 // Min out with dim
@@ -1564,16 +1446,170 @@
     int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
     native::zero_numel_check_dims(input_t, dim_, "min()");
 
-    min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, "min", "min_out_mps");
+    min_max_out_mps(input_t, dim, keepdim, output_t, indices_t, MPSReductionType::MIN, "min_out_mps");
 }
 
+void argmax_argmin_out_mps
+   (const Tensor& input_t,
+    c10::optional<int64_t> dim,
+    bool keepdim,
+    const Tensor& output_t,
+    MPSReductionType reduction_type,
+    const std::string& func_name) {
+    namespace native_mps = at::native::mps;
+
+    // Derive from MPSCachedGraph
+    struct CachedGraph : public native_mps::MPSCachedGraph
+    {
+      CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+      MPSGraphTensor *inputTensor_ = nil;
+      MPSGraphTensor *outputTensor_ = nil;
+    };
+
+    native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
+
+    int64_t dim_;
+
+    if (dim.has_value()) {
+        dim_ = maybe_wrap_dim(dim.value(), input_t.dim());
+        zero_numel_check_dims(input_t, dim_, reduction_type == MPSReductionType::MAX ? "argmax()" : "argmin()");
+    } else {
+        TORCH_CHECK_INDEX(
+        input_t.numel() != 0,
+        reduction_type == MPSReductionType::MAX ? "argmax()" : "argmin()", ": Expected reduction dim to be specified for input.numel() == 0.");
+        // Since input will be flattened, take argmax or argmin along 0'th dimension
+        dim_ = 0;
+    }
+
+    // Calculate the output shape according to keepdim=True
+    // If there is no dim argument, the input shape is flattened
+    IntArrayRef input_shape = input_t.sizes();
+    int64_t num_input_dims = input_shape.size();
+    NSMutableArray<NSNumber*> *apparent_in_shape = nil;
+    NSMutableArray<NSNumber*> *apparent_out_shape = nil;
+
+    if(dim.has_value()) {
+        apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
+        for(int i = 0; i < num_input_dims; i++) {
+            if(dim_ == i)
+                apparent_out_shape[i] = @1;
+            else
+                apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]];
+        }
+    }
+    else {
+        apparent_in_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
+        int64_t num_in_elements = 1;
+        for(int i = 0; i < num_input_dims; i++) {
+            num_in_elements *= input_shape[i];
+        }
+        apparent_in_shape[0] = [NSNumber numberWithInt:num_in_elements];
+
+        apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
+        apparent_out_shape[0] = @1;
+    }
+
+    if (output_t.numel() == 0) {
+        return;
+    }
+
+    auto stream = at::mps::getCurrentMPSStream();
+
+    @autoreleasepool {
+        std::string key = func_name + ":" + std::to_string(dim_) + ":" + native_mps::getTensorsStringKey(input_t);
+        CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
+
+        if(!cachedGraph) {
+          native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
+
+            CachedGraph *newCachedGraph = nil;
+
+            @autoreleasepool {
+              MPSGraph* mpsGraph = native_mps::make_mps_graph();
+              newCachedGraph = new CachedGraph(mpsGraph);
+
+              MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
+
+              MPSGraphTensor* castInputTensor = nil;
+              MPSGraphTensor* argreduceOutTensor = nil;
+
+              if(input_t.scalar_type() != ScalarType::Float &&
+                 input_t.scalar_type() != ScalarType::Int   &&
+                 input_t.scalar_type() != ScalarType::Half)
+                castInputTensor =  [mpsGraph castTensor:inputTensor
+                                                 toType:MPSDataTypeFloat32
+                                                   name:@"castInputTensor"];
+              else
+                castInputTensor = inputTensor;
+
+              if (reduction_type == MPSReductionType::MAX) {
+                argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor:castInputTensor
+                                                                        axis:(NSInteger)dim_
+                                                                        name:nil];
+              }
+              else {
+                argreduceOutTensor = [mpsGraph reductionArgMinimumWithTensor:castInputTensor
+                                                                        axis:(NSInteger)dim_
+                                                                        name:nil];
+              }
+              MPSGraphTensor* outputTensor = [mpsGraph castTensor:argreduceOutTensor
+                                                           toType:MPSDataTypeInt64
+                                                             name:@"castOutpuTensor"];
+
+              newCachedGraph->inputTensor_ = inputTensor;
+              newCachedGraph->outputTensor_ = outputTensor;
+            }
+            return newCachedGraph;
+          });
+          cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
+        }
+
+        native_mps::Placeholder inputPlaceholder = native_mps::Placeholder();
+        if(apparent_in_shape)
+            inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_in_shape);
+        else
+            inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
+
+        auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape);
+
+        NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
+          inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
+        };
+
+        NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
+          outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
+        };
+
+        native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+    }
+}
+
+TORCH_IMPL_FUNC(argmax_out_mps)
+   (const Tensor& input_t,
+    c10::optional<int64_t> dim,
+    bool keepdim,
+    const Tensor& output_t) {
+
+    argmax_argmin_out_mps(input_t, dim, keepdim, output_t, MPSReductionType::MAX, "argmax_out_mps");
+}
+
+TORCH_IMPL_FUNC(argmin_out_mps)
+   (const Tensor& input_t,
+    c10::optional<int64_t> dim,
+    bool keepdim,
+    const Tensor& output_t) {
+
+    argmax_argmin_out_mps(input_t, dim, keepdim, output_t, MPSReductionType::MIN, "argmin_out_mps");
+}
+
+
 // Min/Max with dim
 std::tuple<Tensor, Tensor> min_max_mps
    (const Tensor& input_t,
     int64_t dim,
     bool keepdim,
-    string reduction_type,
-    string func_name) {
+    MPSReductionType reduction_type,
+    const std::string& func_name) {
 
     namespace native_mps = at::native::mps;
 
@@ -1661,7 +1697,7 @@
     int64_t dim,
     bool keepdim) {
 
-    return min_max_mps(input_t, dim, keepdim, "max", "max_mps");
+    return min_max_mps(input_t, dim, keepdim, MPSReductionType::MAX, "max_mps");
 }
 
 // Min with dim
@@ -1670,9 +1706,8 @@
     int64_t dim,
     bool keepdim) {
 
-    return min_max_mps(input_t, dim, keepdim, "min", "min_mps");
+    return min_max_mps(input_t, dim, keepdim, MPSReductionType::MIN, "min_mps");
 }
 
-}
-
-}
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f22852f..4bca2bd 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -660,6 +660,7 @@
   structured: True
   dispatch:
     CPU, CUDA: argmin_out
+    MPS: argmin_out_mps
 
 - func: acosh(Tensor self) -> Tensor
   variants: function, method
diff --git a/test/test_mps.py b/test/test_mps.py
index f8614ea..dc857a3 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2187,9 +2187,14 @@
 
         helper((2, 3, 4, 5))
 
-    # Test forward argmax
-    def test_argmax(self):
-        def helper(n, c, h, w, dtype=torch.float32):
+    # Test forward argmin argmax
+    def test_argmin_argmax(self):
+        def helper(n, c, h, w, reduction_type, dtype=torch.float32):
+            if reduction_type == "max":
+                arg_reduction_fn = torch.argmax
+            else:
+                arg_reduction_fn = torch.argmin
+
             cpu_x = None
             x = None
             if(dtype not in [torch.float32, torch.bool]):
@@ -2202,46 +2207,50 @@
                 cpu_x = torch.randn(n, c, h, w, device='cpu', dtype=dtype, requires_grad=True)
                 x = cpu_x.detach().clone().to('mps').requires_grad_()
 
-            y = torch.argmax(x)
-            ref_y = torch.argmax(cpu_x)
+            y = arg_reduction_fn(x)
+            ref_y = arg_reduction_fn(cpu_x)
             self.assertEqual(y, ref_y)
 
-            y_0 = torch.argmax(x, dim=0)
-            refy_0 = torch.argmax(cpu_x, dim=0)
+            y_0 = arg_reduction_fn(x, dim=0)
+            refy_0 = arg_reduction_fn(cpu_x, dim=0)
             self.assertEqual(y_0, refy_0)
 
-            y_0dim = torch.argmax(x, dim=0, keepdim=True)
-            refy_0dim = torch.argmax(cpu_x, dim=0, keepdim=True)
+            y_0dim = arg_reduction_fn(x, dim=0, keepdim=True)
+            refy_0dim = arg_reduction_fn(cpu_x, dim=0, keepdim=True)
             self.assertEqual(y_0dim, refy_0dim)
 
-            y_1 = torch.argmax(x, dim=1)
-            refy_1 = torch.argmax(cpu_x, dim=1)
+            y_1 = arg_reduction_fn(x, dim=1)
+            refy_1 = arg_reduction_fn(cpu_x, dim=1)
             self.assertEqual(y_1, refy_1)
 
-            y_1dim = torch.argmax(x, dim=1, keepdim=True)
-            refy_1dim = torch.argmax(cpu_x, dim=1, keepdim=True)
+            y_1dim = arg_reduction_fn(x, dim=1, keepdim=True)
+            refy_1dim = arg_reduction_fn(cpu_x, dim=1, keepdim=True)
             self.assertEqual(y_1dim, refy_1dim)
 
-            y_2 = torch.argmax(x, dim=2)
-            refy_2 = torch.argmax(cpu_x, dim=2)
+            y_2 = arg_reduction_fn(x, dim=2)
+            refy_2 = arg_reduction_fn(cpu_x, dim=2)
             self.assertEqual(y_2, refy_2)
 
-            y_2dim = torch.argmax(x, dim=2, keepdim=True)
-            refy_2dim = torch.argmax(cpu_x, dim=2, keepdim=True)
+            y_2dim = arg_reduction_fn(x, dim=2, keepdim=True)
+            refy_2dim = arg_reduction_fn(cpu_x, dim=2, keepdim=True)
             self.assertEqual(y_2dim, refy_2dim)
 
-            y_3 = torch.argmax(x, dim=3)
-            refy_3 = torch.argmax(cpu_x, dim=3)
+            y_3 = arg_reduction_fn(x, dim=3)
+            refy_3 = arg_reduction_fn(cpu_x, dim=3)
             self.assertEqual(y_3, refy_3)
 
-            y_3dim = torch.argmax(x, dim=3, keepdim=True)
-            refy_3dim = torch.argmax(cpu_x, dim=3, keepdim=True)
+            y_3dim = arg_reduction_fn(x, dim=3, keepdim=True)
+            refy_3dim = arg_reduction_fn(cpu_x, dim=3, keepdim=True)
             self.assertEqual(y_3dim, refy_3dim)
 
-        helper(2, 8, 4, 4, torch.float32)
-        helper(2, 8, 4, 4, torch.int32)
-        helper(2, 8, 4, 4, torch.float16)
-        helper(2, 8, 4, 4, torch.int64)
+        helper(2, 8, 4, 4, "max", torch.float32)
+        helper(2, 8, 4, 4, "max", torch.int32)
+        helper(2, 8, 4, 4, "max", torch.float16)
+        helper(2, 8, 4, 4, "max", torch.int64)
+        helper(2, 8, 4, 4, "min", torch.float32)
+        helper(2, 8, 4, 4, "min", torch.int32)
+        helper(2, 8, 4, 4, "min", torch.float16)
+        helper(2, 8, 4, 4, "min", torch.int64)
 
     # Test forward max
     # Note - don't test grad now