Add CPU lerp support for Half dtype (#105607)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105607
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/native/cpu/LerpKernel.cpp b/aten/src/ATen/native/cpu/LerpKernel.cpp
index afff853..c9bf752 100644
--- a/aten/src/ATen/native/cpu/LerpKernel.cpp
+++ b/aten/src/ATen/native/cpu/LerpKernel.cpp
@@ -79,6 +79,24 @@
auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec);
return convert_float_bfloat16(result0, result1);
});
+ } else if (iter.common_dtype() == kHalf) {
+ using hVec = Vectorized<Half>;
+ using fVec = Vectorized<float>;
+ float weight_val = weight.to<float>();
+ auto weight_vec = fVec(weight_val);
+ at::native::cpu_kernel_vec(
+ iter,
+ [weight_val](Half self_val, Half end_val) -> Half {
+ return lerp(self_val, end_val, weight_val);
+ },
+ [=](hVec self_vec, hVec end_vec) -> hVec {
+ fVec self_vec0, self_vec1, end_vec0, end_vec1;
+ std::tie(self_vec0, self_vec1) = convert_half_float(self_vec);
+ std::tie(end_vec0, end_vec1) = convert_half_float(end_vec);
+ auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec);
+ auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec);
+ return convert_float_half(result0, result1);
+ });
} else {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_scalar", [&] {
auto weight_val = weight.to<scalar_t>();
@@ -113,6 +131,23 @@
auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec1);
return convert_float_bfloat16(result0, result1);
});
+ } else if (iter.common_dtype() == kHalf) {
+ using hVec = Vectorized<Half>;
+ using fVec = Vectorized<float>;
+ at::native::cpu_kernel_vec(
+ iter,
+ [=](Half self_val, Half end_val, Half weight_val) -> Half {
+ return lerp(self_val, end_val, weight_val);
+ },
+ [=](hVec self_vec, hVec end_vec, hVec weight_vec) -> hVec {
+ fVec self_vec0, self_vec1, end_vec0, end_vec1, weight_vec0, weight_vec1;
+ std::tie(self_vec0, self_vec1) = convert_half_float(self_vec);
+ std::tie(end_vec0, end_vec1) = convert_half_float(end_vec);
+ std::tie(weight_vec0, weight_vec1) = convert_half_float(weight_vec);
+ auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec0);
+ auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec1);
+ return convert_float_half(result0, result1);
+ });
} else {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_tensor", [&] {
at::native::cpu_kernel_vec(
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index 0785666..6c044cc 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -3410,7 +3410,6 @@
@onlyCUDA
@dtypes(torch.half, torch.bfloat16)
def test_lerp_lowp(self, device, dtype):
- ref_dtype = torch.float
xvals = (0.0, -30000.0)
yvals = (0.1, -20000.0)
xs = [torch.full((4,), xval, device=device, dtype=dtype) for xval in xvals]
@@ -3425,7 +3424,7 @@
self.assertEqual(actual, expected, atol=0.0, rtol=0.0)
@onlyCPU
- @dtypes(torch.bfloat16)
+ @dtypes(torch.half, torch.bfloat16)
def test_lerp_lowp_cpu(self, device, dtype):
xvals = (0.0, -30000.0)
yvals = (0.1, -20000.0)
diff --git a/test/test_foreach.py b/test/test_foreach.py
index 8ba6925..f76d16b 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -809,7 +809,7 @@
kwargs["weight"] = args[1]
ref_kwargs["weight"] = args[1]
- if dtype in integral_types() or dtype == torch.bool or (not self.is_cuda and dtype == torch.half):
+ if dtype in integral_types() or dtype == torch.bool:
with self.assertRaises(RuntimeError):
wrapped_op(inputs, self.is_cuda, is_fastpath, **kwargs)
return
diff --git a/test/test_mps.py b/test/test_mps.py
index ee8500b..0da3f7b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -10471,7 +10471,7 @@
'nn.functional.huber_loss',
'true_divide', 'kron',
'gradient', 'var', 'std', 'ldexp',
- 'linalg.vector_norm',
+ 'linalg.vector_norm', 'lerp',
'addr', 'var_mean',
'var_mean_unbiased',
'acosh', 'asinh', 'asin',
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index a722aef..f7f8ffe 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8840,6 +8840,7 @@
foreach_lerp_op_db: List[ForeachFuncInfo] = [
ForeachFuncInfo(
"lerp",
+ dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_lerp_sample_func(3, True, False),
@@ -14914,7 +14915,7 @@
dtypes=[torch.bfloat16]),
),),
OpInfo('lerp',
- dtypes=floating_and_complex_types_and(torch.bfloat16),
+ dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
dtypesIfCUDA=floating_and_complex_types_and(torch.chalf, torch.half, torch.bfloat16),
dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_lerp,