Add lerp CPU support for Half (#105607)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105607
Approved by: https://github.com/albanD
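
In short: the CPU lerp kernel previously had a widen-to-float32 fast path only for BFloat16, and plain Half fell through to AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES, which rejects it. This PR mirrors the BFloat16 path for Half, for both the scalar-weight and tensor-weight kernels. A minimal sketch of what now works on CPU (illustrative usage, not part of the patch):

```python
import torch

# Half tensors on CPU; before this change the lerp CPU kernel
# raised a RuntimeError for torch.half inputs.
start = torch.full((4,), 0.0, dtype=torch.half)
end = torch.full((4,), 0.1, dtype=torch.half)

# Scalar weight -> hits the new scalar-weight Half kernel below.
out_scalar = torch.lerp(start, end, 0.5)

# Tensor weight -> hits the new tensor-weight Half kernel below.
weight = torch.full((4,), 0.5, dtype=torch.half)
out_tensor = torch.lerp(start, end, weight)
```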
diff --git a/aten/src/ATen/native/cpu/LerpKernel.cpp b/aten/src/ATen/native/cpu/LerpKernel.cpp
index afff853..c9bf752 100644
--- a/aten/src/ATen/native/cpu/LerpKernel.cpp
+++ b/aten/src/ATen/native/cpu/LerpKernel.cpp
@@ -79,6 +79,24 @@
           auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec);
           return convert_float_bfloat16(result0, result1);
       });
+  } else if (iter.common_dtype() == kHalf) {
+    using hVec = Vectorized<Half>;
+    using fVec = Vectorized<float>;
+    float weight_val = weight.to<float>();
+    auto weight_vec = fVec(weight_val);
+    at::native::cpu_kernel_vec(
+      iter,
+      [weight_val](Half self_val, Half end_val) -> Half {
+        return lerp(self_val, end_val, weight_val);
+      },
+      [=](hVec self_vec, hVec end_vec) -> hVec {
+          fVec self_vec0, self_vec1, end_vec0, end_vec1;
+          std::tie(self_vec0, self_vec1) = convert_half_float(self_vec);
+          std::tie(end_vec0, end_vec1) = convert_half_float(end_vec);
+          auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec);
+          auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec);
+          return convert_float_half(result0, result1);
+      });
   } else {
     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_scalar", [&] {
       auto weight_val = weight.to<scalar_t>();
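
A note on the pattern in this hunk: Vectorized<Half> carries twice as many lanes as Vectorized<float>, so convert_half_float splits one Half vector into a pair of float vectors, lerp_vec runs in float32, and convert_float_half packs the pair back. A rough single-lane model in Python (my sketch; the weight-dependent branch mirrors the cancellation-avoiding formula ATen's scalar lerp uses):

```python
import numpy as np

def lerp_half_lane(self_val, end_val, weight):
    # Widen Half inputs to float32, as convert_half_float does per vector.
    a = np.float32(np.float16(self_val))
    b = np.float32(np.float16(end_val))
    w = np.float32(weight)
    # Two algebraically equivalent forms; the choice by |weight| avoids
    # catastrophic cancellation when self and end are close.
    if abs(w) < 0.5:
        out = a + w * (b - a)
    else:
        out = b - (b - a) * (np.float32(1) - w)
    # Narrow back to Half, as convert_float_half does.
    return np.float16(out)
```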
@@ -113,6 +131,23 @@
           auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec1);
           return convert_float_bfloat16(result0, result1);
       });
+  } else if (iter.common_dtype() == kHalf) {
+    using hVec = Vectorized<Half>;
+    using fVec = Vectorized<float>;
+    at::native::cpu_kernel_vec(
+      iter,
+      [=](Half self_val, Half end_val, Half weight_val) -> Half {
+        return lerp(self_val, end_val, weight_val);
+      },
+      [=](hVec self_vec, hVec end_vec, hVec weight_vec) -> hVec {
+          fVec self_vec0, self_vec1, end_vec0, end_vec1, weight_vec0, weight_vec1;
+          std::tie(self_vec0, self_vec1) = convert_half_float(self_vec);
+          std::tie(end_vec0, end_vec1) = convert_half_float(end_vec);
+          std::tie(weight_vec0, weight_vec1) = convert_half_float(weight_vec);
+          auto result0 = lerp_vec(self_vec0, end_vec0, weight_vec0);
+          auto result1 = lerp_vec(self_vec1, end_vec1, weight_vec1);
+          return convert_float_half(result0, result1);
+      });
   } else {
     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_tensor", [&] {
       at::native::cpu_kernel_vec(
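
The tensor-weight kernel is the same structure with per-element weights, which is why weight_vec is widened alongside self and end. Since both this kernel and a float32 reference round to Half exactly once, they should agree exactly; a quick check under that assumption:

```python
import torch

x = torch.linspace(-1.0, 1.0, 8).half()
y = torch.ones(8, dtype=torch.half)
w = torch.rand(8, dtype=torch.half)

out = torch.lerp(x, y, w)  # new tensor-weight Half CPU kernel
# Same widen-compute-narrow computation done explicitly in float32:
ref = torch.lerp(x.float(), y.float(), w.float()).half()
torch.testing.assert_close(out, ref, atol=0.0, rtol=0.0)
```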
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index 0785666..6c044cc 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -3410,7 +3410,6 @@
     @onlyCUDA
     @dtypes(torch.half, torch.bfloat16)
     def test_lerp_lowp(self, device, dtype):
-        ref_dtype = torch.float
         xvals = (0.0, -30000.0)
         yvals = (0.1, -20000.0)
         xs = [torch.full((4,), xval, device=device, dtype=dtype) for xval in xvals]
@@ -3425,7 +3424,7 @@
             self.assertEqual(actual, expected, atol=0.0, rtol=0.0)
 
     @onlyCPU
-    @dtypes(torch.bfloat16)
+    @dtypes(torch.half, torch.bfloat16)
     def test_lerp_lowp_cpu(self, device, dtype):
         xvals = (0.0, -30000.0)
         yvals = (0.1, -20000.0)
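
The hunk only shows the test's setup; the body (truncated above) compares the low-precision result against a float32 reference with zero tolerance. A self-contained paraphrase for the new torch.half case (the 0.75 weight is my choice for illustration):

```python
import torch

dtype = torch.half  # now covered on CPU alongside bfloat16
xs = [torch.full((4,), v, dtype=dtype) for v in (0.0, -30000.0)]
ys = [torch.full((4,), v, dtype=dtype) for v in (0.1, -20000.0)]
for x in xs:
    for y in ys:
        actual = torch.lerp(x, y, 0.75)
        # float32 reference, rounded back to the test dtype
        expected = torch.lerp(x.float(), y.float(), 0.75).to(dtype)
        torch.testing.assert_close(actual, expected, atol=0.0, rtol=0.0)
```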
diff --git a/test/test_foreach.py b/test/test_foreach.py
index 8ba6925..f76d16b 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -809,7 +809,7 @@
                 kwargs["weight"] = args[1]
                 ref_kwargs["weight"] = args[1]
 
-            if dtype in integral_types() or dtype == torch.bool or (not self.is_cuda and dtype == torch.half):
+            if dtype in integral_types() or dtype == torch.bool:
                 with self.assertRaises(RuntimeError):
                     wrapped_op(inputs, self.is_cuda, is_fastpath, **kwargs)
                 return
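
With the CPU Half kernel in place, the foreach test no longer expects a RuntimeError for torch.half on CPU. An illustration using the private _foreach_lerp binding (underscore-prefixed, internal, and subject to change):

```python
import torch

xs = [torch.rand(3, dtype=torch.half) for _ in range(2)]
ys = [torch.rand(3, dtype=torch.half) for _ in range(2)]
ws = [torch.rand(3, dtype=torch.half) for _ in range(2)]

# Previously raised RuntimeError for torch.half on CPU; now returns
# a list of Half tensors.
outs = torch._foreach_lerp(xs, ys, ws)
```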
diff --git a/test/test_mps.py b/test/test_mps.py
index ee8500b..0da3f7b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -10471,7 +10471,7 @@
         'nn.functional.huber_loss',
         'true_divide', 'kron',
         'gradient', 'var', 'std', 'ldexp',
-        'linalg.vector_norm',
+        'linalg.vector_norm', 'lerp',
         'addr', 'var_mean',
         'var_mean_unbiased',
         'acosh', 'asinh', 'asin',
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index a722aef..f7f8ffe 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8840,6 +8840,7 @@
 foreach_lerp_op_db: List[ForeachFuncInfo] = [
     ForeachFuncInfo(
         "lerp",
+        dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
         dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
         dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
         sample_inputs_func=foreach_lerp_sample_func(3, True, False),
@@ -14914,7 +14915,7 @@
                                     dtypes=[torch.bfloat16]),
                    ),),
     OpInfo('lerp',
-           dtypes=floating_and_complex_types_and(torch.bfloat16),
+           dtypes=floating_and_complex_types_and(torch.bfloat16, torch.half),
            dtypesIfCUDA=floating_and_complex_types_and(torch.chalf, torch.half, torch.bfloat16),
            dtypesIfROCM=floating_and_complex_types_and(torch.half, torch.bfloat16),
            sample_inputs_func=sample_inputs_lerp,
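
For context, the dtype helpers in these OpInfo entries live in torch.testing._internal.common_dtype and just extend the base floating + complex set; a quick way to inspect what the new CPU entry enables (internal testing utility, not public API):

```python
import torch
from torch.testing._internal.common_dtype import floating_and_complex_types_and

dts = floating_and_complex_types_and(torch.bfloat16, torch.half)
# float32, float64, complex64, complex128, plus bfloat16 and float16
# (order may vary)
print(dts)
```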