add Half support for layer_norm on CPU (#99590)

### Testing
Single socket (icx, 32cores):
| shape | fp32 forward (ms) | fp16 forward (ms) | mixed fp32 fp16 forward (ms) | fp32 backward (ms) | fp16 backward (ms) | mixed fp32 fp16 backward (ms) |
| -- | -- | -- | -- | -- | -- | -- |
| (1, 8, 16) | 0.012 | 0.011 | 0.011 | 0.051 | 0.051 | 0.050 |
| (8 ,8, 16) | 0.013 | 0.013 | 0.013 | 0.054 | 0.053 | 0.051 |
| (32, 8, 16) | 0.015 | 0.014 | 0.014 | 0.059 | 0.054 | 0.052 |
| (64, 128, 56, 56) | 1.875 | 0.790 | 1.016 | 12.845 | 7.151 | 6.985 |
| (64, 128, 256, 256) | 50.226 | 25.462 | 35.736 | 328.957 | 179.615 | 175.618 |

Single core (icx):

| shape | fp32 forward (ms) | fp16 forward (ms) | mixed fp32 fp16 forward (ms) | fp32 backward (ms) | fp16 backward (ms) | mixed fp32 fp16 backward (ms) |
| -- | -- | -- | -- | -- | -- | -- |
| (1, 8, 16) | 0.012 | 0.011 | 0.011 | 0.040 | 0.041 | 0.041 |
| (8 ,8, 16) | 0.012 | 0.012 | 0.012 | 0.042 | 0.042 | 0.042 |
| (32, 8, 16) | 0.027 | 0.014 | 0.014 | 0.048 | 0.048 | 0.046 |
| (64, 128, 56, 56) | 58.054 | 11.034 | 17.928 | 108.603 | 48.816 | 50.244 |
| (64, 128, 256, 256) | 1327.758 | 352.394 | 496.994 | 2846.182 | 1224.247 | 1218.422 |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99590
Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/cpuhrsch
diff --git a/aten/src/ATen/cpu/vec/functional_bfloat16.h b/aten/src/ATen/cpu/vec/functional_bfloat16.h
index 74ca8b1..03cb017 100644
--- a/aten/src/ATen/cpu/vec/functional_bfloat16.h
+++ b/aten/src/ATen/cpu/vec/functional_bfloat16.h
@@ -69,7 +69,7 @@
 //
 template <typename scalar_t, typename Op,
           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
+inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
   using bVec = vec::Vectorized<scalar_t>;
   using fVec = vec::Vectorized<float>;
   if (size < bVec::size()) {
@@ -111,7 +111,7 @@
 
 template <typename scalar_t, typename Op1, typename Op2,
           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
+inline std::pair<float, float> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
     const scalar_t* data, int64_t size) {
   using bVec = vec::Vectorized<scalar_t>;
   using fVec = vec::Vectorized<float>;
@@ -169,7 +169,7 @@
 
 template <typename scalar_t, typename MapOp, typename ReduceOp,
           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline scalar_t map_reduce_all(
+inline float map_reduce_all(
     const MapOp& map_fun,
     const ReduceOp& red_fun,
     const scalar_t* data,
@@ -225,7 +225,7 @@
 
 template <typename scalar_t, typename MapOp, typename ReduceOp,
           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline scalar_t map2_reduce_all(
+inline float map2_reduce_all(
     const MapOp& map_fun,
     const ReduceOp& red_fun,
     const scalar_t* data,
@@ -294,7 +294,7 @@
 
 template <typename scalar_t, typename MapOp, typename ReduceOp,
           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline scalar_t map3_reduce_all(
+inline float map3_reduce_all(
     const MapOp& map_fun,
     const ReduceOp& red_fun,
     const scalar_t* data,
diff --git a/aten/src/ATen/native/cpu/layer_norm_kernel.cpp b/aten/src/ATen/native/cpu/layer_norm_kernel.cpp
index a0c3e09..a668305 100644
--- a/aten/src/ATen/native/cpu/layer_norm_kernel.cpp
+++ b/aten/src/ATen/native/cpu/layer_norm_kernel.cpp
@@ -23,14 +23,15 @@
 
 namespace {
 
-template <typename T, typename T_ACC>
+template <typename T,
+          typename std::enable_if_t<!is_reduced_floating_point_v<T>, int> = 0>
 void LayerNormKernelImplInternal(
     const Tensor& X,
     const Tensor& gamma,
     const Tensor& beta,
     int64_t M,
     int64_t N,
-    T_ACC eps,
+    T eps,
     Tensor* Y,
     Tensor* mean,
     Tensor* rstd) {
@@ -83,7 +84,8 @@
   });
 }
 
-template <typename param_t>
+template <typename T, typename param_t,
+          typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
 void layer_norm_kernel_mixed_type(
     const Tensor& X,
     const Tensor& gamma,
@@ -94,12 +96,12 @@
     Tensor* Y,
     Tensor* mean,
     Tensor* rstd) {
-  using bVec = Vectorized<BFloat16>;
+  using bVec = Vectorized<T>;
   using fVec = Vectorized<float>;
-  const BFloat16* X_data = X.data_ptr<BFloat16>();
+  const T* X_data = X.data_ptr<T>();
   const param_t* gamma_data = gamma.defined() ? gamma.data_ptr<param_t>() : nullptr;
   const param_t* beta_data = beta.defined() ? beta.data_ptr<param_t>() : nullptr;
-  BFloat16* Y_data = Y->data_ptr<BFloat16>();
+  T* Y_data = Y->data_ptr<T>();
   param_t* mean_data = mean ? mean->data_ptr<param_t>() : nullptr;
   param_t* rstd_data = rstd ? rstd->data_ptr<param_t>() : nullptr;
 
@@ -109,38 +111,29 @@
   const bool rstd_null = rstd_data == nullptr;
   at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
     for (const auto i : c10::irange(start, end)) {
-      const BFloat16* X_ptr = X_data + i * N;
-      BFloat16* Y_ptr = Y_data + i * N;
+      const T* X_ptr = X_data + i * N;
+      T* Y_ptr = Y_data + i * N;
       float mean_val;
       float rstd_val;
       std::tie(mean_val, rstd_val) = RowwiseMoments(X_ptr, N);
       rstd_val = float(1) / std::sqrt(rstd_val + eps);
       const float scale = rstd_val;
       const float bias = -rstd_val * mean_val;
-      if (gamma_null || beta_null) {
-        for (const auto j : c10::irange(N)) {
-          const param_t gamma_v = gamma_null ? param_t(1) : gamma_data[j];
-          const param_t beta_v = beta_null ? param_t(0) : beta_data[j];
-          Y_ptr[j] = (X_ptr[j] * scale + bias) * gamma_v + beta_v;
-        }
-      } else {
-        int64_t d = 0;
-        for (; d < N - (N % bVec::size()); d += bVec::size()) {
-          bVec x_bvec = bVec::loadu(X_ptr + d);
-          fVec x_fvec0, x_fvec1;
-          std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
-          fVec gamma_fvec0, gamma_fvec1;
-          std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d);
-          fVec beta_fvec0, beta_fvec1;
-          std::tie(beta_fvec0, beta_fvec1) = load2f(beta_data + d);
-          fVec y_fvec0 = (x_fvec0 * fVec(scale) + fVec(bias)) * gamma_fvec0 + beta_fvec0;
-          fVec y_fvec1 = (x_fvec1 * fVec(scale) + fVec(bias)) * gamma_fvec1 + beta_fvec1;
-          bVec y_bvec = convert_float_bfloat16(y_fvec0, y_fvec1);
-          y_bvec.store(Y_ptr + d);
-        }
-        for (; d < N; d++) {
-          Y_ptr[d] = (X_ptr[d] * scale + bias) * gamma_data[d] + beta_data[d];
-        }
+      int64_t d = 0;
+      for (; d < N - (N % bVec::size()); d += bVec::size()) {
+        bVec x_bvec = bVec::loadu(X_ptr + d);
+        auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
+        auto [gamma_fvec0, gamma_fvec1] = gamma_null ? std::make_tuple(fVec(1), fVec(1)) : load2f(gamma_data + d);
+        auto [beta_fvec0, beta_fvec1] = beta_null ? std::make_tuple(fVec(0), fVec(0)) : load2f(beta_data + d);
+        fVec y_fvec0 = (x_fvec0 * fVec(scale) + fVec(bias)) * gamma_fvec0 + beta_fvec0;
+        fVec y_fvec1 = (x_fvec1 * fVec(scale) + fVec(bias)) * gamma_fvec1 + beta_fvec1;
+        bVec y_bvec = convert_from_float<T>(y_fvec0, y_fvec1);
+        y_bvec.store(Y_ptr + d);
+      }
+      for (; d < N; d++) {
+        const float gamma_v = gamma_null ? float(1) : float(gamma_data[d]);
+        const float beta_v = beta_null ? float(0) : float(beta_data[d]);
+        Y_ptr[d] = (float(X_ptr[d]) * scale + bias) * gamma_v + beta_v;
       }
       if (!mean_null) {
         mean_data[i] = mean_val;
@@ -152,8 +145,9 @@
   });
 }
 
-template <>
-void LayerNormKernelImplInternal<BFloat16, float>(
+template <typename T,
+          typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+void LayerNormKernelImplInternal(
     const Tensor& X,
     const Tensor& gamma,
     const Tensor& beta,
@@ -165,9 +159,9 @@
     Tensor* rstd) {
   const bool mixed_type = is_mixed_type(X, gamma, beta);
   if (mixed_type) {
-    layer_norm_kernel_mixed_type<float>(X, gamma, beta, M, N, eps, Y, mean, rstd);
+    layer_norm_kernel_mixed_type<T, float>(X, gamma, beta, M, N, eps, Y, mean, rstd);
   } else {
-    layer_norm_kernel_mixed_type<BFloat16>(X, gamma, beta, M, N, eps, Y, mean, rstd);
+    layer_norm_kernel_mixed_type<T, T>(X, gamma, beta, M, N, eps, Y, mean, rstd);
   }
 }
 
@@ -184,15 +178,14 @@
   TORCH_DCHECK_EQ(X.numel(), M * N);
   DCHECK(!gamma.defined() || gamma.numel() == N);
   DCHECK(!beta.defined() || beta.numel() == N);
-  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, X.scalar_type(),
+  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, X.scalar_type(),
       "LayerNormKernelImpl", [&]() {
-    using acc_t = at::opmath_type<scalar_t>;
-    LayerNormKernelImplInternal<scalar_t, acc_t>(
-        X, gamma, beta, M, N, static_cast<acc_t>(eps), Y, mean, rstd);
+    LayerNormKernelImplInternal<scalar_t>(
+        X, gamma, beta, M, N, eps, Y, mean, rstd);
   });
 }
 
-template <typename T, typename T2, typename T_ACC>
+template <typename T, typename T2, typename opmath_t>
 void layer_norm_backward_frame(
     const T* dY_data,
     const T* X_data,
@@ -202,19 +195,19 @@
     T* dX_data,
     T* dgamma_buffer_ptr,
     T* dbeta_buffer_ptr,
-    const T_ACC scale,
+    const opmath_t scale,
     const bool gamma_null,
     const bool dX_null,
     const bool dgamma_null,
     const bool dbeta_null,
     int64_t N,
     int64_t i) {
-  using Vec = vec::Vectorized<T_ACC>;
+  using Vec = vec::Vectorized<opmath_t>;
   const T* dY_ptr = dY_data + i * N;
   const T* X_ptr = X_data + i * N;
   if (!dgamma_null) {
-    const T_ACC a = rstd_data[i];
-    const T_ACC b = -a * mean_data[i];
+    const opmath_t a = rstd_data[i];
+    const opmath_t b = -a * mean_data[i];
     // Scalar math:
     // for (const auto j : c10::irange(N)) {
     //   dgamma_data[j] += dY_ptr[j] * (a * X_ptr[j] + b);
@@ -243,8 +236,8 @@
   }
   if (!dX_null) {
     T* dX_ptr = dX_data + i * N;
-    T_ACC ds = T_ACC(0);
-    T_ACC db = T_ACC(0);
+    opmath_t ds = opmath_t(0);
+    opmath_t db = opmath_t(0);
     // Scalar math:
     // for (const auto j : c10::irange(N)) {
     //   const T gamma_v = gamma_null ? T(1) : gamma_data[j];
@@ -275,9 +268,9 @@
           gamma_data,
           N);
     }
-    const T_ACC a = rstd_data[i];
-    const T_ACC b = (db * mean_data[i] - ds) * a * a * a * scale;
-    const T_ACC c = -b * mean_data[i] - db * a * scale;
+    const opmath_t a = rstd_data[i];
+    const opmath_t b = (db * opmath_t(mean_data[i]) - ds) * a * a * a * scale;
+    const opmath_t c = -b * opmath_t(mean_data[i]) - db * a * scale;
     // Scalar math:
     // for (const auto j : c10::irange(N)) {
     //   const T gamma_v = gamma_null ? T(1) : gamma_data[j];
@@ -306,16 +299,17 @@
   }
 }
 
-template <>
-void layer_norm_backward_frame<BFloat16, float, float>(
-    const BFloat16* dY_data,
-    const BFloat16* X_data,
+template <typename T, typename T2, typename opmath_t,
+          typename std::enable_if_t<is_reduced_floating_point_v<T> && std::is_same<T2, float>::value, int> = 0>
+void layer_norm_backward_frame(
+    const T* dY_data,
+    const T* X_data,
     const float* mean_data,
     const float* rstd_data,
     const float* gamma_data,
-    BFloat16* dX_data,
-    BFloat16* dgamma_buffer_ptr,
-    BFloat16* dbeta_buffer_ptr,
+    T* dX_data,
+    T* dgamma_buffer_ptr,
+    T* dbeta_buffer_ptr,
     const float scale,
     const bool gamma_null,
     const bool dX_null,
@@ -323,10 +317,10 @@
     const bool dbeta_null,
     int64_t N,
     int64_t i) {
-  using bVec = Vectorized<BFloat16>;
+  using bVec = Vectorized<T>;
   using fVec = Vectorized<float>;
-  const BFloat16* dY_ptr = dY_data + i * N;
-  const BFloat16* X_ptr = X_data + i * N;
+  const T* dY_ptr = dY_data + i * N;
+  const T* X_ptr = X_data + i * N;
   if (!dgamma_null) {
     const float a = rstd_data[i];
     const float b = -a * mean_data[i];
@@ -334,7 +328,7 @@
     // for (const auto j : c10::irange(N)) {
     //   dgamma_data[j] += dY_ptr[j] * (a * X_ptr[j] + b);
     // }
-    vec::map3<BFloat16>(
+    vec::map3<T>(
         [a, b](fVec dgamma, fVec dy, fVec x) {
           return dgamma + dy * (fVec(a) * x + fVec(b));
         },
@@ -349,7 +343,7 @@
     // for (const auto j : c10::irange(N)) {
     //   dbeta_data[j] += dY_ptr[j];
     // }
-    vec::map2<BFloat16>(
+    vec::map2<T>(
         [](fVec dbeta, fVec dy) { return dbeta + dy; },
         dbeta_buffer_ptr,
         dbeta_buffer_ptr,
@@ -357,7 +351,7 @@
         N);
   }
   if (!dX_null) {
-    BFloat16* dX_ptr = dX_data + i * N;
+    T* dX_ptr = dX_data + i * N;
     float ds = float(0);
     float db = float(0);
     // Scalar math:
@@ -367,21 +361,21 @@
     //   db += dY_ptr[j] * gamma_v;
     // }
     if (gamma_null) {
-      ds = vec::map2_reduce_all<BFloat16>(
+      ds = vec::map2_reduce_all<T>(
           [](fVec x, fVec y) { return x * y; },
           [](fVec x, fVec y) { return x + y; },
           dY_ptr,
           X_ptr,
           N);
-      db = vec::reduce_all<BFloat16>(
+      db = vec::reduce_all<T>(
           [](fVec& x, fVec& y) { return x + y; }, dY_ptr, N);
     } else {
       if (N < bVec::size()) {
         bVec x_bvec = bVec::loadu(X_ptr, N);
         bVec dy_bvec = bVec::loadu(dY_ptr, N);
         fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
-        std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
-        std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+        std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+        std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
         std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data, N);
         if (N > fVec::size()) {
           fVec db_fvec0 = dy_fvec0 * gamma_fvec0;
@@ -404,8 +398,8 @@
         bVec dy_bvec = bVec::loadu(dY_ptr);
         fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
         fVec ds_fvec0, ds_fvec1, db_fvec0, db_fvec1, acc_ds_fvec0, acc_ds_fvec1, acc_db_fvec0, acc_db_fvec1;
-        std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
-        std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+        std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+        std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
         std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data);
         acc_db_fvec0 = dy_fvec0 * gamma_fvec0;
         acc_db_fvec1 = dy_fvec1 * gamma_fvec1;
@@ -414,8 +408,8 @@
         for (; d < N - (N % bVec::size()); d += bVec::size()) {
           x_bvec = bVec::loadu(X_ptr + d);
           dy_bvec = bVec::loadu(dY_ptr + d);
-          std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
-          std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+          std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+          std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
           std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d);
           db_fvec0 = dy_fvec0 * gamma_fvec0;
           db_fvec1 = dy_fvec1 * gamma_fvec1;
@@ -429,8 +423,8 @@
         if (N - d > 0) {
           x_bvec = bVec::loadu(X_ptr + d, N - d);
           dy_bvec = bVec::loadu(dY_ptr + d, N - d);
-          std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
-          std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+          std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+          std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
           std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d, N - d);
           if (N - d > fVec::size()) {
             db_fvec0 = dy_fvec0 * gamma_fvec0;
@@ -463,7 +457,7 @@
     //   dX_ptr[j] = a * dY_ptr[j] * gamma_v + b * X_ptr[j] + c;
     // }
     if (gamma_null) {
-      vec::map2<BFloat16>(
+      vec::map2<T>(
           [a, b, c](fVec dy, fVec x) {
             return fVec(a) * dy + fVec(b) * x + fVec(c);
           },
@@ -477,24 +471,24 @@
         bVec x_bvec = bVec::loadu(X_ptr + d);
         bVec dy_bvec = bVec::loadu(dY_ptr + d);
         fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
-        std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
-        std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+        std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+        std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
         std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d);
         fVec r_fvec0 = fVec(a) * dy_fvec0 * gamma_fvec0 + fVec(b) * x_fvec0 + fVec(c);
         fVec r_fvec1 = fVec(a) * dy_fvec1 * gamma_fvec1 + fVec(b) * x_fvec1 + fVec(c);
-        bVec r_bvec = convert_float_bfloat16(r_fvec0, r_fvec1);
+        bVec r_bvec = convert_from_float<T>(r_fvec0, r_fvec1);
         r_bvec.store(dX_ptr + d);
       }
       if (N - d > 0) {
         bVec x_bvec = bVec::loadu(X_ptr + d, N - d);
         bVec dy_bvec = bVec::loadu(dY_ptr + d, N - d);
         fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
-        std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
-        std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+        std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+        std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
         std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d, N - d);
         fVec r_fvec0 = fVec(a) * dy_fvec0 * gamma_fvec0 + fVec(b) * x_fvec0 + fVec(c);
         fVec r_fvec1 = fVec(a) * dy_fvec1 * gamma_fvec1 + fVec(b) * x_fvec1 + fVec(c);
-        bVec r_bvec = convert_float_bfloat16(r_fvec0, r_fvec1);
+        bVec r_bvec = convert_from_float<T>(r_fvec0, r_fvec1);
         r_bvec.store(dX_ptr + d, N - d);
       }
     }
@@ -513,7 +507,7 @@
     Tensor* dX,
     Tensor* dgamma,
     Tensor* dbeta) {
-  using T_ACC = at::opmath_type<T>;
+  using opmath_t = at::opmath_type<T>;
   TORCH_DCHECK_EQ(dY.numel(), M * N);
   TORCH_DCHECK_EQ(X.numel(), M * N);
   TORCH_DCHECK_EQ(mean.numel(), M);
@@ -528,7 +522,7 @@
   T* dX_data = dX->defined() ? dX->template data_ptr<T>() : nullptr;
   T2* dgamma_data = dgamma->defined() ? dgamma->template data_ptr<T2>() : nullptr;
   T2* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T2>() : nullptr;
-  const T_ACC scale = T_ACC(1) / static_cast<T_ACC>(N);
+  const opmath_t scale = opmath_t(1) / static_cast<opmath_t>(N);
   const bool gamma_null = gamma_data == nullptr;
   const bool dX_null = dX_data == nullptr;
   const bool dgamma_null = dgamma_data == nullptr;
@@ -565,7 +559,7 @@
     T* dbeta_buffer_ptr =
         dbeta_null ? nullptr : buffer_data + num_threads * N + tid * N;
     for (const auto i : c10::irange(start, end)) {
-      layer_norm_backward_frame<T, T2, T_ACC>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, dgamma_buffer_ptr, dbeta_buffer_ptr, scale, gamma_null, dX_null, dgamma_null, dbeta_null, N, i);
+      layer_norm_backward_frame<T, T2, opmath_t>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, dgamma_buffer_ptr, dbeta_buffer_ptr, scale, gamma_null, dX_null, dgamma_null, dbeta_null, N, i);
     }
   });
 
@@ -573,8 +567,8 @@
   if (buffer_data != nullptr) {
     parallel_for(0, N, 1, [&](int64_t start, int64_t end) {
       for (const auto j : c10::irange(start, end)) {
-        T_ACC dgamma_v = T_ACC(0);
-        T_ACC dbeta_v = T_ACC(0);
+        opmath_t dgamma_v = opmath_t(0);
+        opmath_t dbeta_v = opmath_t(0);
         for (const auto i : c10::irange(num_threads)) {
           dgamma_v += buffer_data[i * N + j];
           dbeta_v += buffer_data[num_threads * N + i * N + j];
@@ -603,16 +597,22 @@
     Tensor* dX,
     Tensor* dgamma,
     Tensor* dbeta) {
-  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, X.scalar_type(),
-      "LayerNormBackwardKernelImpl", [&]() {
-    if (X.scalar_type() == at::kBFloat16 && gamma.scalar_type() == at::kFloat) {
-      LayerNormBackwardKernelImplInternal<BFloat16, float>(
-          dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
-    } else {
+  if (at::isReducedFloatingType(X.scalar_type())) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(X.scalar_type(), "LayerNormBackwardKernelImpl", [&]() {
+      if (gamma.scalar_type() == at::kFloat) {
+        LayerNormBackwardKernelImplInternal<scalar_t, float>(
+            dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
+      } else {
+        LayerNormBackwardKernelImplInternal<scalar_t, scalar_t>(
+            dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
+      }
+      });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "LayerNormBackwardKernelImpl", [&]() {
       LayerNormBackwardKernelImplInternal<scalar_t, scalar_t>(
           dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
-    }
-  });
+    });
+  }
 }
 
 } // namespace
diff --git a/aten/src/ATen/native/cpu/moments_utils.h b/aten/src/ATen/native/cpu/moments_utils.h
index 3f60309..c89aa6b 100644
--- a/aten/src/ATen/native/cpu/moments_utils.h
+++ b/aten/src/ATen/native/cpu/moments_utils.h
@@ -76,7 +76,7 @@
   AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0);
 }
 
-// each bfloat16 vector will be converted to two float vectors,
+// each bfloat16/half vector will be converted to two float vectors,
 // and accumulated successively on m1_stk0/m2_stk0.
 template <typename T>
 inline typename std::enable_if<!std::is_same<T, at::opmath_type<T>>::value, void>::type
diff --git a/test/test_meta.py b/test/test_meta.py
index e856025..1967da3 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -710,7 +710,7 @@
 meta_function_device_expected_failures['cpu'] = {
     torch.native_batch_norm: {bf16, f16},
     torch._native_batch_norm_legit: {bf16, f16},
-    torch.native_layer_norm: {bf16},
+    torch.native_layer_norm: {bf16, f16},
 }
 
 meta_function_device_expected_failures['cuda'] = {
@@ -855,7 +855,7 @@
     aten.native_batch_norm.default: {bf16, f16},
     aten._native_batch_norm_legit.default: {bf16, f16},
     aten._native_batch_norm_legit.no_stats: {bf16, f16},
-    aten.native_layer_norm.default: {bf16},
+    aten.native_layer_norm.default: {bf16, f16},
     aten.histc.default: {f16},
     aten.histc.out: {f16},
 }
diff --git a/test/test_mps.py b/test/test_mps.py
index 73db4c4..ae43dd5 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -11104,6 +11104,8 @@
         'cross', 'linalg.cross',
         'prod', 'masked.prod',
         'nextafter',
+        'native_layer_norm',
+        'nn.functional.layer_norm',
 
         # for macOS 12
         'masked.normalize', 'masked.sum', 'masked.var',
diff --git a/test/test_nn.py b/test/test_nn.py
index 0876075..81fd93b 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -7948,7 +7948,7 @@
             mean = out_reshaped.mean(-1)
             var = out_reshaped.var(-1, unbiased=False)
 
-            delta = 1e-1 if dtype == torch.bfloat16 else 1e-5
+            delta = 1e-1 if (dtype == torch.bfloat16 or dtype == torch.half) else 1e-5
             self.assertEqual(torch.abs(mean.data).mean(), 0, atol=delta, rtol=0)
             self.assertEqual(torch.abs(var.data).mean(), 1, atol=delta, rtol=0)
 
@@ -7982,12 +7982,12 @@
         output.sum().backward()
         self.assertEqualTypeString(output, input)
 
-    def _test_LayerNorm_cpu_mixed_dtype(self, device):
+    def _test_LayerNorm_cpu_mixed_dtype(self, device, dtype):
         for elementwise_affine in [True, False]:
             # layer norm input shape is normalized to m x n, cpu vectorized on n,
             # so make sure n exceeds vector length
-            input = torch.empty(2, 3, 11, 3, device=device, dtype=torch.bfloat16).random_(1, 10)
-            m = nn.LayerNorm([11, 3], elementwise_affine=elementwise_affine).to(device, torch.bfloat16)
+            input = torch.empty(2, 3, 11, 3, device=device, dtype=dtype).random_(1, 10)
+            m = nn.LayerNorm([11, 3], elementwise_affine=elementwise_affine).to(device, dtype)
 
             # fp32
             m_fp32 = deepcopy(m).to(device, torch.float)
@@ -7995,21 +7995,21 @@
             out_fp32 = m_fp32(x_fp32)
             out_fp32.sum().backward()
 
-            # bf16
+            # bf16/half
             m_bf16 = deepcopy(m)
             x_bf16 = input.clone().detach().requires_grad_()
             out_bf16 = m_bf16(x_bf16)
             out_bf16.sum().backward()
 
-            # bf16 mixed type
+            # bf16/half mixed type
             m_mix = deepcopy(m).to(device, torch.float)
             x_mix = input.clone().detach().requires_grad_()
             out_mix = m_mix(x_mix)
             out_mix.sum().backward()
-            self.assertEqual(out_fp32.bfloat16(), out_bf16)
-            self.assertEqual(out_fp32.bfloat16(), out_mix)
-            self.assertEqual(x_fp32.grad.bfloat16(), x_bf16.grad, atol=1e-1, rtol=1e-1)
-            self.assertEqual(x_fp32.grad.bfloat16(), x_mix.grad, atol=1e-1, rtol=1e-1)
+            self.assertEqual(out_fp32.to(dtype=dtype), out_bf16)
+            self.assertEqual(out_fp32.to(dtype=dtype), out_mix)
+            self.assertEqual(x_fp32.grad.to(dtype=dtype), x_bf16.grad, atol=1e-1, rtol=1e-1)
+            self.assertEqual(x_fp32.grad.to(dtype=dtype), x_mix.grad, atol=1e-1, rtol=1e-1)
 
     def _test_GroupNorm_general(self, device, dtype=torch.float):
         good_shape_g = {
@@ -8518,13 +8518,15 @@
         self._test_LayerNorm_general(device)
 
         if self.device_type == 'cuda' or self.device_type == 'cpu':
-            self._test_LayerNorm_general(device, dtype=torch.bfloat16)
+            for dtype in [torch.half, torch.bfloat16]:
+                self._test_LayerNorm_general(device, dtype=dtype)
 
         if self.device_type == 'cuda':
             self._test_LayerNorm_cuda_half(device)
 
         if self.device_type == 'cpu':
-            self._test_LayerNorm_cpu_mixed_dtype(device)
+            for dtype in [torch.half, torch.bfloat16]:
+                self._test_LayerNorm_cpu_mixed_dtype(device, dtype=dtype)
 
     @onlyNativeDeviceTypes
     def test_LayerNorm_numeric(self, device):
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index eb02724..cafcd84 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -12844,8 +12844,7 @@
     OpInfo('native_layer_norm',
            aten_name='native_layer_norm',
            ref=reference_native_layer_norm,
-           dtypes=floating_types_and(torch.bfloat16),
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
            supports_out=False,
            assert_jit_shape_analysis=True,
            supports_fwgrad_bwgrad=True,
@@ -13378,8 +13377,7 @@
            aten_backward_name='layer_norm_backward',
            aliases=('layer_norm',),
            ref=reference_layer_norm,
-           dtypes=floating_types_and(torch.bfloat16),
-           dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+           dtypes=floating_types_and(torch.half, torch.bfloat16),
            supports_out=False,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,