| #ifndef CAFFE2_OPERATORS_DO_OP_H_ |
| #define CAFFE2_OPERATORS_DO_OP_H_ |
| |
| #include <string> |
| #include <unordered_map> |
| #include <unordered_set> |
| #include <vector> |
| |
| #include "caffe2/core/context.h" |
| #include "caffe2/core/logging.h" |
| #include "caffe2/core/operator.h" |
| #include "caffe2/operators/create_scope_op.h" |
| #include "caffe2/proto/caffe2_pb.h" |
| #include "c10/util/irange.h" |
| |
| namespace caffe2 { |
| |
| template <class Context> |
| class DoOp final : public Operator<Context> { |
| public: |
| explicit DoOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<Context>(operator_def, ws), parent_ws_(ws) { |
| CAFFE_ENFORCE( |
| this->template HasSingleArgumentOfType<NetDef>("net"), |
| "net must be specified in Do operator"); |
| net_def_ = this->template GetSingleArgument<NetDef>("net", NetDef()); |
| is_gradient_op_ = operator_def.is_gradient_op(); |
| copy_external_blobs_ = |
| this->template GetSingleArgument<bool>("copy_external_blobs", false); |
| reuse_workspace_ = |
| this->template GetSingleArgument<bool>("reuse_workspace", false); |
| CAFFE_ENFORCE( |
| !(is_gradient_op_ && reuse_workspace_), |
| "Gradient Do op requires use of stacked workspaces"); |
| CAFFE_ENFORCE( |
| !(copy_external_blobs_ && reuse_workspace_), |
| "Reuse workspace and copy external blobs simultaneously in Do op"); |
| |
| const auto& inner_blobs = |
| this->template GetRepeatedArgument<std::string>("inner_blobs"); |
| const auto& outer_blobs_idx = |
| this->template GetRepeatedArgument<int>("outer_blobs_idx"); |
| CAFFE_ENFORCE_EQ( |
| inner_blobs.size(), |
| outer_blobs_idx.size(), |
| "Invalid blob bindings: different inner/outer blobs lengths"); |
| |
| const auto& outer_blob_names = checkAndGetOuterNames(operator_def); |
| std::unordered_set<std::string> used_outer_names; |
| for (const auto blob_idx : c10::irange(inner_blobs.size())) { |
| CAFFE_ENFORCE( |
| !blob_bindings_.count(inner_blobs[blob_idx]), |
| "Invalid blob bindings: redefinition of inner blob " + |
| inner_blobs[blob_idx]); |
| CAFFE_ENFORCE( |
| outer_blobs_idx[blob_idx] >= 0 && |
| // NOLINTNEXTLINE(clang-diagnostic-sign-compare) |
| outer_blobs_idx[blob_idx] < outer_blob_names.size(), |
| "Invalid blob bindings: outer blob index (" + |
| c10::to_string(outer_blobs_idx[blob_idx]) + ", inner name: " + |
| inner_blobs[blob_idx] + ") is out of bounds [0, " + |
| c10::to_string(outer_blob_names.size() - 1) + "]"); |
| const auto& outer_name = outer_blob_names[outer_blobs_idx[blob_idx]]; |
| CAFFE_ENFORCE( |
| !used_outer_names.count(outer_name), |
| "Reusage of outer name: " + outer_name); |
| used_outer_names.insert(outer_name); |
| blob_bindings_[inner_blobs[blob_idx]] = outer_name; |
| forwarded_inner_blobs_.insert(inner_blobs[blob_idx]); |
| } |
| std::unordered_set<std::string> all_outer_names( |
| outer_blob_names.begin(), outer_blob_names.end()); |
| CAFFE_ENFORCE_EQ( |
| used_outer_names.size(), |
| all_outer_names.size(), |
| "Not all outer names are used in blob bindings"); |
| } |
| |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| |
| bool RunOnDevice() override { |
| auto* ws_stack = |
| this->template Output<detail::WorkspaceStack>(OutputSize() - 1); |
| std::shared_ptr<Workspace> net_workspace; |
| if (is_gradient_op_) { |
| net_workspace = |
| ws_stack->popGradientWorkspace(parent_ws_, blob_bindings_); |
| } else { |
| if (reuse_workspace_ && !ws_stack->empty()) { |
| net_workspace = |
| ws_stack->reuseLastForwardWorkspace(parent_ws_, blob_bindings_); |
| } else { |
| net_workspace = |
| ws_stack->pushForwardWorkspace(parent_ws_, blob_bindings_); |
| } |
| } |
| CAFFE_ENFORCE(net_workspace, "Failed to initialize Do op workspace"); |
| |
| // TODO(iliacher): figure how to reuse existing net with a new workspace |
| auto* net = net_workspace->GetNet(net_def_.name()); |
| if (!net) { |
| net = net_workspace->CreateNet(net_def_, true); |
| } |
| CAFFE_ENFORCE(net, "Failed to initialize subnet"); |
| auto success = net->Run(); |
| if (!is_gradient_op_ && copy_external_blobs_) { |
| net_workspace->template CopyForwardedTensors<Context>( |
| forwarded_inner_blobs_); |
| } |
| return success; |
| } |
| |
| private: |
| // returns vector of input blob names followed by output blob names in |
| // operator definition order; ensures that input (output) names are unique, |
| // checks number of input (output) blobs |
| std::vector<std::string> checkAndGetOuterNames( |
| const OperatorDef& operator_def) const { |
| auto input_names = getInputBlobNames(operator_def); |
| CAFFE_ENFORCE(!input_names.empty(), "Expected at least one input blob"); |
| std::string input_ws_blob = input_names.back(); // copy |
| // removing blob that holds pointer op workspace |
| input_names.pop_back(); |
| |
| std::unordered_set<std::string> all_input_names( |
| input_names.begin(), input_names.end()); |
| CAFFE_ENFORCE_EQ( |
| input_names.size(), all_input_names.size(), "Duplicate input blobs"); |
| |
| auto output_names = getOutputBlobNames(operator_def); |
| CAFFE_ENFORCE(!output_names.empty(), "Expected at least one output blob"); |
| const auto& output_ws_blob = output_names.back(); |
| CAFFE_ENFORCE_EQ( |
| input_ws_blob, |
| output_ws_blob, |
| "Expected same input/output workspace blob"); |
| // remove blob that holds pointer to op workspace |
| output_names.pop_back(); |
| |
| std::unordered_set<std::string> all_output_names( |
| output_names.begin(), output_names.end()); |
| CAFFE_ENFORCE_EQ( |
| output_names.size(), all_output_names.size(), "Duplicate output blobs"); |
| |
| std::vector<std::string> outer_blob_names; |
| outer_blob_names.reserve(input_names.size() + output_names.size()); |
| outer_blob_names.insert( |
| outer_blob_names.end(), input_names.begin(), input_names.end()); |
| outer_blob_names.insert( |
| outer_blob_names.end(), output_names.begin(), output_names.end()); |
| return outer_blob_names; |
| } |
| |
| std::vector<std::string> getInputBlobNames( |
| const OperatorDef& operator_def) const { |
| std::vector<std::string> names; |
| names.reserve(operator_def.input_size()); |
| for (const auto idx : c10::irange(operator_def.input_size())) { |
| names.push_back(operator_def.input(idx)); |
| } |
| return names; |
| } |
| |
| std::vector<std::string> getOutputBlobNames( |
| const OperatorDef& operator_def) const { |
| std::vector<std::string> names; |
| names.reserve(operator_def.output_size()); |
| for (const auto idx : c10::irange(operator_def.output_size())) { |
| names.push_back(operator_def.output(idx)); |
| } |
| return names; |
| } |
| |
| std::unordered_map<std::string, std::string> blob_bindings_; |
| std::unordered_set<std::string> forwarded_inner_blobs_; |
| bool is_gradient_op_; |
| bool copy_external_blobs_; |
| bool reuse_workspace_; |
| NetDef net_def_; |
| Workspace* parent_ws_; |
| }; |
| |
| } // namespace caffe2 |
| |
| #endif // CAFFE2_OPERATORS_DO_OP_H_ |