| #pragma once |
| #include <string.h> |
| #include <cmath> |
| #include <cstdint> |
| #include "c10/util/irange.h" |
| #include "caffe2/utils/conversions.h" |
| |
| #include "vectorizer.h" |
| |
| namespace caffe2 { |
| namespace perfkernels { |
| namespace { |
// Logistic function: returns 1 / (1 + exp(-x)).
template <typename T>
inline T sigmoid(T x) {
  // `auto` keeps whatever type std::exp yields for T, so the arithmetic below
  // is carried out in the same type as a single-expression form would use.
  const auto decay = std::exp(-x);
  return 1 / (1 + decay);
}
| |
// Hyperbolic tangent expressed through sigmoid: 2 * sigmoid(2 * x) - 1.
template <typename T>
inline T host_tanh(T x) {
  const auto scaled = 2 * sigmoid(2 * x);
  return scaled - 1;
}
| |
| template <typename T> |
| inline void LstmUnitImpl( |
| const int N, |
| const int D, |
| const int t, |
| const T* H_prev, |
| const T* C_prev, |
| const T* X, |
| const int32_t* seqLengths, |
| const bool drop_states, |
| T* C, |
| T* H, |
| const float forget_bias) { |
| const T forgetBias = convert::To<float, T>(forget_bias); |
| for (const auto n : c10::irange(N)) { |
| const bool valid = seqLengths == nullptr || t < seqLengths[n]; |
| if (!valid) { |
| if (drop_states) { |
| memset(H, 0, sizeof(T) * D); |
| memset(C, 0, sizeof(T) * D); |
| } else { |
| memcpy(H, H_prev, sizeof(T) * D); |
| memcpy(C, C_prev, sizeof(T) * D); |
| } |
| } else { |
| const T* X_D = &X[D]; |
| const T* X_2D = &X[2 * D]; |
| const T* X_3D = &X[3 * D]; |
| VECTOR_LOOP for (const auto d : c10::irange(D)) { |
| const T i = sigmoid(X[d]); |
| const T f = sigmoid(X_D[d] + forgetBias); |
| const T o = sigmoid(X_2D[d]); |
| const T g = host_tanh(X_3D[d]); |
| const T c_prev = C_prev[d]; |
| const T c = f * c_prev + i * g; |
| C[d] = c; |
| const T host_tanh_c = host_tanh(c); |
| H[d] = o * host_tanh_c; |
| } |
| } |
| H_prev += D; |
| C_prev += D; |
| X += 4 * D; |
| C += D; |
| H += D; |
| } |
| } |
| |
| template <typename T> |
| inline void LstmUnitGradientImpl( |
| int N, |
| int D, |
| int t, |
| const T* C_prev, |
| const T* X, |
| const int32_t* seqLengths, |
| const T* C, |
| const T* H, |
| const T* C_diff, |
| const T* H_diff, |
| bool drop_states, |
| T* H_prev_diff, |
| T* C_prev_diff, |
| T* X_diff, |
| const float forget_bias) { |
| const T localForgetBias = convert::To<float, T>(forget_bias); |
| for (const auto n : c10::irange(N)) { |
| const bool valid = seqLengths == nullptr || t < seqLengths[n]; |
| |
| if (!valid) { |
| if (drop_states) { |
| memset(C_prev_diff, 0, sizeof(T) * D); |
| memset(H_prev_diff, 0, sizeof(T) * D); |
| } else { |
| memcpy(H_prev_diff, H_diff, sizeof(T) * D); |
| memcpy(C_prev_diff, C_diff, sizeof(T) * D); |
| } |
| memset(X_diff, 0, 4 * sizeof(T) * D); |
| } else { |
| VECTOR_LOOP for (const auto d : c10::irange(D)) { |
| T* c_prev_diff = C_prev_diff + d; |
| T* h_prev_diff = H_prev_diff + d; |
| T* i_diff = X_diff + d; |
| T* f_diff = X_diff + 1 * D + d; |
| T* o_diff = X_diff + 2 * D + d; |
| T* g_diff = X_diff + 3 * D + d; |
| |
| const T i = sigmoid(X[d]); |
| const T f = sigmoid(X[1 * D + d] + localForgetBias); |
| const T o = sigmoid(X[2 * D + d]); |
| const T g = host_tanh(X[3 * D + d]); |
| const T c_prev = C_prev[d]; |
| const T c = C[d]; |
| const T host_tanh_c = host_tanh(c); |
| const T c_term_diff = |
| C_diff[d] + H_diff[d] * o * (1 - host_tanh_c * host_tanh_c); |
| *c_prev_diff = c_term_diff * f; |
| *h_prev_diff = 0; // not used in 'valid' case |
| *i_diff = c_term_diff * g * i * (1 - i); |
| *f_diff = c_term_diff * c_prev * f * (1 - f); |
| *o_diff = H_diff[d] * host_tanh_c * o * (1 - o); |
| *g_diff = c_term_diff * i * (1 - g * g); |
| } |
| } |
| C_prev += D; |
| X += 4 * D; |
| C += D; |
| H += D; |
| C_diff += D; |
| H_diff += D; |
| X_diff += 4 * D; |
| H_prev_diff += D; |
| C_prev_diff += D; |
| } |
| } |
| |
| } // namespace |
| } // namespace perfkernels |
| } // namespace caffe2 |