| # Copyright (c) 2016-present, Facebook, Inc. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ############################################################################## |
| |
| ## @package SparseTransformer |
| # Module caffe2.experiments.python.SparseTransformer |
| |
| |
| |
| |
| from caffe2.python import workspace |
| import scipy.sparse |
| |
| |
class NetDefNode():
    """
    One vertex of the op graph built from a net definition.

    name:    key under which this node is stored in its neighbours' dicts
    optype:  operator type string
    ops:     dict name -> node, the nodes that take this node as input
    prev:    dict name -> node, the nodes this node takes as input
    visited: traversal flag, False on creation
    op:      payload handed in by the caller (None when not given)
    """

    def __init__(self, name, optype, p=None, op=None):
        self.name = name
        self.optype = optype
        self.op = op
        self.visited = False
        # Both link tables must exist before any input is wired in.
        self.ops = {}
        self.prev = {}
        self.insertInput(p)

    def insertInput(self, p):
        """
        Register p as input of this node and this node as output of p.
        p: a node or a list of nodes; anything else is ignored
        """
        if isinstance(p, list):
            parents = p
        elif isinstance(p, NetDefNode):
            parents = [p]
        else:
            parents = []
        for parent in parents:
            self.prev[parent.name] = parent
            parent.ops[self.name] = self

    def deleteInput(self, p):
        """
        Remove the link between this node and its input p in both
        directions. Non-node arguments are ignored; a missing link
        raises KeyError.
        """
        if not isinstance(p, NetDefNode):
            return
        self.prev.pop(p.name)
        p.ops.pop(self.name)
| |
| |
def maskNallocate(weight_name):
    """
    Store the CSR form of a weight blob in the workspace.

    The blob weight_name is fetched, converted with
    scipy.sparse.csr_matrix, and its data / indptr / indices arrays are
    fed back as blobs named weight_name + "wcsr" / "iw" / "jw".
    Returns those three blob names, in that order.

    NOTE(review): no mask blob is read here; presumably the pruned
    entries of the weight are already zero -- confirm against the caller.
    """
    csr = scipy.sparse.csr_matrix(workspace.FetchBlob(weight_name))
    blobs = (
        (weight_name + "wcsr", csr.data),
        (weight_name + "iw", csr.indptr),
        (weight_name + "jw", csr.indices),
    )
    for blob_name, values in blobs:
        workspace.FeedBlob(blob_name, values)
    return tuple(blob_name for blob_name, _ in blobs)
| |
| |
def transFCRelu(cur, id2node, name2id, ops, model):
    """
    Add trans before and after this FC_Prune->(Relu)->FC_Prune chain
    and re-emit the chain with sparse ops.

    Walking from cur, every FC_Prune is re-created through
    model.FC_Sparse (using the blobs made by maskNallocate) and every
    Relu through model.Relu; a model.Transpose is issued before the first
    and after the last op of the chain. After each builder call the op is
    read back from model.net.Proto().op[-1] and wrapped in a NetDefNode
    linked to the previously created node. Old chain nodes are unlinked
    from their input as the walk proceeds and marked visited.

    cur: first node of the chain; assumed to be a FC_Prune with exactly
        one input
    id2node, name2id, ops: not used here; kept for the call signature
    model: object offering Transpose / FC_Sparse / Relu and net.Proto()
    """
    chain_types = ("FC_Prune", "Relu")

    # 1. add trans before the start of this chain
    # assuming that cur is a FC_Prune, and it has only one input
    # (next(iter(...)) works on Python 2 and 3, unlike itervalues().next())
    pre = next(iter(cur.prev.values()))
    # Create a node /op and insert it.
    # TODO(wyiming): check whether it is correct here
    current_blob = model.Transpose(cur.op.input[0], cur.op.input[0] + "_trans")
    trans_op = model.net.Proto().op[-1]
    trans_node = NetDefNode(trans_op.output[0], "Transpose", pre, trans_op)
    trans_node.visited = True
    pre_new = trans_node

    # 2. use while loop to visit the chain
    while True:
        # breakup with the parent
        cur.deleteInput(pre)
        if cur.optype not in chain_types:
            print("Reaching the end of the chain")
            break
        if len(cur.ops) > 1:
            print("A FC/Relu giving more than 1 useful outputs")
        if cur.optype == "FC_Prune":
            op = cur.op
            wcsr, iw, jw = maskNallocate(op.input[1])
            bias_name = op.input[3]
            # TODO(wyiming): create a new Op here
            current_blob = model.FC_Sparse(current_blob,
                                           cur.op.output[0] + "_Sparse",
                                           wcsr, iw, jw, bias_name)
            sps_op = model.net.Proto().op[-1]
            sps_node = NetDefNode(cur.op.output[0] + "_Sparse",
                                  "FC_Sparse",
                                  pre_new, sps_op)
            sps_node.visited = True
            pre_new = sps_node
        elif cur.optype == "Relu":
            # in-place Relu on the blob produced by the new chain
            current_blob = model.Relu(current_blob, current_blob)
            rel_op = model.net.Proto().op[-1]
            rel_node = NetDefNode(str(current_blob), "Relu",
                                  pre_new, rel_op)
            rel_node.visited = True
            pre_new = rel_node

        cur.visited = True
        pre = cur
        # The last FC_Prune/Relu consumer (in dict order) continues the chain.
        chain_next = None
        for child in cur.ops.values():
            if child.optype in chain_types:
                chain_next = child
        if chain_next is not None:
            cur = chain_next
            continue

        # assume that there is only 1 output that is not PrintOP
        # cur becomes None when the chain has no consumer at all, instead
        # of failing with StopIteration.
        cur = next(iter(cur.ops.values()), None)
        if cur is not None:
            cur.deleteInput(pre)
        print("No FC/RElu children")
        if cur is not None:
            print(cur.op.type)
        break

    # 3. add trans after this chain like 1.
    current_blob = model.Transpose(current_blob, pre.op.output[0])
    trans_op = model.net.Proto().op[-1]
    trans_node = NetDefNode(str(current_blob), "Transpose", pre_new, trans_op)
    trans_node.visited = True
    if cur is not None:
        # hook the first op after the chain onto the trailing Transpose
        cur.insertInput(trans_node)
        print(cur.prev)
    print(trans_node.ops)
