blob: d78bb601b9c44e1dd5902378c32a2a57e7e8a231 [file] [log] [blame]
#include "caffe2/predictor/transforms.h"
#include "caffe2/onnx/onnx_exporter.h"
#include "caffe2/utils/proto_utils.h"
#include <unordered_set>
namespace caffe2 {
namespace {
bool HasInput(const string& blob, const OperatorDef& op) {
for (const auto& in : op.input()) {
if (blob == in) {
return true;
}
}
return false;
}
bool HasOutput(const string& blob, const OperatorDef& op) {
for (const auto& out : op.output()) {
if (blob == out) {
return true;
}
}
return false;
}
void RewriteSubnetsForIfOp(
const string& from,
const string& to,
OperatorDef* op) {
ArgumentHelper helper(*op);
Argument *then_arg = nullptr, *else_arg = nullptr;
std::map<std::string, std::string> oldname_to_newname;
oldname_to_newname[from] = to;
if (helper.HasSingleArgumentOfType<NetDef>("then_net")) {
then_arg = GetMutableArgument("then_net", false, op);
onnx::rewriteSubnet(then_arg, oldname_to_newname);
}
if (helper.HasSingleArgumentOfType<NetDef>("else_net")) {
else_arg = GetMutableArgument("else_net", false, op);
onnx::rewriteSubnet(else_arg, oldname_to_newname);
}
}
void RenameInputs(
const string& from,
const string& to,
OperatorDef* def,
int op_idx,
std::unordered_map<std::string, std::unordered_set<int>>& children) {
VLOG(2) << "RenameInputs (from=" << from << ", to=" << to << ", "
<< def->DebugString() << ")";
for (int i = 0; i < def->input_size(); i++) {
if (def->input(i) == from) {
*def->mutable_input(i) = to;
children[from].erase(op_idx);
children[to].insert(op_idx);
}
}
// Rename inputs in the subnets of If/AsyncIf op
if (def->type() == "If" || def->type() == "AsyncIf") {
RewriteSubnetsForIfOp(from, to, def);
}
}
void RenameOutputs(
const string& from,
const string& to,
OperatorDef* def,
int op_idx,
std::unordered_map<std::string, std::unordered_set<int>>& parents) {
VLOG(2) << "RenameOutputs (from=" << from << ", to=" << to << ", "
<< def->DebugString() << ")";
for (string& output : *def->mutable_output()) {
if (output == from) {
output = to;
parents[from].erase(op_idx);
parents[to].insert(op_idx);
}
}
// Rename outputs in the subnets of If/AsyncIf op
if (def->type() == "If" || def->type() == "AsyncIf") {
RewriteSubnetsForIfOp(from, to, def);
}
}
// Renames input blob `from` to `to` in every op of `net` recorded in
// children[from] (the ops that read `from`). Does nothing if `from` has no
// entry in `children`.
void RenameInputsInChildren(
    const string& from,
    const string& to,
    caffe2::NetDef* net,
    std::unordered_map<std::string, std::unordered_set<int>>& children) {
  VLOG(2) << "RenameInputsInChildren (from=" << from << ", to=" << to << ")";
  const auto readers_it = children.find(from);
  if (readers_it == children.end()) {
    return;
  }
  // Iterate over a snapshot: RenameInputs mutates `children`, including the
  // very set of readers we are walking.
  const std::unordered_set<int> readers = readers_it->second;
  for (const int reader_idx : readers) {
    RenameInputs(from, to, net->mutable_op(reader_idx), reader_idx, children);
  }
}
// Renames output blob `from` to `to` in every op of `net` recorded in
// parents[from] (the ops that write `from`). Does nothing if `from` has no
// entry in `parents`.
void RenameOutputInParents(
    const std::string& from,
    const std::string& to,
    caffe2::NetDef* net,
    std::unordered_map<std::string, std::unordered_set<int>>& parents) {
  VLOG(2) << "RenameOutputInParents (from=" << from << ", to=" << to << ")";
  const auto writers_it = parents.find(from);
  if (writers_it == parents.end()) {
    return;
  }
  // Iterate over a snapshot: RenameOutputs mutates `parents`, including the
  // very set of writers we are walking.
  const std::unordered_set<int> writers = writers_it->second;
  for (const int writer_idx : writers) {
    RenameOutputs(from, to, net->mutable_op(writer_idx), writer_idx, parents);
  }
}
// Decides whether `op` (at index `op_idx` of the net) can be made in-place by
// renaming blobs, as part of RemoveOpsByType.
//
// Returns false unless `op` has type `op_type` and has exactly one input
// `in` and one output `out` that differ. Then:
//  - If `out` is not a net output, the caller renames `out` to `in` in every
//    reader of `out` ("propagate input downwards"). That is rejected if some
//    op after `op` and before the LAST reader of `out` writes `in`, because
//    that reader would observe the overwritten value.
//  - Otherwise the caller renames `in` to `out` in every writer and reader of
//    `in` ("propagate output upwards"). That is rejected when `in` is a net
//    input/output, or when `in` has no writer. It is also rejected when some
//    existing reader of `out` sits after the EARLIEST writer of `in` (from
//    where `out` would take its new value) and before `op`.
//
// @param op        operator under consideration
// @param op_idx    index of `op` in the net's op list
// @param op_type   operator type being removed
// @param inputs    names of the net's inputs
// @param outputs   names of the net's outputs
// @param parents   blob name -> indices of ops that write that blob
// @param children  blob name -> indices of ops that read that blob
// @return true if `op` is a candidate for the renaming described above
bool FoundOpCandidate(
    const OperatorDef* op,
    int op_idx,
    const std::string& op_type,
    const std::unordered_set<std::string>& inputs,
    const std::unordered_set<std::string>& outputs,
    const std::unordered_map<std::string, std::unordered_set<int>>& parents,
    const std::unordered_map<std::string, std::unordered_set<int>>& children) {
  if (op->type() != op_type) {
    VLOG(2) << "InplaceOps(" << op_type << ") skipping op: \n"
            << op->DebugString();
    return false;
  }
  if (op->input_size() != 1 || op->output_size() != 1) {
    VLOG(2) << "InplaceOps(" << op_type
            << ") only supports ops with exactly 1 output "
            << "and exactly 1 input. Skipping op: \n"
            << op->DebugString();
    return false;
  }
  // Use actual copies because op->input/output may change.
  const std::string in = op->input(0);
  const std::string out = op->output(0);
  if (in == out) {
    // This case can still exist when in/out is in the predict_net's outputs.
    // The op is an inplace op already.
    return false;
  }
  // The following is to handle the special cases of inputs being overwritten
  // by ops in the net and then appear in outputs of the net.
  if (outputs.count(out) == 0) {
    // Propagate input downwards: every reader of `out` will read `in`
    // instead. Checking only the earliest reader is not enough: a write to
    // `in` between the earliest and the latest reader would still corrupt
    // the later readers. So find the latest reader of `out`.
    int latest_child = -1;
    const auto child_it = children.find(out);
    if (child_it != children.end()) {
      for (int child : child_it->second) {
        latest_child = std::max(latest_child, child);
      }
    }
    if (latest_child == -1) {
      // Nobody reads `out`; renaming it away is always safe.
      return true;
    }
    const auto parent_it = parents.find(in);
    if (parent_it != parents.end()) {
      for (int parent : parent_it->second) {
        if (parent > op_idx && parent < latest_child) {
          VLOG(2) << "InplaceOps(" << op_type << ") skipping op: \n"
                  << op->DebugString();
          return false;
        }
      }
    }
  } else {
    // Propagate output upwards.
    if (inputs.count(in) != 0 || outputs.count(in) != 0) {
      // This is the case when the op is absolutely needed. It exists to serve
      // one and only one purpose, to copy from in to out where in is one of
      // the net's inputs or outputs and out is one of the net's outputs.
      VLOG(2) << "InplaceOps(" << op_type << ") skipping op: \n"
              << op->DebugString();
      return false;
    }
    // Every writer of `in` will write `out` instead, so `out` takes its new
    // value starting at the EARLIEST writer of `in`. Later writers can be
    // ops made in-place by the downward pass.
    int earliest_parent = INT_MAX;
    const auto parent_it = parents.find(in);
    if (parent_it != parents.end()) {
      for (int parent : parent_it->second) {
        earliest_parent = std::min(earliest_parent, parent);
      }
    }
    if (earliest_parent == INT_MAX) {
      // `in` is never produced by an op in the net.
      return false;
    }
    // Make sure that after output is propagated, it doesn't have readers that
    // come after its new (earliest) writer but before this op.
    const auto child_it = children.find(out);
    if (child_it != children.end()) {
      for (int child : child_it->second) {
        if (child < op_idx && child > earliest_parent) {
          VLOG(2) << "InplaceOps(" << op_type << ") skipping op: \n"
                  << op->DebugString();
          return false;
        }
      }
    }
  }
  return true;
}
} // namespace
// Removes single-input/single-output operators of type `op_type` from
// graph.predict_net_def. It first renames blobs so that each such op becomes
// in-place (input == output), then deletes every in-place op of that type.
// NOTE(review): this presumes `op_type` ops pass their input through
// unchanged (copy-like); confirm at call sites.
//
// Conceptually it's a pretty easy process and consists of 3 steps:
// 1) SSA rewrite; 2) propagate inputs forwards; 3) propagate outputs
// backwards and then forwards again. However, because of model outputs
// which can't be overwritten during the SSA process, and the fact that
// inputs could be overwritten by ops and also appear in outputs, it adds
// a lot of extra complexity to handle these special cases. A lot of this
// extra complexity is handled in FoundOpCandidate.
//
// The net is left untouched if it contains a RecurrentNetwork op. The SSA
// rewrite runs only once per graph, tracked by
// graph.predictor_net_ssa_rewritten.
void RemoveOpsByType(InferenceGraph& graph, const std::string& op_type) {
  NetDef* net = graph.predict_net_def.get();
  for (const auto& op : net->op()) {
    if (op.type() == "RecurrentNetwork") {
      LOG(INFO) << "RemoveOpsByType does not support RecurrentNetwork yet";
      return;
    }
  }
  std::unordered_set<std::string> inputs(
      graph.input_names.begin(), graph.input_names.end());
  std::unordered_set<std::string> outputs(
      graph.output_names.begin(), graph.output_names.end());
  if (!graph.predictor_net_ssa_rewritten) {
    net->mutable_external_output()->Clear();
    // Add external_outputs to net as they're necessary to correctly do SSA
    // rewriting.
    for (const auto& o : graph.output_names) {
      net->add_external_output(o);
    }
    onnx::SsaRewrite(nullptr, net);
    // Clear external_outputs again.
    net->mutable_external_output()->Clear();
    graph.predictor_net_ssa_rewritten = true;
  }
  // Construct parents (blob -> writer op indices) and children
  // (blob -> reader op indices) to facilitate graph traversal.
  std::unordered_map<std::string, std::unordered_set<int>> parents, children;
  for (int i = 0; i < net->op_size(); i++) {
    const OperatorDef& op = net->op(i);
    for (const auto& in : op.input()) {
      children[in].insert(i);
    }
    for (const auto& output : op.output()) {
      parents[output].insert(i);
    }
  }
  // Inplace ops. Step 1: propagate inputs downward.
  for (int i = 0; i < net->op_size(); i++) {
    OperatorDef* op = net->mutable_op(i);
    if (!FoundOpCandidate(op, i, op_type, inputs, outputs, parents, children)) {
      continue;
    }
    const std::string in = op->input(0);
    const std::string out = op->output(0);
    if (outputs.count(out) == 0) {
      // Rename all appearances of out to in.
      VLOG(2) << "InplaceOps(" << op_type << ") inplacing op:\n"
              << op->DebugString();
      RenameInputsInChildren(out, in, net, children);
      RenameOutputs(out, in, op, i, parents);
    }
  }
  // Step 2: propagate outputs upward.
  for (int i = 0; i < net->op_size(); i++) {
    OperatorDef* op = net->mutable_op(i);
    if (!FoundOpCandidate(op, i, op_type, inputs, outputs, parents, children)) {
      continue;
    }
    const std::string in = op->input(0);
    const std::string out = op->output(0);
    if (outputs.count(out) != 0) {
      if (inputs.count(in) == 0 && outputs.count(in) == 0) {
        // Rename all appearances (regardless of inputs/outputs) of in (if not
        // in inputs) to out, when out is guaranteed to be produced by a
        // parent op. With the parents/children graph which remembers all
        // appearances of nodes (not just immediate parent/children), we don't
        // need to propagate the outputs back down again because those cases
        // are already handled by RenameOutputInParents and
        // RenameInputsInChildren.
        if (parents.count(in) > 0 && !parents[in].empty()) {
          RenameOutputInParents(in, out, net, parents);
          VLOG(2) << "InplaceOps(" << op_type << ") inplacing op:\n"
                  << op->DebugString();
          RenameInputsInChildren(in, out, net, children);
          RenameInputs(in, out, op, i, children);
        }
      }
    }
  }
  // Remove inplace ops. Surviving ops are compacted to the front with pointer
  // swaps, which keeps their relative order. The removed tail is then dropped
  // with a single DeleteSubrange call. This avoids deep-copying every
  // OperatorDef (including any nested subnet arguments) just to inspect it.
  // It also avoids erasing one op at a time, which shifts the remaining tail
  // on every erase.
  auto* ops = net->mutable_op();
  int num_kept = 0;
  for (int i = 0; i < ops->size(); i++) {
    const OperatorDef& op = ops->Get(i);
    const bool is_removable = op.type() == op_type && op.input_size() == 1 &&
        op.output_size() == 1 && op.input(0) == op.output(0);
    if (is_removable) {
      VLOG(2) << "RemoveOpsByType(" << op_type << ") deleting inplace op: \n"
              << op.DebugString();
      continue;
    }
    VLOG(2) << "RemoveOpsByType(" << op_type << ") skipping op: \n"
            << op.DebugString();
    if (num_kept != i) {
      // `op` is not used after this swap.
      ops->SwapElements(num_kept, i);
    }
    num_kept++;
  }
  const int num_removed = ops->size() - num_kept;
  if (num_removed > 0) {
    ops->DeleteSubrange(num_kept, num_removed);
  }
  VLOG(2) << "RemoveOpsByType(" << op_type << ") removed " << num_removed
          << " ops";
}
} // namespace caffe2