blob: be3faaa82e526d8e800a903db25c195a07c28e09 [file] [log] [blame]
/*
* Copyright (C) 2022 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define LOG_TAG "FlatbufferModelBuilder"
#include "FlatbufferModelBuilder.h"
#include <LegacyUtils.h>
#include "FlatbufferModelBuilderUtils.h"
#include "operation_converters/OperationConverterResolver.h"
namespace android {
namespace nn {
void FlatbufferModelBuilder::verifyModel(const tflite::Model* model) {
flatbuffers::Verifier verifier(mBuilder.GetBufferPointer(), mBuilder.GetSize());
CHECK(model != nullptr);
CHECK(model->Verify(verifier));
}
void FlatbufferModelBuilder::initializeBufferVector() {
mBufferVector.clear();
std::vector<uint8_t> emptyData;
auto emptyBuffer = tflite::CreateBufferDirect(mBuilder, &emptyData);
mBufferVector.push_back(emptyBuffer);
}
void FlatbufferModelBuilder::initializeOpCodeIndexForOperationType() {
mOpCodeIndexForOperationType.clear();
mOpCodeIndexForOperationType.resize(kNumberOfOperationTypes, -1);
}
std::vector<MetadataFlatbuffer> FlatbufferModelBuilder::createMetadataVector() {
std::vector<MetadataFlatbuffer> metadataVector;
for (uint32_t i = 0; i < mBufferVector.size(); i++) {
auto metadata = tflite::CreateMetadataDirect(mBuilder, std::to_string(i).c_str() /* name */,
i /* buffer */);
metadataVector.push_back(metadata);
}
return metadataVector;
}
Result<const tflite::Model*> FlatbufferModelBuilder::createTfliteModel() {
mModel = makeModel();
// Initialize and clear data structures
initializeBufferVector();
mOpCodesVector.clear();
initializeOpCodeIndexForOperationType();
// Generate subgraphs
auto subgraphsVector = NN_TRY(createSubGraphs());
auto metadataVector = createMetadataVector();
ModelFlatbuffer flatbufferModel = tflite::CreateModelDirect(
mBuilder, 3 /* version*/, &mOpCodesVector /* operator_codes */,
&subgraphsVector /* subgraphs */, nullptr /* description */,
&mBufferVector /* buffers */, nullptr /* metadata_buffer */,
&metadataVector /* metadata */);
mBuilder.Finish(flatbufferModel);
const tflite::Model* tfliteModel = tflite::GetModel(mBuilder.GetBufferPointer());
verifyModel(tfliteModel);
return tfliteModel;
}
Result<SubGraphFlatbuffer> FlatbufferModelBuilder::createSubGraphFlatbuffer(
const Model::Subgraph& subgraph) {
// TFLite does not support unspecified ranks in Operands
NN_TRY(checkAllTensorOperandsHaveSpecifiedRank(subgraph.operands));
// TFLite does not support dynamic shapes for subgrah output Operands
NN_TRY(checkNoSubgraphOutputOperandsHaveDynamicShape(subgraph.operands));
SubGraphContext context(&mModel, &subgraph, &mBuilder, &mOpCodesVector,
&mOpCodeIndexForOperationType, &mBufferVector);
for (const Operation& operation : subgraph.operations) {
const IOperationConverter* converter =
OperationConverterResolver::get()->findOperationConverter(operation.type);
NN_RET_CHECK(converter != nullptr)
<< "IOperationConverter not implemented for OperationType: " << operation.type;
NN_TRY(converter->convert(operation, &context));
}
for (uint32_t idx : subgraph.inputIndexes) {
context.addSubGraphInput(idx);
}
for (uint32_t idx : subgraph.outputIndexes) {
context.addSubGraphOutput(idx);
}
return context.finish();
}
Result<std::vector<SubGraphFlatbuffer>> FlatbufferModelBuilder::createSubGraphs() {
// We do not support control flow yet
NN_RET_CHECK(mModel.referenced.empty()) << "Control flow for multiple subgraphs not supported";
std::vector<SubGraphFlatbuffer> subGraphVector;
auto mainSubGraph = NN_TRY(createSubGraphFlatbuffer(mModel.main));
subGraphVector.push_back(mainSubGraph);
return subGraphVector;
}
} // namespace nn
} // namespace android