Add Half support for addcmul, addcdiv, cumsum, and topk on CPU (#103319)
Add Half (float16) support for addcmul, addcdiv, cumsum, and topk on CPU. The addcmul and addcdiv CPU kernels now dispatch over both reduced floating-point types (Half and BFloat16) and accumulate in float before converting back, while the cumsum and topk kernels extend their dispatch macros to include kHalf. OpInfo dtype lists and the relevant tests are updated accordingly.
Note: this PR introduces the failure tracked in https://github.com/pytorch/pytorch/issues/111454; the float16 cumsum case is marked as an expected failure (xfail) in test/onnx/test_fx_op_consistency.py because ONNX Runtime rejects the exported model during initialization.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103319
Approved by: https://github.com/jgong5, https://github.com/cpuhrsch
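
For reference, a minimal smoke test of the newly enabled CPU float16 paths (the tensor names, shapes, and values below are illustrative only and are not taken from the test suite):

    import torch

    # Illustrative example: float16 CPU tensors for the ops covered by this PR.
    x = torch.randn(8, dtype=torch.float16)
    t1 = torch.randn(8, dtype=torch.float16)
    t2 = torch.rand(8, dtype=torch.float16) + 0.5  # keep the divisor away from zero

    print(torch.addcmul(x, t1, t2, value=0.5))  # x + value * t1 * t2
    print(torch.addcdiv(x, t1, t2, value=0.5))  # x + value * t1 / t2
    print(torch.cumsum(x, dim=0))
    print(x.topk(3))

Before this change, these calls raised a "not implemented for 'Half'" dispatch error on CPU; the kernels compute in float internally, matching the existing BFloat16 behavior.
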
diff --git a/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp b/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp
index 74098a3..25243b2 100644
--- a/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp
@@ -5,34 +5,34 @@
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/core/Scalar.h>
-
+#include <ATen/cpu/vec/functional.h>
namespace at::native {
namespace {
static void addcmul_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) {
ScalarType dtype = iter.common_dtype();
- if (dtype == kBFloat16) {
- float float_val = value.to<float>();
- auto float_vec = Vectorized<float>(float_val);
- cpu_kernel_vec(
- iter,
- [=](BFloat16 self_val, BFloat16 t1_val, BFloat16 t2_val) -> BFloat16 {
- return float(self_val) + float_val * float(t1_val) * float(t2_val);
- },
- [=](Vectorized<BFloat16> self_vec,
- Vectorized<BFloat16> t1_vec,
- Vectorized<BFloat16> t2_vec) {
- Vectorized<float> self_vec0, self_vec1;
- std::tie(self_vec0, self_vec1) = convert_bfloat16_float(self_vec);
- Vectorized<float> t1_vec0, t1_vec1, t2_vec0, t2_vec1;
- std::tie(t1_vec0, t1_vec1) = convert_bfloat16_float(t1_vec);
- std::tie(t2_vec0, t2_vec1) = convert_bfloat16_float(t2_vec);
- self_vec0 = self_vec0 + float_vec * t1_vec0 * t2_vec0;
- self_vec1 = self_vec1 + float_vec * t1_vec1 * t2_vec1;
- return convert_float_bfloat16(self_vec0, self_vec1);
- });
+ if (at::isReducedFloatingType(dtype)) {
+ AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "addcmul_cpu_out", [&]() {
+ float float_val = value.to<float>();
+ auto float_vec = Vectorized<float>(float_val);
+ cpu_kernel_vec(
+ iter,
+ [=](scalar_t self_val, scalar_t t1_val, scalar_t t2_val) -> scalar_t {
+ return float(self_val) + float_val * float(t1_val) * float(t2_val);
+ },
+ [=](Vectorized<scalar_t> self_vec,
+ Vectorized<scalar_t> t1_vec,
+ Vectorized<scalar_t> t2_vec) -> Vectorized<scalar_t> {
+ auto [self_vec0, self_vec1] = convert_to_float<scalar_t>(self_vec);
+ auto [t1_vec0, t1_vec1] = convert_to_float<scalar_t>(t1_vec);
+ auto [t2_vec0, t2_vec1] = convert_to_float<scalar_t>(t2_vec);
+ self_vec0 = self_vec0 + float_vec * t1_vec0 * t2_vec0;
+ self_vec1 = self_vec1 + float_vec * t1_vec1 * t2_vec1;
+ return convert_from_float<scalar_t>(self_vec0, self_vec1);
+ });
+ });
} else {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::ComplexHalf, at::ScalarType::Half,
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::ComplexHalf,
dtype, "addcmul_cpu_out", [&] {
scalar_t scalar_val = value.to<scalar_t>();
auto scalar_vec = Vectorized<scalar_t>(scalar_val);
@@ -52,26 +52,26 @@
static void addcdiv_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) {
ScalarType dtype = iter.common_dtype();
- if (dtype == kBFloat16) {
- float float_val = value.to<float>();
- auto float_vec = Vectorized<float>(float_val);
- cpu_kernel_vec(
- iter,
- [=](BFloat16 self_val, BFloat16 t1_val, BFloat16 t2_val) -> BFloat16 {
- return float(self_val) + float_val * float(t1_val) / float(t2_val);
- },
- [=](Vectorized<BFloat16> self_vec,
- Vectorized<BFloat16> t1_vec,
- Vectorized<BFloat16> t2_vec) {
- Vectorized<float> self_vec0, self_vec1;
- std::tie(self_vec0, self_vec1) = convert_bfloat16_float(self_vec);
- Vectorized<float> t1_vec0, t1_vec1, t2_vec0, t2_vec1;
- std::tie(t1_vec0, t1_vec1) = convert_bfloat16_float(t1_vec);
- std::tie(t2_vec0, t2_vec1) = convert_bfloat16_float(t2_vec);
- self_vec0 = self_vec0 + float_vec * t1_vec0 / t2_vec0;
- self_vec1 = self_vec1 + float_vec * t1_vec1 / t2_vec1;
- return convert_float_bfloat16(self_vec0, self_vec1);
- });
+ if (at::isReducedFloatingType(dtype)) {
+ AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "addcdiv_cpu_out", [&]() {
+ float float_val = value.to<float>();
+ auto float_vec = Vectorized<float>(float_val);
+ cpu_kernel_vec(
+ iter,
+ [=](scalar_t self_val, scalar_t t1_val, scalar_t t2_val) -> scalar_t {
+ return float(self_val) + float_val * float(t1_val) / float(t2_val);
+ },
+ [=](Vectorized<scalar_t> self_vec,
+ Vectorized<scalar_t> t1_vec,
+ Vectorized<scalar_t> t2_vec) -> Vectorized<scalar_t> {
+ auto [self_vec0, self_vec1] = convert_to_float<scalar_t>(self_vec);
+ auto [t1_vec0, t1_vec1] = convert_to_float<scalar_t>(t1_vec);
+ auto [t2_vec0, t2_vec1] = convert_to_float<scalar_t>(t2_vec);
+ self_vec0 = self_vec0 + float_vec * t1_vec0 / t2_vec0;
+ self_vec1 = self_vec1 + float_vec * t1_vec1 / t2_vec1;
+ return convert_from_float<scalar_t>(self_vec0, self_vec1);
+ });
+ });
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX(dtype, "addcdiv_cpu_out", [&] {
scalar_t scalar_val = value.to<scalar_t>();
diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
index 5780c78..f4f73f4 100644
--- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
@@ -81,7 +81,7 @@
auto wrap_dim = maybe_wrap_dim(dim, self.dim());
int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, self.scalar_type(), "cumsum_out_cpu", [&] {
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, self.scalar_type(), "cumsum_out_cpu", [&] {
cpu_cum_base_kernel<scalar_t>(result, self, wrap_dim, [&] (
scalar_t* result_data, auto result_dim_stride,
const scalar_t* self_data, auto self_dim_stride, scalar_t init_val) {
diff --git a/aten/src/ATen/native/cpu/SortingKernel.cpp b/aten/src/ATen/native/cpu/SortingKernel.cpp
index 9833c67..8975690 100644
--- a/aten/src/ATen/native/cpu/SortingKernel.cpp
+++ b/aten/src/ATen/native/cpu/SortingKernel.cpp
@@ -223,7 +223,7 @@
auto mode_indices_stride = indices.strides()[dim];
auto tmp_values_stride = self.strides()[dim];
- AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, self.scalar_type(), "topk_cpu", [&] {
+ AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "topk_cpu", [&] {
auto loop = [&](char** data, const int64_t* strides, int64_t n) {
if (self.scalar_type() == ScalarType::BFloat16) {
return topk_impl_loop<scalar_t, float>(
diff --git a/test/onnx/test_fx_op_consistency.py b/test/onnx/test_fx_op_consistency.py
index 4760b11..4bd56a7 100644
--- a/test/onnx/test_fx_op_consistency.py
+++ b/test/onnx/test_fx_op_consistency.py
@@ -371,6 +371,15 @@
"cumsum", dtypes=onnx_test_common.BOOL_TYPES + (torch.uint8, torch.int8, torch.int16,),
reason=onnx_test_common.reason_onnx_does_not_support("Cumsum", "bool, uint8, int8, int16")
),
+ # See https://github.com/pytorch/pytorch/issues/111454
+ xfail(
+ "cumsum", dtypes=(torch.float16,),
+ reason=onnx_test_common.reason_onnx_runtime_does_not_support("RUNTIME_EXCEPTION : \
+ Exception during initialization: /onnxruntime_src/onnxruntime/core/framework/\
+ allocation_planner.cc:230 int& onnxruntime::PlannerImpl::\
+ UseCount(onnxruntime::OrtValueIndex) n >= 0 && static_cast<size_t>(n) \
+ < ort_value_info_.size() was false.")
+ ),
xfail(
"cross",
reason=onnx_test_common.reason_onnx_script_does_not_support("linalg_cross"),
diff --git a/test/test_mps.py b/test/test_mps.py
index 133d10a..efa6daf 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -10898,7 +10898,7 @@
# You most likely do NOT want to modify this manually
FP16_LOW_PRECISION_LIST = {
- 'add', 'sub', 'div',
+ 'add', 'sub', 'div', 'addcdiv',
'__rdiv__', '__rmul__',
'nn.functional.huber_loss',
'true_divide', 'kron',
diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py
index 41f8ee9..08b62cc 100644
--- a/test/test_sort_and_select.py
+++ b/test/test_sort_and_select.py
@@ -754,9 +754,8 @@
for curr_size in (small, large, verylarge):
self._test_topk_dtype(device, dtype, True, curr_size)
- @onlyCUDA
- @dtypes(torch.bfloat16)
- def test_topk_bfloat16(self, device, dtype):
+ @dtypes(torch.bfloat16, torch.half)
+ def test_topk_lower_precision(self, device, dtype):
small = 10
large = 4096
@@ -765,7 +764,7 @@
self._test_topk_dtype(device, dtype, False, curr_size)
@dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
- @dtypes(torch.float, torch.double, torch.bfloat16)
+ @dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
def test_topk_nonfinite(self, device, dtype):
x = torch.tensor([float('nan'), float('inf'), 1e4, 0, -1e4, -float('inf')], device=device, dtype=dtype)
val, idx = x.topk(4)
@@ -796,7 +795,7 @@
@onlyNativeDeviceTypes
@dtypesIfCUDA(*all_types_and(torch.bfloat16))
- @dtypes(*all_types())
+ @dtypes(*all_types_and(torch.bfloat16, torch.half))
def test_topk_zero(self, device, dtype):
# https://github.com/pytorch/pytorch/issues/49205
t = torch.rand(2, 2, device=device).to(dtype=dtype)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 46e8996..d011c83 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -10164,8 +10164,7 @@
reference_inputs_func=partial(
reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)),
OpInfo('addcdiv',
- dtypes=floating_and_complex_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+ dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
@@ -10788,8 +10787,7 @@
supports_out=True,
supports_forward_ad=True),
OpInfo('cumsum',
- dtypes=all_types_and_complex_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
@@ -14022,8 +14020,7 @@
),
),
OpInfo('topk',
- dtypes=all_types_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
+ dtypes=all_types_and(torch.bfloat16, torch.float16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
@@ -17033,8 +17030,7 @@
check_batched_forward_grad=False,
sample_inputs_func=sample_trapezoid),
OpInfo('cumulative_trapezoid',
- dtypes=all_types_and_complex_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(torch.bfloat16, torch.float16),
+ dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# See https://github.com/pytorch/pytorch/pull/78358
diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py
index 890180f..31b3796 100644
--- a/torch/testing/_internal/opinfo/definitions/_masked.py
+++ b/torch/testing/_internal/opinfo/definitions/_masked.py
@@ -561,8 +561,7 @@
),
OpInfo(
"masked.cumsum",
- dtypes=all_types_and_complex_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
method_variant=None,
# Runs very slowly on slow gradcheck - alternatively reduce input sizes
gradcheck_fast_mode=True,