| |
| |
def Prune2Sparse(cur, id2node, name2id, ops, model):
    """
    Depth-first walk from cur that hands every not-yet-visited FC_Prune
    node to transFCRelu, marks each node it reaches as visited, and
    recurses into the node's consumers (cur.ops).

    cur: node to start from
    id2node, name2id, ops, model: passed through to transFCRelu unchanged
    """
    # Assume that FC and Relu takes in only 1 input;
    # If not raise warning
    if not cur.visited and cur.optype == "FC_Prune":
        transFCRelu(cur, id2node, name2id, ops, model)

    cur.visited = True

    # Handling a child can change cur.ops: transFCRelu on a FC_Prune child
    # links a new Transpose node to the child's input node and unlinks the
    # child from it. Iterating the live dict would therefore break, and a
    # single snapshot would miss the newly linked node, so snapshots are
    # taken repeatedly until no unhandled consumer is left. .items() (not
    # .iteritems()) keeps this working on Python 2 and 3.
    handled = set()
    while True:
        pending = [(name, node) for name, node in cur.ops.items()
                   if name not in handled]
        if not pending:
            break
        for name, node in pending:
            handled.add(name)
            Prune2Sparse(node, id2node, name2id, ops, model)
| |
| |
def net2list(net_root):
    """
    Return the ops of the graph below net_root as a list in topological
    order.

    Every node reachable from net_root through .ops contributes its .op
    exactly once, and only after all of its reachable inputs have been
    listed; nodes that become ready at the same time keep breadth-first
    order. net_root itself is not listed. For chain- or tree-shaped
    graphs this is the plain BFS order.

    net_root: root node; each node needs an .ops dict of consumers and
        an .op payload
    """
    from collections import deque

    # Pass 1: count, for every reachable node, how many reachable edges
    # point at it. Nodes are keyed by id() so they need not be hashable.
    indegree = {}
    discovered = {id(net_root)}
    stack = [net_root]
    while stack:
        node = stack.pop()
        for child in node.ops.values():
            key = id(child)
            indegree[key] = indegree.get(key, 0) + 1
            if key not in discovered:
                discovered.add(key)
                stack.append(child)

    # Pass 2: Kahn's algorithm -- a node is emitted once its last input
    # has been emitted. deque gives O(1) pops from the front.
    op_list = []
    ready = deque([net_root])
    while ready:
        node = ready.popleft()
        if node is not net_root:
            op_list.append(node.op)
        for child in node.ops.values():
            key = id(child)
            indegree[key] -= 1
            if indegree[key] == 0:
                ready.append(child)

    return op_list
| |
| |
def netbuilder(model):
    """
    Build a NetDefNode graph from the ops of model.net.Proto().

    Ops of type "Print" are skipped. Every other op becomes a node that
    is linked to the nodes of the earlier ops producing its inputs; an op
    with no such producer is linked to a synthetic root node instead.

    Returns (root node, dict blob name -> id of the op that last wrote
    it, dict op id -> node).
    """
    print("Welcome to model checker")
    proto = model.net.Proto()
    producer_of = {}
    node_by_id = {}
    net_root = NetDefNode("net_root", "root", None)

    for op_id, op in enumerate(proto.op):
        if op.type == "Print":
            continue

        if op.name:
            label = '%s/%s (op#%d)' % (op.name, op.type, op_id)
        else:
            label = '%s (op#%d)' % (op.type, op_id)
        node = NetDefNode(label, op.type, op=op)
        node_by_id[op_id] = node

        # assume that un_occured name are non_layers
        # TODO: write a non-layer checker and log it
        producers = [node_by_id[producer_of[blob]]
                     for blob in op.input if blob in producer_of]
        # Without any layer input the op hangs directly off the root.
        for producer in producers or [net_root]:
            node.insertInput(producer)

        # Outputs are registered last, so an op never becomes its own input.
        producer_of.update((blob, op_id) for blob in op.output)

    return net_root, producer_of, node_by_id