| #include <unordered_set> |
| |
| #include "caffe2/core/db.h" |
| #include "caffe2/core/logging.h" |
| #include "caffe2/utils/proto_utils.h" |
| |
| namespace caffe2 { |
| namespace db { |
| |
| class ProtoDBCursor : public Cursor { |
| public: |
| explicit ProtoDBCursor(const TensorProtos* proto) : proto_(proto), iter_(0) {} |
| // NOLINTNEXTLINE(modernize-use-equals-default) |
| ~ProtoDBCursor() override {} |
| |
| void Seek(const string& /*str*/) override { |
| CAFFE_THROW("ProtoDB is not designed to support seeking."); |
| } |
| |
| void SeekToFirst() override { |
| iter_ = 0; |
| } |
| void Next() override { |
| ++iter_; |
| } |
| string key() override { |
| return proto_->protos(iter_).name(); |
| } |
| string value() override { |
| return SerializeAsString_EnforceCheck( |
| proto_->protos(iter_), "ProtoDBCursor"); |
| } |
| bool Valid() override { |
| return iter_ < proto_->protos_size(); |
| } |
| |
| private: |
| const TensorProtos* proto_; |
| int iter_; |
| }; |
| |
| class ProtoDBTransaction : public Transaction { |
| public: |
| explicit ProtoDBTransaction(TensorProtos* proto) |
| : proto_(proto), existing_names_() { |
| for (const auto& tensor : proto_->protos()) { |
| existing_names_.insert(tensor.name()); |
| } |
| } |
| ~ProtoDBTransaction() override { |
| // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) |
| Commit(); |
| } |
| void Put(const string& key, string&& value) override { |
| if (existing_names_.count(key)) { |
| CAFFE_THROW("An item with key ", key, " already exists."); |
| } |
| auto* tensor = proto_->add_protos(); |
| CAFFE_ENFORCE( |
| tensor->ParseFromString(value), |
| "Cannot parse content from the value string."); |
| CAFFE_ENFORCE( |
| tensor->name() == key, |
| "Passed in key ", |
| key, |
| " does not equal to the tensor name ", |
| tensor->name()); |
| } |
| // Commit does nothing. The protocol buffer will be written at destruction |
| // of ProtoDB. |
| void Commit() override {} |
| |
| private: |
| TensorProtos* proto_; |
| std::unordered_set<string> existing_names_; |
| |
| C10_DISABLE_COPY_AND_ASSIGN(ProtoDBTransaction); |
| }; |
| |
| class ProtoDB : public DB { |
| public: |
| ProtoDB(const string& source, Mode mode) |
| : DB(source, mode), proto_(), source_(source) { |
| if (mode == READ || mode == WRITE) { |
| // Read the current protobuffer. |
| CAFFE_ENFORCE( |
| ReadProtoFromFile(source, &proto_), "Cannot read protobuffer."); |
| } |
| LOG(INFO) << "Opened protodb " << source; |
| } |
| ~ProtoDB() override { |
| // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) |
| Close(); |
| } |
| |
| void Close() override { |
| if (mode_ == NEW || mode_ == WRITE) { |
| WriteProtoToBinaryFile(proto_, source_); |
| } |
| } |
| |
| unique_ptr<Cursor> NewCursor() override { |
| return make_unique<ProtoDBCursor>(&proto_); |
| } |
| unique_ptr<Transaction> NewTransaction() override { |
| return make_unique<ProtoDBTransaction>(&proto_); |
| } |
| |
| private: |
| TensorProtos proto_; |
| string source_; |
| }; |
| |
| REGISTER_CAFFE2_DB(ProtoDB, ProtoDB); |
| // For lazy-minded, one can also call with lower-case name. |
| REGISTER_CAFFE2_DB(protodb, ProtoDB); |
| |
| } // namespace db |
| } // namespace caffe2 |