Min and max NaN propagation fix in MPS backend (#130445)

Partial fix to issue #130295

Moves the min and max ops to the NaN-propagating API in MPS to align with the PyTorch convention. Adds a regression test to validate that the fix achieves parity with the CPU backend.
Co-authored-by: Nikita Shulga <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130445
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 3bc34d2..571cb03 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -245,9 +245,9 @@
 
         castOutputTensor = [mpsGraph reductionSumWithTensor:nonZeros axes:wrappedAxes name:nil];
       } else if (reduction_type == MPSReductionType::AMAX) {
-        castOutputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor axes:wrappedAxes name:nil];
+        castOutputTensor = [mpsGraph reductionMaximumPropagateNaNWithTensor:castInputTensor axes:wrappedAxes name:nil];
       } else if (reduction_type == MPSReductionType::AMIN) {
-        castOutputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor axes:wrappedAxes name:nil];
+        castOutputTensor = [mpsGraph reductionMinimumPropagateNaNWithTensor:castInputTensor axes:wrappedAxes name:nil];
       } else if (reduction_type == MPSReductionType::TRACE) {
         MPSGraphTensor* bandPartWithTensor = [mpsGraph bandPartWithTensor:castInputTensor
                                                                  numLower:0
@@ -630,9 +630,9 @@
 
       NSArray<NSNumber*>* axes = getTensorAxes(input_t);
       if (reduction_type == MPSReductionType::MAX) {
-        castOutputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor axes:axes name:nil];
+        castOutputTensor = [mpsGraph reductionMaximumPropagateNaNWithTensor:castInputTensor axes:axes name:nil];
       } else if (reduction_type == MPSReductionType::MIN) {
-        castOutputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor axes:axes name:nil];
+        castOutputTensor = [mpsGraph reductionMinimumPropagateNaNWithTensor:castInputTensor axes:axes name:nil];
       }
 
       MPSGraphTensor* outputTensor = castOutputTensor;
@@ -705,9 +705,9 @@
           castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
 
       if (reduction_type == MPSReductionType::MAX) {
-        outputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil];
+        outputTensor = [mpsGraph reductionMaximumPropagateNaNWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil];
       } else if (reduction_type == MPSReductionType::MIN) {
-        outputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil];
+        outputTensor = [mpsGraph reductionMinimumPropagateNaNWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil];
       }
 
       MPSGraphTensor* argreduceOutTensor = nil;
diff --git a/test/test_mps.py b/test/test_mps.py
index 5540837..74bf819 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -8299,6 +8299,29 @@
 
         [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]]
 
+    def test_min_max_nan_propagation(self):
+        """Regression test for the NaN-propagation fix (issue #130295): min/max/amin/amax
+        on MPS must produce the same result as the CPU backend when the input holds a NaN.
+        """
+        def helper(dtype):
+            # Create the input in the dtype under test; omitting dtype= would make every
+            # iteration run in float32 and leave float16/bfloat16 unverified.
+            cpu_x = torch.tensor([1.0, float("nan"), 3.0], device="cpu", dtype=dtype)
+            mps_x = cpu_x.detach().clone().to('mps')
+
+            # max/min and amax/amin are patched in separate branches of ReduceOps.mm,
+            # so each op is compared against its CPU result individually.
+            for op in (torch.max, torch.amax, torch.min, torch.amin):
+                with self.subTest(op=op.__name__, dtype=dtype):
+                    cpu_result = op(cpu_x)
+                    mps_result = op(mps_x).to('cpu')
+                    self.assertEqual(cpu_result, mps_result)
+
+        # NOTE(review): bfloat16 on MPS may require a recent macOS release; confirm
+        # whether this dtype needs to be gated on the OS version.
+        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            helper(dtype)
+
     def test_isin(self):
         def helper(dtype):
             shapes = [([2, 5], [3, 5, 2]), ([10, 3, 5], [20, 1, 3]),