| #include "store_ops.h" |
| |
| #include "caffe2/core/blob_serialization.h" |
| |
| namespace caffe2 { |
| |
// Names of the operator arguments read by the store operators in this file.
// kBlobName: argument holding the key (or keys) used in the store.
// kAddValue: argument holding the increment applied by StoreAdd.
constexpr const char* kBlobName = "blob_name";
constexpr const char* kAddValue = "add_value";
| |
| StoreSetOp::StoreSetOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<CPUContext>(operator_def, ws), |
| blobName_( |
| GetSingleArgument<std::string>(kBlobName, operator_def.input(DATA))) { |
| } |
| |
| bool StoreSetOp::RunOnDevice() { |
| // Serialize and pass to store |
| auto* handler = |
| OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get(); |
| handler->set(blobName_, SerializeBlob(InputBlob(DATA), blobName_)); |
| return true; |
| } |
| |
| REGISTER_CPU_OPERATOR(StoreSet, StoreSetOp); |
| OPERATOR_SCHEMA(StoreSet) |
| .NumInputs(2) |
| .NumOutputs(0) |
| .SetDoc(R"DOC( |
| Set a blob in a store. The key is the input blob's name and the value |
| is the data in that blob. The key can be overridden by specifying the |
| 'blob_name' argument. |
| )DOC") |
| .Arg("blob_name", "alternative key for the blob (optional)") |
| .Input(0, "handler", "unique_ptr<StoreHandler>") |
| .Input(1, "data", "data blob"); |
| |
| StoreGetOp::StoreGetOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<CPUContext>(operator_def, ws), |
| blobName_(GetSingleArgument<std::string>( |
| kBlobName, |
| operator_def.output(DATA))) {} |
| |
| bool StoreGetOp::RunOnDevice() { |
| // Get from store and deserialize |
| auto* handler = |
| OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get(); |
| DeserializeBlob(handler->get(blobName_), OperatorBase::Outputs()[DATA]); |
| return true; |
| } |
| |
| REGISTER_CPU_OPERATOR(StoreGet, StoreGetOp); |
| OPERATOR_SCHEMA(StoreGet) |
| .NumInputs(1) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Get a blob from a store. The key is the output blob's name. The key |
| can be overridden by specifying the 'blob_name' argument. |
| )DOC") |
| .Arg("blob_name", "alternative key for the blob (optional)") |
| .Input(0, "handler", "unique_ptr<StoreHandler>") |
| .Output(0, "data", "data blob"); |
| |
| StoreAddOp::StoreAddOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<CPUContext>(operator_def, ws), |
| blobName_(GetSingleArgument<std::string>(kBlobName, "")), |
| addValue_(GetSingleArgument<int64_t>(kAddValue, 1)) { |
| CAFFE_ENFORCE(HasArgument(kBlobName)); |
| } |
| |
| bool StoreAddOp::RunOnDevice() { |
| auto* handler = |
| OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get(); |
| Output(VALUE)->Resize(1); |
| Output(VALUE)->mutable_data<int64_t>()[0] = |
| handler->add(blobName_, addValue_); |
| return true; |
| } |
| |
| REGISTER_CPU_OPERATOR(StoreAdd, StoreAddOp); |
| OPERATOR_SCHEMA(StoreAdd) |
| .NumInputs(1) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Add a value to a remote counter. If the key is not set, the store |
| initializes it to 0 and then performs the add operation. The operation |
| returns the resulting counter value. |
| )DOC") |
| .Arg("blob_name", "key of the counter (required)") |
| .Arg("add_value", "value that is added (optional, default: 1)") |
| .Input(0, "handler", "unique_ptr<StoreHandler>") |
| .Output(0, "value", "the current value of the counter"); |
| |
| StoreWaitOp::StoreWaitOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<CPUContext>(operator_def, ws), |
| blobNames_(GetRepeatedArgument<std::string>(kBlobName)) {} |
| |
| bool StoreWaitOp::RunOnDevice() { |
| auto* handler = |
| OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get(); |
| if (InputSize() == 2 && Input(1).IsType<std::string>()) { |
| CAFFE_ENFORCE( |
| blobNames_.empty(), "cannot specify both argument and input blob"); |
| std::vector<std::string> blobNames; |
| auto* namesPtr = Input(1).data<std::string>(); |
| for (int i = 0; i < Input(1).size(); ++i) { |
| // NOLINTNEXTLINE(performance-inefficient-vector-operation) |
| blobNames.push_back(namesPtr[i]); |
| } |
| handler->wait(blobNames); |
| } else { |
| handler->wait(blobNames_); |
| } |
| return true; |
| } |
| |
| REGISTER_CPU_OPERATOR(StoreWait, StoreWaitOp); |
| OPERATOR_SCHEMA(StoreWait) |
| .NumInputs(1, 2) |
| .NumOutputs(0) |
| .SetDoc(R"DOC( |
| Wait for the specified blob names to be set. The blob names can be passed |
| either as an input blob with blob names or as an argument. |
| )DOC") |
| .Arg("blob_names", "names of the blobs to wait for (optional)") |
| .Input(0, "handler", "unique_ptr<StoreHandler>") |
| .Input(1, "names", "names of the blobs to wait for (optional)"); |
| } // namespace caffe2 |