Add Half support for aminmax on CPU (#106853)
Add Half (float16) support for aminmax on CPU by extending the dtype dispatch in both the full-reduction kernel (ReduceAllOpsKernel.cpp) and the dim-wise reduction kernel (TensorCompareKernel.cpp), and widen the corresponding dtype coverage in test_reductions.py, test_mps.py, and the aminmax OpInfo.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106853
Approved by: https://github.com/cpuhrsch
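
With this change, torch.aminmax accepts Half inputs on CPU for both reduction paths touched by the patch; previously the Half CPU path would fail in the dtype dispatch. A minimal usage sketch (not part of the patch; shapes and values are illustrative):

```python
import torch

x = torch.randn(4, 8).to(torch.float16)   # Half tensor on CPU

# Full reduction (the path dispatched in ReduceAllOpsKernel.cpp)
mn, mx = torch.aminmax(x)

# Dim-wise reduction (the path dispatched in TensorCompareKernel.cpp)
mn_d, mx_d = torch.aminmax(x, dim=1, keepdim=True)

print(mn.dtype, mx_d.shape)  # torch.float16, torch.Size([4, 1])
```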
diff --git a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp
index a883b15..125f3ce 100644
--- a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp
@@ -199,7 +199,7 @@
}
);
} else {
- AT_DISPATCH_ALL_TYPES_AND(kBFloat16, input.scalar_type(), "aminmax_cpu", [&] {
+ AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "aminmax_cpu", [&] {
using Vec = Vectorized<opmath_type<scalar_t>>;
using scalar_t_pair = std::pair<scalar_t, scalar_t>;
reduce_all_impl_vec_two_outputs<scalar_t>(
diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
index 28ebb7e..f014c34 100644
--- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
+++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -185,7 +185,7 @@
return;
}
- AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "aminmax_cpu", [&] {
+ AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "aminmax_cpu", [&] {
compare_base_kernel<scalar_t, scalar_t>(min_result, max_result, self, wrap_dim, keepdim, [&] (
scalar_t* min_result_data, scalar_t* max_result_data,
const scalar_t* self_data, auto self_dim_stride) {
diff --git a/test/test_mps.py b/test/test_mps.py
index e6da181..5d3a19e 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -79,7 +79,7 @@
'cdist': [torch.float32],
'masked.scatter': [torch.float16, torch.float32],
'index_fill': [torch.float16, torch.float32], # missing `aten::_unique`.
- 'aminmax': [torch.float32],
+ 'aminmax': [torch.float32, torch.float16],
'polar': [torch.float32],
# Correctness issues
diff --git a/test/test_reductions.py b/test/test_reductions.py
index 14c3698..1206ab8 100644
--- a/test/test_reductions.py
+++ b/test/test_reductions.py
@@ -1207,7 +1207,7 @@
self._test_minmax_helper(torch.amax, np.amax, device, dtype)
@onlyNativeDeviceTypes
- @dtypes(torch.float, torch.double)
+ @dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
@dtypesIfCUDA(torch.half, torch.float, torch.bfloat16)
def test_aminmax(self, device, dtype):
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 460619b..ce7fc78 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -12268,8 +12268,7 @@
supports_fwgrad_bwgrad=True),
OpInfo('aminmax',
ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)),
- dtypes=all_types_and(torch.bool),
- dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+ dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
decorators=(onlyNativeDeviceTypes,),
supports_autograd=False,
sample_inputs_func=sample_inputs_aminmax,
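
The OpInfo change folds torch.float16 and torch.bfloat16 into the shared dtypes list, so a separate dtypesIfCUDA entry is no longer needed, and the widened @dtypes decorator lets test_aminmax exercise the reduced-precision CPU path. A sketch of the kind of check this coverage implies, comparing Half results against a float32 reference (a hypothetical standalone check, not the _test_minmax_helper used by the actual test):

```python
import torch

x = torch.randn(3, 5).half()               # CPU Half input
mn16, mx16 = torch.aminmax(x, dim=1)
mn32, mx32 = torch.aminmax(x.float(), dim=1)

# Min/max select existing elements, so the Half results should equal the
# float32 reference after casting back to Half.
torch.testing.assert_close(mn16, mn32.half())
torch.testing.assert_close(mx16, mx32.half())
```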