add Half support for layer_norm on CPU (#99590)
### Testing
Single socket (icx, 32cores):
| shape | fp32 forward (ms) | fp16 forward (ms) | mixed fp32 fp16 forward (ms) | fp32 backward (ms) | fp16 backward (ms) | mixed fp32 fp16 backward (ms) |
| -- | -- | -- | -- | -- | -- | -- |
| (1, 8, 16) | 0.012 | 0.011 | 0.011 | 0.051 | 0.051 | 0.050 |
| (8 ,8, 16) | 0.013 | 0.013 | 0.013 | 0.054 | 0.053 | 0.051 |
| (32, 8, 16) | 0.015 | 0.014 | 0.014 | 0.059 | 0.054 | 0.052 |
| (64, 128, 56, 56) | 1.875 | 0.790 | 1.016 | 12.845 | 7.151 | 6.985 |
| (64, 128, 256, 256) | 50.226 | 25.462 | 35.736 | 328.957 | 179.615 | 175.618 |
Single core (icx):
| shape | fp32 forward (ms) | fp16 forward (ms) | mixed fp32 fp16 forward (ms) | fp32 backward (ms) | fp16 backward (ms) | mixed fp32 fp16 backward (ms) |
| -- | -- | -- | -- | -- | -- | -- |
| (1, 8, 16) | 0.012 | 0.011 | 0.011 | 0.040 | 0.041 | 0.041 |
| (8 ,8, 16) | 0.012 | 0.012 | 0.012 | 0.042 | 0.042 | 0.042 |
| (32, 8, 16) | 0.027 | 0.014 | 0.014 | 0.048 | 0.048 | 0.046 |
| (64, 128, 56, 56) | 58.054 | 11.034 | 17.928 | 108.603 | 48.816 | 50.244 |
| (64, 128, 256, 256) | 1327.758 | 352.394 | 496.994 | 2846.182 | 1224.247 | 1218.422 |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99590
Approved by: https://github.com/mingfeima, https://github.com/jgong5, https://github.com/cpuhrsch
diff --git a/aten/src/ATen/cpu/vec/functional_bfloat16.h b/aten/src/ATen/cpu/vec/functional_bfloat16.h
index 74ca8b1..03cb017 100644
--- a/aten/src/ATen/cpu/vec/functional_bfloat16.h
+++ b/aten/src/ATen/cpu/vec/functional_bfloat16.h
@@ -69,7 +69,7 @@
//
template <typename scalar_t, typename Op,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
+inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
using bVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
if (size < bVec::size()) {
@@ -111,7 +111,7 @@
template <typename scalar_t, typename Op1, typename Op2,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
+inline std::pair<float, float> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
const scalar_t* data, int64_t size) {
using bVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
@@ -169,7 +169,7 @@
template <typename scalar_t, typename MapOp, typename ReduceOp,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline scalar_t map_reduce_all(
+inline float map_reduce_all(
const MapOp& map_fun,
const ReduceOp& red_fun,
const scalar_t* data,
@@ -225,7 +225,7 @@
template <typename scalar_t, typename MapOp, typename ReduceOp,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline scalar_t map2_reduce_all(
+inline float map2_reduce_all(
const MapOp& map_fun,
const ReduceOp& red_fun,
const scalar_t* data,
@@ -294,7 +294,7 @@
template <typename scalar_t, typename MapOp, typename ReduceOp,
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
-inline scalar_t map3_reduce_all(
+inline float map3_reduce_all(
const MapOp& map_fun,
const ReduceOp& red_fun,
const scalar_t* data,
diff --git a/aten/src/ATen/native/cpu/layer_norm_kernel.cpp b/aten/src/ATen/native/cpu/layer_norm_kernel.cpp
index a0c3e09..a668305 100644
--- a/aten/src/ATen/native/cpu/layer_norm_kernel.cpp
+++ b/aten/src/ATen/native/cpu/layer_norm_kernel.cpp
@@ -23,14 +23,15 @@
namespace {
-template <typename T, typename T_ACC>
+template <typename T,
+ typename std::enable_if_t<!is_reduced_floating_point_v<T>, int> = 0>
void LayerNormKernelImplInternal(
const Tensor& X,
const Tensor& gamma,
const Tensor& beta,
int64_t M,
int64_t N,
- T_ACC eps,
+ T eps,
Tensor* Y,
Tensor* mean,
Tensor* rstd) {
@@ -83,7 +84,8 @@
});
}
-template <typename param_t>
+template <typename T, typename param_t,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
void layer_norm_kernel_mixed_type(
const Tensor& X,
const Tensor& gamma,
@@ -94,12 +96,12 @@
Tensor* Y,
Tensor* mean,
Tensor* rstd) {
- using bVec = Vectorized<BFloat16>;
+ using bVec = Vectorized<T>;
using fVec = Vectorized<float>;
- const BFloat16* X_data = X.data_ptr<BFloat16>();
+ const T* X_data = X.data_ptr<T>();
const param_t* gamma_data = gamma.defined() ? gamma.data_ptr<param_t>() : nullptr;
const param_t* beta_data = beta.defined() ? beta.data_ptr<param_t>() : nullptr;
- BFloat16* Y_data = Y->data_ptr<BFloat16>();
+ T* Y_data = Y->data_ptr<T>();
param_t* mean_data = mean ? mean->data_ptr<param_t>() : nullptr;
param_t* rstd_data = rstd ? rstd->data_ptr<param_t>() : nullptr;
@@ -109,38 +111,29 @@
const bool rstd_null = rstd_data == nullptr;
at::parallel_for(0, M, 1, [&](int64_t start, int64_t end) {
for (const auto i : c10::irange(start, end)) {
- const BFloat16* X_ptr = X_data + i * N;
- BFloat16* Y_ptr = Y_data + i * N;
+ const T* X_ptr = X_data + i * N;
+ T* Y_ptr = Y_data + i * N;
float mean_val;
float rstd_val;
std::tie(mean_val, rstd_val) = RowwiseMoments(X_ptr, N);
rstd_val = float(1) / std::sqrt(rstd_val + eps);
const float scale = rstd_val;
const float bias = -rstd_val * mean_val;
- if (gamma_null || beta_null) {
- for (const auto j : c10::irange(N)) {
- const param_t gamma_v = gamma_null ? param_t(1) : gamma_data[j];
- const param_t beta_v = beta_null ? param_t(0) : beta_data[j];
- Y_ptr[j] = (X_ptr[j] * scale + bias) * gamma_v + beta_v;
- }
- } else {
- int64_t d = 0;
- for (; d < N - (N % bVec::size()); d += bVec::size()) {
- bVec x_bvec = bVec::loadu(X_ptr + d);
- fVec x_fvec0, x_fvec1;
- std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
- fVec gamma_fvec0, gamma_fvec1;
- std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d);
- fVec beta_fvec0, beta_fvec1;
- std::tie(beta_fvec0, beta_fvec1) = load2f(beta_data + d);
- fVec y_fvec0 = (x_fvec0 * fVec(scale) + fVec(bias)) * gamma_fvec0 + beta_fvec0;
- fVec y_fvec1 = (x_fvec1 * fVec(scale) + fVec(bias)) * gamma_fvec1 + beta_fvec1;
- bVec y_bvec = convert_float_bfloat16(y_fvec0, y_fvec1);
- y_bvec.store(Y_ptr + d);
- }
- for (; d < N; d++) {
- Y_ptr[d] = (X_ptr[d] * scale + bias) * gamma_data[d] + beta_data[d];
- }
+ int64_t d = 0;
+ for (; d < N - (N % bVec::size()); d += bVec::size()) {
+ bVec x_bvec = bVec::loadu(X_ptr + d);
+ auto [x_fvec0, x_fvec1] = convert_to_float<T>(x_bvec);
+ auto [gamma_fvec0, gamma_fvec1] = gamma_null ? std::make_tuple(fVec(1), fVec(1)) : load2f(gamma_data + d);
+ auto [beta_fvec0, beta_fvec1] = beta_null ? std::make_tuple(fVec(0), fVec(0)) : load2f(beta_data + d);
+ fVec y_fvec0 = (x_fvec0 * fVec(scale) + fVec(bias)) * gamma_fvec0 + beta_fvec0;
+ fVec y_fvec1 = (x_fvec1 * fVec(scale) + fVec(bias)) * gamma_fvec1 + beta_fvec1;
+ bVec y_bvec = convert_from_float<T>(y_fvec0, y_fvec1);
+ y_bvec.store(Y_ptr + d);
+ }
+ for (; d < N; d++) {
+ const float gamma_v = gamma_null ? float(1) : float(gamma_data[d]);
+ const float beta_v = beta_null ? float(0) : float(beta_data[d]);
+ Y_ptr[d] = (float(X_ptr[d]) * scale + bias) * gamma_v + beta_v;
}
if (!mean_null) {
mean_data[i] = mean_val;
@@ -152,8 +145,9 @@
});
}
-template <>
-void LayerNormKernelImplInternal<BFloat16, float>(
+template <typename T,
+ typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0>
+void LayerNormKernelImplInternal(
const Tensor& X,
const Tensor& gamma,
const Tensor& beta,
@@ -165,9 +159,9 @@
Tensor* rstd) {
const bool mixed_type = is_mixed_type(X, gamma, beta);
if (mixed_type) {
- layer_norm_kernel_mixed_type<float>(X, gamma, beta, M, N, eps, Y, mean, rstd);
+ layer_norm_kernel_mixed_type<T, float>(X, gamma, beta, M, N, eps, Y, mean, rstd);
} else {
- layer_norm_kernel_mixed_type<BFloat16>(X, gamma, beta, M, N, eps, Y, mean, rstd);
+ layer_norm_kernel_mixed_type<T, T>(X, gamma, beta, M, N, eps, Y, mean, rstd);
}
}
@@ -184,15 +178,14 @@
TORCH_DCHECK_EQ(X.numel(), M * N);
DCHECK(!gamma.defined() || gamma.numel() == N);
DCHECK(!beta.defined() || beta.numel() == N);
- AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, X.scalar_type(),
+ AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, X.scalar_type(),
"LayerNormKernelImpl", [&]() {
- using acc_t = at::opmath_type<scalar_t>;
- LayerNormKernelImplInternal<scalar_t, acc_t>(
- X, gamma, beta, M, N, static_cast<acc_t>(eps), Y, mean, rstd);
+ LayerNormKernelImplInternal<scalar_t>(
+ X, gamma, beta, M, N, eps, Y, mean, rstd);
});
}
-template <typename T, typename T2, typename T_ACC>
+template <typename T, typename T2, typename opmath_t>
void layer_norm_backward_frame(
const T* dY_data,
const T* X_data,
@@ -202,19 +195,19 @@
T* dX_data,
T* dgamma_buffer_ptr,
T* dbeta_buffer_ptr,
- const T_ACC scale,
+ const opmath_t scale,
const bool gamma_null,
const bool dX_null,
const bool dgamma_null,
const bool dbeta_null,
int64_t N,
int64_t i) {
- using Vec = vec::Vectorized<T_ACC>;
+ using Vec = vec::Vectorized<opmath_t>;
const T* dY_ptr = dY_data + i * N;
const T* X_ptr = X_data + i * N;
if (!dgamma_null) {
- const T_ACC a = rstd_data[i];
- const T_ACC b = -a * mean_data[i];
+ const opmath_t a = rstd_data[i];
+ const opmath_t b = -a * mean_data[i];
// Scalar math:
// for (const auto j : c10::irange(N)) {
// dgamma_data[j] += dY_ptr[j] * (a * X_ptr[j] + b);
@@ -243,8 +236,8 @@
}
if (!dX_null) {
T* dX_ptr = dX_data + i * N;
- T_ACC ds = T_ACC(0);
- T_ACC db = T_ACC(0);
+ opmath_t ds = opmath_t(0);
+ opmath_t db = opmath_t(0);
// Scalar math:
// for (const auto j : c10::irange(N)) {
// const T gamma_v = gamma_null ? T(1) : gamma_data[j];
@@ -275,9 +268,9 @@
gamma_data,
N);
}
- const T_ACC a = rstd_data[i];
- const T_ACC b = (db * mean_data[i] - ds) * a * a * a * scale;
- const T_ACC c = -b * mean_data[i] - db * a * scale;
+ const opmath_t a = rstd_data[i];
+ const opmath_t b = (db * opmath_t(mean_data[i]) - ds) * a * a * a * scale;
+ const opmath_t c = -b * opmath_t(mean_data[i]) - db * a * scale;
// Scalar math:
// for (const auto j : c10::irange(N)) {
// const T gamma_v = gamma_null ? T(1) : gamma_data[j];
@@ -306,16 +299,17 @@
}
}
-template <>
-void layer_norm_backward_frame<BFloat16, float, float>(
- const BFloat16* dY_data,
- const BFloat16* X_data,
+template <typename T, typename T2, typename opmath_t,
+ typename std::enable_if_t<is_reduced_floating_point_v<T> && std::is_same<T2, float>::value, int> = 0>
+void layer_norm_backward_frame(
+ const T* dY_data,
+ const T* X_data,
const float* mean_data,
const float* rstd_data,
const float* gamma_data,
- BFloat16* dX_data,
- BFloat16* dgamma_buffer_ptr,
- BFloat16* dbeta_buffer_ptr,
+ T* dX_data,
+ T* dgamma_buffer_ptr,
+ T* dbeta_buffer_ptr,
const float scale,
const bool gamma_null,
const bool dX_null,
@@ -323,10 +317,10 @@
const bool dbeta_null,
int64_t N,
int64_t i) {
- using bVec = Vectorized<BFloat16>;
+ using bVec = Vectorized<T>;
using fVec = Vectorized<float>;
- const BFloat16* dY_ptr = dY_data + i * N;
- const BFloat16* X_ptr = X_data + i * N;
+ const T* dY_ptr = dY_data + i * N;
+ const T* X_ptr = X_data + i * N;
if (!dgamma_null) {
const float a = rstd_data[i];
const float b = -a * mean_data[i];
@@ -334,7 +328,7 @@
// for (const auto j : c10::irange(N)) {
// dgamma_data[j] += dY_ptr[j] * (a * X_ptr[j] + b);
// }
- vec::map3<BFloat16>(
+ vec::map3<T>(
[a, b](fVec dgamma, fVec dy, fVec x) {
return dgamma + dy * (fVec(a) * x + fVec(b));
},
@@ -349,7 +343,7 @@
// for (const auto j : c10::irange(N)) {
// dbeta_data[j] += dY_ptr[j];
// }
- vec::map2<BFloat16>(
+ vec::map2<T>(
[](fVec dbeta, fVec dy) { return dbeta + dy; },
dbeta_buffer_ptr,
dbeta_buffer_ptr,
@@ -357,7 +351,7 @@
N);
}
if (!dX_null) {
- BFloat16* dX_ptr = dX_data + i * N;
+ T* dX_ptr = dX_data + i * N;
float ds = float(0);
float db = float(0);
// Scalar math:
@@ -367,21 +361,21 @@
// db += dY_ptr[j] * gamma_v;
// }
if (gamma_null) {
- ds = vec::map2_reduce_all<BFloat16>(
+ ds = vec::map2_reduce_all<T>(
[](fVec x, fVec y) { return x * y; },
[](fVec x, fVec y) { return x + y; },
dY_ptr,
X_ptr,
N);
- db = vec::reduce_all<BFloat16>(
+ db = vec::reduce_all<T>(
[](fVec& x, fVec& y) { return x + y; }, dY_ptr, N);
} else {
if (N < bVec::size()) {
bVec x_bvec = bVec::loadu(X_ptr, N);
bVec dy_bvec = bVec::loadu(dY_ptr, N);
fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
- std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
- std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+ std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+ std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data, N);
if (N > fVec::size()) {
fVec db_fvec0 = dy_fvec0 * gamma_fvec0;
@@ -404,8 +398,8 @@
bVec dy_bvec = bVec::loadu(dY_ptr);
fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
fVec ds_fvec0, ds_fvec1, db_fvec0, db_fvec1, acc_ds_fvec0, acc_ds_fvec1, acc_db_fvec0, acc_db_fvec1;
- std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
- std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+ std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+ std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data);
acc_db_fvec0 = dy_fvec0 * gamma_fvec0;
acc_db_fvec1 = dy_fvec1 * gamma_fvec1;
@@ -414,8 +408,8 @@
for (; d < N - (N % bVec::size()); d += bVec::size()) {
x_bvec = bVec::loadu(X_ptr + d);
dy_bvec = bVec::loadu(dY_ptr + d);
- std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
- std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+ std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+ std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d);
db_fvec0 = dy_fvec0 * gamma_fvec0;
db_fvec1 = dy_fvec1 * gamma_fvec1;
@@ -429,8 +423,8 @@
if (N - d > 0) {
x_bvec = bVec::loadu(X_ptr + d, N - d);
dy_bvec = bVec::loadu(dY_ptr + d, N - d);
- std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
- std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+ std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+ std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d, N - d);
if (N - d > fVec::size()) {
db_fvec0 = dy_fvec0 * gamma_fvec0;
@@ -463,7 +457,7 @@
// dX_ptr[j] = a * dY_ptr[j] * gamma_v + b * X_ptr[j] + c;
// }
if (gamma_null) {
- vec::map2<BFloat16>(
+ vec::map2<T>(
[a, b, c](fVec dy, fVec x) {
return fVec(a) * dy + fVec(b) * x + fVec(c);
},
@@ -477,24 +471,24 @@
bVec x_bvec = bVec::loadu(X_ptr + d);
bVec dy_bvec = bVec::loadu(dY_ptr + d);
fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
- std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
- std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+ std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+ std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d);
fVec r_fvec0 = fVec(a) * dy_fvec0 * gamma_fvec0 + fVec(b) * x_fvec0 + fVec(c);
fVec r_fvec1 = fVec(a) * dy_fvec1 * gamma_fvec1 + fVec(b) * x_fvec1 + fVec(c);
- bVec r_bvec = convert_float_bfloat16(r_fvec0, r_fvec1);
+ bVec r_bvec = convert_from_float<T>(r_fvec0, r_fvec1);
r_bvec.store(dX_ptr + d);
}
if (N - d > 0) {
bVec x_bvec = bVec::loadu(X_ptr + d, N - d);
bVec dy_bvec = bVec::loadu(dY_ptr + d, N - d);
fVec x_fvec0, x_fvec1, dy_fvec0, dy_fvec1, gamma_fvec0, gamma_fvec1;
- std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
- std::tie(dy_fvec0, dy_fvec1) = convert_bfloat16_float(dy_bvec);
+ std::tie(x_fvec0, x_fvec1) = convert_to_float<T>(x_bvec);
+ std::tie(dy_fvec0, dy_fvec1) = convert_to_float<T>(dy_bvec);
std::tie(gamma_fvec0, gamma_fvec1) = load2f(gamma_data + d, N - d);
fVec r_fvec0 = fVec(a) * dy_fvec0 * gamma_fvec0 + fVec(b) * x_fvec0 + fVec(c);
fVec r_fvec1 = fVec(a) * dy_fvec1 * gamma_fvec1 + fVec(b) * x_fvec1 + fVec(c);
- bVec r_bvec = convert_float_bfloat16(r_fvec0, r_fvec1);
+ bVec r_bvec = convert_from_float<T>(r_fvec0, r_fvec1);
r_bvec.store(dX_ptr + d, N - d);
}
}
@@ -513,7 +507,7 @@
Tensor* dX,
Tensor* dgamma,
Tensor* dbeta) {
- using T_ACC = at::opmath_type<T>;
+ using opmath_t = at::opmath_type<T>;
TORCH_DCHECK_EQ(dY.numel(), M * N);
TORCH_DCHECK_EQ(X.numel(), M * N);
TORCH_DCHECK_EQ(mean.numel(), M);
@@ -528,7 +522,7 @@
T* dX_data = dX->defined() ? dX->template data_ptr<T>() : nullptr;
T2* dgamma_data = dgamma->defined() ? dgamma->template data_ptr<T2>() : nullptr;
T2* dbeta_data = dbeta->defined() ? dbeta->template data_ptr<T2>() : nullptr;
- const T_ACC scale = T_ACC(1) / static_cast<T_ACC>(N);
+ const opmath_t scale = opmath_t(1) / static_cast<opmath_t>(N);
const bool gamma_null = gamma_data == nullptr;
const bool dX_null = dX_data == nullptr;
const bool dgamma_null = dgamma_data == nullptr;
@@ -565,7 +559,7 @@
T* dbeta_buffer_ptr =
dbeta_null ? nullptr : buffer_data + num_threads * N + tid * N;
for (const auto i : c10::irange(start, end)) {
- layer_norm_backward_frame<T, T2, T_ACC>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, dgamma_buffer_ptr, dbeta_buffer_ptr, scale, gamma_null, dX_null, dgamma_null, dbeta_null, N, i);
+ layer_norm_backward_frame<T, T2, opmath_t>(dY_data, X_data, mean_data, rstd_data, gamma_data, dX_data, dgamma_buffer_ptr, dbeta_buffer_ptr, scale, gamma_null, dX_null, dgamma_null, dbeta_null, N, i);
}
});
@@ -573,8 +567,8 @@
if (buffer_data != nullptr) {
parallel_for(0, N, 1, [&](int64_t start, int64_t end) {
for (const auto j : c10::irange(start, end)) {
- T_ACC dgamma_v = T_ACC(0);
- T_ACC dbeta_v = T_ACC(0);
+ opmath_t dgamma_v = opmath_t(0);
+ opmath_t dbeta_v = opmath_t(0);
for (const auto i : c10::irange(num_threads)) {
dgamma_v += buffer_data[i * N + j];
dbeta_v += buffer_data[num_threads * N + i * N + j];
@@ -603,16 +597,22 @@
Tensor* dX,
Tensor* dgamma,
Tensor* dbeta) {
- AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, X.scalar_type(),
- "LayerNormBackwardKernelImpl", [&]() {
- if (X.scalar_type() == at::kBFloat16 && gamma.scalar_type() == at::kFloat) {
- LayerNormBackwardKernelImplInternal<BFloat16, float>(
- dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
- } else {
+ if (at::isReducedFloatingType(X.scalar_type())) {
+ AT_DISPATCH_REDUCED_FLOATING_TYPES(X.scalar_type(), "LayerNormBackwardKernelImpl", [&]() {
+ if (gamma.scalar_type() == at::kFloat) {
+ LayerNormBackwardKernelImplInternal<scalar_t, float>(
+ dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
+ } else {
+ LayerNormBackwardKernelImplInternal<scalar_t, scalar_t>(
+ dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
+ }
+ });
+ } else {
+ AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "LayerNormBackwardKernelImpl", [&]() {
LayerNormBackwardKernelImplInternal<scalar_t, scalar_t>(
dY.contiguous(), X, mean, rstd, gamma, M, N, dX, dgamma, dbeta);
- }
- });
+ });
+ }
}
} // namespace
diff --git a/aten/src/ATen/native/cpu/moments_utils.h b/aten/src/ATen/native/cpu/moments_utils.h
index 3f60309..c89aa6b 100644
--- a/aten/src/ATen/native/cpu/moments_utils.h
+++ b/aten/src/ATen/native/cpu/moments_utils.h
@@ -76,7 +76,7 @@
AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0);
}
-// each bfloat16 vector will be converted to two float vectors,
+// each bfloat16/half vector will be converted to two float vectors,
// and accumulated successively on m1_stk0/m2_stk0.
template <typename T>
inline typename std::enable_if<!std::is_same<T, at::opmath_type<T>>::value, void>::type
diff --git a/test/test_meta.py b/test/test_meta.py
index e856025..1967da3 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -710,7 +710,7 @@
meta_function_device_expected_failures['cpu'] = {
torch.native_batch_norm: {bf16, f16},
torch._native_batch_norm_legit: {bf16, f16},
- torch.native_layer_norm: {bf16},
+ torch.native_layer_norm: {bf16, f16},
}
meta_function_device_expected_failures['cuda'] = {
@@ -855,7 +855,7 @@
aten.native_batch_norm.default: {bf16, f16},
aten._native_batch_norm_legit.default: {bf16, f16},
aten._native_batch_norm_legit.no_stats: {bf16, f16},
- aten.native_layer_norm.default: {bf16},
+ aten.native_layer_norm.default: {bf16, f16},
aten.histc.default: {f16},
aten.histc.out: {f16},
}
diff --git a/test/test_mps.py b/test/test_mps.py
index 73db4c4..ae43dd5 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -11104,6 +11104,8 @@
'cross', 'linalg.cross',
'prod', 'masked.prod',
'nextafter',
+ 'native_layer_norm',
+ 'nn.functional.layer_norm',
# for macOS 12
'masked.normalize', 'masked.sum', 'masked.var',
diff --git a/test/test_nn.py b/test/test_nn.py
index 0876075..81fd93b 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -7948,7 +7948,7 @@
mean = out_reshaped.mean(-1)
var = out_reshaped.var(-1, unbiased=False)
- delta = 1e-1 if dtype == torch.bfloat16 else 1e-5
+ delta = 1e-1 if (dtype == torch.bfloat16 or dtype == torch.half) else 1e-5
self.assertEqual(torch.abs(mean.data).mean(), 0, atol=delta, rtol=0)
self.assertEqual(torch.abs(var.data).mean(), 1, atol=delta, rtol=0)
@@ -7982,12 +7982,12 @@
output.sum().backward()
self.assertEqualTypeString(output, input)
- def _test_LayerNorm_cpu_mixed_dtype(self, device):
+ def _test_LayerNorm_cpu_mixed_dtype(self, device, dtype):
for elementwise_affine in [True, False]:
# layer norm input shape is normalized to m x n, cpu vectorized on n,
# so make sure n exceeds vector length
- input = torch.empty(2, 3, 11, 3, device=device, dtype=torch.bfloat16).random_(1, 10)
- m = nn.LayerNorm([11, 3], elementwise_affine=elementwise_affine).to(device, torch.bfloat16)
+ input = torch.empty(2, 3, 11, 3, device=device, dtype=dtype).random_(1, 10)
+ m = nn.LayerNorm([11, 3], elementwise_affine=elementwise_affine).to(device, dtype)
# fp32
m_fp32 = deepcopy(m).to(device, torch.float)
@@ -7995,21 +7995,21 @@
out_fp32 = m_fp32(x_fp32)
out_fp32.sum().backward()
- # bf16
+ # bf16/half
m_bf16 = deepcopy(m)
x_bf16 = input.clone().detach().requires_grad_()
out_bf16 = m_bf16(x_bf16)
out_bf16.sum().backward()
- # bf16 mixed type
+ # bf16/half mixed type
m_mix = deepcopy(m).to(device, torch.float)
x_mix = input.clone().detach().requires_grad_()
out_mix = m_mix(x_mix)
out_mix.sum().backward()
- self.assertEqual(out_fp32.bfloat16(), out_bf16)
- self.assertEqual(out_fp32.bfloat16(), out_mix)
- self.assertEqual(x_fp32.grad.bfloat16(), x_bf16.grad, atol=1e-1, rtol=1e-1)
- self.assertEqual(x_fp32.grad.bfloat16(), x_mix.grad, atol=1e-1, rtol=1e-1)
+ self.assertEqual(out_fp32.to(dtype=dtype), out_bf16)
+ self.assertEqual(out_fp32.to(dtype=dtype), out_mix)
+ self.assertEqual(x_fp32.grad.to(dtype=dtype), x_bf16.grad, atol=1e-1, rtol=1e-1)
+ self.assertEqual(x_fp32.grad.to(dtype=dtype), x_mix.grad, atol=1e-1, rtol=1e-1)
def _test_GroupNorm_general(self, device, dtype=torch.float):
good_shape_g = {
@@ -8518,13 +8518,15 @@
self._test_LayerNorm_general(device)
if self.device_type == 'cuda' or self.device_type == 'cpu':
- self._test_LayerNorm_general(device, dtype=torch.bfloat16)
+ for dtype in [torch.half, torch.bfloat16]:
+ self._test_LayerNorm_general(device, dtype=dtype)
if self.device_type == 'cuda':
self._test_LayerNorm_cuda_half(device)
if self.device_type == 'cpu':
- self._test_LayerNorm_cpu_mixed_dtype(device)
+ for dtype in [torch.half, torch.bfloat16]:
+ self._test_LayerNorm_cpu_mixed_dtype(device, dtype=dtype)
@onlyNativeDeviceTypes
def test_LayerNorm_numeric(self, device):
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index eb02724..cafcd84 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -12844,8 +12844,7 @@
OpInfo('native_layer_norm',
aten_name='native_layer_norm',
ref=reference_native_layer_norm,
- dtypes=floating_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+ dtypes=floating_types_and(torch.half, torch.bfloat16),
supports_out=False,
assert_jit_shape_analysis=True,
supports_fwgrad_bwgrad=True,
@@ -13378,8 +13377,7 @@
aten_backward_name='layer_norm_backward',
aliases=('layer_norm',),
ref=reference_layer_norm,
- dtypes=floating_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
+ dtypes=floating_types_and(torch.half, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,