// blob: 8864ff6bdf6d94b476a24b9abdd5a8fc4b065395 [file] [log] [blame]
#include "caffe2/predictor/predictor_utils.h"
#include "caffe2/core/blob.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/predictor_consts.pb.h"
#include "caffe2/utils/proto_utils.h"
namespace caffe2 {
namespace predictor_utils {
// Looks up a net by name inside a MetaNetDef.
//
// Walks def.nets() in order and returns the value of the first entry whose
// key equals `name`. When no entry matches, throws via CAFFE_THROW with a
// message listing every key that was visited, joined by ", ".
TORCH_API const NetDef& getNet(const MetaNetDef& def, const std::string& name) {
  std::string available;
  // Empty before the first key, ", " afterwards, so the list is joined
  // without a leading separator.
  const char* separator = "";
  for (const auto& entry : def.nets()) {
    const auto& entry_name = entry.key();
    if (entry_name == name) {
      return entry.value();
    }
    // Only non-matching keys need recording: a match returns immediately and
    // the accumulated list is discarded.
    available += separator;
    available += entry_name;
    separator = ", ";
  }
  CAFFE_THROW("Net not found: ",
              name,
              "; available nets: ",
              available);
}
// Reads the MetaNetDef stored under `key` from the database behind `cursor`.
//
// `cursor` must be non-null. If the cursor supports seeking it is first
// positioned with Seek(key); the remaining entries are then scanned forward
// until one with exactly that key is found. The entry's value is parsed as a
// BlobProto, deserialized into a Blob that must hold a string, and that
// string is parsed as the MetaNetDef which is returned. Each of those steps
// is guarded by CAFFE_ENFORCE. If the cursor is exhausted without a match,
// CAFFE_THROW is raised naming the missing key.
std::unique_ptr<MetaNetDef> extractMetaNetDef(
    db::Cursor* cursor,
    const std::string& key) {
  CAFFE_ENFORCE(cursor);
  if (cursor->SupportsSeek()) {
    cursor->Seek(key);
  }

  while (cursor->Valid()) {
    if (cursor->key() == key) {
      // Matching entry: unwrap BlobProto -> Blob(string) -> MetaNetDef.
      BlobProto proto;
      CAFFE_ENFORCE(proto.ParseFromString(cursor->value()));

      Blob blob;
      DeserializeBlob(proto, &blob);
      CAFFE_ENFORCE(blob.template IsType<string>());

      auto def = std::make_unique<MetaNetDef>();
      CAFFE_ENFORCE(def->ParseFromString(blob.template Get<string>()));
      return def;
    }
    cursor->Next();
  }

  CAFFE_THROW("Failed to find in db the key: ", key);
}
// Loads the predictor's MetaNetDef from `db` and runs its global-init net
// once in the `master` workspace.
//
// Parameters:
//   db     - reader over the model database; must be non-null. Its ownership
//            is handed to a blob created in `master` (named by
//            PredictorConsts::predictor_dbreader()) before the net is run.
//   master - workspace that receives the DBReader blob and runs the net;
//            must be non-null.
//
// Returns the extracted MetaNetDef.
//
// Fails (CAFFE_ENFORCE / CAFFE_THROW) when either argument is null, the
// meta net def key is not present in the db, the model info declares a
// predictor type other than single_predictor, the global init net is not
// listed in the MetaNetDef, or running the net reports failure.
std::unique_ptr<MetaNetDef> runGlobalInitialization(
    std::unique_ptr<db::DBReader> db,
    Workspace* master) {
  CAFFE_ENFORCE(db.get());
  // Validate the workspace up front, as is done for `db`, instead of
  // dereferencing a null pointer later on.
  CAFFE_ENFORCE(master, "master workspace must not be null");

  auto* cursor = db->cursor();
  auto metaNetDef = extractMetaNetDef(
      cursor, PredictorConsts::default_instance().meta_net_def());
  if (metaNetDef->has_modelinfo()) {
    CAFFE_ENFORCE(
        metaNetDef->modelinfo().predictortype() ==
            PredictorConsts::default_instance().single_predictor(),
        "Can only load single predictor");
  }
  VLOG(1) << "Extracted meta net def";

  // Bind by reference: the NetDef is owned by *metaNetDef, which stays alive
  // until this function returns, so no copy of the net is needed.
  const auto& globalInitNet = getNet(
      *metaNetDef, PredictorConsts::default_instance().global_init_net_type());
  VLOG(1) << "Global init net: " << ProtoDebugString(globalInitNet);

  // Now, pass away ownership of the DB into the master workspace for
  // use by the globalInitNet. The blob is created first, in its own
  // statement, so that `db` is only released once there is a blob ready to
  // take it (independent of argument evaluation order).
  auto* dbReaderBlob = master->CreateBlob(
      PredictorConsts::default_instance().predictor_dbreader());
  dbReaderBlob->Reset(db.release());

  // Now, with the DBReader set, we can run globalInitNet.
  CAFFE_ENFORCE(
      master->RunNetOnce(globalInitNet),
      "Failed running the globalInitNet: ",
      ProtoDebugString(globalInitNet));
  return metaNetDef;
}
} // namespace predictor_utils
} // namespace caffe2