| #ifndef CAFFE2_OPERATORS_ONNX_WHILE_OP_H_ |
| #define CAFFE2_OPERATORS_ONNX_WHILE_OP_H_ |
| |
| #include "caffe2/core/context.h" |
| #include "caffe2/core/logging.h" |
| #include "caffe2/core/operator.h" |
| #include "caffe2/operators/create_scope_op.h" |
| #include "c10/util/irange.h" |
| |
| namespace caffe2 { |
| |
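| // While-loop operator modeled on the ONNX Loop operator: it repeatedly runs |
| // a "body" subnet, threading N loop-carried dependencies (LCDs) from one |
| // iteration to the next and stacking K scan outputs along a new leading |
| // (time) axis, until the optional maximum trip count is exhausted and/or |
| // the condition variable recomputed by the body becomes false. |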
| template <class Context> |
| class ONNXWhileOp final : public Operator<Context> { |
| public: |
| explicit ONNXWhileOp(const OperatorDef& operator_def, Workspace* ws) |
| : Operator<Context>(operator_def, ws), |
| parent_ws_(ws), |
| has_trip_count_( |
| this->template GetSingleArgument<int64_t>("has_trip_count", 0)), |
| has_cond_(this->template GetSingleArgument<int64_t>("has_cond", 0)), |
| save_scopes_( |
| this->template GetSingleArgument<int64_t>("save_scopes", 0)), |
| disable_scopes_( |
| this->template GetSingleArgument<int64_t>("disable_scopes", 0)), |
| num_loop_carried_deps_(this->template GetSingleArgument<int64_t>( |
| "num_loop_carried_deps", |
| -1)) { |
| CAFFE_ENFORCE( |
| this->template HasSingleArgumentOfType<NetDef>("body"), |
| "body net must be specified in ONNXWhile operator"); |
| if (disable_scopes_) { |
| CAFFE_ENFORCE( |
| !save_scopes_, "Cannot save scopes when disable_scopes=True"); |
| } |
| body_net_def_ = this->template GetSingleArgument<NetDef>("body", NetDef()); |
| // Give an unnamed body net a unique name so that repeated instantiations |
| // do not collide in the workspace: "loop_net", then "loop_net.1", ... |
| static int64_t counter = -1; |
| if (!body_net_def_.has_name()) { |
| ++counter; |
| body_net_def_.set_name( |
| counter == 0 ? std::string("loop_net") |
| : "loop_net." + c10::to_string(counter)); |
| } |
| } |
| |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| |
| bool RunOnDevice() override { |
| // Dispatch on the runtime element type of the condition tensor (Input(1)). |
| return DispatchHelper<TensorTypes<int, bool, long>>::call(this, Input(1)); |
| } |
| |
| // Operator |
| // Inputs: max trip count, condition, initial loop-carried dependencies |
| // Outputs: final loop-carried dependencies, scan_outputs |
| // Body |
| // Inputs: iteration number, condition, loop-carried dependencies |
| // Outputs: condition, loop-carried dependencies, scan_outputs |
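| // |
| // For example, with N = 2 loop-carried dependencies and K = 1 scan output, |
| // the body net declares external inputs |
| // [iteration, condition, lcd_0, lcd_1] |
| // and external outputs |
| // [condition_out, lcd_0_out, lcd_1_out, scan_0], |
| // while this operator takes inputs |
| // [max_trip_count, first_condition, lcd_0_init, lcd_1_init] |
| // and produces outputs |
| // [lcd_0_final, lcd_1_final, scan_0_stacked]. |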
| template <typename CondVarType> |
| bool DoRunWithType() { |
| // Clear workspaces from previous invocations of the loop |
| // and set up a local scope for the first iteration. |
| ws_stack_.clear(); |
| auto loop_ws = !disable_scopes_ |
| ? ws_stack_.pushForwardWorkspace(parent_ws_).get() |
| : parent_ws_; |
| |
| constexpr int64_t num_inputs_before_lcds = 2; |
| // First input is the maximum trip count. Second input is the condition |
| // variable (for the first iteration). The rest of the inputs are |
| // loop-carried dependencies. |
| int64_t num_loop_carried_deps; |
| if (num_loop_carried_deps_ != -1) { |
| num_loop_carried_deps = num_loop_carried_deps_; |
| } else { |
| num_loop_carried_deps = InputSize() - num_inputs_before_lcds; |
| } |
| int64_t max_trip_count = *Input(0).template data<int64_t>(); |
| const bool first_iter_condition = *Input(1).template data<CondVarType>(); |
| |
| scope_ = std::make_shared<LocalScope>( |
| loop_ws, body_net_def_, num_loop_carried_deps); |
| |
| // Body graph has 1+N+K outputs: recalculated condition variable, N |
| // loop-carried dependencies, and K scan_outputs |
| int num_scan_outputs = |
| scope_->net()->external_output().size() - num_loop_carried_deps - 1; |
| |
| CAFFE_ENFORCE_GE( |
| num_scan_outputs, |
| 0, |
| "Body graph must have 1+N+K outputs: the recomputed condition " |
| "variable, N loop-carried dependencies, and K (possibly zero) " |
| "scan outputs"); |
| |
| // Copy initial loop-carried dependencies |
| for (const auto i : c10::irange(num_loop_carried_deps)) { |
| scope_->lcd_tensor(i)->CopyFrom(Input(i + num_inputs_before_lcds)); |
| } |
| |
| // Initialize iteration variable |
| scope_->set_iteration(0ll); |
| |
| // Initialize input condition variable |
| scope_->template set_input_condition<CondVarType>(first_iter_condition); |
| |
| auto valid_iter_num = [this, max_trip_count](int64_t i) { |
| return !has_trip_count_ || i < max_trip_count; |
| }; |
| |
| auto condition_true = [this, first_iter_condition]( |
| int64_t i, bool cond_value) { |
| if (!has_cond_) { |
| return true; |
| } |
| return i == 0 ? first_iter_condition : cond_value; |
| }; |
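| // Together these two predicates cover the ONNX Loop variants: trip count |
| // only (for-loop), condition only (while-loop), both (bounded while-loop), |
| // and neither, in which case the loop ends only if the body net fails to |
| // run. |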
| |
| // Allocate scan_outputs for zero-iteration case |
| for (const auto i : c10::irange(num_scan_outputs)) { |
| Output(i + num_loop_carried_deps)->Resize(0); |
| Output(i + num_loop_carried_deps)->template mutable_data<int32_t>(); |
| } |
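| // The int32_t dtype above is arbitrary: the tensors hold zero elements, |
| // but downstream consumers still see initialized outputs even when the |
| // loop body never executes. |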
| |
| // Use this to keep track of the sizes of the scan outputs and validate |
| // they're the same across iterations. |
| std::vector<std::vector<int64_t>> scan_outputs_sizes; |
| |
| Workspace* cur_ws = nullptr; |
| bool cur_output_condition = false; |
| |
| while (true) { |
| int64_t itr = scope_->iteration(); |
| if (valid_iter_num(itr) && condition_true(itr, cur_output_condition)) { |
| if (!scope_->net()->Run()) { |
| return false; |
| } |
| |
| cur_ws = scope_->workspace(); |
| cur_output_condition = scope_->template output_condition<CondVarType>(); |
| if (save_scopes_) { |
| loop_ws = ws_stack_.pushForwardWorkspace(parent_ws_).get(); |
| scope_ = std::make_shared<LocalScope>( |
| loop_ws, body_net_def_, num_loop_carried_deps); |
| } |
| |
| // Copy forward loop-carried dependencies |
| for (const auto i : c10::irange(num_loop_carried_deps)) { |
| Blob* b = cur_ws->GetBlob(scope_->net()->external_output()[i + 1]); |
| const Tensor& t = b->template Get<Tensor>(); |
| scope_->lcd_tensor(i)->CopyFrom(t); |
| } |
| // Copy out scan_outputs |
| for (const auto i : c10::irange(num_scan_outputs)) { |
| int net_output_idx = i + 1 + num_loop_carried_deps; |
| const Tensor& scan_output = |
| cur_ws->GetBlob(scope_->net()->external_output()[net_output_idx]) |
| ->template Get<Tensor>(); |
| auto* scan_output_target = Output(i + num_loop_carried_deps); |
| if (itr == 0) { |
| auto dims = scan_output.sizes().vec(); |
| scan_outputs_sizes.push_back(dims); |
| dims.insert(dims.begin(), 1); |
| scan_output_target->Resize(dims); |
| scan_output_target->CopyFrom(scan_output); |
| } else { |
| auto dims = scan_output.sizes().vec(); |
| CAFFE_ENFORCE_EQ( |
| dims, |
| scan_outputs_sizes[i], |
| "Size of scan output changed across iterations"); |
| // Grow the leading (time) dimension by one element; the second argument |
| // is the growth percentage used to amortize reallocations. |
| scan_output_target->Extend(1, 100); |
| |
| int64_t timestep_size = 1; |
| for (const int64_t t : scan_outputs_sizes[i]) { |
| timestep_size *= t; |
| } |
| |
| // Append this iteration's scan output at timestep offset itr. |
| const void* src_data = scan_output.raw_data(); |
| auto& sot_meta = scan_output_target->dtype(); |
| void* dst_data = |
| (char*)scan_output_target->raw_mutable_data(sot_meta) + |
| timestep_size * scan_output.itemsize() * itr; |
| memcpy(dst_data, src_data, timestep_size * scan_output.itemsize()); |
| } |
| } |
| scope_->set_iteration(itr + 1ll); |
| scope_->template set_input_condition<CondVarType>(cur_output_condition); |
| } else { |
| break; |
| } |
| } |
| |
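| // If the loop ran zero iterations, the LCD tensors still hold the initial |
| // values copied in before the loop, so the outputs below stay well-defined. |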
| // Copy out final loop-carried dependencies |
| for (const auto i : c10::irange(num_loop_carried_deps)) { |
| Output(i)->CopyFrom(*scope_->lcd_tensor(i)); |
| } |
| |
| return true; |
| } |
| |
| private: |
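| // LocalScope wires the body NetDef into a given workspace: it creates the |
| // iteration, condition, and LCD blobs that the body reads and writes, and |
| // instantiates (or reuses) the body net itself. |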
| class LocalScope { |
| public: |
| LocalScope(Workspace* loop_ws, const NetDef& body_net_def, size_t num_lcds) |
| : loop_ws_(loop_ws) { |
| CAFFE_ENFORCE(loop_ws_, "Failed to initialize local loop workspace"); |
| |
| // Create loop-carried dependency blobs in the workspace. Body inputs |
| // 0 and 1 are the iteration counter and condition; LCDs start at 2. |
| lcd_tensors_.clear(); |
| for (const auto i : c10::irange(num_lcds)) { |
| Blob* b = loop_ws_->CreateBlob(body_net_def.external_input(i + 2)); |
| Tensor* t = BlobGetMutableTensor(b, Context::GetDeviceType()); |
| lcd_tensors_.push_back(t); |
| } |
| // First input is the iteration variable |
| auto* iteration_var_blob = |
| loop_ws_->CreateBlob(body_net_def.external_input(0)); |
| iteration_var_ = |
| BlobGetMutableTensor(iteration_var_blob, Context::GetDeviceType()); |
| |
| // Second input is the incoming condition variable |
| input_condition_var_ = BlobGetMutableTensor( |
| loop_ws_->CreateBlob(body_net_def.external_input(1)), |
| Context::GetDeviceType()); |
| |
| // First output is the recomputed condition variable; pre-allocate it |
| // as a single-element bool tensor. |
| auto* condition_var_blob = |
| loop_ws_->CreateBlob(body_net_def.external_output(0)); |
| condition_var_ = |
| BlobGetMutableTensor(condition_var_blob, Context::GetDeviceType()); |
| condition_var_->Resize(1); |
| condition_var_->template mutable_data<bool>(); |
| |
| body_net_ = loop_ws_->GetNet(body_net_def.name()); |
| if (!body_net_) { |
| body_net_ = loop_ws_->CreateNet(body_net_def, true); |
| } |
| CAFFE_ENFORCE(body_net_, "Failed to initialize loop subnet"); |
| } |
| |
| NetBase* net() const { |
| return body_net_; |
| } |
| |
| Workspace* workspace() const { |
| return loop_ws_; |
| } |
| |
| int64_t iteration() const { |
| auto* iteration_var_ptr = |
| iteration_var_->template mutable_data<int64_t>(); |
| return *iteration_var_ptr; |
| } |
| |
| Tensor* lcd_tensor(int idx) { |
| return lcd_tensors_[idx]; |
| } |
| |
| void set_iteration(int64_t itr) { |
| // Resize() with no arguments makes this a zero-dim (scalar) tensor. |
| iteration_var_->Resize(); |
| auto* iteration_var_ptr = |
| iteration_var_->template mutable_data<int64_t>(); |
| *iteration_var_ptr = itr; |
| } |
| |
| template <typename CondVarType> |
| void set_input_condition(bool cond_value) { |
| input_condition_var_->Resize(1); |
| auto* input_condition_var_ptr = |
| input_condition_var_->template mutable_data<CondVarType>(); |
| *input_condition_var_ptr = cond_value; |
| } |
| |
| template <typename CondVarType> |
| bool output_condition() const { |
| auto* condition_var_ptr = |
| condition_var_->template mutable_data<CondVarType>(); |
| return *condition_var_ptr; |
| } |
| |
| private: |
| Workspace* loop_ws_; |
| |
| NetBase* body_net_; // owned by a workspace |
| Tensor* iteration_var_; |
| Tensor* input_condition_var_; |
| Tensor* condition_var_; |
| |
| std::vector<Tensor*> lcd_tensors_; |
| }; |
| |
| NetDef body_net_def_; |
| Workspace* parent_ws_; |
| detail::WorkspaceStack ws_stack_; |
| |
| bool has_trip_count_; |
| bool has_cond_; |
| bool save_scopes_; |
| bool disable_scopes_; |
| int64_t num_loop_carried_deps_; |
| |
| std::shared_ptr<LocalScope> scope_; |
| }; |
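| |
| // Note: the concrete instantiation is registered in the corresponding .cc |
| // file, e.g. REGISTER_CPU_OPERATOR(ONNXWhile, ONNXWhileOp<CPUContext>). |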
| |
| } // namespace caffe2 |
| |
| #endif // CAFFE2_OPERATORS_ONNX_WHILE_OP_H_ |