| #include "caffe2/core/graph.h" |
| |
| #include "caffe2/core/common.h" |
| #include "caffe2/core/logging.h" |
| #include "caffe2/core/net.h" |
| #include "caffe2/proto/caffe2_pb.h" |
| |
| namespace caffe2 { |
| |
| namespace transform { |
| |
| Graph::Graph(const NetDef& net) : netdef_(net) { |
| nodes_.clear(); |
| nodes_.resize(net.op_size()); |
| |
| // Copy over operators |
| for (int x = 0; x < net.op_size(); x++) { |
| node(x).op = net.op(x); |
| } |
| |
| // For any blob, which operator was the last to write to it? |
| // In python, this is known as "versions". |
| std::unordered_map<string, int> edge_parent; |
| |
| for (int i = 0; i < (int)nodes_.size(); i++) { |
| for (const string& blob : node(i).op.input()) { |
| auto it = edge_parent.find(blob); |
| if (it != edge_parent.end()) { |
| int j = it->second; |
| node(i).parents[j].push_back(blob); |
| node(j).children[i].push_back(blob); |
| } else { |
| external_input_.insert(blob); |
| } |
| } |
| for (const string& blob : node(i).op.output()) { |
| edge_parent[blob] = i; |
| } |
| } |
| |
| // Traverse opposite direction to find external outputs |
| |
| // For any blob, which operator was the last to read to from it? |
| std::unordered_map<string, int> edge_child; |
| |
| for (int i = (int)nodes_.size() - 1; i >= 0; i--) { |
| for (const string& blob : node(i).op.output()) { |
| auto it = edge_child.find(blob); |
| if (it == edge_child.end()) { |
| external_output_.insert(blob); |
| } |
| } |
| for (const string& blob : node(i).op.input()) { |
| edge_child[blob] = i; |
| } |
| } |
| } |
| |
| const std::vector<std::pair<string, int>> Graph::GetSubgraphInput( |
| const std::vector<int>& match) { |
| return GetSubgraphPerimeterHelper(true, match); |
| } |
| |
| const std::vector<std::pair<string, int>> Graph::GetSubgraphOutput( |
| const std::vector<int>& match) { |
| return GetSubgraphPerimeterHelper(false, match); |
| } |
| |
| // This helper function will either get: |
| // 1) a list for the blobs that write INTO a subgraph |
| // 2) a list of for the blobs that are written FROM a subgraph. |
| // |
| // The "from_children" flag determines if it is case 1 (true) or case 2 (false). |
| const std::vector<std::pair<string, int>> Graph::GetSubgraphPerimeterHelper( |
| bool from_children, |
| const std::vector<int>& match) { |
| std::vector<std::pair<string, int>> edge_list; |
| std::unordered_set<int> match_set(match.begin(), match.end()); |
| for (int x = 0; x < (int)nodes_.size(); x++) { |
| if (!is_node_active(x)) { |
| continue; |
| } |
| if (!match_set.count(x)) { // x is not in subgraph |
| const auto& list = from_children ? node(x).children : node(x).parents; |
| for (const auto& edge : list) { |
| int parent = edge.first; |
| const auto& blobs = edge.second; |
| if (match_set.count(parent)) { // but has a parent that is in subgraph |
| for (const string& blob : blobs) { |
| // NOLINTNEXTLINE(modernize-use-emplace) |
| edge_list.push_back({blob, x}); |
| } |
| } |
| } |
| } |
| } |
| // return the list in sorted order, to allow binary searching |
| std::sort(edge_list.begin(), edge_list.end()); |
| return edge_list; |
| } |
| |
| NetDef Graph::GetNetDef() { |
| std::vector<bool> visited(nodes_.size(), false); |
| |
| // Copy over all the properties of the netdef we're based on |
| NetDef netdef = netdef_; |
| |
| // But we're going to put in our own operators. |
| netdef.clear_op(); |
| |
| // Keeps track of the number of parents yet to be processed. |
| std::vector<int> unchecked_parent_count; |
| |
| // We will perform a topological traversal on the nodes, but we will prefer |
| // nodes that come earlier in the execution order. |
| |
| // This is a min-heap, which stores its elements in ascending order. |
| // This stores the nodes in the order we process them to be in. |
| // This guarantees the lowest lexicographical topological ordering. |
| |
| // This also means the original nodes will be kept in their execution order. |
| // NOLINTNEXTLINE(modernize-use-transparent-functors) |
| std::priority_queue<int, std::vector<int>, std::greater<int>> q; |
| |
| // In our graph, G, the nodes don't have a strict ordering. But in the netdef, |
| // they must (since nets are operators executed in some order). |
| // How do we make sure that the order of operators in our generated netdef |
| // is valid? |
| // 1) The ordering of the netdef must be topologically sorted, respect to G. |
| // If A -> B is an edge in the graph G, then A must come before B in the |
| // netdef's ordering. |
| // 2) No blob conflicts: If A -> B is an edge in the graph G, and A writes to |
| // blob X and B reads from blob X, then there cannot be an op that writes |
| // to blob X between A and B in the ordering. |
| // |
| // Perform a Topological Sort, to find an order for the Operators to be in. |
| // We will keep track of the number of parents each node has. |
| // We begin with an empty queue, and push in all nodes that do not have any |
| // parents. Then, we keep track of all unprocessed parents for each node. |
| // When a node has no more unprocessed parents, we can push it into the queue |
| // to be processed. This guarantees condition 1 is satisfied. |
| |
| // TODO(benz): Currently, condition 2 is not guaranteed to be satisified. |
| // However, giving each blob unique names via SSA will satisfy this condition. |
| // Then, the resulting graph can be optimized with memonger. |
| |
| for (int i = 0; i < (int)nodes_.size(); i++) { |
| unchecked_parent_count.push_back(node(i).parents.size()); |
| if (node(i).parents.size() == 0 && is_node_active(i)) { |
| q.push(i); |
| visited[i] = true; |
| } |
| } |
| |
| while (!q.empty()) { |
| int idx = q.top(); |
| q.pop(); |
| if (!is_node_active(idx)) { |
| continue; |
| } |
| // Creates a new OperatorDef in NetDef |
| auto& op = *(netdef.add_op()); |
| // Sets it equal to the OperatorDef at node(idx) |
| op = node(idx).op; |
| for (const auto& edge : node(idx).children) { |
| int child = edge.first; |
| if (!visited[child] && is_node_active(child)) { |
| unchecked_parent_count[child]--; |
| if (unchecked_parent_count[child] == 0) { |
| q.push(child); |
| visited[child] = true; |
| } |
| } |
| } |
| } |
| return netdef; |
| } |
| |
| void Graph::DeactivateSubgraph(std::vector<int> subgraph) { |
| for (int idx : subgraph) { |
| // remove all edges connected to inactive node |
| for (const auto& edge : node(idx).parents) { |
| int parent = edge.first; |
| node(parent).children.erase(idx); |
| } |
| for (const auto& edge : node(idx).children) { |
| int child = edge.first; |
| node(child).parents.erase(idx); |
| } |
| // actually mark flags as false |
| node(idx).active = false; |
| } |
| } |
| |
| } // namespace transform |
| |
| OperatorDef* AddOp( |
| NetDef* netdef_ptr, |
| string op_type, |
| std::vector<string> inputs, |
| std::vector<string> outputs) { |
| CHECK(netdef_ptr); |
| auto& netdef = *netdef_ptr; |
| auto op_ptr = netdef.add_op(); |
| auto& op = *op_ptr; |
| op.set_type(op_type); |
| for (const string& inp : inputs) { |
| op.add_input(inp); |
| } |
| for (const string& outp : outputs) { |
| op.add_output(outp); |
| } |
| return op_ptr; |
| } |
| |
| bool MatchStrings(string p, string s) { |
| if (p == "*") { // star accepts anything |
| return true; |
| } |
| // TODO(benz): memoize this. (high constant factor boost in performance) |
| vector<string> choices = split('|', p); |
| for (const string& candidate : choices) { |
| if (candidate == s) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op) { |
| for (const auto& p_arg : p_op.arg()) { |
| if (!p_arg.has_name()) { |
| continue; |
| } |
| bool found = false; |
| for (const auto& g_arg : g_op.arg()) { |
| if (p_arg.name() == g_arg.name()) { |
| found = true; |
| if (p_arg.has_f()) { |
| if (!g_arg.has_f() || p_arg.f() != g_arg.f()) { |
| return false; |
| } |
| } |
| if (p_arg.has_i()) { |
| if (!g_arg.has_i() || p_arg.i() != g_arg.i()) { |
| return false; |
| } |
| } |
| if (p_arg.has_s()) { |
| if (!g_arg.has_s() || !MatchStrings(p_arg.s(), g_arg.s())) { |
| return false; |
| } |
| } |
| if (p_arg.floats_size() != g_arg.floats_size()) { |
| return false; |
| } |
| for (int i = 0; i < p_arg.floats_size(); i++) { |
| if (p_arg.floats(i) != g_arg.floats(i)) { |
| return false; |
| } |
| } |
| if (p_arg.ints_size() != g_arg.ints_size()) { |
| return false; |
| } |
| for (int i = 0; i < p_arg.ints_size(); i++) { |
| if (p_arg.ints(i) != g_arg.ints(i)) { |
| return false; |
| } |
| } |
| if (p_arg.strings_size() != g_arg.strings_size()) { |
| return false; |
| } |
| for (int i = 0; i < p_arg.strings_size(); i++) { |
| if (!MatchStrings(p_arg.strings(i), g_arg.strings(i))) { |
| return false; |
| } |
| } |
| } |
| } |
| if (!found) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| } // namespace caffe2 |