Min and max NaN propagation fix in MPS backend (#130445)
Partial fix for issue #130295
Moves the min and max ops to the NaN-propagating MPS API to align with the PyTorch convention that NaN values propagate through reductions. Adds a regression test to validate that the fix achieves parity with the CPU backend.
Co-authored-by: Nikita Shulga <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130445
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 3bc34d2..571cb03 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -245,9 +245,9 @@
castOutputTensor = [mpsGraph reductionSumWithTensor:nonZeros axes:wrappedAxes name:nil];
} else if (reduction_type == MPSReductionType::AMAX) {
- castOutputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor axes:wrappedAxes name:nil];
+ castOutputTensor = [mpsGraph reductionMaximumPropagateNaNWithTensor:castInputTensor axes:wrappedAxes name:nil];
} else if (reduction_type == MPSReductionType::AMIN) {
- castOutputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor axes:wrappedAxes name:nil];
+ castOutputTensor = [mpsGraph reductionMinimumPropagateNaNWithTensor:castInputTensor axes:wrappedAxes name:nil];
} else if (reduction_type == MPSReductionType::TRACE) {
MPSGraphTensor* bandPartWithTensor = [mpsGraph bandPartWithTensor:castInputTensor
numLower:0
@@ -630,9 +630,9 @@
NSArray<NSNumber*>* axes = getTensorAxes(input_t);
if (reduction_type == MPSReductionType::MAX) {
- castOutputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor axes:axes name:nil];
+ castOutputTensor = [mpsGraph reductionMaximumPropagateNaNWithTensor:castInputTensor axes:axes name:nil];
} else if (reduction_type == MPSReductionType::MIN) {
- castOutputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor axes:axes name:nil];
+ castOutputTensor = [mpsGraph reductionMinimumPropagateNaNWithTensor:castInputTensor axes:axes name:nil];
}
MPSGraphTensor* outputTensor = castOutputTensor;
@@ -705,9 +705,9 @@
castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
if (reduction_type == MPSReductionType::MAX) {
- outputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil];
+ outputTensor = [mpsGraph reductionMaximumPropagateNaNWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil];
} else if (reduction_type == MPSReductionType::MIN) {
- outputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil];
+ outputTensor = [mpsGraph reductionMinimumPropagateNaNWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil];
}
MPSGraphTensor* argreduceOutTensor = nil;
diff --git a/test/test_mps.py b/test/test_mps.py
index 5540837..74bf819 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -8299,6 +8299,29 @@
[helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]]
+    def test_min_max_nan_propagation(self):
+        # Regression test for #130295: min/max/amin/amax on MPS must propagate NaN like CPU.
+        def helper(dtype):
+            # dtype must reach the tensor, otherwise the dtype sweep below tests float32 three times.
+            cpu_x = torch.tensor([1.0, float("nan"), 3.0], device="cpu", dtype=dtype)
+            mps_x = cpu_x.detach().clone().to('mps')
+
+            cpu_max = torch.max(cpu_x)
+            mps_max = torch.max(mps_x).to('cpu')
+
+            cpu_amax = torch.amax(cpu_x)
+            mps_amax = torch.amax(mps_x).to('cpu')
+
+            cpu_min = torch.min(cpu_x)
+            mps_min = torch.min(mps_x).to('cpu')
+
+            cpu_amin = torch.amin(cpu_x)
+            mps_amin = torch.amin(mps_x).to('cpu')
+
+            self.assertEqual(cpu_max, mps_max)
+            self.assertEqual(cpu_amax, mps_amax)
+            self.assertEqual(cpu_min, mps_min)
+            self.assertEqual(cpu_amin, mps_amin)
+        [helper(dtype) for dtype in [torch.float32, torch.float16, torch.bfloat16]]
+
def test_isin(self):
def helper(dtype):
shapes = [([2, 5], [3, 5, 2]), ([10, 3, 5], [20, 1, 3]),