Support zero-sized tensors and check for omitted inputs in CpuExecutor.
Zero-sized tensors are supported only internally to the driver or
when reported in output shapes. It is illegal to pre-specify a
zero-sized tensor as a model input or output.
To summarize the meanings of dimension = 0:
- Dimension = 0 for model input: the input is dynamic and must be
  fully specified at execution time
- Dimension = 0 for internal operand / model output: the shape is
  unknown, to be deduced from execution
- Dimension = 0 from getOutputOperandDimensions (illustrated in the
  sketch below):
  * If NO_ERROR, it is a zero-sized output
  * If OUTPUT_INSUFFICIENT_SIZE, it is unknown
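As a client-facing illustration, a minimal sketch of that distinction
against the public NNAPI (the helper name is illustrative; only the
two status codes come from the contract above):

    #include <algorithm>
    #include <vector>

    #include <android/NeuralNetworks.h>

    // Sketch: after an execution completes, tell a genuinely zero-sized
    // output apart from an output whose shape is still unknown.
    bool isZeroSizedOutput(ANeuralNetworksExecution* execution, int32_t index) {
        uint32_t rank = 0;
        if (ANeuralNetworksExecution_getOutputOperandRank(execution, index, &rank) !=
            ANEURALNETWORKS_NO_ERROR) {
            return false;  // e.g. OUTPUT_INSUFFICIENT_SIZE: shape unknown
        }
        std::vector<uint32_t> dims(rank);
        if (ANeuralNetworksExecution_getOutputOperandDimensions(
                    execution, index, dims.data()) != ANEURALNETWORKS_NO_ERROR) {
            return false;  // OUTPUT_INSUFFICIENT_SIZE: shape unknown
        }
        // NO_ERROR: any dimension of 0 is a real zero-sized output.
        return std::any_of(dims.begin(), dims.end(),
                           [](uint32_t d) { return d == 0; });
    }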
Add two additional fields to OperationRegistration (see the sketch
after this list):
- allowOmittedOperand: if false, CpuExecutor enforces that
  inputs/outputs are not null
- allowZeroSizedInput: if false, CpuExecutor enforces that no
  dimension of any input is 0
The current implementation assumes that none of the operations on the
old switch-statement path supports zero-sized inputs; only operations
dispatched through OperationResolver can support zero-sized input
tensors. For now, all operations with OperationResolver report false
for allowZeroSizedInput; it will be enabled for a small subset of
operations in separate CLs, as sketched below.
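Such a follow-up would pass the flag through the registration macro,
mirroring the allowOmittedOperand change in the diff below (SOME_OP
and its entry points are placeholders, not a real operation):

    NN_REGISTER_OPERATION(SOME_OP, "SOME_OP", some_op::validate,
                          some_op::prepare, some_op::execute,
                          .allowZeroSizedInput = true);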
Bug: 126737477
Test: NeuralNetworksTest_static
Change-Id: Ia94d67e4c8c6a49b543d29ebb3b31d509ece0970
Merged-In: Ia94d67e4c8c6a49b543d29ebb3b31d509ece0970
(cherry picked from commit d305bbd09f750145e5d56b14ae268c2919a7cd3c)
diff --git a/common/operations/BidirectionalSequenceRNN.cpp b/common/operations/BidirectionalSequenceRNN.cpp
index 8b23b94..b1bbe25 100644
--- a/common/operations/BidirectionalSequenceRNN.cpp
+++ b/common/operations/BidirectionalSequenceRNN.cpp
@@ -96,7 +96,7 @@
const T* auxInput = nullptr;
const T* fwAuxWeights = nullptr;
const T* bwAuxWeights = nullptr;
- const bool hasAuxInputs = !context->isNullInput(kAuxInputTensor);
+ const bool hasAuxInputs = !context->isOmittedInput(kAuxInputTensor);
if (hasAuxInputs) {
auxInput = context->getInputBuffer<T>(kAuxInputTensor);
fwAuxWeights = context->getInputBuffer<T>(kFwAuxWeightsTensor);
@@ -285,13 +285,14 @@
Shape fwAuxWeights = context->getInputShape(kFwAuxWeightsTensor);
Shape bwAuxWeights = context->getInputShape(kBwAuxWeightsTensor);
- const bool auxInputsAllOrNone =
- (context->isNullInput(kAuxInputTensor) && context->isNullInput(kFwAuxWeightsTensor) &&
- context->isNullInput(kBwAuxWeightsTensor)) ||
- (!context->isNullInput(kAuxInputTensor) && !context->isNullInput(kFwAuxWeightsTensor) &&
- !context->isNullInput(kBwAuxWeightsTensor));
+ const bool auxInputsAllOrNone = (context->isOmittedInput(kAuxInputTensor) &&
+ context->isOmittedInput(kFwAuxWeightsTensor) &&
+ context->isOmittedInput(kBwAuxWeightsTensor)) ||
+ (!context->isOmittedInput(kAuxInputTensor) &&
+ !context->isOmittedInput(kFwAuxWeightsTensor) &&
+ !context->isOmittedInput(kBwAuxWeightsTensor));
NN_RET_CHECK(auxInputsAllOrNone);
- const bool hasAuxInputs = !context->isNullInput(kAuxInputTensor);
+ const bool hasAuxInputs = !context->isOmittedInput(kAuxInputTensor);
int32_t timeMajor = context->getInputValue<bool>(kTimeMajorParam);
const uint32_t batchSize =
@@ -370,7 +371,7 @@
NN_REGISTER_OPERATION(BIDIRECTIONAL_SEQUENCE_RNN, "BIDIRECTIONAL_SEQUENCE_RNN",
bidirectional_sequence_rnn::validate, bidirectional_sequence_rnn::prepare,
- bidirectional_sequence_rnn::execute);
+ bidirectional_sequence_rnn::execute, .allowOmittedOperand = true);
} // namespace nn
} // namespace android