| import torch |
| |
| from torch.fx.node import Node |
| from torch.fx._symbolic_trace import symbolic_trace |
| from torch.fx.passes.tools_common import legalize_graph |
| import itertools |
| import operator |
| |
| from typing import Dict, List |
| |
| |
def split_result_tensors(result: torch.Tensor, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
    """
    Break the output of a merged matmul back into one chunk per original input.

    This helper is emitted into the graph by the ``merge_matmul`` transformation
    below. The merged result is divided along dimension 0 into pieces whose
    sizes equal the leading dimension (``shape[0]``) of each entry in ``inputs``.

    Arguments:
        result: The output tensor of the merged matmul.
        inputs: The LHS tensors that were concatenated to form the merged
                matmul's LHS, in concatenation order.

    Returns:
        The chunks produced by ``torch.split(result, sizes)``, one per input.
    """
    # While the fx tracer is running, ``result`` is a Proxy and ``x.shape[0]``
    # would be a traced attribute rather than an int. torch.split still needs
    # ints, so placeholder zero sizes are used in that case.
    tracing = isinstance(result, torch.fx.Proxy)
    if tracing:
        sizes = [0] * len(inputs)
    else:
        sizes = [inp.shape[0] for inp in inputs]

    return torch.split(result, sizes)
| |
| |
def may_depend_on(a: Node, b: Node, search_depth: int = 6):
    """
    Conservatively decide whether node ``a`` depends on node ``b`` in a torch.fx.Graph.

    The check walks backwards from ``a`` through its input nodes.

    Arguments:
        a: The node that may have a dependency on b.
        b: The node that a may have a dependency on.
        search_depth: Maximum number of input hops to follow when looking for
                      an indirect dependency. When the budget runs out before
                      a verdict is reached, a dependency is assumed.

    Returns:
        True if a may depend on b (including a being b itself, or the search
        budget running out), False if it definitely does not.
    """
    # A node counts as depending on itself.
    if a == b:
        return True

    parents = a.all_input_nodes

    # With no inputs there is nothing through which a could reach b.
    if not parents:
        return False

    # Budget exhausted without a verdict: err on the side of a dependency.
    if search_depth == 0:
        return True

    # Short-circuits on the first input that may lead back to b.
    return any(may_depend_on(parent, b, search_depth - 1) for parent in parents)
| |
| |
def are_nodes_independent(nodes: List[Node]) -> bool:
    """
    Check whether all of the given nodes are pairwise data independent.

    Arguments:
        nodes: The nodes to check for data dependencies.

    Returns:
        True if no pair in ``nodes`` may have a data dependency in either
        direction, False if any pair may. ``may_depend_on`` assumes a
        dependency when its search depth runs out, so a False result does not
        prove that a dependency actually exists.
    """
    # Every unordered pair is checked in both directions; the first possible
    # dependency found makes the set non-independent.
    return not any(
        may_depend_on(first, second) or may_depend_on(second, first)
        for first, second in itertools.combinations(nodes, 2)
    )
| |
| |
def merge_matmul(in_mod: torch.nn.Module) -> torch.fx.GraphModule:
    """
    A graph transformation that merges matrix multiplication operations that
    share the same right-hand side operand into one large matrix multiplication.

    For example, two matmuls A @ C and B @ C, with A of shape (M, K), B of
    shape (T, K) and C of shape (K, R), are replaced by

        cat([A, B]) @ C        # shape (M + T, R)

    The result is then split along dimension 0 back into the (M, R) chunk
    that stands in for A @ C and the (T, R) chunk that stands in for B @ C.

    Arguments:
        in_mod: The module to transform. It is traced with ``symbolic_trace``
                and the transformation is applied to the resulting GraphModule.

    Returns:
        The traced GraphModule with mergeable matmuls merged, recompiled and
        linted.
    """
    gm = symbolic_trace(in_mod)

    # Group matmul nodes by their RHS operand. Keys are the RHS Node itself,
    # or the attribute name (a str) when the RHS is a get_attr node.
    rhs_users: Dict[object, List[Node]] = {}
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target is not torch.matmul:
            continue

        # Only positional two-operand calls whose operands are both graph
        # nodes are merge candidates. Anything else (keyword arguments,
        # non-Node operands) is left untouched instead of crashing the pass.
        if node.kwargs or len(node.args) != 2:
            continue
        lhs, rhs = node.args
        if not isinstance(lhs, Node) or not isinstance(rhs, Node):
            continue

        # TODO: Properly handle aliasing caused by get_attr. For now, key a
        # get_attr RHS by its attribute name so that separate get_attr nodes
        # reading the same attribute land in the same group.
        rhs_key = rhs.target if rhs.op == "get_attr" else rhs
        rhs_users.setdefault(rhs_key, []).append(node)

    merged_any = False
    for rhs_key, mms in rhs_users.items():
        # There must be at least two matmuls for a merge to make sense.
        if len(mms) < 2:
            continue

        # All matmuls must not depend on each other directly or indirectly
        # in order for the merge to be possible.
        if not are_nodes_independent(mms):
            continue

        # Collect the LHS operands and the single RHS operand. A str key means
        # the RHS was a get_attr, so a fresh get_attr node is emitted for it.
        lhs = [mm.args[0] for mm in mms]
        rhs = gm.graph.get_attr(rhs_key) if isinstance(rhs_key, str) else rhs_key

        # Concatenate all the LHS operands.
        # NOTE(review): torch.cat along dim 0 assumes the LHS operands agree
        # on all trailing dimensions; shapes are not checked here.
        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})

        # Multiply the concatenated LHS operands with the one RHS. This yields
        # all of the individual matmul results, concatenated together.
        merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs), {})

        # Split the merged result using the LHS operands' leading dimensions
        # to decide how large each chunk is.
        merge_mm_split = gm.graph.call_function(
            split_result_tensors, (merge_mm, lhs), {}
        )
        merge_mm_res = [
            gm.graph.call_function(operator.getitem, (merge_mm_split, idx), {})
            for idx in range(len(lhs))
        ]

        # Replace all uses of each original matmul with its chunk of the
        # merged result, then drop the now-unused original.
        for old, new in zip(mms, merge_mm_res):
            old.replace_all_uses_with(new)
            gm.graph.erase_node(old)

        merged_any = True

    # The new nodes were appended at the end of the graph, so it must be
    # topologically sorted before use. legalize_graph replaces gm.graph with a
    # newly built Graph. It is therefore run once, after all merges: running it
    # per group would leave the nodes still held in rhs_users attached to the
    # discarded graph, breaking every later merge.
    if merged_any:
        legalize_graph(gm)

    gm.recompile()
    gm.graph.lint()
    return gm