LSTM: require input layer norm weights to be omitted when CIFG is used.
In the CIFG variant of LSTM, the input layer norm weights are not used in
the computation, so they must not be provided.
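For background: in a CIFG ("coupled input and forget gate") cell the input
gate is derived from the forget gate instead of being computed from its own
weights, so there is no input-gate pre-activation to layer-normalize. A
minimal standalone sketch of the coupling (illustrative names, not the NNAPI
implementation):

    #include <cmath>

    // Sigmoid of a gate pre-activation.
    static float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

    // Non-CIFG: the input gate has its own (optionally layer-normalized)
    // pre-activation computed from dedicated input gate weights.
    static float InputGate(float input_gate_preactivation) {
        return Sigmoid(input_gate_preactivation);
    }

    // CIFG: the input gate is coupled to the forget gate, i_t = 1 - f_t,
    // so no input gate weights (and no input layer norm weights) exist.
    static float InputGateCifg(float forget_gate_activation) {
        return 1.0f - forget_gate_activation;
    }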
Bug: 129126572
Test: NeuralNetworksTest_static
Change-Id: Id835e7a8f7fa80354eb280cfc0154d9e235695aa
Merged-In: Id835e7a8f7fa80354eb280cfc0154d9e235695aa
(cherry picked from commit 5a9edc3c1290494f0a339e4b2e0038bfe59df416)
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp
index 477ff54..3bc4072 100644
--- a/common/operations/LSTM.cpp
+++ b/common/operations/LSTM.cpp
@@ -22,6 +22,7 @@
#include "OperationsUtils.h"
#include "Tracing.h"
+#include "Utils.h"
namespace android {
namespace nn {
@@ -92,10 +93,14 @@
// We check the version of LSTM by checking the number of inputs to the
// op. For LSTM version 1.0 there were 23 inputs and for 1.2 there are 27.
if (operation.inputs.size() == 27) {
- input_layer_norm_weights_ = GetInput(operation, operands, kInputLayerNormWeightsTensor);
- forget_layer_norm_weights_ = GetInput(operation, operands, kForgetLayerNormWeightsTensor);
- cell_layer_norm_weights_ = GetInput(operation, operands, kCellLayerNormWeightsTensor);
- output_layer_norm_weights_ = GetInput(operation, operands, kOutputLayerNormWeightsTensor);
+ input_layer_norm_weights_ =
+ GetInput(operation, operands, kInputLayerNormWeightsTensor); // optional
+ forget_layer_norm_weights_ =
+ GetInput(operation, operands, kForgetLayerNormWeightsTensor); // optional
+ cell_layer_norm_weights_ =
+ GetInput(operation, operands, kCellLayerNormWeightsTensor); // optional
+ output_layer_norm_weights_ =
+ GetInput(operation, operands, kOutputLayerNormWeightsTensor); // optional
} else {
// For LSTM from HAL v1.0 assign operands with no values
static RunTimeOperandInfo no_value;
@@ -203,7 +208,9 @@
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to get the condition.
params->use_peephole = !IsNullInput(cell_to_output_weights);
- params->use_layer_norm = !IsNullInput(input_layer_norm_weights);
+ // Checking output instead of input layer norm weights because the input
+ // layer norm weights can be omitted when CIFG LSTM is used.
+ params->use_layer_norm = !IsNullInput(output_layer_norm_weights);
params->use_projection_weight = (projection_weights->lifetime != OperandLifeTime::NO_VALUE);
params->use_projection_bias = (projection_bias->lifetime != OperandLifeTime::NO_VALUE);
@@ -262,12 +269,24 @@
NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights, 0), n_cell);
}
- const bool layer_norm_weights_all_or_none =
- (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
- IsNullInput(cell_layer_norm_weights) && IsNullInput(input_layer_norm_weights)) ||
- (!IsNullInput(input_layer_norm_weights) && !IsNullInput(forget_layer_norm_weights) &&
- !IsNullInput(cell_layer_norm_weights) && !IsNullInput(input_layer_norm_weights));
- NN_CHECK(layer_norm_weights_all_or_none);
+ if (params->use_cifg) {
+ NN_RET_CHECK(IsNullInput(input_layer_norm_weights))
+ << "input_layer_norm_weights are provided while CIFG is used";
+ const bool layer_norm_weights_all_or_none_cifg =
+ (IsNullInput(forget_layer_norm_weights) && IsNullInput(cell_layer_norm_weights) &&
+ IsNullInput(output_layer_norm_weights)) ||
+ (!IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
+ !IsNullInput(output_layer_norm_weights));
+ NN_RET_CHECK(layer_norm_weights_all_or_none_cifg);
+ } else {
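+ // Without CIFG, the layer norm weights of all four gates must be
+ // provided or omitted together.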
+ const bool layer_norm_weights_all_or_none =
+ (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
+ IsNullInput(cell_layer_norm_weights) && IsNullInput(output_layer_norm_weights)) ||
+ (!IsNullInput(input_layer_norm_weights) &&
+ !IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
+ !IsNullInput(output_layer_norm_weights));
+ NN_RET_CHECK(layer_norm_weights_all_or_none);
+ }
return true;
}
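For reference, the all-or-none conditions added above are equivalent to
requiring that the null-ness of the listed tensors agree. A minimal
standalone sketch of that predicate (hypothetical helper, not part of the
NNAPI codebase):

    #include <initializer_list>

    // True iff all flags agree, i.e. the tensors are all null or all
    // non-null. Assumes a non-empty list.
    static bool AllOrNone(std::initializer_list<bool> is_null) {
        const bool first = *is_null.begin();
        for (bool v : is_null) {
            if (v != first) return false;
        }
        return true;
    }

    // CIFG:     AllOrNone({forget_null, cell_null, output_null})
    // non-CIFG: AllOrNone({input_null, forget_null, cell_null, output_null})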