Refactors LSTM Eval for reuse in Uni(Bi)directionalSequenceLSTM.

Bug: 113563341

Test: Covered by existing test NeuralNetworksTest_static
--gtest_filter=GeneratedTests.lstm*

Change-Id: Iff47304515c7e6a436757c2e513a954487070a76
Merged-In: Iff47304515c7e6a436757c2e513a954487070a76
(cherry picked from commit 91a87797ce004e0890d75e347d63ae443c84f062)
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp
index 5159952..bbadfa6 100644
--- a/common/operations/LSTM.cpp
+++ b/common/operations/LSTM.cpp
@@ -37,6 +37,11 @@
     return reinterpret_cast<const T*>(operand->buffer);
 }
 
+template <typename T>  // Optional-operand accessor: nullptr when IsNullInput(operand) is true.
+inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
+    return IsNullInput(operand) ? nullptr : reinterpret_cast<const T*>(operand->buffer);
+}
+
 }  // anonymous namespace
 
 LSTMCell::LSTMCell(const Operation& operation, std::vector<RunTimeOperandInfo>& operands) {
@@ -334,6 +339,227 @@
 }
 
 // static
+bool LSTMCell::LSTMEvalFloat32(
+        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
+        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
+        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
+        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
+        const float* recurrent_to_forget_weights_buffer,
+        const float* recurrent_to_cell_weights_buffer,
+        const float* recurrent_to_output_weights_buffer,
+        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
+        const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
+        const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
+        const float* cell_bias_buffer, const float* output_gate_bias_buffer,
+        const float* projection_weights_buffer, const float* projection_bias_buffer,
+        const float* output_state_in_buffer, const float* cell_state_in_buffer,
+        const float* input_layer_norm_weights_buffer, const float* forget_layer_norm_weights_buffer,
+        const float* cell_layer_norm_weights_buffer, const float* output_layer_norm_weights_buffer,
+        float* output_state_out_buffer, float* cell_state_out_buffer, float* output_buffer,
+        float* scratch_buffer_buffer, bool timeMajor) {
+    NNTRACE_COMP("LSTMCell::LSTMEvalFloat32");
+    // Runs LSTMStep once per time step (a rank-2 input is one step); stops at a failed step.
+    const uint32_t inputRank = getNumberOfDimensions(input_shape);
+    NN_CHECK(inputRank == 2 || inputRank == 3);
+    // Rank-3 input: timeMajor selects which of the first two dimensions is time vs. batch.
+    const uint32_t maxTime =
+            (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
+    const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
+                                                : getSizeOfDimension(input_shape, 0);
+    const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
+    const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
+    // Elements consumed from the input / written to the output per time step.
+    const size_t inStride = static_cast<size_t>(batchSize) * inputSize;
+    const size_t outStride = static_cast<size_t>(batchSize) * outputSize;
+    for (uint32_t t = 0; t < maxTime; ++t) {
+        const bool stepOk = LSTMStep(
+                params, input_buffer + t * inStride, input_shape, input_to_input_weights_buffer,
+                input_to_forget_weights_buffer, input_to_cell_weights_buffer,
+                input_to_output_weights_buffer, input_to_output_weights_shape,
+                recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer,
+                recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer,
+                recurrent_to_output_weights_shape, cell_to_input_weights_buffer,
+                cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
+                input_gate_bias_buffer, forget_gate_bias_buffer, cell_bias_buffer,
+                output_gate_bias_buffer, projection_weights_buffer, projection_bias_buffer,
+                output_state_in_buffer, cell_state_in_buffer, input_layer_norm_weights_buffer,
+                forget_layer_norm_weights_buffer, cell_layer_norm_weights_buffer,
+                output_layer_norm_weights_buffer, output_state_out_buffer, cell_state_out_buffer,
+                output_buffer + t * outStride, scratch_buffer_buffer);
+        if (!stepOk) return false;
+        // Next step resumes from this step's state (in/out alias; verify LSTMStep allows it).
+        output_state_in_buffer = output_state_out_buffer;
+        cell_state_in_buffer = cell_state_out_buffer;
+    }
+    return true;
+}
+
+// static
+bool LSTMCell::LSTMEvalFloat16(
+        const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
+        const _Float16* input_to_input_weights_buffer,
+        const _Float16* input_to_forget_weights_buffer,
+        const _Float16* input_to_cell_weights_buffer,
+        const _Float16* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape,
+        const _Float16* recurrent_to_input_weights_buffer,
+        const _Float16* recurrent_to_forget_weights_buffer,
+        const _Float16* recurrent_to_cell_weights_buffer,
+        const _Float16* recurrent_to_output_weights_buffer,
+        const Shape& recurrent_to_output_weights_shape,
+        const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer,
+        const _Float16* cell_to_output_weights_buffer, const _Float16* input_gate_bias_buffer,
+        const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer,
+        const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer,
+        const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer,
+        const _Float16* cell_state_in_buffer, const _Float16* input_layer_norm_weights_buffer,
+        const _Float16* forget_layer_norm_weights_buffer,
+        const _Float16* cell_layer_norm_weights_buffer,
+        const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
+        _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer,
+        bool timeMajor) {
+    NNTRACE_COMP("LSTMCell::LSTMEvalFloat16");
+
+    // fp16 front end for LSTMStep: every operand is widened to float, LSTMStep runs once per
+    // time step (a rank-2 input is a single step), and the results are narrowed back to fp16.
+    const uint32_t inputRank = getNumberOfDimensions(input_shape);
+    NN_CHECK(inputRank == 2 || inputRank == 3);
+
+    // Rank-3 input: timeMajor selects which of the first two dimensions is time vs. batch.
+    const uint32_t maxTime =
+            (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
+    const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
+                                                : getSizeOfDimension(input_shape, 0);
+    const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
+    const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
+    const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
+
+    const uint32_t batchInputSize = batchSize * inputSize;
+    const uint32_t batchOutputSize = batchSize * outputSize;
+
+    // Widens an fp16 operand into `storage`, which the caller has sized to the element count.
+    // An omitted optional operand (nullptr) yields nullptr rather than a zero-filled array, so
+    // the callee still sees it as absent.
+    const auto widen = [](const _Float16* src, std::vector<float>* storage) -> const float* {
+        if (src == nullptr) return nullptr;
+        convertFloat16ToFloat32(src, storage);
+        return storage->data();
+    };
+
+    std::vector<float> input_float32(maxTime * batchInputSize);
+    convertFloat16ToFloat32(input_buffer, &input_float32);
+    std::vector<float> input_to_input_weights_float32(numCells * inputSize);
+    const float* input_to_input_weights_ptr =
+            widen(input_to_input_weights_buffer, &input_to_input_weights_float32);
+    std::vector<float> input_to_forget_weights_float32(numCells * inputSize);
+    convertFloat16ToFloat32(input_to_forget_weights_buffer, &input_to_forget_weights_float32);
+    std::vector<float> input_to_cell_weights_float32(numCells * inputSize);
+    convertFloat16ToFloat32(input_to_cell_weights_buffer, &input_to_cell_weights_float32);
+    std::vector<float> input_to_output_weights_float32(numCells * inputSize);
+    convertFloat16ToFloat32(input_to_output_weights_buffer, &input_to_output_weights_float32);
+
+    std::vector<float> recurrent_to_input_weights_float32(numCells * outputSize);
+    const float* recurrent_to_input_weights_ptr =
+            widen(recurrent_to_input_weights_buffer, &recurrent_to_input_weights_float32);
+    std::vector<float> recurrent_to_forget_weights_float32(numCells * outputSize);
+    convertFloat16ToFloat32(recurrent_to_forget_weights_buffer,
+                            &recurrent_to_forget_weights_float32);
+    std::vector<float> recurrent_to_cell_weights_float32(numCells * outputSize);
+    convertFloat16ToFloat32(recurrent_to_cell_weights_buffer, &recurrent_to_cell_weights_float32);
+    std::vector<float> recurrent_to_output_weights_float32(numCells * outputSize);
+    convertFloat16ToFloat32(recurrent_to_output_weights_buffer,
+                            &recurrent_to_output_weights_float32);
+
+    std::vector<float> cell_to_input_weights_float32(numCells);
+    const float* cell_to_input_weights_ptr =
+            widen(cell_to_input_weights_buffer, &cell_to_input_weights_float32);
+    std::vector<float> cell_to_forget_weights_float32(numCells);
+    const float* cell_to_forget_weights_ptr =
+            widen(cell_to_forget_weights_buffer, &cell_to_forget_weights_float32);
+    std::vector<float> cell_to_output_weights_float32(numCells);
+    const float* cell_to_output_weights_ptr =
+            widen(cell_to_output_weights_buffer, &cell_to_output_weights_float32);
+
+    std::vector<float> input_gate_bias_float32(numCells);
+    const float* input_gate_bias_ptr = widen(input_gate_bias_buffer, &input_gate_bias_float32);
+    std::vector<float> forget_gate_bias_float32(numCells);
+    convertFloat16ToFloat32(forget_gate_bias_buffer, &forget_gate_bias_float32);
+    std::vector<float> cell_bias_float32(numCells);
+    convertFloat16ToFloat32(cell_bias_buffer, &cell_bias_float32);
+    std::vector<float> output_gate_bias_float32(numCells);
+    convertFloat16ToFloat32(output_gate_bias_buffer, &output_gate_bias_float32);
+
+    std::vector<float> projection_weights_float32(numCells * outputSize);
+    const float* projection_weights_ptr =
+            widen(projection_weights_buffer, &projection_weights_float32);
+    std::vector<float> projection_bias_float32(outputSize);
+    const float* projection_bias_ptr = widen(projection_bias_buffer, &projection_bias_float32);
+
+    std::vector<float> output_state_in_float32(batchSize * outputSize);
+    convertFloat16ToFloat32(output_state_in_buffer, &output_state_in_float32);
+    std::vector<float> cell_state_in_float32(batchSize * numCells);
+    convertFloat16ToFloat32(cell_state_in_buffer, &cell_state_in_float32);
+
+    std::vector<float> input_layer_norm_weights_float32(numCells);
+    const float* input_layer_norm_weights_ptr =
+            widen(input_layer_norm_weights_buffer, &input_layer_norm_weights_float32);
+    std::vector<float> forget_layer_norm_weights_float32(numCells);
+    const float* forget_layer_norm_weights_ptr =
+            widen(forget_layer_norm_weights_buffer, &forget_layer_norm_weights_float32);
+    std::vector<float> cell_layer_norm_weights_float32(numCells);
+    const float* cell_layer_norm_weights_ptr =
+            widen(cell_layer_norm_weights_buffer, &cell_layer_norm_weights_float32);
+    std::vector<float> output_layer_norm_weights_float32(numCells);
+    const float* output_layer_norm_weights_ptr =
+            widen(output_layer_norm_weights_buffer, &output_layer_norm_weights_float32);
+
+    std::vector<float> output_state_out_float32(batchOutputSize);
+    convertFloat16ToFloat32(output_state_out_buffer, &output_state_out_float32);
+    std::vector<float> cell_state_out_float32(batchSize * numCells);
+    convertFloat16ToFloat32(cell_state_out_buffer, &cell_state_out_float32);
+
+    std::vector<float> output_float32(maxTime * batchOutputSize);
+    convertFloat16ToFloat32(output_buffer, &output_float32);
+    std::vector<float> scratch_buffer_float32(params.use_cifg ? 3 * batchSize * numCells
+                                                              : 4 * batchSize * numCells);
+    convertFloat16ToFloat32(scratch_buffer_buffer, &scratch_buffer_float32);
+
+    // State fed to each step; after the first step it is the state the previous step produced.
+    const float* outputStateIn = output_state_in_float32.data();
+    const float* cellStateIn = cell_state_in_float32.data();
+    const float* inputCurrentTimeStep = input_float32.data();
+    float* outputCurrentTimeStep = output_float32.data();
+    for (uint32_t t = 0; t < maxTime; ++t) {
+        const bool stepOk = LSTMStep(
+                params, inputCurrentTimeStep, input_shape, input_to_input_weights_ptr,
+                input_to_forget_weights_float32.data(), input_to_cell_weights_float32.data(),
+                input_to_output_weights_float32.data(), input_to_output_weights_shape,
+                recurrent_to_input_weights_ptr, recurrent_to_forget_weights_float32.data(),
+                recurrent_to_cell_weights_float32.data(),
+                recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape,
+                cell_to_input_weights_ptr, cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+                input_gate_bias_ptr, forget_gate_bias_float32.data(), cell_bias_float32.data(),
+                output_gate_bias_float32.data(), projection_weights_ptr, projection_bias_ptr,
+                outputStateIn, cellStateIn, input_layer_norm_weights_ptr,
+                forget_layer_norm_weights_ptr, cell_layer_norm_weights_ptr,
+                output_layer_norm_weights_ptr, output_state_out_float32.data(),
+                cell_state_out_float32.data(), outputCurrentTimeStep,
+                scratch_buffer_float32.data());
+        if (!stepOk) return false;
+        inputCurrentTimeStep += batchInputSize;
+        outputCurrentTimeStep += batchOutputSize;
+        // Next step resumes from this step's state (in/out alias; verify LSTMStep allows it).
+        outputStateIn = output_state_out_float32.data();
+        cellStateIn = cell_state_out_float32.data();
+    }
+
+    convertFloat32ToFloat16(output_state_out_float32, output_state_out_buffer);
+    convertFloat32ToFloat16(cell_state_out_float32, cell_state_out_buffer);
+    convertFloat32ToFloat16(output_float32, output_buffer);
+    convertFloat32ToFloat16(scratch_buffer_float32, scratch_buffer_buffer);
+    return true;
+}
+
+// static
 bool LSTMCell::LSTMStep(
         const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
         const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
@@ -543,208 +769,65 @@
 bool LSTMCell::Eval() {
     switch (input_->type) {
         case OperandType::TENSOR_FLOAT32: {
-            LSTMStep(params_, GetBuffer<const float>(input_), input_->shape(),
-                     GetBuffer<const float>(input_to_input_weights_),
-                     GetBuffer<const float>(input_to_forget_weights_),
-                     GetBuffer<const float>(input_to_cell_weights_),
-                     GetBuffer<const float>(input_to_output_weights_),
-                     input_to_output_weights_->shape(),
-                     GetBuffer<const float>(recurrent_to_input_weights_),
-                     GetBuffer<const float>(recurrent_to_forget_weights_),
-                     GetBuffer<const float>(recurrent_to_cell_weights_),
-                     GetBuffer<const float>(recurrent_to_output_weights_),
-                     recurrent_to_output_weights_->shape(),
-                     GetBuffer<const float>(cell_to_input_weights_),
-                     GetBuffer<const float>(cell_to_forget_weights_),
-                     GetBuffer<const float>(cell_to_output_weights_),
-                     GetBuffer<const float>(input_gate_bias_),
-                     GetBuffer<const float>(forget_gate_bias_), GetBuffer<const float>(cell_bias_),
-                     GetBuffer<const float>(output_gate_bias_),
-                     GetBuffer<const float>(projection_weights_),
-                     GetBuffer<const float>(projection_bias_),
-                     GetBuffer<const float>(output_state_in_),
-                     GetBuffer<const float>(cell_state_in_),
-                     GetBuffer<const float>(input_layer_norm_weights_),
-                     GetBuffer<const float>(forget_layer_norm_weights_),
-                     GetBuffer<const float>(cell_layer_norm_weights_),
-                     GetBuffer<const float>(output_layer_norm_weights_),
-                     GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
-                     GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
+            LSTMEvalFloat32(params_, GetBuffer<const float>(input_), input_->shape(),
+                            GetBuffer<const float>(input_to_input_weights_),
+                            GetBuffer<const float>(input_to_forget_weights_),
+                            GetBuffer<const float>(input_to_cell_weights_),
+                            GetBuffer<const float>(input_to_output_weights_),
+                            input_to_output_weights_->shape(),
+                            GetBuffer<const float>(recurrent_to_input_weights_),
+                            GetBuffer<const float>(recurrent_to_forget_weights_),
+                            GetBuffer<const float>(recurrent_to_cell_weights_),
+                            GetBuffer<const float>(recurrent_to_output_weights_),
+                            recurrent_to_output_weights_->shape(),
+                            GetBuffer<const float>(cell_to_input_weights_),
+                            GetBuffer<const float>(cell_to_forget_weights_),
+                            GetBuffer<const float>(cell_to_output_weights_),
+                            GetBuffer<const float>(input_gate_bias_),
+                            GetBuffer<const float>(forget_gate_bias_),
+                            GetBuffer<const float>(cell_bias_),
+                            GetBuffer<const float>(output_gate_bias_),
+                            GetBuffer<const float>(projection_weights_),
+                            GetBuffer<const float>(projection_bias_),
+                            GetBuffer<const float>(output_state_in_),
+                            GetBuffer<const float>(cell_state_in_),
+                            GetBuffer<const float>(input_layer_norm_weights_),
+                            GetBuffer<const float>(forget_layer_norm_weights_),
+                            GetBuffer<const float>(cell_layer_norm_weights_),
+                            GetBuffer<const float>(output_layer_norm_weights_),
+                            GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
+                            GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
         } break;
         case OperandType::TENSOR_FLOAT16: {
-            std::vector<float> input_float32(getNumberOfElements(input_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(input_), &input_float32);
-            const float* input_to_input_weights_buffer = nullptr;
-            std::vector<float> input_to_input_weights_float32(
-                    getNumberOfElements(input_to_input_weights_->shape()));
-            if (!IsNullInput(input_to_input_weights_)) {
-                convertFloat16ToFloat32(GetBuffer<_Float16>(input_to_input_weights_),
-                                        &input_to_input_weights_float32);
-                input_to_input_weights_buffer = input_to_input_weights_float32.data();
-            }
-            std::vector<float> input_to_forget_weights_float32(
-                    getNumberOfElements(input_to_forget_weights_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(input_to_forget_weights_),
-                                    &input_to_forget_weights_float32);
-            std::vector<float> input_to_cell_weights_float32(
-                    getNumberOfElements(input_to_cell_weights_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(input_to_cell_weights_),
-                                    &input_to_cell_weights_float32);
-            std::vector<float> input_to_output_weights_float32(
-                    getNumberOfElements(input_to_output_weights_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(input_to_output_weights_),
-                                    &input_to_output_weights_float32);
-            const float* recurrent_to_input_weights_buffer = nullptr;
-            std::vector<float> recurrent_to_input_weights_float32(
-                    getNumberOfElements(recurrent_to_input_weights_->shape()));
-            if (!IsNullInput(recurrent_to_input_weights_)) {
-                convertFloat16ToFloat32(GetBuffer<_Float16>(recurrent_to_input_weights_),
-                                        &recurrent_to_input_weights_float32);
-                recurrent_to_input_weights_buffer = recurrent_to_input_weights_float32.data();
-            }
-            std::vector<float> recurrent_to_forget_weights_float32(
-                    getNumberOfElements(recurrent_to_forget_weights_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(recurrent_to_forget_weights_),
-                                    &recurrent_to_forget_weights_float32);
-            std::vector<float> recurrent_to_cell_weights_float32(
-                    getNumberOfElements(recurrent_to_cell_weights_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(recurrent_to_cell_weights_),
-                                    &recurrent_to_cell_weights_float32);
-            std::vector<float> recurrent_to_output_weights_float32(
-                    getNumberOfElements(recurrent_to_output_weights_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(recurrent_to_output_weights_),
-                                    &recurrent_to_output_weights_float32);
-            const float* cell_to_input_weights_buffer = nullptr;
-            std::vector<float> cell_to_input_weights_float32(
-                    getNumberOfElements(cell_to_input_weights_->shape()));
-            if (!IsNullInput(cell_to_input_weights_)) {
-                convertFloat16ToFloat32(GetBuffer<_Float16>(cell_to_input_weights_),
-                                        &cell_to_input_weights_float32);
-                cell_to_input_weights_buffer = cell_to_input_weights_float32.data();
-            }
-            const float* cell_to_forget_weights_buffer = nullptr;
-            std::vector<float> cell_to_forget_weights_float32(
-                    getNumberOfElements(cell_to_forget_weights_->shape()));
-            if (!IsNullInput(cell_to_forget_weights_)) {
-                convertFloat16ToFloat32(GetBuffer<_Float16>(cell_to_forget_weights_),
-                                        &cell_to_forget_weights_float32);
-                cell_to_forget_weights_buffer = cell_to_forget_weights_float32.data();
-            }
-            const float* cell_to_output_weights_buffer = nullptr;
-            std::vector<float> cell_to_output_weights_float32(
-                    getNumberOfElements(cell_to_output_weights_->shape()));
-            if (!IsNullInput(cell_to_output_weights_)) {
-                convertFloat16ToFloat32(GetBuffer<_Float16>(cell_to_output_weights_),
-                                        &cell_to_output_weights_float32);
-                cell_to_output_weights_buffer = cell_to_output_weights_float32.data();
-            }
-            const float* input_gate_bias_buffer = nullptr;
-            std::vector<float> input_gate_bias_float32(
-                    getNumberOfElements(input_gate_bias_->shape()));
-            if (!IsNullInput(input_gate_bias_)) {
-                convertFloat16ToFloat32(GetBuffer<_Float16>(input_gate_bias_),
-                                        &input_gate_bias_float32);
-                input_gate_bias_buffer = input_gate_bias_float32.data();
-            }
-            std::vector<float> forget_gate_bias_float32(
-                    getNumberOfElements(forget_gate_bias_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(forget_gate_bias_),
-                                    &forget_gate_bias_float32);
-            std::vector<float> cell_bias_float32(getNumberOfElements(cell_bias_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(cell_bias_), &cell_bias_float32);
-            std::vector<float> output_gate_bias_float32(
-                    getNumberOfElements(output_gate_bias_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(output_gate_bias_),
-                                    &output_gate_bias_float32);
-            const float* projection_weights_buffer = nullptr;
-            std::vector<float> projection_weights_float32(
-                    getNumberOfElements(projection_weights_->shape()));
-            if (!IsNullInput(projection_weights_)) {
-                convertFloat16ToFloat32(GetBuffer<_Float16>(projection_weights_),
-                                        &projection_weights_float32);
-                projection_weights_buffer = projection_weights_float32.data();
-            }
-            const float* projection_bias_buffer = nullptr;
-            std::vector<float> projection_bias_float32(
-                    getNumberOfElements(projection_bias_->shape()));
-            if (!IsNullInput(projection_bias_)) {
-                convertFloat16ToFloat32(GetBuffer<_Float16>(projection_bias_),
-                                        &projection_bias_float32);
-                projection_bias_buffer = projection_bias_float32.data();
-            }
-            std::vector<float> output_state_in_float32(
-                    getNumberOfElements(output_state_in_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(output_state_in_),
-                                    &output_state_in_float32);
-            std::vector<float> cell_state_in_float32(getNumberOfElements(cell_state_in_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(cell_state_in_), &cell_state_in_float32);
-            const float* input_layer_norm_weights_buffer = nullptr;
-            std::vector<float> input_layer_norm_weights_float32(
-                    getNumberOfElements(input_layer_norm_weights_->shape()));
-            if (!IsNullInput(input_layer_norm_weights_)) {
-                convertFloat16ToFloat32(GetBuffer<_Float16>(input_layer_norm_weights_),
-                                        &input_layer_norm_weights_float32);
-                input_layer_norm_weights_buffer = input_layer_norm_weights_float32.data();
-            }
-            const float* forget_layer_norm_weights_buffer = nullptr;
-            std::vector<float> forget_layer_norm_weights_float32(
-                    getNumberOfElements(forget_layer_norm_weights_->shape()));
-            if (!IsNullInput(forget_layer_norm_weights_)) {
-                convertFloat16ToFloat32(GetBuffer<_Float16>(forget_layer_norm_weights_),
-                                        &forget_layer_norm_weights_float32);
-                forget_layer_norm_weights_buffer = forget_layer_norm_weights_float32.data();
-            }
-            const float* cell_layer_norm_weights_buffer = nullptr;
-            std::vector<float> cell_layer_norm_weights_float32(
-                    getNumberOfElements(cell_layer_norm_weights_->shape()));
-            if (!IsNullInput(cell_layer_norm_weights_)) {
-                convertFloat16ToFloat32(GetBuffer<_Float16>(cell_layer_norm_weights_),
-                                        &cell_layer_norm_weights_float32);
-                cell_layer_norm_weights_buffer = cell_layer_norm_weights_float32.data();
-            }
-            const float* output_layer_norm_weights_buffer = nullptr;
-            std::vector<float> output_layer_norm_weights_float32(
-                    getNumberOfElements(output_layer_norm_weights_->shape()));
-            if (!IsNullInput(output_layer_norm_weights_)) {
-                convertFloat16ToFloat32(GetBuffer<_Float16>(output_layer_norm_weights_),
-                                        &output_layer_norm_weights_float32);
-                output_layer_norm_weights_buffer = output_layer_norm_weights_float32.data();
-            }
-            std::vector<float> output_state_out_float32(
-                    getNumberOfElements(output_state_out_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(output_state_out_),
-                                    &output_state_out_float32);
-            std::vector<float> cell_state_out_float32(
-                    getNumberOfElements(cell_state_out_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(cell_state_out_), &cell_state_out_float32);
-            std::vector<float> output_float32(getNumberOfElements(output_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(output_), &output_float32);
-            std::vector<float> scratch_buffer_float32(
-                    getNumberOfElements(scratch_buffer_->shape()));
-            convertFloat16ToFloat32(GetBuffer<_Float16>(scratch_buffer_), &scratch_buffer_float32);
-
-            LSTMStep(params_, input_float32.data(), input_->shape(), input_to_input_weights_buffer,
-                     input_to_forget_weights_float32.data(), input_to_cell_weights_float32.data(),
-                     input_to_output_weights_float32.data(), input_to_output_weights_->shape(),
-                     recurrent_to_input_weights_buffer, recurrent_to_forget_weights_float32.data(),
-                     recurrent_to_cell_weights_float32.data(),
-                     recurrent_to_output_weights_float32.data(),
-                     recurrent_to_output_weights_->shape(), cell_to_input_weights_buffer,
-                     cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
-                     input_gate_bias_buffer, forget_gate_bias_float32.data(),
-                     cell_bias_float32.data(), output_gate_bias_float32.data(),
-                     projection_weights_buffer, projection_bias_buffer,
-                     output_state_in_float32.data(), cell_state_in_float32.data(),
-                     input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
-                     cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
-                     output_state_out_float32.data(), cell_state_out_float32.data(),
-                     output_float32.data(), scratch_buffer_float32.data());
-
-            convertFloat32ToFloat16(output_state_out_float32,
-                                    GetBuffer<_Float16>(output_state_out_));
-            convertFloat32ToFloat16(cell_state_out_float32, GetBuffer<_Float16>(cell_state_out_));
-            convertFloat32ToFloat16(output_float32, GetBuffer<_Float16>(output_));
-            convertFloat32ToFloat16(scratch_buffer_float32, GetBuffer<_Float16>(scratch_buffer_));
+            LSTMEvalFloat16(params_, GetBuffer<const _Float16>(input_), input_->shape(),
+                            GetOptionalBuffer<const _Float16>(input_to_input_weights_),
+                            GetBuffer<const _Float16>(input_to_forget_weights_),
+                            GetBuffer<const _Float16>(input_to_cell_weights_),
+                            GetBuffer<const _Float16>(input_to_output_weights_),
+                            input_to_output_weights_->shape(),
+                            GetOptionalBuffer<const _Float16>(recurrent_to_input_weights_),
+                            GetBuffer<const _Float16>(recurrent_to_forget_weights_),
+                            GetBuffer<const _Float16>(recurrent_to_cell_weights_),
+                            GetBuffer<const _Float16>(recurrent_to_output_weights_),
+                            recurrent_to_output_weights_->shape(),
+                            GetOptionalBuffer<const _Float16>(cell_to_input_weights_),
+                            GetOptionalBuffer<const _Float16>(cell_to_forget_weights_),
+                            GetOptionalBuffer<const _Float16>(cell_to_output_weights_),
+                            GetOptionalBuffer<const _Float16>(input_gate_bias_),
+                            GetBuffer<const _Float16>(forget_gate_bias_),
+                            GetBuffer<const _Float16>(cell_bias_),
+                            GetBuffer<const _Float16>(output_gate_bias_),
+                            GetOptionalBuffer<const _Float16>(projection_weights_),
+                            GetOptionalBuffer<const _Float16>(projection_bias_),
+                            GetBuffer<const _Float16>(output_state_in_),
+                            GetBuffer<const _Float16>(cell_state_in_),
+                            GetOptionalBuffer<const _Float16>(input_layer_norm_weights_),
+                            GetOptionalBuffer<const _Float16>(forget_layer_norm_weights_),
+                            GetOptionalBuffer<const _Float16>(cell_layer_norm_weights_),
+                            GetOptionalBuffer<const _Float16>(output_layer_norm_weights_),
+                            GetBuffer<_Float16>(output_state_out_),
+                            GetBuffer<_Float16>(cell_state_out_), GetBuffer<_Float16>(output_),
+                            GetBuffer<_Float16>(scratch_buffer_));
         } break;
         default: {
             LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);