| #include "caffe2/predictor/predictor_utils.h" |
| |
| #include "caffe2/core/blob.h" |
| #include "caffe2/core/logging.h" |
| #include "caffe2/proto/caffe2_pb.h" |
| #include "caffe2/proto/predictor_consts.pb.h" |
| #include "caffe2/utils/proto_utils.h" |
| |
| namespace caffe2 { |
| namespace predictor_utils { |
| |
| TORCH_API const NetDef& getNet(const MetaNetDef& def, const std::string& name) { |
| std::string net_names; |
| bool is_first = true; |
| for (const auto& n : def.nets()) { |
| if (!is_first) { |
| net_names += ", "; |
| } |
| is_first = false; |
| net_names += n.key(); |
| if (n.key() == name) { |
| return n.value(); |
| } |
| } |
| CAFFE_THROW("Net not found: ", |
| name, |
| "; available nets: ", |
| net_names); |
| } |
| |
| std::unique_ptr<MetaNetDef> extractMetaNetDef( |
| db::Cursor* cursor, |
| const std::string& key) { |
| CAFFE_ENFORCE(cursor); |
| if (cursor->SupportsSeek()) { |
| cursor->Seek(key); |
| } |
| for (; cursor->Valid(); cursor->Next()) { |
| if (cursor->key() != key) { |
| continue; |
| } |
| // We've found a match. Parse it out. |
| BlobProto proto; |
| CAFFE_ENFORCE(proto.ParseFromString(cursor->value())); |
| Blob blob; |
| DeserializeBlob(proto, &blob); |
| CAFFE_ENFORCE(blob.template IsType<string>()); |
| auto def = std::make_unique<MetaNetDef>(); |
| CAFFE_ENFORCE(def->ParseFromString(blob.template Get<string>())); |
| return def; |
| } |
| CAFFE_THROW("Failed to find in db the key: ", key); |
| } |
| |
| std::unique_ptr<MetaNetDef> runGlobalInitialization( |
| std::unique_ptr<db::DBReader> db, |
| Workspace* master) { |
| CAFFE_ENFORCE(db.get()); |
| auto* cursor = db->cursor(); |
| |
| auto metaNetDef = extractMetaNetDef( |
| cursor, PredictorConsts::default_instance().meta_net_def()); |
| if (metaNetDef->has_modelinfo()) { |
| CAFFE_ENFORCE( |
| metaNetDef->modelinfo().predictortype() == |
| PredictorConsts::default_instance().single_predictor(), |
| "Can only load single predictor"); |
| } |
| VLOG(1) << "Extracted meta net def"; |
| |
| // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
| const auto globalInitNet = getNet( |
| *metaNetDef, PredictorConsts::default_instance().global_init_net_type()); |
| VLOG(1) << "Global init net: " << ProtoDebugString(globalInitNet); |
| |
| // Now, pass away ownership of the DB into the master workspace for |
| // use by the globalInitNet. |
| master->CreateBlob(PredictorConsts::default_instance().predictor_dbreader()) |
| ->Reset(db.release()); |
| |
| // Now, with the DBReader set, we can run globalInitNet. |
| CAFFE_ENFORCE( |
| master->RunNetOnce(globalInitNet), |
| "Failed running the globalInitNet: ", |
| ProtoDebugString(globalInitNet)); |
| |
| return metaNetDef; |
| } |
| |
| } // namespace predictor_utils |
| } // namespace caffe2 |