| #ifndef CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_ |
| #define CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_ |
| |
| #include <string> |
| #include <unordered_map> |
| #include <unordered_set> |
| #include <vector> |
| |
| #include "caffe2/core/context.h" |
| #include "caffe2/core/logging.h" |
| #include "caffe2/core/operator.h" |
| #include "caffe2/proto/caffe2_pb.h" |
| |
| C10_DECLARE_bool(caffe2_workspace_stack_debug); |
| |
| namespace caffe2 { |
| namespace detail { |
| |
| /* |
| * Keeps track of forward and backward gradient workspaces in stack, |
| * reuses previously created workspaces, non-thread safe |
| */ |
| class TORCH_API WorkspaceStack { |
| public: |
| explicit WorkspaceStack() : parent_ws_(nullptr), top_(-1) {} |
| |
| std::shared_ptr<Workspace> pushForwardWorkspace(Workspace* parent_ws) { |
| return pushForwardWorkspace( |
| parent_ws, std::unordered_map<std::string, std::string>()); |
| } |
| |
| std::shared_ptr<Workspace> pushForwardWorkspace( |
| Workspace* parent_ws, |
| const std::unordered_map<std::string, std::string>& blob_bindings) { |
| checkStack(); |
| if (FLAGS_caffe2_workspace_stack_debug) { |
| if (parent_ws_) { |
| CAFFE_ENFORCE_EQ(parent_ws_, parent_ws, "Parent workspace mismatch"); |
| } else { |
| parent_ws_ = parent_ws; |
| } |
| if (!blob_bindings_.empty()) { |
| checkBindingsMatch(blob_bindings_, blob_bindings); |
| } else { |
| blob_bindings_ = blob_bindings; |
| } |
| } |
| |
| if (top_ == workspaces_.size() - 1) { |
| workspaces_.push_back( |
| std::make_shared<Workspace>(parent_ws, blob_bindings)); |
| } else { |
| // when reusing workspace, make sure copies of external blobs are |
| // removed and blob bindings are set |
| auto& workspace = workspaces_[top_ + 1]; |
| const auto& local_blobs = workspace->LocalBlobs(); |
| std::unordered_set<std::string> local_blobs_set; |
| local_blobs_set.insert(local_blobs.begin(), local_blobs.end()); |
| bool found_local_copy = false; |
| for (const auto& blob_pair : blob_bindings) { |
| if (local_blobs_set.count(blob_pair.first)) { |
| workspace->RemoveBlob(blob_pair.first); |
| found_local_copy = true; |
| } |
| } |
| if (found_local_copy) { |
| workspace->AddBlobMapping(parent_ws, blob_bindings); |
| } |
| } |
| |
| return workspaces_[++top_]; |
| } |
| |
| std::shared_ptr<Workspace> popGradientWorkspace( |
| Workspace* parent_ws, |
| const std::unordered_map<std::string, std::string>& grad_blob_bindings) { |
| checkStack(); |
| if (FLAGS_caffe2_workspace_stack_debug) { |
| if (parent_ws_) { |
| CAFFE_ENFORCE_EQ(parent_ws_, parent_ws, "Parent workspace mismatch"); |
| } else { |
| parent_ws_ = parent_ws; |
| } |
| if (!grad_blob_bindings_.empty()) { |
| checkBindingsMatch(grad_blob_bindings_, grad_blob_bindings); |
| } else { |
| grad_blob_bindings_ = grad_blob_bindings; |
| } |
| } |
| |
| if (top_ < 0) { |
| return nullptr; |
| } |
| auto& grad_workspace = workspaces_[top_]; |
| grad_workspace->AddBlobMapping(parent_ws, grad_blob_bindings, true); |
| --top_; |
| return grad_workspace; |
| } |
| |
| std::shared_ptr<Workspace> reuseLastForwardWorkspace(Workspace* parent_ws) { |
| return reuseLastForwardWorkspace( |
| parent_ws, std::unordered_map<std::string, std::string>()); |
| } |
| |
| std::shared_ptr<Workspace> reuseLastForwardWorkspace( |
| Workspace* parent_ws, |
| const std::unordered_map<std::string, std::string>& blob_bindings) { |
| checkStack(); |
| if (top_ < 0) { |
| return nullptr; |
| } |
| workspaces_[top_]->AddBlobMapping(parent_ws, blob_bindings); |
| return workspaces_[top_]; |
| } |
| |
| void clear() { |
| checkStack(); |
| top_ = -1; |
| } |
| |
| bool empty() const { |
| return top_ < 0; |
| } |
| |
| private: |
| void checkStack() const { |
| CAFFE_ENFORCE_GT( |
| (int)workspaces_.size(), top_, "Corrupted workspaces stack"); |
| } |
| |
| void checkBindingsMatch( |
| const std::unordered_map<std::string, std::string>& bindings, |
| const std::unordered_map<std::string, std::string>& test_bindings) const { |
| CAFFE_ENFORCE_EQ( |
| bindings.size(), test_bindings.size(), "Blob bindings mismatch"); |
| for (const auto& blob_binding : bindings) { |
| CAFFE_ENFORCE( |
| test_bindings.count(blob_binding.first), "Blob bindings mismatch"); |
| CAFFE_ENFORCE_EQ( |
| test_bindings.at(blob_binding.first), |
| blob_binding.second, |
| "Blob bindings mismatch"); |
| } |
| } |
| |
| std::unordered_map<std::string, std::string> blob_bindings_; |
| std::unordered_map<std::string, std::string> grad_blob_bindings_; |
| Workspace* parent_ws_; |
| int top_; |
| std::vector<std::shared_ptr<Workspace>> workspaces_; |
| }; |
| } // namespace detail |
| |
| template <class Context> |
| class CreateScopeOp final : public Operator<Context> { |
| public: |
| template <class... Args> |
| explicit CreateScopeOp(Args&&... args) |
| : Operator<Context>(std::forward<Args>(args)...) {} |
| |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| bool RunOnDevice() override; |
| }; |
| |
| template <class Context> |
| class HasScopeOp final : public Operator<Context> { |
| public: |
| template <class... Args> |
| explicit HasScopeOp(Args&&... args) |
| : Operator<Context>(std::forward<Args>(args)...) {} |
| |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| bool RunOnDevice() override; |
| }; |
| |
| } // namespace caffe2 |
| |
| #endif // CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_ |