| #include "caffe2/operators/load_save_op.h" |
| |
| #if CAFFE2_HAVE_RE2 |
| #include <re2/re2.h> |
| #else |
| #include <regex> |
| #endif |
| |
| namespace caffe2 { |
| |
| template <> |
| void LoadOp<CPUContext>::SetCurrentDevice(BlobProto* proto) { |
| if (proto->has_tensor()) { |
| proto->mutable_tensor()->clear_device_detail(); |
| proto->mutable_tensor()->mutable_device_detail()->set_device_type( |
| PROTO_CPU); |
| } |
| } |
| |
| template <int VALUE_TYPE = TensorProto_DataType_FLOAT> |
| std::vector<TensorShape> LoadTensorInference( |
| const OperatorDef& def, |
| const vector<TensorShape>& /* unused */) { |
| ArgumentHelper helper(def); |
| auto shape = helper.GetRepeatedArgument<int64_t>("shape"); |
| vector<TensorShape> out; |
  // Currently the load op supports only a single shape.
  // TODO: Extend it to support a vector of shapes.
  // Since it supports just one shape, we return the correct shape
  // information only when exactly one blob is loaded. Otherwise, we return
  // unknown TensorShapes.
  if (def.output_size() == 1 && shape.size() > 0) {
    TensorShape ts;
    ts.set_data_type(static_cast<TensorProto_DataType>(
        helper.GetSingleArgument<int>("dtype", VALUE_TYPE)));
    for (auto d : shape) {
      ts.add_dims(d);
    }
    out.push_back(ts);
  } else {
    for (int i = 0; i < def.output_size(); i++) {
      TensorShape ts;
      ts.set_unknown_shape(true);
      out.push_back(ts);
    }
  }
  return out;
}

namespace internal {

SaveOpImpl::SaveOpImpl(
    OperatorBase* op,
    const OperatorDef& operator_def,
    Workspace* ws)
    : operator_(op),
      strip_prefix_(
          op->template GetSingleArgument<string>("strip_prefix", "")),
      db_type_(op->template GetSingleArgument<string>("db_type", "")),
      db_options_(op->template GetSingleArgument<string>("db_options", "")),
      blob_names_(
          op->template GetRepeatedArgument<string>("blob_name_overrides")) {
  CAFFE_ENFORCE_GT(db_type_.size(), 0, "Must specify a db type.");
  CAFFE_ENFORCE(
      blob_names_.empty() || blob_names_.size() == op->Inputs().size(),
      "Number of blobs and blob_name_overrides mismatch.");
  CAFFE_ENFORCE(
      blob_names_.empty() || strip_prefix_.empty(),
      "strip_prefix and blob_name_overrides are mutually exclusive.");

  auto absolute_path =
      op->template GetSingleArgument<int>("absolute_path", false);
  auto db_name = op->template GetSingleArgument<string>("db", "");
  CAFFE_ENFORCE_GT(db_name.size(), 0, "Must specify a db name.");
  full_db_name_ = absolute_path ? db_name : (ws->RootFolder() + "/" + db_name);

  auto options_data = op->template GetSingleArgument<string>("options", "");
  if (!options_data.empty()) {
    if (!options_.ParseFromString(options_data)) {
      CAFFE_ENFORCE(false, "unable to parse serialization options");
    }
  }
  if (op->template HasSingleArgumentOfType<int>("chunk_size")) {
    // The chunk size argument pre-dates the options argument.
    // If it was passed in, add it to the options list as a final default
    // setting.
    auto chunk_size_argument =
        op->template GetSingleArgument<int>("chunk_size", kDefaultChunkSize);
    // The chunk_size argument used 0 to mean "no chunking", and -1 to mean
    // "default chunk size". This is backwards from the behavior of the
    // chunk_size field in the BlobSerializationOptions, so swap these values
    // if we see them. (BlobSerializationOptions uses 0 to mean "default chunk
    // size" since protobuf v3 does not support custom default values, and so
    // we need to use 0 to mean the default behavior.)
    constexpr int kOldDefaultChunkSize = -1;
    constexpr int kOldNoChunking = 0;
    if (chunk_size_argument == kOldDefaultChunkSize) {
      chunk_size_argument = kDefaultChunkSize;
    } else if (chunk_size_argument == kOldNoChunking) {
      chunk_size_argument = kNoChunking;
    }
    options_.mutable_options()->Add()->set_chunk_size(chunk_size_argument);
  }

  if (blob_names_.empty()) {
    std::set<std::string> input_names;
    blob_names_.resize(op->Inputs().size());
    // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
    for (int i = 0; i < blob_names_.size(); ++i) {
      std::string name;
      if (strip_prefix_.empty()) {
        name = operator_def.input(i);
      } else {
        auto match_pos = operator_def.input(i).find(strip_prefix_);
        if (match_pos == string::npos) {
          name = operator_def.input(i);
        } else {
          name = operator_def.input(i).substr(
              match_pos + strip_prefix_.size(), string::npos);
        }
      }
      CAFFE_ENFORCE(
          input_names.insert(name).second, "Duplicated input: ", name);
      blob_names_[i] = name;
    }
  }
}

namespace {
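// Returns the first entry in options_list whose blob_name_regex is empty (an
// empty pattern acts as a catch-all) or fully matches blob_name; if no entry
// matches, returns default_options.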
const BlobSerializationOptions& GetBlobOptions(
    c10::string_view blob_name,
    const SerializationOptions& options_list,
    const BlobSerializationOptions& default_options) {
  for (const auto& options : options_list.options()) {
    const auto& name_regex = options.blob_name_regex();
    if (name_regex.empty()) {
      return options;
    }

#if CAFFE2_HAVE_RE2
    // If we have re2, prefer it over std::regex.
    re2::RE2 regex(name_regex);
    if (re2::RE2::FullMatch(
            re2::StringPiece(blob_name.data(), blob_name.size()), regex)) {
      return options;
    }
#else
    // std::regex should be avoided if at all possible, but we use it as a
    // fallback when re2 is unavailable (for one of std::regex's known issues
    // see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=61582).
    if (std::regex_match(
            blob_name.begin(), blob_name.end(), std::regex(name_regex))) {
      return options;
    }
#endif
  }
  return default_options;
}
} // namespace

bool SaveOpImpl::RunOnDevice() {
  std::unique_ptr<DB> out_db(
      caffe2::db::CreateDB(db_type_, full_db_name_, caffe2::db::NEW));
  CAFFE_ENFORCE(
      out_db.get(),
      "Cannot find db implementation of type ",
      db_type_,
      " (while trying to open ",
      full_db_name_,
      ")");
  if (!db_options_.empty()) {
    out_db->SetOptions(db_options_);
  }

  BlobSerializerBase::SerializationAcceptor acceptor =
      [&](const std::string& blobName, std::string&& data) {
        // transaction should take care of locking
        VLOG(2) << "Sending " << blobName << " blob's data of size "
                << data.size() << " to db";
        auto transaction = out_db->NewTransaction();
        transaction->Put(blobName, std::move(data));
        transaction->Commit();
      };

  const vector<const Blob*>& inputs = operator_->OperatorBase::Inputs();
  VLOG(0) << "Saving " << inputs.size() << " inputs to " << db_type_ << ": "
          << full_db_name_;
  BlobSerializationOptions default_options;
  // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
  for (int i = 0; i < inputs.size(); ++i) {
    SerializeBlob(
        *inputs[i],
        blob_names_[i],
        acceptor,
        GetBlobOptions(blob_names_[i], options_, default_options));
  }
  out_db->Close();
  return true;
}

} // namespace internal

namespace {
class EstimateAllBlobSizesOp final : public Operator<CPUContext> {
 public:
  explicit EstimateAllBlobSizesOp(
      const OperatorDef& operator_def,
      Workspace* ws)
      : Operator<CPUContext>(operator_def, ws),
        include_shared_(GetSingleArgument<int>("include_shared", true)),
        ws_(ws) {
    auto options_data = GetSingleArgument<string>("options", "");
    if (!options_data.empty()) {
      if (!options_.ParseFromString(options_data)) {
        CAFFE_ENFORCE(false, "unable to parse serialization options");
      }
    }
  }

  bool RunOnDevice() override {
    const auto& blob_names = include_shared_ ? ws_->Blobs() : ws_->LocalBlobs();
    auto* names_out = Output(
        0, {static_cast<int64_t>(blob_names.size())}, at::dtype<std::string>());
    auto* sizes_out = Output(
        1, {static_cast<int64_t>(blob_names.size())}, at::dtype<int64_t>());
    BlobSerializationOptions default_options;
    for (size_t idx = 0; idx < blob_names.size(); ++idx) {
      const auto& name = blob_names[idx];
      auto* blob = ws_->GetBlob(name);
      if (!blob) {
        LOG(ERROR) << "unable to find blob " << name
                   << " when estimating serialization size";
        continue;
      }

      names_out->template mutable_data<std::string>()[idx] = name;
      const auto& blob_serialization_options =
          internal::GetBlobOptions(name, options_, default_options);
      sizes_out->template mutable_data<int64_t>()[idx] =
          EstimateSerializedBlobSize(*blob, name, blob_serialization_options);
    }
    return true;
  }

 private:
  bool include_shared_{true};
  Workspace* ws_{nullptr};
  SerializationOptions options_;
};
} // namespace

REGISTER_CPU_OPERATOR(DBExists, DBExistsOp<CPUContext>);
REGISTER_CPU_OPERATOR(Load, LoadOp<CPUContext>);
REGISTER_CPU_OPERATOR(Save, SaveOp<CPUContext>);
REGISTER_CPU_OPERATOR(Checkpoint, CheckpointOp<CPUContext>);
// Old name for the CPU operator: do NOT use, as it may be deprecated later.
REGISTER_CPU_OPERATOR(Snapshot, CheckpointOp<CPUContext>);
REGISTER_CPU_OPERATOR(EstimateAllBlobSizes, EstimateAllBlobSizesOp);

OPERATOR_SCHEMA(DBExists)
    .NumInputs(0)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Checks if the db described by the arguments exists.

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/load_save_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```
workspace.ResetWorkspace()

op = core.CreateOperator(
    "DBExists",
    [],
    ["exists"],
    db_name="test_db",
    db_type="leveldb",
)

workspace.RunOperatorOnce(op)
print("exists:", workspace.FetchBlob("exists"))

```

</details>

)DOC")
    .Output(0, "exists", "*(type: Tensor`<bool>`)* Scalar boolean output "
        "tensor. True if the db exists, else false.")
    .Arg(
        "absolute_path",
        "*(type: int; default: 0)* If set to non-zero, the path specified by "
        "the `db_name` arg is used directly. If not set (default), the path "
        "of the current root folder of the workspace is prepended to the "
        "path specified by the `db_name` arg.")
    .Arg("db_name", "*(type: string)* Path to the db in question; see the "
        "`absolute_path` arg details for options regarding the current root "
        "folder of the workspace.")
    .Arg("db_type", "*(type: string)* Type of db (options: \"lmdb\", "
        "\"leveldb\", \"minidb\").");

OPERATOR_SCHEMA(Load)
    .NumInputs(0, INT_MAX)
    .NumOutputs(0, INT_MAX)
    .TensorInferenceFunction(LoadTensorInference<>)
    .SetDoc(R"DOC(
The Load operator loads a set of serialized blobs from a db or multiple dbs. It
takes $[0, \infty)$ number of inputs and $[0, \infty)$ number of outputs, using
the db keys to match the db entries with the outputs.
If at least one input is passed, then it is assumed that the input blobs are a
set of DBReaders to load from. Otherwise the `db` or `dbs` argument is used to
load blobs from a single db or multiple dbs, respectively. The `db_type`
argument specifies the type of the input db/dbs.

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/load_save_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```
workspace.ResetWorkspace()

op = core.CreateOperator(
    "Load",
    [],
    ["X", "Y"],
    db="test_db",
    db_type="lmdb"
)

workspace.RunOperatorOnce(op)
print("X:", workspace.FetchBlob("X"))
print("Y:", workspace.FetchBlob("Y"))

```
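
As a variation (an illustrative sketch, not from the original docs), setting
the documented `load_all` argument to 1 loads every blob stored in the db
instead of naming the outputs explicitly:

```
op = core.CreateOperator(
    "Load",
    [],
    [],
    db="test_db",
    db_type="lmdb",
    load_all=1,  # load all blobs in the db, creating/overwriting as needed
)
workspace.RunOperatorOnce(op)
```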

</details>

)DOC")
    .Input(
        0,
        "X, Y, ...",
        "*(type: List(DBReader))* [OPTIONAL] List of DBReaders to load from. "
        "Can use this instead of the `db`/`dbs` args.")
    .Arg(
        "absolute_path",
        "*(type: int; default: 0)* If set to non-zero, the path specified by "
        "the `db` arg is used directly. If not set (default), the path of the "
        "current root folder of the workspace is prepended to the path "
        "specified by the `db` arg.")
    .Arg(
        "add_prefix",
        "*(type: string, default: \"\")* Blobs will be prefixed with this "
        "when loading. Useful for avoiding collisions with blobs existing in "
        "the workspace. The output blob names specified to this op should "
        "include this prefix.")
    .Arg(
        "strip_prefix",
        "*(type: string, default: \"\")* Characters in the provided blob "
        "names that match `strip_prefix` will be removed prior to loading. "
        "Also, characters that precede `strip_prefix` will be removed. Useful "
        "for removing device scope from blob names.")
| .Arg("db", "*(type: string)* The output path of the db. See the " |
| "`absolute_path` arg details for options regarding the current root folder " |
| "of the workspace.") |
    .Arg(
        "dbs",
        "*(type: List(string))* List of paths to dbs to load blobs from. See "
        "the `absolute_path` arg details for options regarding the current "
        "root folder of the workspace.")
| .Arg("db_type", "(type: string)* Type of db to save (options: \"lmdb\", " |
| "\"leveldb\", \"minidb\").") |
    .Arg(
        "keep_device",
        "*(type: int; default: 0)* If nonzero, the blobs are loaded into the "
        "device that is specified in the serialized `BlobProto`. Otherwise, "
        "the device will be set as the one that the `Load` operator is being "
        "run under.")
    .Arg(
        "load_all",
        "*(type: int; default: 0)* If nonzero, will load all blobs pointed to "
        "by the db to the workspace overwriting/creating blobs as needed.")
    .Arg(
        "allow_incomplete",
        "*(type: bool; default: False)* If True, the op will not fail if some "
        "of the specified output blobs are missing from the db.")
    .Arg(
        "source_blob_names",
        "*(type: List(string))* If set, used instead of output blob names to "
        "specify which blobs in the db shall be loaded. Must be the same "
        "length as number of output blobs.");

OPERATOR_SCHEMA(Save)
    .NumInputs(1, INT_MAX)
    .NumOutputs(0)
    .SetDoc(R"DOC(
Saves a set of blobs to a db. It takes $[1, \infty)$ number of inputs and has
no output. The contents of the inputs are written into the db using the
settings specified by the arguments.

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/load_save_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```
workspace.ResetWorkspace()

op = core.CreateOperator(
    "Save",
    ["X", "Y", "Z"],
    [],
    db="test_db2",
    db_type="leveldb",
    blob_name_overrides=["x_scores", "y_scores", "z_scores"]
)

workspace.FeedBlob("X", np.random.randint(20, size=(5,5)))
workspace.FeedBlob("Y", np.random.randint(20, size=(5,5)))
workspace.FeedBlob("Z", np.random.randint(20, size=(5,5)))
workspace.RunOperatorOnce(op)

```
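
As a variation (an illustrative sketch, not from the original docs), the
documented `chunk_size` argument splits large tensors into chunks when
serializing; per the legacy encoding handled in SaveOpImpl, 0 means "no
chunking" and -1 means "use the default chunk size":

```
op = core.CreateOperator(
    "Save",
    ["X"],
    [],
    db="test_db3",
    db_type="minidb",
    chunk_size=2048,  # illustrative positive value; see the chunk_size arg
)
workspace.RunOperatorOnce(op)
```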

</details>

)DOC")
    .Arg(
        "absolute_path",
        "*(type: int; default: 0)* If set to non-zero, save the db directly "
        "to the path specified by the `db` arg. If not set (default), prepend "
        "the path of the current root folder of the workspace to the path "
        "specified by the `db` arg.")
    .Arg(
        "strip_prefix",
        "*(type: string, default: \"\")* Characters in the provided blob "
        "names that match `strip_prefix` will be removed prior to saving. "
        "Also, characters that precede `strip_prefix` will be removed. Useful "
        "for removing device scope from blob names.")
    .Arg(
        "blob_name_overrides",
        "*(List(string))* If set, used as blob names instead of original blob "
        "names. Must be same length as number of blobs.")
    .Arg("db", "*(type: string)* The output path of the db. See the "
        "`absolute_path` arg details for options regarding the current root "
        "folder of the workspace.")
    .Arg("db_type", "*(type: string)* Type of db to save (options: \"lmdb\", "
        "\"leveldb\", \"minidb\").")
| .Arg("chunk_size", "*(type: string; default: kDefaultChunkSize)* The chunk " |
| "size to split tensor data into. If not set, caffe2_tensor_chunk_size will " |
| "be used") |
| .Input(0, "X", "*(type: Tensor)* Input tensor(s)."); |

OPERATOR_SCHEMA(Checkpoint)
    .NumInputs(1, INT_MAX)
    .NumOutputs(0)
| .SetDoc(R"DOC( |
| The Checkpoint operator is similar to the Save operator, but allows one to save |
| to db every few iterations, with a db name that is appended with the iteration |
| count. It takes [1, infinity) number of inputs and has no output. The first |
| input has to be a TensorCPU of type int and has size 1 (i.e. the iteration |
| counter). This is determined whether we need to do checkpointing. |
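
<details>

<summary> <b>Example</b> </summary>

**Code**

The following is an illustrative sketch, not taken from the original docs; it
assumes `from caffe2.python import core, workspace` and `import numpy as np`,
and stores the iteration counter as int64.

```
workspace.ResetWorkspace()

# Hypothetical iteration counter blob; must be the first input.
workspace.FeedBlob("iter", np.array([0], dtype=np.int64))
workspace.FeedBlob("X", np.random.rand(5, 5).astype(np.float32))

op = core.CreateOperator(
    "Checkpoint",
    ["iter", "X"],
    [],
    db="checkpoint_%08d.db",
    db_type="minidb",
    every=100,
)

# Writes checkpoint_00000000.db, since (iter mod every) == 0.
workspace.RunOperatorOnce(op)
```

</details>
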
| )DOC") |
| .Arg( |
| "absolute_path", |
| "(int, default 0) if set, use the db path directly and do not prepend " |
| "the current root folder of the workspace.") |
| .Arg( |
| "db", |
| "(string) a template string that one can combine with the " |
| "iteration to create the final db name. For example, " |
| "\"/home/lonestarr/checkpoint_%08d.db\"") |
| .Arg("db_type", "(string) the type of the db.") |
| .Arg( |
| "every", |
| "(int, default 1) the checkpointing is carried out when " |
| "(iter mod every) is zero."); |
| |
| OPERATOR_SCHEMA(Snapshot); |
| |
| OPERATOR_SCHEMA(EstimateAllBlobSizes) |
| .NumInputs(0) |
| .NumOutputs(2) |
| .SetDoc(R"DOC( |
| Returns two outputs: a 1D tensor of strings containing the names |
| of each blob in the active workspace, and a 1D tensor of integers containing the |
| estimated serialized size of each blob (in bytes). |
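
<details>

<summary> <b>Example</b> </summary>

**Code**

A minimal illustrative sketch, not taken from the original docs; it assumes
`from caffe2.python import core, workspace` and `import numpy as np`.

```
workspace.ResetWorkspace()
workspace.FeedBlob("X", np.random.rand(5, 5).astype(np.float32))

op = core.CreateOperator(
    "EstimateAllBlobSizes",
    [],
    ["blob_names", "blob_sizes"],
)

workspace.RunOperatorOnce(op)
print("names:", workspace.FetchBlob("blob_names"))
print("sizes:", workspace.FetchBlob("blob_sizes"))
```

</details>
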
| )DOC") |
| .Arg( |
| "include_shared", |
| "(bool, default true) Whether to include blobs " |
| "inherited from parent workspaces.") |
| .Arg( |
| "options", |
| "(string, default empty) A BlobSerializationOptions message specifying " |
| "options for how specific blobs should be serialized.") |
| .Output(0, "blob_names", "1D tensor of strings containing blob names.") |
| .Output(1, "blob_sizes", "1D tensor of int64_t containing blob sizes."); |
| |
| NO_GRADIENT(Load); |
| SHOULD_NOT_DO_GRADIENT(DBExists); |
| SHOULD_NOT_DO_GRADIENT(Save); |
| SHOULD_NOT_DO_GRADIENT(Checkpoint); |
| SHOULD_NOT_DO_GRADIENT(Snapshot); |
SHOULD_NOT_DO_GRADIENT(EstimateAllBlobSizes);

} // namespace caffe2