Add Half support for cummax, cummin, cumprod, logcumsumexp, and prod on CPU (#112132)
Extends the CPU dispatch for cummax, cummin, cumprod, logcumsumexp, and prod to cover torch.half, bringing CPU dtype coverage for these ops in line with CUDA, and updates the OpInfo dtype lists, test skips, and tolerances accordingly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112132
Approved by: https://github.com/cpuhrsch
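
For reference, a minimal sketch of what this change enables on CPU (illustrative only; each of these calls previously failed with a "not implemented for 'Half'" error on CPU):

    import torch

    x = torch.rand(4, 5, dtype=torch.half)  # CPU tensor
    values, indices = torch.cummax(x, dim=1)
    values, indices = torch.cummin(x, dim=1)
    cp = torch.cumprod(x, dim=1)
    lse = torch.logcumsumexp(x, dim=1)
    p = torch.prod(x, dim=1)
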
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 04d0e0c..7a47490 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -794,7 +794,7 @@
}
void cummax_helper_cpu(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
- AT_DISPATCH_ALL_TYPES_AND2(kBool, kBFloat16,
+ AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf,
self.scalar_type(), "cummax_cpu",
[&] {
at::native::tensor_dim_apply3<scalar_t, int64_t>(self, values, indices, dim, cummax_cummin_helper<scalar_t, int64_t, std::greater_equal<scalar_t>>);
@@ -829,7 +829,7 @@
}
void cummin_helper_cpu(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
- AT_DISPATCH_ALL_TYPES_AND2(kBool, kBFloat16,
+ AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf,
self.scalar_type(), "cummin_cpu",
[&] {
at::native::tensor_dim_apply3<scalar_t, int64_t>(self, values, indices, dim, cummax_cummin_helper<scalar_t, int64_t, std::less_equal<scalar_t>>);
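
Both helpers change only their dispatch macro: AT_DISPATCH_ALL_TYPES_AND2(kBool, kBFloat16, ...) becomes the AND3 variant with kHalf, which adds a Half case to the generated dispatch switch; the comparison functors (std::greater_equal for cummax, std::less_equal for cummin) are unchanged. A quick sanity sketch of the resulting behavior (shapes and values are illustrative):

    import torch

    x = torch.tensor([1.0, 3.0, 2.0], dtype=torch.half)
    values, indices = torch.cummax(x, dim=0)
    # values stay in half; indices are int64, matching the
    # tensor_dim_apply3<scalar_t, int64_t> instantiation above.
    assert values.dtype == torch.half and indices.dtype == torch.int64
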
diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
index f4f73f4..405fda4 100644
--- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
@@ -100,7 +100,7 @@
auto wrap_dim = maybe_wrap_dim(dim, self.dim());
int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, self.scalar_type(), "cumprod_out_cpu", [&] {
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, self.scalar_type(), "cumprod_out_cpu", [&] {
cpu_cum_base_kernel<scalar_t>(result, self, wrap_dim, [&] (
scalar_t* result_data, auto result_dim_stride,
const scalar_t* self_data, auto self_dim_stride, scalar_t init_val) {
@@ -119,7 +119,7 @@
auto wrap_dim = maybe_wrap_dim(dim, self.dim());
int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);
- AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(kBFloat16, self.scalar_type(), "logcumsumexp_out_cpu", [&] {
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "logcumsumexp_out_cpu", [&] {
cpu_cum_base_kernel<scalar_t>(result, self, wrap_dim, [&] (
scalar_t* result_data, auto result_dim_stride,
const scalar_t* self_data, auto self_dim_stride, scalar_t init_val) {
@@ -176,7 +176,7 @@
// NOLINTNEXTLINE(bugprone-argument-comment)
/*identity=*/1);
} else {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "prod_out_cpu", [&] {
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "prod_out_cpu", [&] {
binary_kernel_reduce_vec(
iter,
[=](scalar_t a, scalar_t b)
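
The same one-macro-wider pattern covers cumprod, logcumsumexp, and prod. Note that the reduction lambdas operate on scalar_t, so with Half inputs the accumulation itself runs at half precision; the tolerance adjustments in the tests below follow from that. Illustrative usage:

    import torch

    x = torch.rand(3, 4, dtype=torch.half)
    torch.cumprod(x, dim=1)       # dispatches through "cumprod_out_cpu"
    torch.logcumsumexp(x, dim=1)  # dispatches through "logcumsumexp_out_cpu"
    torch.prod(x, dim=1)          # dispatches through "prod_out_cpu"
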
diff --git a/test/test_mps.py b/test/test_mps.py
index b359a4c..817e11e 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -164,7 +164,9 @@
'__rpow__': [torch.float32],
# See https://github.com/pytorch/pytorch/issues/106112 for more information
- 'cumprod': [torch.float32],
+ 'cumprod': [torch.float32, torch.float16],
+ # See https://github.com/pytorch/pytorch/issues/109166 for more information
+ 'masked.cumprod': [torch.float16],
}
SKIPLIST_GRAD = {
@@ -10943,6 +10945,7 @@
'nn.functional.kl_div',
'nn.functional.softmin',
'cross', 'linalg.cross',
+ 'prod', 'masked.prod',
# for macOS 12
'masked.normalize', 'masked.sum', 'masked.var',
diff --git a/test/test_reductions.py b/test/test_reductions.py
index f755267..2727698 100644
--- a/test/test_reductions.py
+++ b/test/test_reductions.py
@@ -1435,6 +1435,18 @@
torch.prod(x, 1, out=res2)
self.assertEqual(res1, res2)
+ @onlyCPU
+ @dtypes(torch.float16, torch.bfloat16)
+ def test_prod_lowp(self, device, dtype):
+ x = torch.rand(100, 100, dtype=dtype, device=device)
+ x_ref = x.float()
+ res1 = torch.prod(x, 1)
+ res2 = torch.prod(x_ref, 1)
+ self.assertEqual(res1, res2.to(dtype=dtype))
+ res1 = torch.prod(x, 0)
+ res2 = torch.prod(x_ref, 0)
+ self.assertEqual(res1, res2.to(dtype=dtype))
+
def test_prod_bool(self, device):
vals = [[True, True], [True, False], [False, False], []]
for val in vals:
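
The new test_prod_lowp compares the low-precision result against a float32 reference downcast to the input dtype, since reducing 100 elements in half or bfloat16 accumulates rounding error. The same check outside the test harness might look like this (the explicit tolerances here are illustrative, not the test's defaults):

    import torch

    x = torch.rand(100, 100, dtype=torch.half)
    ref = torch.prod(x.float(), dim=1).to(torch.half)  # float32 reference
    out = torch.prod(x, dim=1)
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=2e-3)
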
diff --git a/test/test_sparse.py b/test/test_sparse.py
index e335fa9..fad4db9 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -4225,6 +4225,10 @@
class TestSparseMaskedReductions(TestCase):
exact_dtype = True
+ fp16_low_precision_list = {
+ 'masked.prod',
+ }
+
@ops(sparse_masked_reduction_ops)
def test_future_empty_dim(self, device, dtype, op):
"""Currently, `dim=()` in reductions operations means "reduce over
@@ -4263,7 +4267,12 @@
self.assertEqual(actual.layout, torch.sparse_coo)
expected = op(t, *sample_input.args, **sample_input_kwargs).to_sparse()
- self.assertEqual(actual, expected)
+ atol = None
+ rtol = None
+ if op.name in self.fp16_low_precision_list and dtype == torch.half:
+ atol = 1e-5
+ rtol = 2e-3
+ self.assertEqual(actual, expected, atol=atol, rtol=rtol)
class TestSparseMeta(TestCase):
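
Because masked.prod in float16 can drift from the dense reference, the sparse comparison loosens tolerances for exactly that op/dtype pair. A standalone sketch of what rtol=2e-3 buys (values chosen for illustration):

    import torch

    a = torch.tensor([1.000, 2.0], dtype=torch.half)
    b = torch.tensor([1.001, 2.0], dtype=torch.half)
    # |a - b| ~ 9.8e-4 <= atol + rtol * |b| ~ 2.0e-3, so this passes:
    torch.testing.assert_close(a, b, atol=1e-5, rtol=2e-3)
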
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 1667fa0..bf5fcdb 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -10821,8 +10821,7 @@
),
sample_inputs_func=sample_inputs_cumulative_ops),
OpInfo('cumprod',
- dtypes=all_types_and_complex_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
@@ -10833,8 +10832,7 @@
sample_inputs_func=sample_inputs_cumprod,
gradcheck_fast_mode=False),
OpInfo('cummax',
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -10842,8 +10840,7 @@
),
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL),
OpInfo('cummin',
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
sample_inputs_func=partial(sample_inputs_cumulative_ops, supports_dtype_kwargs=False),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@@ -17294,8 +17291,7 @@
)
),
OpInfo('logcumsumexp',
- dtypes=floating_and_complex_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
+ dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
backward_dtypes=floating_and_complex_types_and(torch.bfloat16),
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.bfloat16),
supports_forward_ad=True,
@@ -18371,7 +18367,7 @@
supports_fwgrad_bwgrad=True,
promotes_int_to_int64=True,
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
sample_inputs_func=sample_inputs_prod,
ref=prod_numpy,
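
The OpInfo edits all follow from the kernel change: once CPU supports half, the CPU and CUDA dtype sets coincide and the dtypesIfCUDA override can be dropped (prod keeps its override only because CUDA additionally supports torch.chalf). For logcumsumexp, note that backward_dtypes still lists only bfloat16 among the low-precision types, so only the forward gains half on CPU. A hedged sketch of that asymmetry:

    import torch

    x = torch.rand(5, dtype=torch.half, requires_grad=True)
    y = torch.logcumsumexp(x, dim=0)  # forward now dispatches on CPU
    # Half gradients are not advertised by the OpInfo above; compute the
    # backward in a wider dtype if it is needed.
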
diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py
index d9b44a3..98fef72 100644
--- a/torch/testing/_internal/opinfo/definitions/_masked.py
+++ b/torch/testing/_internal/opinfo/definitions/_masked.py
@@ -504,11 +504,7 @@
supports_sparse=True,
supports_sparse_csr=True,
promotes_int_to_int64=True,
- # FIXME: "prod_cpu" not implemented for 'Half'
- dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(
- torch.bool, torch.float16, torch.bfloat16
- ),
+ dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
skips=(
DecorateInfo(
unittest.expectedFailure,
@@ -554,6 +550,12 @@
"TestReductions",
"test_ref_small_input",
),
+ DecorateInfo(
+ toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1.5e-03)}),
+ "TestMasked",
+ "test_mask_layout",
+ device_type="cpu",
+ ),
],
sample_inputs_func=sample_inputs_masked_reduction,
sample_inputs_sparse_coo_func=sample_inputs_sparse_coo_masked_reduction,
@@ -585,8 +587,7 @@
),
OpInfo(
"masked.cumprod",
- dtypes=all_types_and_complex_and(torch.bfloat16),
- dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
method_variant=None,
# Runs very slowly on slow gradcheck - alternatively reduce input sizes
gradcheck_fast_mode=True,
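
With the FIXME gone, masked.prod and masked.cumprod advertise float16 on CPU as well. A usage sketch, assuming the prototype torch.masked API (its exact signature may vary across releases):

    import torch

    x = torch.rand(2, 3, dtype=torch.half)
    mask = torch.tensor([[True, False, True],
                         [True, True, False]])
    out = torch.masked.prod(x, dim=1, mask=mask)     # prototype API
    cum = torch.masked.cumprod(x, dim=1, mask=mask)  # prototype API
    assert out.dtype == torch.half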