blob: 0b29040d2b6de241490ccc5bbeaeec96c195a5a4 [file] [log] [blame]
#include "store_ops.h"
#include "caffe2/core/blob_serialization.h"
namespace caffe2 {
// Operator argument keys shared by the store operators in this file.
// kBlobName: argument carrying the store key(s) an operator works on.
constexpr const char* kBlobName = "blob_name";
// kAddValue: argument carrying the increment used by StoreAdd.
constexpr const char* kAddValue = "add_value";
// Constructs a StoreSet operator.
//
// The store key (blobName_) is read from the kBlobName argument; the name of
// the DATA input, as recorded in the operator definition, is passed as the
// fallback value for that lookup.
StoreSetOp::StoreSetOp(const OperatorDef& operator_def, Workspace* ws)
    : Operator<CPUContext>(operator_def, ws),
      blobName_(GetSingleArgument<std::string>(
          kBlobName,
          operator_def.input(DATA))) {}
bool StoreSetOp::RunOnDevice() {
// Serialize and pass to store
auto* handler =
OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
handler->set(blobName_, SerializeBlob(InputBlob(DATA), blobName_));
return true;
}
// CPU registration and schema for StoreSet: two inputs (the store handler and
// the blob to publish), no outputs, and an optional 'blob_name' argument.
REGISTER_CPU_OPERATOR(StoreSet, StoreSetOp);

OPERATOR_SCHEMA(StoreSet)
    .NumInputs(2)
    .NumOutputs(0)
    .SetDoc(R"DOC(
Set a blob in a store. The key is the input blob's name and the value
is the data in that blob. The key can be overridden by specifying the
'blob_name' argument.
)DOC")
    .Arg("blob_name", "alternative key for the blob (optional)")
    .Input(0, "handler", "unique_ptr<StoreHandler>")
    .Input(1, "data", "data blob");
// Constructs a StoreGet operator.
//
// The store key (blobName_) is read from the kBlobName argument; the name of
// the DATA output, as recorded in the operator definition, is passed as the
// fallback value for that lookup.
StoreGetOp::StoreGetOp(const OperatorDef& operator_def, Workspace* ws)
    : Operator<CPUContext>(operator_def, ws),
      blobName_(
          GetSingleArgument<std::string>(kBlobName, operator_def.output(DATA))) {
}
// Fetches the value stored under blobName_ from the StoreHandler found in the
// HANDLER input and deserializes it into the DATA output blob. Always returns
// true.
bool StoreGetOp::RunOnDevice() {
  const auto& holder =
      OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER);
  auto* store = holder.get();

  DeserializeBlob(store->get(blobName_), OperatorBase::Outputs()[DATA]);
  return true;
}
// CPU registration and schema for StoreGet: one input (the store handler),
// one output (the retrieved blob), and an optional 'blob_name' argument.
REGISTER_CPU_OPERATOR(StoreGet, StoreGetOp);

OPERATOR_SCHEMA(StoreGet)
    .NumInputs(1)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Get a blob from a store. The key is the output blob's name. The key
can be overridden by specifying the 'blob_name' argument.
)DOC")
    .Arg("blob_name", "alternative key for the blob (optional)")
    .Input(0, "handler", "unique_ptr<StoreHandler>")
    .Output(0, "data", "data blob");
// Constructs a StoreAdd operator.
//
// blobName_ is the counter key, read from the kBlobName argument. Unlike
// StoreSet/StoreGet there is no blob name to fall back on, so the argument is
// mandatory and its absence is rejected at construction time.
// addValue_ is the increment, read from the kAddValue argument (default 1).
StoreAddOp::StoreAddOp(const OperatorDef& operator_def, Workspace* ws)
    : Operator<CPUContext>(operator_def, ws),
      blobName_(GetSingleArgument<std::string>(kBlobName, "")),
      addValue_(GetSingleArgument<int64_t>(kAddValue, 1)) {
  // Explicit message so a misconfigured net points at the missing argument
  // instead of only echoing the failed expression.
  CAFFE_ENFORCE(
      HasArgument(kBlobName), "StoreAdd requires the 'blob_name' argument");
}
// Adds addValue_ to the counter stored under blobName_ and writes the value
// returned by the store into the single-element VALUE output. Always returns
// true.
bool StoreAddOp::RunOnDevice() {
  const auto& holder =
      OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER);
  auto* store = holder.get();

  // Size the output first, then perform the add, then store the result.
  Output(VALUE)->Resize(1);
  const auto counter = store->add(blobName_, addValue_);
  Output(VALUE)->mutable_data<int64_t>()[0] = counter;
  return true;
}
// CPU registration and schema for StoreAdd: one input (the store handler),
// one output (the counter value), a required 'blob_name' argument and an
// optional 'add_value' argument.
REGISTER_CPU_OPERATOR(StoreAdd, StoreAddOp);

OPERATOR_SCHEMA(StoreAdd)
    .NumInputs(1)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Add a value to a remote counter. If the key is not set, the store
initializes it to 0 and then performs the add operation. The operation
returns the resulting counter value.
)DOC")
    .Arg("blob_name", "key of the counter (required)")
    .Arg("add_value", "value that is added (optional, default: 1)")
    .Input(0, "handler", "unique_ptr<StoreHandler>")
    .Output(0, "value", "the current value of the counter");
// Constructs a StoreWait operator.
//
// blobNames_ holds the keys to wait for when they are supplied as an operator
// argument; it is read as a repeated string argument under kBlobName
// ("blob_name").
StoreWaitOp::StoreWaitOp(const OperatorDef& operator_def, Workspace* ws)
    : Operator<CPUContext>(operator_def, ws),
      blobNames_(GetRepeatedArgument<std::string>(kBlobName)) {
}
// Waits on the StoreHandler found in the HANDLER input for a set of keys.
//
// The keys come from one of two places:
//   * a second input holding std::string elements, in which case the
//     argument-supplied list (blobNames_) must be empty, or
//   * otherwise, blobNames_ as captured by the constructor.
// Always returns true; a violated precondition is reported via CAFFE_ENFORCE.
bool StoreWaitOp::RunOnDevice() {
  auto* handler =
      OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
  if (InputSize() == 2 && Input(1).IsType<std::string>()) {
    CAFFE_ENFORCE(
        blobNames_.empty(), "cannot specify both argument and input blob");
    const auto& namesInput = Input(1);
    const auto* first = namesInput.data<std::string>();
    // Copy the names out of the input in one shot: the range constructor
    // allocates once and avoids an int index compared against the element
    // count.
    const std::vector<std::string> blobNames(first, first + namesInput.size());
    handler->wait(blobNames);
  } else {
    // NOTE(review): a second input that does not hold strings also lands
    // here and is ignored — confirm that this is intended.
    handler->wait(blobNames_);
  }
  return true;
}
// CPU registration and schema for StoreWait: one or two inputs (the store
// handler and, optionally, a blob of names), no outputs.
//
// The documented argument name is 'blob_name' because that is the key the
// StoreWaitOp constructor reads (kBlobName), as a repeated string argument.
REGISTER_CPU_OPERATOR(StoreWait, StoreWaitOp);

OPERATOR_SCHEMA(StoreWait)
    .NumInputs(1, 2)
    .NumOutputs(0)
    .SetDoc(R"DOC(
Wait for the specified blob names to be set. The blob names can be passed
either as an input blob with blob names or as an argument.
)DOC")
    .Arg("blob_name", "names of the blobs to wait for (optional, repeated)")
    .Input(0, "handler", "unique_ptr<StoreHandler>")
    .Input(1, "names", "names of the blobs to wait for (optional)");
} // namespace caffe2