| #pragma once |
| |
#include "caffe2/core/common.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/string_utils.h"

#include <algorithm>
#include <cstddef>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
| |
| namespace caffe2 { |
| |
| namespace transform { |
| |
| /** |
| * Graph representation of an operator. |
| */ |
| struct TORCH_API Node { |
| public: |
| // Empty constructor for resize |
| Node() {} |
| |
| // Alternate constructor |
| Node( |
| const OperatorDef& op, |
| bool active, |
| std::map<int, std::vector<string>> parents, |
| std::map<int, std::vector<string>> children) |
| : op(op), active(active), parents(parents), children(children) {} |
| |
| // The OperatorDef which this node represents. |
| OperatorDef op; |
| |
| // Keeps track of if an operator has been deleted through a transformation. |
| bool active = true; |
| |
| // Stores a pair (idx, blob_list), |
| // idx = index of the child |
| // blob_list = a list of strings, containing the blobs that connect the nodes |
| std::map<int, std::vector<string>> parents; |
| std::map<int, std::vector<string>> children; |
| }; |
| |
| /** |
| * Graph representation of a Netdef. |
| */ |
| struct TORCH_API Graph { |
| public: |
| /** |
| * Given a subgraph, gets all of the parents of the subgraph, as well as |
| * their associated blob names. Sorted by blob names. |
| * |
| * <string, int> := (name of blob writing into subgraph, |
| * index of node that writes into subgraph using that blob) |
| */ |
| const std::vector<std::pair<string, int>> GetSubgraphInput( |
| const std::vector<int>& subgraph); |
| |
| /** |
| * Given a subgraph, gets all of the children of the subgraph, as well as |
| * their associated blob names. Sorted by blob names. |
| * |
| * <string, int> := (name of blob reading from subgraph, |
| * index of node that reads from subgraph using that blob) |
| */ |
| const std::vector<std::pair<string, int>> GetSubgraphOutput( |
| const std::vector<int>& subgraph); |
| |
| /** |
| * Graph generation. |
| * Given a netdef, returns a Graph. |
| * |
| * Each node represents an operator. |
| * An edge exists between two nodes if the parent op writes to a blob, which |
| * is the input of the child blob, with no other op writing to the blob in |
| * between the execution order. |
| * |
| * Time Complexity: O(E), where E is the number of blobs |
| */ |
| explicit Graph(const NetDef& net_def); |
| |
| /** |
| * Generates a NetDef Representation for the current graph. |
| * Nodes are visited in topological order, which is proper Opdef ordering. |
| * TODO(benz): |
| * There exists conflicts with repeated blob names, where topological sorting |
| * is not sufficient for correct netdef representation, unless blobs are |
| * renamed. |
| * For example, if after a transformation, We have operator ancestry: |
| * A --> B --> C, and also A --> D --> E, where B -> C and D -> E uses the |
| * same blob name, then A, B, D, E, C is a correct topological ordering, |
| * but D will write to the blob that C reads from, instead of B. |
| * Currently believe that there will always be ambiguity unless blobs are |
| * renamed. |
| * This is solved by performing SSA on all transformed blob names. |
| */ |
| NetDef GetNetDef(); |
| |
| /** |
| * Deactivate a subgraph, and get rid of all edges into this subgraph. |
| */ |
| void DeactivateSubgraph(std::vector<int> subgraph); |
| |
| size_t size() const { |
| return nodes_.size(); |
| } |
| |
| void push_node(const Node& new_node) { |
| return nodes_.push_back(new_node); |
| } |
| |
| void resize_nodes(size_t new_size) { |
| nodes_.resize(new_size); |
| } |
| |
| // Index safe, less verbose way to access nodes |
| inline const Node& node(size_t idx) const { |
| return nodes_.at(idx); |
| } |
| |
| inline Node& node(size_t idx) { |
| return nodes_.at(idx); |
| } |
| |
| inline bool is_node_active(size_t idx) { |
| return node(idx).active; |
| } |
| |
| inline const std::set<string>& external_input() const { |
| return external_input_; |
| } |
| |
| inline const std::set<string>& external_output() const { |
| return external_output_; |
| } |
| |
| private: |
| const std::vector<std::pair<string, int>> GetSubgraphPerimeterHelper( |
| bool from_children, |
| const std::vector<int>& match); |
| |
| // Stores the netdef representation. Is updated upon calls to GetNetDef. |
| NetDef netdef_; |
| |
| // Stores which blobs the graph reads from, and writes to. |
| std::set<string> external_input_; |
| std::set<string> external_output_; |
| |
| // Keeps track of all the Operators currently within graph, even if inactive. |
| std::vector<Node> nodes_; |
| }; |
| |
| } // namespace transform |
| |
| // Adds an operator def to a netdef. |
| // Returns the ptr, if you want to add anything extra (such as device_option) |
| TORCH_API OperatorDef* AddOp( |
| NetDef* netdef_ptr, |
| string op_type, |
| std::vector<string> inputs, |
| std::vector<string> outputs); |
| |
| /** |
| * This allows for the use of * and | to match operator types, |
| * engines, or any other property that is represented by strings. |
| * |
| * For example, if we wanted to match an operator to Conv or FC, we can give: |
| * "Conv|FC" as the type() of that op. |
| */ |
| TORCH_API bool MatchStrings(string p, string s); |
| |
| /** |
| * This ensures that each named arg that exists in the pattern exists in g_op, |
| * is equal in value. |
| */ |
| TORCH_API bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op); |
| |
| } // namespace caffe2 |