Refactors LSTM Eval for reuse in Uni(Bi)directionalSequenceLSTM.
Bug: 113563341
Test: Covered by existing test NeuralNetworksTest_static
--gtest_filter=GeneratedTests.lstm*
Change-Id: Iff47304515c7e6a436757c2e513a954487070a76
Merged-In: Iff47304515c7e6a436757c2e513a954487070a76
(cherry picked from commit 91a87797ce004e0890d75e347d63ae443c84f062)
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp
index 5159952..bbadfa6 100644
--- a/common/operations/LSTM.cpp
+++ b/common/operations/LSTM.cpp
@@ -37,6 +37,11 @@
return reinterpret_cast<const T*>(operand->buffer);
}
+// Returns the operand's buffer viewed as `const T*`, or nullptr when
+// IsNullInput(operand) is true.
+template <typename T>
+inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
+    if (IsNullInput(operand)) {
+        return nullptr;
+    }
+    return reinterpret_cast<const T*>(operand->buffer);
+}
+
} // anonymous namespace
LSTMCell::LSTMCell(const Operation& operation, std::vector<RunTimeOperandInfo>& operands) {
@@ -334,6 +339,227 @@
}
// static
+// Evaluates the float32 LSTM across all time steps by calling LSTMStep() once per step.
+//
+// input_shape must have rank 2 or 3 (enforced with NN_CHECK). A rank-2 input is processed as
+// one step whose batch size is dimension 0. For a rank-3 input, timeMajor == true reads
+// dimension 0 as the step count and dimension 1 as the batch size; timeMajor == false swaps
+// the two. The per-step input size is the last dimension of input_shape and the per-step
+// output size is dimension 1 of recurrent_to_output_weights_shape.
+//
+// Every weight, bias, state and scratch pointer is forwarded to LSTMStep() unchanged on each
+// step; only the input and output pointers advance between steps.
+//
+// Returns false as soon as a step reports failure, true once every step has run.
+bool LSTMCell::LSTMEvalFloat32(
+        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
+        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
+        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
+        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
+        const float* recurrent_to_forget_weights_buffer,
+        const float* recurrent_to_cell_weights_buffer,
+        const float* recurrent_to_output_weights_buffer,
+        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
+        const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
+        const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
+        const float* cell_bias_buffer, const float* output_gate_bias_buffer,
+        const float* projection_weights_buffer, const float* projection_bias_buffer,
+        const float* output_state_in_buffer, const float* cell_state_in_buffer,
+        const float* input_layer_norm_weights_buffer, const float* forget_layer_norm_weights_buffer,
+        const float* cell_layer_norm_weights_buffer, const float* output_layer_norm_weights_buffer,
+        float* output_state_out_buffer, float* cell_state_out_buffer, float* output_buffer,
+        float* scratch_buffer_buffer, bool timeMajor) {
+    NNTRACE_COMP("LSTMCell::LSTMEvalFloat32");
+
+    const uint32_t inputRank = getNumberOfDimensions(input_shape);
+    NN_CHECK(inputRank == 2 || inputRank == 3);
+
+    const bool hasTimeDimension = (inputRank == 3);
+    const uint32_t maxTime =
+            hasTimeDimension ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
+    const uint32_t batchSize =
+            getSizeOfDimension(input_shape, (hasTimeDimension && timeMajor) ? 1 : 0);
+    const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
+    const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
+
+    // Number of elements consumed from the input / written to the output by one step.
+    // NOTE(review): the pointers advance by these strides regardless of timeMajor; confirm
+    // that batch-major callers hand over data laid out step by step.
+    const uint32_t batchInputSize = batchSize * inputSize;
+    const uint32_t batchOutputSize = batchSize * outputSize;
+
+    const float* inputCurrentTimeStep = input_buffer;
+    float* outputCurrentTimeStep = output_buffer;
+    // Unsigned index: maxTime is uint32_t, so a signed counter would mix signedness.
+    for (uint32_t t = 0; t < maxTime; ++t) {
+        // Propagate a failing step instead of silently reporting success.
+        if (!LSTMStep(params, inputCurrentTimeStep, input_shape, input_to_input_weights_buffer,
+                      input_to_forget_weights_buffer, input_to_cell_weights_buffer,
+                      input_to_output_weights_buffer, input_to_output_weights_shape,
+                      recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer,
+                      recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer,
+                      recurrent_to_output_weights_shape, cell_to_input_weights_buffer,
+                      cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
+                      input_gate_bias_buffer, forget_gate_bias_buffer, cell_bias_buffer,
+                      output_gate_bias_buffer, projection_weights_buffer, projection_bias_buffer,
+                      output_state_in_buffer, cell_state_in_buffer,
+                      input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
+                      cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
+                      output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep,
+                      scratch_buffer_buffer)) {
+            return false;
+        }
+        inputCurrentTimeStep += batchInputSize;
+        outputCurrentTimeStep += batchOutputSize;
+    }
+    return true;
+}
+
+// static
+// Half-precision entry point: widens every operand to float32, runs LSTMStep() once per time
+// step on the float32 copies, then narrows the outputs back into the caller's _Float16
+// buffers.
+//
+// input_shape must have rank 2 or 3 (enforced with NN_CHECK). A rank-2 input is processed as
+// one step whose batch size is dimension 0. For a rank-3 input, timeMajor == true reads
+// dimension 0 as the step count and dimension 1 as the batch size; timeMajor == false swaps
+// the two. numCells is dimension 0 of input_to_output_weights_shape and outputSize is
+// dimension 1 of recurrent_to_output_weights_shape; all temporary buffer sizes are derived
+// from these values.
+//
+// Optional operands (input-gate weights and bias, cell-to-gate weights, projection weights
+// and bias, layer-norm weights) may be nullptr. A null operand is forwarded to LSTMStep() as
+// nullptr, matching what the float32 path hands over for an omitted operand.
+//
+// Returns false as soon as a step reports failure (the _Float16 output buffers are then left
+// untouched), true once every step has run and the outputs have been written back.
+bool LSTMCell::LSTMEvalFloat16(
+        const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
+        const _Float16* input_to_input_weights_buffer,
+        const _Float16* input_to_forget_weights_buffer,
+        const _Float16* input_to_cell_weights_buffer,
+        const _Float16* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape,
+        const _Float16* recurrent_to_input_weights_buffer,
+        const _Float16* recurrent_to_forget_weights_buffer,
+        const _Float16* recurrent_to_cell_weights_buffer,
+        const _Float16* recurrent_to_output_weights_buffer,
+        const Shape& recurrent_to_output_weights_shape,
+        const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer,
+        const _Float16* cell_to_output_weights_buffer, const _Float16* input_gate_bias_buffer,
+        const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer,
+        const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer,
+        const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer,
+        const _Float16* cell_state_in_buffer, const _Float16* input_layer_norm_weights_buffer,
+        const _Float16* forget_layer_norm_weights_buffer,
+        const _Float16* cell_layer_norm_weights_buffer,
+        const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
+        _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer,
+        bool timeMajor) {
+    NNTRACE_COMP("LSTMCell::LSTMEvalFloat16");
+
+    const uint32_t inputRank = getNumberOfDimensions(input_shape);
+    NN_CHECK(inputRank == 2 || inputRank == 3);
+
+    const bool hasTimeDimension = (inputRank == 3);
+    const uint32_t maxTime =
+            hasTimeDimension ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
+    const uint32_t batchSize =
+            getSizeOfDimension(input_shape, (hasTimeDimension && timeMajor) ? 1 : 0);
+    const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
+    const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
+    const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
+
+    // Element counts shared by several operands.
+    const uint32_t batchInputSize = batchSize * inputSize;
+    const uint32_t batchOutputSize = batchSize * outputSize;
+    const uint32_t inputWeightCount = numCells * inputSize;
+    const uint32_t recurrentWeightCount = numCells * outputSize;
+    const uint32_t batchCellCount = batchSize * numCells;
+    const uint32_t scratchCount = (params.use_cifg ? 3 : 4) * batchCellCount;
+
+    // Copies `count` half-precision values into a float32 vector. A null source produces an
+    // empty vector so that absent optional operands cost no allocation.
+    const auto toFloat32 = [](const _Float16* src, uint32_t count) -> std::vector<float> {
+        std::vector<float> widened(src != nullptr ? count : 0);
+        if (src != nullptr) {
+            convertFloat16ToFloat32(src, &widened);
+        }
+        return widened;
+    };
+    // Pointer to hand to LSTMStep() for an optional operand: nullptr when the caller supplied
+    // no buffer, otherwise the widened copy.
+    const auto orNull = [](const _Float16* src, const std::vector<float>& widened) -> const float* {
+        return src != nullptr ? widened.data() : nullptr;
+    };
+
+    const auto input_float32 = toFloat32(input_buffer, maxTime * batchInputSize);
+
+    const auto input_to_input_weights_float32 =
+            toFloat32(input_to_input_weights_buffer, inputWeightCount);
+    const auto input_to_forget_weights_float32 =
+            toFloat32(input_to_forget_weights_buffer, inputWeightCount);
+    const auto input_to_cell_weights_float32 =
+            toFloat32(input_to_cell_weights_buffer, inputWeightCount);
+    const auto input_to_output_weights_float32 =
+            toFloat32(input_to_output_weights_buffer, inputWeightCount);
+
+    const auto recurrent_to_input_weights_float32 =
+            toFloat32(recurrent_to_input_weights_buffer, recurrentWeightCount);
+    const auto recurrent_to_forget_weights_float32 =
+            toFloat32(recurrent_to_forget_weights_buffer, recurrentWeightCount);
+    const auto recurrent_to_cell_weights_float32 =
+            toFloat32(recurrent_to_cell_weights_buffer, recurrentWeightCount);
+    const auto recurrent_to_output_weights_float32 =
+            toFloat32(recurrent_to_output_weights_buffer, recurrentWeightCount);
+
+    const auto cell_to_input_weights_float32 = toFloat32(cell_to_input_weights_buffer, numCells);
+    const auto cell_to_forget_weights_float32 = toFloat32(cell_to_forget_weights_buffer, numCells);
+    const auto cell_to_output_weights_float32 = toFloat32(cell_to_output_weights_buffer, numCells);
+
+    const auto input_gate_bias_float32 = toFloat32(input_gate_bias_buffer, numCells);
+    const auto forget_gate_bias_float32 = toFloat32(forget_gate_bias_buffer, numCells);
+    const auto cell_bias_float32 = toFloat32(cell_bias_buffer, numCells);
+    const auto output_gate_bias_float32 = toFloat32(output_gate_bias_buffer, numCells);
+
+    const auto projection_weights_float32 =
+            toFloat32(projection_weights_buffer, recurrentWeightCount);
+    const auto projection_bias_float32 = toFloat32(projection_bias_buffer, outputSize);
+
+    const auto output_state_in_float32 = toFloat32(output_state_in_buffer, batchOutputSize);
+    const auto cell_state_in_float32 = toFloat32(cell_state_in_buffer, batchCellCount);
+
+    const auto input_layer_norm_weights_float32 =
+            toFloat32(input_layer_norm_weights_buffer, numCells);
+    const auto forget_layer_norm_weights_float32 =
+            toFloat32(forget_layer_norm_weights_buffer, numCells);
+    const auto cell_layer_norm_weights_float32 =
+            toFloat32(cell_layer_norm_weights_buffer, numCells);
+    const auto output_layer_norm_weights_float32 =
+            toFloat32(output_layer_norm_weights_buffer, numCells);
+
+    // Output-side buffers are seeded from the caller's current contents, then written back
+    // after the last step.
+    std::vector<float> output_state_out_float32 =
+            toFloat32(output_state_out_buffer, batchOutputSize);
+    std::vector<float> cell_state_out_float32 = toFloat32(cell_state_out_buffer, batchCellCount);
+    std::vector<float> output_float32 = toFloat32(output_buffer, maxTime * batchOutputSize);
+    std::vector<float> scratch_buffer_float32 = toFloat32(scratch_buffer_buffer, scratchCount);
+
+    const float* inputCurrentTimeStep = input_float32.data();
+    float* outputCurrentTimeStep = output_float32.data();
+    // Unsigned index: maxTime is uint32_t, so a signed counter would mix signedness.
+    for (uint32_t t = 0; t < maxTime; ++t) {
+        // Propagate a failing step instead of silently reporting success.
+        if (!LSTMStep(params, inputCurrentTimeStep, input_shape,
+                      orNull(input_to_input_weights_buffer, input_to_input_weights_float32),
+                      input_to_forget_weights_float32.data(),
+                      input_to_cell_weights_float32.data(),
+                      input_to_output_weights_float32.data(), input_to_output_weights_shape,
+                      orNull(recurrent_to_input_weights_buffer, recurrent_to_input_weights_float32),
+                      recurrent_to_forget_weights_float32.data(),
+                      recurrent_to_cell_weights_float32.data(),
+                      recurrent_to_output_weights_float32.data(),
+                      recurrent_to_output_weights_shape,
+                      orNull(cell_to_input_weights_buffer, cell_to_input_weights_float32),
+                      orNull(cell_to_forget_weights_buffer, cell_to_forget_weights_float32),
+                      orNull(cell_to_output_weights_buffer, cell_to_output_weights_float32),
+                      orNull(input_gate_bias_buffer, input_gate_bias_float32),
+                      forget_gate_bias_float32.data(), cell_bias_float32.data(),
+                      output_gate_bias_float32.data(),
+                      orNull(projection_weights_buffer, projection_weights_float32),
+                      orNull(projection_bias_buffer, projection_bias_float32),
+                      output_state_in_float32.data(), cell_state_in_float32.data(),
+                      orNull(input_layer_norm_weights_buffer, input_layer_norm_weights_float32),
+                      orNull(forget_layer_norm_weights_buffer, forget_layer_norm_weights_float32),
+                      orNull(cell_layer_norm_weights_buffer, cell_layer_norm_weights_float32),
+                      orNull(output_layer_norm_weights_buffer, output_layer_norm_weights_float32),
+                      output_state_out_float32.data(), cell_state_out_float32.data(),
+                      outputCurrentTimeStep, scratch_buffer_float32.data())) {
+            return false;
+        }
+        inputCurrentTimeStep += batchInputSize;
+        outputCurrentTimeStep += batchOutputSize;
+    }
+
+    // Narrow the float32 results back into the caller-provided half-precision buffers.
+    convertFloat32ToFloat16(output_state_out_float32, output_state_out_buffer);
+    convertFloat32ToFloat16(cell_state_out_float32, cell_state_out_buffer);
+    convertFloat32ToFloat16(output_float32, output_buffer);
+    convertFloat32ToFloat16(scratch_buffer_float32, scratch_buffer_buffer);
+    return true;
+}
+
+// static
bool LSTMCell::LSTMStep(
const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
@@ -543,208 +769,65 @@
bool LSTMCell::Eval() {
switch (input_->type) {
case OperandType::TENSOR_FLOAT32: {
- LSTMStep(params_, GetBuffer<const float>(input_), input_->shape(),
- GetBuffer<const float>(input_to_input_weights_),
- GetBuffer<const float>(input_to_forget_weights_),
- GetBuffer<const float>(input_to_cell_weights_),
- GetBuffer<const float>(input_to_output_weights_),
- input_to_output_weights_->shape(),
- GetBuffer<const float>(recurrent_to_input_weights_),
- GetBuffer<const float>(recurrent_to_forget_weights_),
- GetBuffer<const float>(recurrent_to_cell_weights_),
- GetBuffer<const float>(recurrent_to_output_weights_),
- recurrent_to_output_weights_->shape(),
- GetBuffer<const float>(cell_to_input_weights_),
- GetBuffer<const float>(cell_to_forget_weights_),
- GetBuffer<const float>(cell_to_output_weights_),
- GetBuffer<const float>(input_gate_bias_),
- GetBuffer<const float>(forget_gate_bias_), GetBuffer<const float>(cell_bias_),
- GetBuffer<const float>(output_gate_bias_),
- GetBuffer<const float>(projection_weights_),
- GetBuffer<const float>(projection_bias_),
- GetBuffer<const float>(output_state_in_),
- GetBuffer<const float>(cell_state_in_),
- GetBuffer<const float>(input_layer_norm_weights_),
- GetBuffer<const float>(forget_layer_norm_weights_),
- GetBuffer<const float>(cell_layer_norm_weights_),
- GetBuffer<const float>(output_layer_norm_weights_),
- GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
- GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
+ LSTMEvalFloat32(params_, GetBuffer<const float>(input_), input_->shape(),
+ GetBuffer<const float>(input_to_input_weights_),
+ GetBuffer<const float>(input_to_forget_weights_),
+ GetBuffer<const float>(input_to_cell_weights_),
+ GetBuffer<const float>(input_to_output_weights_),
+ input_to_output_weights_->shape(),
+ GetBuffer<const float>(recurrent_to_input_weights_),
+ GetBuffer<const float>(recurrent_to_forget_weights_),
+ GetBuffer<const float>(recurrent_to_cell_weights_),
+ GetBuffer<const float>(recurrent_to_output_weights_),
+ recurrent_to_output_weights_->shape(),
+ GetBuffer<const float>(cell_to_input_weights_),
+ GetBuffer<const float>(cell_to_forget_weights_),
+ GetBuffer<const float>(cell_to_output_weights_),
+ GetBuffer<const float>(input_gate_bias_),
+ GetBuffer<const float>(forget_gate_bias_),
+ GetBuffer<const float>(cell_bias_),
+ GetBuffer<const float>(output_gate_bias_),
+ GetBuffer<const float>(projection_weights_),
+ GetBuffer<const float>(projection_bias_),
+ GetBuffer<const float>(output_state_in_),
+ GetBuffer<const float>(cell_state_in_),
+ GetBuffer<const float>(input_layer_norm_weights_),
+ GetBuffer<const float>(forget_layer_norm_weights_),
+ GetBuffer<const float>(cell_layer_norm_weights_),
+ GetBuffer<const float>(output_layer_norm_weights_),
+ GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
+ GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
} break;
case OperandType::TENSOR_FLOAT16: {
- std::vector<float> input_float32(getNumberOfElements(input_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(input_), &input_float32);
- const float* input_to_input_weights_buffer = nullptr;
- std::vector<float> input_to_input_weights_float32(
- getNumberOfElements(input_to_input_weights_->shape()));
- if (!IsNullInput(input_to_input_weights_)) {
- convertFloat16ToFloat32(GetBuffer<_Float16>(input_to_input_weights_),
- &input_to_input_weights_float32);
- input_to_input_weights_buffer = input_to_input_weights_float32.data();
- }
- std::vector<float> input_to_forget_weights_float32(
- getNumberOfElements(input_to_forget_weights_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(input_to_forget_weights_),
- &input_to_forget_weights_float32);
- std::vector<float> input_to_cell_weights_float32(
- getNumberOfElements(input_to_cell_weights_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(input_to_cell_weights_),
- &input_to_cell_weights_float32);
- std::vector<float> input_to_output_weights_float32(
- getNumberOfElements(input_to_output_weights_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(input_to_output_weights_),
- &input_to_output_weights_float32);
- const float* recurrent_to_input_weights_buffer = nullptr;
- std::vector<float> recurrent_to_input_weights_float32(
- getNumberOfElements(recurrent_to_input_weights_->shape()));
- if (!IsNullInput(recurrent_to_input_weights_)) {
- convertFloat16ToFloat32(GetBuffer<_Float16>(recurrent_to_input_weights_),
- &recurrent_to_input_weights_float32);
- recurrent_to_input_weights_buffer = recurrent_to_input_weights_float32.data();
- }
- std::vector<float> recurrent_to_forget_weights_float32(
- getNumberOfElements(recurrent_to_forget_weights_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(recurrent_to_forget_weights_),
- &recurrent_to_forget_weights_float32);
- std::vector<float> recurrent_to_cell_weights_float32(
- getNumberOfElements(recurrent_to_cell_weights_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(recurrent_to_cell_weights_),
- &recurrent_to_cell_weights_float32);
- std::vector<float> recurrent_to_output_weights_float32(
- getNumberOfElements(recurrent_to_output_weights_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(recurrent_to_output_weights_),
- &recurrent_to_output_weights_float32);
- const float* cell_to_input_weights_buffer = nullptr;
- std::vector<float> cell_to_input_weights_float32(
- getNumberOfElements(cell_to_input_weights_->shape()));
- if (!IsNullInput(cell_to_input_weights_)) {
- convertFloat16ToFloat32(GetBuffer<_Float16>(cell_to_input_weights_),
- &cell_to_input_weights_float32);
- cell_to_input_weights_buffer = cell_to_input_weights_float32.data();
- }
- const float* cell_to_forget_weights_buffer = nullptr;
- std::vector<float> cell_to_forget_weights_float32(
- getNumberOfElements(cell_to_forget_weights_->shape()));
- if (!IsNullInput(cell_to_forget_weights_)) {
- convertFloat16ToFloat32(GetBuffer<_Float16>(cell_to_forget_weights_),
- &cell_to_forget_weights_float32);
- cell_to_forget_weights_buffer = cell_to_forget_weights_float32.data();
- }
- const float* cell_to_output_weights_buffer = nullptr;
- std::vector<float> cell_to_output_weights_float32(
- getNumberOfElements(cell_to_output_weights_->shape()));
- if (!IsNullInput(cell_to_output_weights_)) {
- convertFloat16ToFloat32(GetBuffer<_Float16>(cell_to_output_weights_),
- &cell_to_output_weights_float32);
- cell_to_output_weights_buffer = cell_to_output_weights_float32.data();
- }
- const float* input_gate_bias_buffer = nullptr;
- std::vector<float> input_gate_bias_float32(
- getNumberOfElements(input_gate_bias_->shape()));
- if (!IsNullInput(input_gate_bias_)) {
- convertFloat16ToFloat32(GetBuffer<_Float16>(input_gate_bias_),
- &input_gate_bias_float32);
- input_gate_bias_buffer = input_gate_bias_float32.data();
- }
- std::vector<float> forget_gate_bias_float32(
- getNumberOfElements(forget_gate_bias_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(forget_gate_bias_),
- &forget_gate_bias_float32);
- std::vector<float> cell_bias_float32(getNumberOfElements(cell_bias_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(cell_bias_), &cell_bias_float32);
- std::vector<float> output_gate_bias_float32(
- getNumberOfElements(output_gate_bias_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(output_gate_bias_),
- &output_gate_bias_float32);
- const float* projection_weights_buffer = nullptr;
- std::vector<float> projection_weights_float32(
- getNumberOfElements(projection_weights_->shape()));
- if (!IsNullInput(projection_weights_)) {
- convertFloat16ToFloat32(GetBuffer<_Float16>(projection_weights_),
- &projection_weights_float32);
- projection_weights_buffer = projection_weights_float32.data();
- }
- const float* projection_bias_buffer = nullptr;
- std::vector<float> projection_bias_float32(
- getNumberOfElements(projection_bias_->shape()));
- if (!IsNullInput(projection_bias_)) {
- convertFloat16ToFloat32(GetBuffer<_Float16>(projection_bias_),
- &projection_bias_float32);
- projection_bias_buffer = projection_bias_float32.data();
- }
- std::vector<float> output_state_in_float32(
- getNumberOfElements(output_state_in_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(output_state_in_),
- &output_state_in_float32);
- std::vector<float> cell_state_in_float32(getNumberOfElements(cell_state_in_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(cell_state_in_), &cell_state_in_float32);
- const float* input_layer_norm_weights_buffer = nullptr;
- std::vector<float> input_layer_norm_weights_float32(
- getNumberOfElements(input_layer_norm_weights_->shape()));
- if (!IsNullInput(input_layer_norm_weights_)) {
- convertFloat16ToFloat32(GetBuffer<_Float16>(input_layer_norm_weights_),
- &input_layer_norm_weights_float32);
- input_layer_norm_weights_buffer = input_layer_norm_weights_float32.data();
- }
- const float* forget_layer_norm_weights_buffer = nullptr;
- std::vector<float> forget_layer_norm_weights_float32(
- getNumberOfElements(forget_layer_norm_weights_->shape()));
- if (!IsNullInput(forget_layer_norm_weights_)) {
- convertFloat16ToFloat32(GetBuffer<_Float16>(forget_layer_norm_weights_),
- &forget_layer_norm_weights_float32);
- forget_layer_norm_weights_buffer = forget_layer_norm_weights_float32.data();
- }
- const float* cell_layer_norm_weights_buffer = nullptr;
- std::vector<float> cell_layer_norm_weights_float32(
- getNumberOfElements(cell_layer_norm_weights_->shape()));
- if (!IsNullInput(cell_layer_norm_weights_)) {
- convertFloat16ToFloat32(GetBuffer<_Float16>(cell_layer_norm_weights_),
- &cell_layer_norm_weights_float32);
- cell_layer_norm_weights_buffer = cell_layer_norm_weights_float32.data();
- }
- const float* output_layer_norm_weights_buffer = nullptr;
- std::vector<float> output_layer_norm_weights_float32(
- getNumberOfElements(output_layer_norm_weights_->shape()));
- if (!IsNullInput(output_layer_norm_weights_)) {
- convertFloat16ToFloat32(GetBuffer<_Float16>(output_layer_norm_weights_),
- &output_layer_norm_weights_float32);
- output_layer_norm_weights_buffer = output_layer_norm_weights_float32.data();
- }
- std::vector<float> output_state_out_float32(
- getNumberOfElements(output_state_out_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(output_state_out_),
- &output_state_out_float32);
- std::vector<float> cell_state_out_float32(
- getNumberOfElements(cell_state_out_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(cell_state_out_), &cell_state_out_float32);
- std::vector<float> output_float32(getNumberOfElements(output_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(output_), &output_float32);
- std::vector<float> scratch_buffer_float32(
- getNumberOfElements(scratch_buffer_->shape()));
- convertFloat16ToFloat32(GetBuffer<_Float16>(scratch_buffer_), &scratch_buffer_float32);
-
- LSTMStep(params_, input_float32.data(), input_->shape(), input_to_input_weights_buffer,
- input_to_forget_weights_float32.data(), input_to_cell_weights_float32.data(),
- input_to_output_weights_float32.data(), input_to_output_weights_->shape(),
- recurrent_to_input_weights_buffer, recurrent_to_forget_weights_float32.data(),
- recurrent_to_cell_weights_float32.data(),
- recurrent_to_output_weights_float32.data(),
- recurrent_to_output_weights_->shape(), cell_to_input_weights_buffer,
- cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
- input_gate_bias_buffer, forget_gate_bias_float32.data(),
- cell_bias_float32.data(), output_gate_bias_float32.data(),
- projection_weights_buffer, projection_bias_buffer,
- output_state_in_float32.data(), cell_state_in_float32.data(),
- input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
- cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
- output_state_out_float32.data(), cell_state_out_float32.data(),
- output_float32.data(), scratch_buffer_float32.data());
-
- convertFloat32ToFloat16(output_state_out_float32,
- GetBuffer<_Float16>(output_state_out_));
- convertFloat32ToFloat16(cell_state_out_float32, GetBuffer<_Float16>(cell_state_out_));
- convertFloat32ToFloat16(output_float32, GetBuffer<_Float16>(output_));
- convertFloat32ToFloat16(scratch_buffer_float32, GetBuffer<_Float16>(scratch_buffer_));
+ LSTMEvalFloat16(params_, GetBuffer<const _Float16>(input_), input_->shape(),
+ GetOptionalBuffer<const _Float16>(input_to_input_weights_),
+ GetBuffer<const _Float16>(input_to_forget_weights_),
+ GetBuffer<const _Float16>(input_to_cell_weights_),
+ GetBuffer<const _Float16>(input_to_output_weights_),
+ input_to_output_weights_->shape(),
+ GetOptionalBuffer<const _Float16>(recurrent_to_input_weights_),
+ GetBuffer<const _Float16>(recurrent_to_forget_weights_),
+ GetBuffer<const _Float16>(recurrent_to_cell_weights_),
+ GetBuffer<const _Float16>(recurrent_to_output_weights_),
+ recurrent_to_output_weights_->shape(),
+ GetOptionalBuffer<const _Float16>(cell_to_input_weights_),
+ GetOptionalBuffer<const _Float16>(cell_to_forget_weights_),
+ GetOptionalBuffer<const _Float16>(cell_to_output_weights_),
+ GetOptionalBuffer<const _Float16>(input_gate_bias_),
+ GetBuffer<const _Float16>(forget_gate_bias_),
+ GetBuffer<const _Float16>(cell_bias_),
+ GetBuffer<const _Float16>(output_gate_bias_),
+ GetOptionalBuffer<const _Float16>(projection_weights_),
+ GetOptionalBuffer<const _Float16>(projection_bias_),
+ GetBuffer<const _Float16>(output_state_in_),
+ GetBuffer<const _Float16>(cell_state_in_),
+ GetOptionalBuffer<const _Float16>(input_layer_norm_weights_),
+ GetOptionalBuffer<const _Float16>(forget_layer_norm_weights_),
+ GetOptionalBuffer<const _Float16>(cell_layer_norm_weights_),
+ GetOptionalBuffer<const _Float16>(output_layer_norm_weights_),
+ GetBuffer<_Float16>(output_state_out_),
+ GetBuffer<_Float16>(cell_state_out_), GetBuffer<_Float16>(output_),
+ GetBuffer<_Float16>(scratch_buffer_));
} break;
default: {
LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);