#include "caffe2/predictor/transforms.h"
#include "caffe2/onnx/onnx_exporter.h"
#include "caffe2/utils/proto_utils.h"

#include <algorithm>
#include <limits>
#include <unordered_set>
| |
| namespace caffe2 { |
| |
| namespace { |
| bool HasInput(const string& blob, const OperatorDef& op) { |
| for (const auto& in : op.input()) { |
| if (blob == in) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| bool HasOutput(const string& blob, const OperatorDef& op) { |
| for (const auto& out : op.output()) { |
| if (blob == out) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| void RewriteSubnetsForIfOp( |
| const string& from, |
| const string& to, |
| OperatorDef* op) { |
| ArgumentHelper helper(*op); |
| Argument *then_arg = nullptr, *else_arg = nullptr; |
| |
| std::map<std::string, std::string> oldname_to_newname; |
| oldname_to_newname[from] = to; |
| |
| if (helper.HasSingleArgumentOfType<NetDef>("then_net")) { |
| then_arg = GetMutableArgument("then_net", false, op); |
| onnx::rewriteSubnet(then_arg, oldname_to_newname); |
| } |
| if (helper.HasSingleArgumentOfType<NetDef>("else_net")) { |
| else_arg = GetMutableArgument("else_net", false, op); |
| onnx::rewriteSubnet(else_arg, oldname_to_newname); |
| } |
| } |
| |
| void RenameInputs( |
| const string& from, |
| const string& to, |
| OperatorDef* def, |
| int op_idx, |
| std::unordered_map<std::string, std::unordered_set<int>>& children) { |
| VLOG(2) << "RenameInputs (from=" << from << ", to=" << to << ", " |
| << def->DebugString() << ")"; |
| for (int i = 0; i < def->input_size(); i++) { |
| if (def->input(i) == from) { |
| *def->mutable_input(i) = to; |
| children[from].erase(op_idx); |
| children[to].insert(op_idx); |
| } |
| } |
| // Rename inputs in the subnets of If/AsyncIf op |
| if (def->type() == "If" || def->type() == "AsyncIf") { |
| RewriteSubnetsForIfOp(from, to, def); |
| } |
| } |
| |
| void RenameOutputs( |
| const string& from, |
| const string& to, |
| OperatorDef* def, |
| int op_idx, |
| std::unordered_map<std::string, std::unordered_set<int>>& parents) { |
| VLOG(2) << "RenameOutputs (from=" << from << ", to=" << to << ", " |
| << def->DebugString() << ")"; |
| for (string& output : *def->mutable_output()) { |
| if (output == from) { |
| output = to; |
| parents[from].erase(op_idx); |
| parents[to].insert(op_idx); |
| } |
| } |
| // Rename outputs in the subnets of If/AsyncIf op |
| if (def->type() == "If" || def->type() == "AsyncIf") { |
| RewriteSubnetsForIfOp(from, to, def); |
| } |
| } |
| |
| void RenameInputsInChildren( |
| const string& from, |
| const string& to, |
| caffe2::NetDef* net, |
| std::unordered_map<std::string, std::unordered_set<int>>& children) { |
| VLOG(2) << "RenameInputsInChildren (from=" << from << ", to=" << to << ")"; |
| if (children.count(from) == 0) { |
| return; |
| } |
| |
| // make an temporary copy here because we're going to modify children |
| for (int child : std::unordered_set<int>(children[from])) { |
| RenameInputs(from, to, net->mutable_op(child), child, children); |
| } |
| } |
| |
| void RenameOutputInParents( |
| const std::string& from, |
| const std::string& to, |
| caffe2::NetDef* net, |
| std::unordered_map<std::string, std::unordered_set<int>>& parents) { |
| VLOG(2) << "RenameOutputInParents (from=" << from << ", to=" << to << ")"; |
| if (parents.count(from) == 0) { |
| return; |
| } |
| |
| // make an temporary copy here because we're going to modify parents |
| for (int parent : std::unordered_set<int>(parents[from])) { |
| RenameOutputs(from, to, net->mutable_op(parent), parent, parents); |
| } |
| } |
| |
| bool FoundOpCandidate( |
| const OperatorDef* op, |
| int op_idx, |
| const std::string& op_type, |
| const std::unordered_set<std::string>& inputs, |
| const std::unordered_set<std::string>& outputs, |
| const std::unordered_map<std::string, std::unordered_set<int>>& parents, |
| const std::unordered_map<std::string, std::unordered_set<int>>& children) { |
| if (op->type() != op_type) { |
| VLOG(2) << "InplaceOps(" << op_type << ") skipping op: \n" |
| << op->DebugString(); |
| return false; |
| } |
| if (op->input_size() != 1 || op->output_size() != 1) { |
| VLOG(2) << "InplaceOps(" << op_type |
| << ") only supports ops with exactly 1 output " |
| << "and exactly 1 input. Skipping op: \n" |
| << op->DebugString(); |
| return false; |
| } |
| |
| // use actual copy because op->input/output may change |
| const std::string in = op->input(0); |
| const std::string out = op->output(0); |
| |
| if (in == out) { |
| // This case can still exist when in/out is in the predict_net's outputs. |
| // The op is an inplace op already. |
| return false; |
| } |
| |
| // The following is to handle the special cases of inputs being overwritten |
| // by ops in the net and then appear in outputs of the net |
| if (outputs.count(out) == 0) { |
| // Propagate input downwards |
| // Make sure that after input is propagated down, it doesn't have parents |
| // that comes after i but before the new child |
| int earliest_child = INT_MAX; |
| const auto& iter = children.find(out); |
| if (iter != children.end()) { |
| for (int child : iter->second) { |
| earliest_child = std::min(earliest_child, child); |
| } |
| } |
| if (earliest_child == INT_MAX) { |
| return true; |
| } |
| const auto& iter2 = parents.find(in); |
| if (iter2 != parents.end()) { |
| for (int parent : iter2->second) { |
| if (parent > op_idx && parent < earliest_child) { |
| VLOG(2) << "InplaceOps(" << op_type << ") skipping op: \n" |
| << op->DebugString(); |
| return false; |
| } |
| } |
| } |
| } else { |
| // Propagate output upwards |
| if (inputs.count(in) != 0 || outputs.count(in) != 0) { |
| // This is the case when the op is absolutely needed. It exists to serve |
| // one and only one purpose, to copy from in to out where in is one of |
| // the net's inputs or outputs and out is one of the net's outputs. |
| VLOG(2) << "InplaceOps(" << op_type << ") skipping op: \n" |
| << op->DebugString(); |
| return false; |
| } |
| // find latest parent of in |
| int latest_parent = -1; |
| const auto& iter = parents.find(in); |
| if (iter != parents.end()) { |
| for (int parent : iter->second) { |
| latest_parent = std::max(latest_parent, parent); |
| } |
| } |
| if (latest_parent == -1) { |
| return false; |
| } |
| // Make sure that after output is propagated, it doesn't have children that |
| // comes after its new parent, but before its previous parent |
| const auto& iter2 = children.find(out); |
| if (iter2 != children.end()) { |
| for (int child : iter2->second) { |
| if (child < op_idx && child > latest_parent) { |
| VLOG(2) << "InplaceOps(" << op_type << ") skipping op: \n" |
| << op->DebugString(); |
| return false; |
| } |
| } |
| } |
| } |
| |
| return true; |
| } |
| |
| } // namespace |
| |
| // Conceptually it's a pretty easy process and consists of 3 steps: |
| // 1) SSA rewrite; 2) propagate inputs forwards; 3) propagate outputs |
| // backwards and then forwards again. However, because of model outputs |
| // which can't be overwritten during the SSA process, and the fact that |
| // inputs could be overwritten by ops and also appear in outputs, it adds |
| // a lot of extra complexity to handle these special cases. A lot of this |
| // extra complexity is handled in FoundOpCandidate. |
| void RemoveOpsByType(InferenceGraph& graph, const std::string& op_type) { |
| int num_removed = 0; |
| NetDef* net = graph.predict_net_def.get(); |
| for (auto& op : net->op()) { |
| if (op.type() == "RecurrentNetwork") { |
| LOG(INFO) << "RemoveOpsByType does not support RecurrentNetwork yet"; |
| return; |
| } |
| } |
| |
| std::unordered_set<std::string> inputs( |
| graph.input_names.begin(), graph.input_names.end()); |
| std::unordered_set<std::string> outputs( |
| graph.output_names.begin(), graph.output_names.end()); |
| |
| if (!graph.predictor_net_ssa_rewritten) { |
| net->mutable_external_output()->Clear(); |
| // add external_outputs to net as they're necessary to correctly do ssa |
| // rewriting |
| for (const auto& o : graph.output_names) { |
| net->add_external_output(o); |
| } |
| onnx::SsaRewrite(nullptr, net); |
| // clear external_outputs |
| net->mutable_external_output()->Clear(); |
| graph.predictor_net_ssa_rewritten = true; |
| } |
| |
| // construct parents/children graphs to facilitate graph traversal |
| std::unordered_map<std::string, std::unordered_set<int>> parents, children; |
| for (int i = 0; i < net->op_size(); i++) { |
| OperatorDef* op = net->mutable_op(i); |
| for (auto& in : op->input()) { |
| children[in].insert(i); |
| } |
| for (auto& output : op->output()) { |
| parents[output].insert(i); |
| } |
| } |
| |
| // Inplace ops. Step 1: propagate inputs downward |
| for (int i = 0; i < net->op_size(); i++) { |
| OperatorDef* op = net->mutable_op(i); |
| if (!FoundOpCandidate(op, i, op_type, inputs, outputs, parents, children)) { |
| continue; |
| } |
| const std::string in = op->input(0); |
| const std::string out = op->output(0); |
| if (outputs.count(out) == 0) { |
| // Rename all apperances of out to in |
| VLOG(2) << "InplaceOps(" << op_type << ") inplacing op:\n" |
| << op->DebugString(); |
| RenameInputsInChildren(out, in, net, children); |
| RenameOutputs(out, in, op, i, parents); |
| } |
| } |
| |
| // Step 2: propagate outputs upward |
| for (int i = 0; i < net->op_size(); i++) { |
| OperatorDef* op = net->mutable_op(i); |
| if (!FoundOpCandidate(op, i, op_type, inputs, outputs, parents, children)) { |
| continue; |
| } |
| const std::string in = op->input(0); |
| const std::string out = op->output(0); |
| if (outputs.count(out) != 0) { |
| if (inputs.count(in) == 0 && outputs.count(in) == 0) { |
| // Rename all apperances (regardless of inputs/outputs) of in (if not |
| // in inputs) to out, when out is guaranteed to be produced a parent |
| // op. With the parents/children graph which remembers all apprerances |
| // of nodes (not just immediate parent/children), we don't need to |
| // propagate the outputs back down again because those cases are already |
| // handled by RenameOutputInParents and RenameInputsInChildren |
| if (parents.count(in) > 0 && !parents[in].empty()) { |
| RenameOutputInParents(in, out, net, parents); |
| VLOG(2) << "InplaceOps(" << op_type << ") inplacing op:\n" |
| << op->DebugString(); |
| RenameInputsInChildren(in, out, net, children); |
| RenameInputs(in, out, op, i, children); |
| } |
| } |
| } |
| } |
| |
| // Remove inplace ops |
| int i = 0; |
| while (i < net->op_size()) { |
| OperatorDef op = net->op(i); |
| if (op.type() == op_type && op.input_size() == 1 && op.output_size() == 1 && |
| op.input(0) == op.output(0)) { |
| net->mutable_op()->erase(net->mutable_op()->begin() + i); |
| num_removed++; |
| VLOG(2) << "RemoveOpsByType(" << op_type << ") deleting inplace op: \n" |
| << op.DebugString(); |
| } else { |
| i++; |
| VLOG(2) << "RemoveOpsByType(" << op_type << ") skipping op: \n" |
| << op.DebugString(); |
| } |
| } |
| VLOG(2) << "RemoveOpsByType(" << op_type << ") removed " << num_removed |
| << " ops"; |
| } |
| } // namespace caffe2 |