import copy
import itertools
import operator
from functools import reduce
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch._dynamo.utils import fake_mode_from_tensors
from torch.fx.experimental.optimization import (
    matches_module_pattern,
    replace_node_module,
)
from torch.fx.experimental.symbolic_shapes import guard_int
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn.modules.utils import _pair
from . import config

from .fx_utils import matches_module_function_pattern


class UnaryAttr:
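    """Map a unary ``nn.Module`` to the ``(op_name, scalars, algorithm)`` triple
    expected by the mkldnn pointwise fusion ops.

    Illustrative example (values are read from the module's attributes)::

        UnaryAttr("leaky_relu", scalars_attr=["negative_slope"])(nn.LeakyReLU(0.1))
        # -> ("leaky_relu", [0.1], "")
    """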
    def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
        self.op_name = op_name
        self.scalars_attr = scalars_attr if scalars_attr else []
        self.algorithm_attr = algorithm_attr if algorithm_attr else ""
        super().__init__()

    def __call__(self, unary_module: nn.Module):
        if type(unary_module) is nn.ReLU6:
            unary_module = nn.Hardtanh(min_val=0, max_val=6)
        assert all(hasattr(unary_module, item) for item in self.scalars_attr)
        scalars = [getattr(unary_module, item) for item in self.scalars_attr]

        algorithm = ""
        if self.algorithm_attr:
            assert hasattr(unary_module, self.algorithm_attr)
            algorithm = getattr(unary_module, self.algorithm_attr)

        return self.op_name, scalars, algorithm


def is_bfloat16_module(m):
    weight_is_bf16 = m.weight.dtype == torch.bfloat16
    bias_is_bf16 = m.bias is None or m.bias.dtype == torch.bfloat16
    return weight_is_bf16 and bias_is_bf16


def is_group_depthwise_conv_transpose(m):
    return (
        type(m) in [nn.ConvTranspose2d] and m.groups > 1 and m.groups == m.in_channels
    )


def check_node_kind(current_node, modules, node_kind):
    if not isinstance(current_node, torch.fx.Node):
        return False
    if current_node.op != "call_module":
        return False
    if not isinstance(current_node.target, str):
        return False
    if current_node.target not in modules:
        return False
    if type(modules[current_node.target]) is not node_kind:
        return False
    return True


def check_node_is_binary(node):
    return (
        (node.op == "call_function" and node.target in [torch.add, torch.sub])
        or (
            node.op == "call_function"
            and node.target
            in [operator.add, operator.iadd, operator.sub, operator.isub]
        )
        or (node.op == "call_method" and node.target in ["add", "add_", "sub", "sub_"])
    )


def check_binary_op_kwargs_is_default(node):
    # For a binary op, we expect the kwargs to hold their default values:
    # torch.add/sub(input, other, *, alpha=1, out=None).
    if len(node.args) > 2:
        return False
    if len(node.kwargs) > 0:
        if "out" in node.kwargs and node.kwargs["out"] is not None:
            return False
        if "alpha" in node.kwargs and node.kwargs["alpha"] != 1.0:
            return False
    return True


class ConvUnary2d(nn.Conv2d):
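    """nn.Conv2d wrapper whose forward dispatches to the fused mkldnn
    convolution + pointwise (unary) kernel, using an mkldnn-reordered
    (prepacked) weight."""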
    def __init__(
        self,
        conv: nn.Module,
        unary: Optional[nn.Module],
        input_size: list,
    ):
        super().__init__(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            conv.padding_mode,
            conv.weight.device,
            conv.weight.dtype,
        )
        self._update_module_params(conv, unary, input_size)

    def _update_module_params(self, conv, unary, input_size):
        self.__dict__ = copy.deepcopy(conv.__dict__)
        self.attr = "none"
        self.scalars = []
        self.algorithm = ""
        if unary is not None:
            self.attr, self.scalars, self.algorithm = unary_modules_map[
                unary.__class__
            ](unary)
        self.weight = torch.nn.Parameter(
            torch._C._nn.mkldnn_reorder_conv2d_weight(
                self.weight.to_mkldnn(),
                self.padding,
                self.stride,
                self.dilation,
                self.groups,
                tuple(guard_int(x) for x in input_size),
            ),
            requires_grad=self.weight.requires_grad,
        )

    def _conv_forward(self, input, weight, bias):
        if self.padding_mode != "zeros":
            return torch.ops.mkldnn._convolution_pointwise(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                _pair(0),
                self.stride,
                self.dilation,
                self.groups,
                self.attr,
                self.scalars,
                self.algorithm,
            )
        return torch.ops.mkldnn._convolution_pointwise(
            input,
            weight,
            bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups,
            self.attr,
            self.scalars,
            self.algorithm,
        )

    def forward(self, input):
        return self._conv_forward(input, self.weight, self.bias)


class ConvBinary2d(nn.Conv2d):
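    """nn.Conv2d wrapper for the fused mkldnn convolution + binary (add/sub)
    kernel; ``forward`` takes the extra ``other`` operand, and a fused unary
    post-op can be attached later via ``_update_unary_params``."""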
    def __init__(
        self,
        conv: nn.Module,
        binary_op_name: str,
        input_size: list,
    ):
        super().__init__(
            conv.in_channels,
            conv.out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            conv.padding_mode,
            conv.weight.device,
            conv.weight.dtype,
        )
        self._update_module_params(conv, binary_op_name, input_size)

    def _update_module_params(self, conv, binary_op_name, input_size):
        self.__dict__ = copy.deepcopy(conv.__dict__)
        self.binary_attr = binary_op_name
        self.binary_alpha = None
        self.unary_attr = None
        self.unary_scalars = []
        self.unary_algorithm = None
        self.weight = torch.nn.Parameter(
            torch._C._nn.mkldnn_reorder_conv2d_weight(
                self.weight.to_mkldnn(),
                self.padding,
                self.stride,
                self.dilation,
                self.groups,
                tuple(guard_int(x) for x in input_size),
            ),
            requires_grad=self.weight.requires_grad,
        )

    def _update_unary_params(self, unary):
        self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[
            unary.__class__
        ](unary)

    def _conv_forward(self, input, other, weight, bias):
        if self.padding_mode != "zeros":
            return torch.ops.mkldnn._convolution_pointwise(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                other,
                weight,
                bias,
                _pair(0),
                self.stride,
                self.dilation,
                self.groups,
                self.binary_attr,
                self.binary_alpha,
                self.unary_attr,
                self.unary_scalars,
                self.unary_algorithm,
            )
        return torch.ops.mkldnn._convolution_pointwise(
            input,
            other,
            weight,
            bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups,
            self.binary_attr,
            self.binary_alpha,
            self.unary_attr,
            self.unary_scalars,
            self.unary_algorithm,
        )

    def forward(self, input, other):
        return self._conv_forward(input, other, self.weight, self.bias)


class PackedLinear(nn.Linear):
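    """nn.Linear wrapper that prepacks the weight for the MKL GEMM
    (``torch.ops.mkl._mkl_linear``), using the batch size derived from the
    observed input shape."""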
    def __init__(self, linear: nn.Module, input_size: list):
        super().__init__(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
            linear.weight.dtype,
        )
        self._update_module_params(linear, input_size)

    def _update_module_params(self, linear, input_size):
        self.__dict__ = copy.deepcopy(linear.__dict__)
        self.batch_size = reduce(lambda x, y: x * y, input_size[:-1])
        self.packed_weight = torch.nn.Parameter(
            torch.ops.mkl._mkl_reorder_linear_weight(
                self.weight.to_mkldnn(), self.batch_size
            ),
            requires_grad=self.weight.requires_grad,
        )

    def forward(self, input):
        y = torch.ops.mkl._mkl_linear(
            input, self.packed_weight, self.weight, self.bias, self.batch_size
        )
        return y


class LinearUnary(nn.Linear):
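    """nn.Linear wrapper whose forward calls the fused mkldnn
    linear + pointwise (unary) kernel."""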
    def __init__(
        self,
        linear: nn.Module,
        unary: nn.Module,
    ):
        super().__init__(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
            linear.weight.dtype,
        )
        self._update_module_params(linear, unary)

    def _update_module_params(self, linear, unary):
        self.__dict__ = copy.deepcopy(linear.__dict__)
        self.attr, self.scalars, self.algorithm = unary_modules_map[unary.__class__](
            unary
        )

    def forward(self, input):
        y = torch.ops.mkldnn._linear_pointwise(
            input, self.weight, self.bias, self.attr, self.scalars, self.algorithm
        )
        return y


class LinearBinary(nn.Linear):
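    """nn.Linear wrapper whose forward calls the fused mkldnn
    linear + binary (add/sub) kernel with an extra ``other`` operand."""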
    def __init__(self, linear: nn.Module, binary_op_name: str):
        super().__init__(
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
            linear.weight.dtype,
        )
        self._update_module_params(linear, binary_op_name)

    def _update_module_params(self, linear, binary_op_name):
        self.__dict__ = copy.deepcopy(linear.__dict__)

        self.attr = binary_op_name

    def forward(self, input, other):
        y = torch.ops.mkldnn._linear_pointwise(
            input, other, self.weight, self.bias, self.attr
        )
        return y


class ConvTransposeUnary2d(nn.ConvTranspose2d):
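    """nn.ConvTranspose2d wrapper whose forward dispatches to the fused mkldnn
    transposed convolution + pointwise (unary) kernel, using an
    mkldnn-reordered (prepacked) weight."""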
    def __init__(
        self,
        conv_transpose: nn.Module,
        unary: Optional[nn.Module],
        input_size: list,
    ):
        super().__init__(
            conv_transpose.in_channels,
            conv_transpose.out_channels,
            conv_transpose.kernel_size,
            conv_transpose.stride,
            conv_transpose.padding,
            conv_transpose.output_padding,
            conv_transpose.groups,
            conv_transpose.bias is not None,
            conv_transpose.dilation,
            conv_transpose.padding_mode,
            conv_transpose.weight.device,
            conv_transpose.weight.dtype,
        )
        self._update_module_params(conv_transpose, unary, input_size)

    def _update_module_params(self, conv_transpose, unary, input_size):
        self.__dict__ = copy.deepcopy(conv_transpose.__dict__)
        self.attr, self.scalars, self.algorithm = (
            unary_modules_map[unary.__class__](unary) if unary else ("none", [], "")
        )
        packed_weight = torch.ops.mkldnn._reorder_convolution_transpose_weight(
            self.weight.to_mkldnn(),
            self.padding,
            self.output_padding,
            self.stride,
            self.dilation,
            self.groups,
            input_size,
        )
        self.weight = torch.nn.Parameter(
            packed_weight,
            requires_grad=self.weight.requires_grad,
        )

    def _conv_transpose_forward(self, input, weight, bias):
        if self.padding_mode != "zeros":
            return torch.ops.mkldnn._convolution_transpose_pointwise(
                F.pad(
                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
                ),
                weight,
                bias,
                _pair(0),
                self.output_padding,
                self.stride,
                self.dilation,
                self.groups,
                self.attr,
                self.scalars,
                self.algorithm,
            )
        return torch.ops.mkldnn._convolution_transpose_pointwise(
            input,
            weight,
            bias,
            self.padding,
            self.output_padding,
            self.stride,
            self.dilation,
            self.groups,
            self.attr,
            self.scalars,
            self.algorithm,
        )

    def forward(self, input):
        return self._conv_transpose_forward(input, self.weight, self.bias)


def packed_conv_eval(conv: nn.Module, input_size: list):
    assert not (conv.training), "Fusion only for eval!"
    return ConvUnary2d(
        conv,
        None,
        input_size,
    )


def packed_conv_transpose_eval(conv_transpose: nn.Module, input_size: list):
    assert not (conv_transpose.training), "Fusion only for eval!"
    return ConvTransposeUnary2d(
        conv_transpose,
        None,
        input_size,
    )


def fused_conv_unary_eval(conv: nn.Module, unary: nn.Module, input_size: list):
    assert not (conv.training), "Fusion only for eval!"
    return ConvUnary2d(
        conv,
        unary,
        input_size,
    )


def fused_conv_binary_eval(conv: nn.Module, binary_op_name: str, input_size: list):
    assert not (conv.training), "Fusion only for eval!"
    return ConvBinary2d(
        conv,
        binary_op_name,
        input_size,
    )


def fused_conv_binary_unary_eval(
    conv_binary: nn.Module, unary: nn.Module, input_size: list
):
    assert not (conv_binary.training), "Fusion only for eval!"
    # Reuse the original ConvBinary2d module and just update its unary attributes.
    conv_binary._update_unary_params(unary)
    return conv_binary


def packed_linear_eval(linear: nn.Module, input_size: list):
    assert not (linear.training), "Fusion only for eval!"
    return PackedLinear(linear, input_size)


def fused_linear_unary_eval(linear: nn.Module, unary: nn.Module, input_size: list):
    assert not (linear.training), "Fusion only for eval!"
    return LinearUnary(
        linear,
        unary,
    )


def fused_linear_binary_eval(linear: nn.Module, attr: str, input_size: list):
    assert not (linear.training), "Fusion only for eval!"
    linear_binary = LinearBinary(
        linear,
        attr,
    )
    return linear_binary


def fused_conv_transpose_unary_eval(
    conv_transpose: nn.Module, unary: nn.Module, input_size: list
):
    assert not (conv_transpose.training), "Fusion only for eval!"
    return ConvTransposeUnary2d(
        conv_transpose,
        unary,
        input_size,
    )


def mkldnn_fuse_fx(gm: torch.fx.GraphModule, example_inputs):
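    """Fuse (and optionally prepack) mkldnn-friendly patterns in an FX graph
    for CPU inference: conv/linear + unary, conv/linear + binary, and
    conv + binary + unary chains.

    The pass is a no-op unless autograd is disabled, mkldnn is available, and
    all tensor example inputs live on CPU.

    Illustrative usage (assumes an FX-traceable model in eval mode)::

        with torch.no_grad():
            gm = torch.fx.symbolic_trace(model.eval())
            gm = mkldnn_fuse_fx(gm, example_inputs)
    """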
    is_cpu = all(
        example_input.device == torch.device("cpu")
        for example_input in example_inputs
        if isinstance(example_input, torch.Tensor)
    )

    # Make sure autograd is disabled.
    if torch.is_grad_enabled():
        return gm
    if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
        return gm
    if not is_cpu:
        return gm
    # For binary fusion, we need to check the inputs' info to make sure that
    # both binary inputs have the same tensor info (device, dtype, and layout).

    fake_mode = fake_mode_from_tensors(example_inputs)
    ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
    gm = fuse_unary(gm)
    gm = fuse_binary(gm)
    # Why re-run fuse_unary? We want to enable conv+binary+unary fusion,
    # such as conv+add+relu for vision models.
    gm = fuse_unary(gm)
    if config.cpp.weight_prepack:
        gm = pack_module(gm)
    return gm


def create_unary_module(node: torch.fx.Node):
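    """Build the equivalent ``nn.Module`` for a unary call_function/call_method
    node so that module-based fusion can also handle functional forms; e.g. an
    ``F.leaky_relu(x, 0.2)`` node maps to ``nn.LeakyReLU(0.2)``."""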
    assert (
        node.op == "call_function" or node.op == "call_method"
    ), "The current node should be a function/method node"
    unary_map = {
        F.relu: nn.ReLU,
        F.sigmoid: nn.Sigmoid,
        F.tanh: nn.Tanh,
        F.hardswish: nn.Hardswish,
        F.leaky_relu: nn.LeakyReLU,
        F.hardtanh: nn.Hardtanh,
        F.gelu: nn.GELU,
        F.relu6: nn.ReLU6,
        F.silu: nn.SiLU,
        F.hardsigmoid: nn.Hardsigmoid,
        torch.relu: nn.ReLU,
        torch.sigmoid: nn.Sigmoid,
        torch.tanh: nn.Tanh,
        "relu": nn.ReLU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
    }
    return unary_map[node.target](*(node.args[1:]), **(node.kwargs))


def fuse_unary(gm: torch.fx.GraphModule):
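    """Fold a unary op (module, functional, or tensor method) into the
    preceding computation module (Conv2d/Linear/ConvBinary2d/ConvTranspose2d),
    replacing the pair with a single fused module."""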
    modules = dict(gm.named_modules())

    for unary_op, (
        computation_module,
        fuse_func,
    ) in itertools.product(unary_ops, computation_op_unary_op_fusion_map.items()):
        pattern = (computation_module, unary_op)
        for node in gm.graph.nodes:
            if matches_module_pattern(
                pattern, node, modules
            ) or matches_module_function_pattern(pattern, node, modules):
                if (
                    len(node.args[0].users) > 1
                ):  # Output of the computation node is used by other nodes.
                    continue
                computation_node = modules[node.args[0].target]
                if node.op == "call_function" or node.op == "call_method":
                    # Make sure the unary function takes only one fx.Node input
                    # (the other args/kwargs should be constant values).
                    if any(isinstance(v, torch.fx.Node) for v in node.args[1:]) or any(
                        isinstance(v, torch.fx.Node) for _, v in node.kwargs.items()
                    ):
                        continue
                    unary_node = create_unary_module(node)
                    unary_node.eval()
                else:
                    unary_node = modules[node.target]
                eval_mode = all(not n.training for n in [computation_node, unary_node])
                if not eval_mode:
                    continue
                # TODO: support padding str input ("valid", "same").
                if type(computation_node) in [nn.Conv2d] and isinstance(
                    computation_node.padding, str
                ):
                    continue
                # TODO: support more conv+binary+unary fusion.
                if type(computation_node) in [ConvBinary2d] and type(
                    unary_node
                ) not in [nn.ReLU]:
                    continue
                # Only fuse for linear when the dtype is bf16.
                if type(computation_node) in [nn.Linear] and not is_bfloat16_module(
                    computation_node
                ):
                    continue
                # TODO: remove this when group depthwise ConvTranspose is supported.
                if is_group_depthwise_conv_transpose(computation_node):
                    continue
                computation_node_input_size = (
                    node.args[0].args[0].meta.get("tensor_meta").shape
                )
                fused_module = fuse_func(
                    computation_node, unary_node, computation_node_input_size
                )
                replace_node_module(node.args[0], modules, fused_module)

                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm


def replace_and_fuse_for_binary(
    computation_node, node, fuse_func, attr, modules, index_node, index_pointwise
):
    computation_node_input_size = (
        node.args[index_node].args[0].meta.get("tensor_meta").shape
    )
    fused_module = fuse_func(computation_node, attr, computation_node_input_size)
    replace_node_module(node.args[index_node], modules, fused_module)
    node.args[index_node].args = node.args[index_node].args + (
        node.args[index_pointwise],
    )
    node.replace_all_uses_with(node.args[index_node])


def binary_inputs_meta_is_same(binary_node):
    tensor0_meta = binary_node.args[0].meta.get("tensor_meta")
    tensor1_meta = binary_node.args[1].meta.get("tensor_meta")
    if not tensor0_meta or not tensor1_meta:
        return False
    if (
        tensor0_meta.shape != tensor1_meta.shape
        or tensor0_meta.stride != tensor1_meta.stride
        or tensor0_meta.dtype != tensor1_meta.dtype
    ):
        return False

    return True


def fuse_binary(gm: torch.fx.GraphModule):
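    """Fold an eligible binary add/sub node into the computation module that
    produces one of its inputs (Conv2d -> ConvBinary2d, Linear -> LinearBinary),
    provided both binary inputs share the same shape, stride, and dtype."""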
    modules = dict(gm.named_modules())
    for node in gm.graph.nodes:
        if check_node_is_binary(node) and check_binary_op_kwargs_is_default(node):
            for node_kind, fuse_func in computation_op_binary_op_fusion_map.items():
                if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
                    node.args[1], torch.fx.Node
                ):
                    continue
                if not binary_inputs_meta_is_same(node):
                    continue
                attr = binary_attr[node.target]
                index_list = supported_index_list[attr]
                for index_dict in index_list:
                    index_node = index_dict["index_computation"]
                    index_pointwise = index_dict["index_pointwise"]
                    if check_node_kind(node.args[index_node], modules, node_kind):
                        if len(node.args[index_node].users) > 1:
                            continue
                        computation_node = modules[node.args[index_node].target]
                        if computation_node.training:
                            continue

                        # TODO: support padding str input ("valid", "same").
                        if type(computation_node) in [nn.Conv2d] and isinstance(
                            computation_node.padding, str
                        ):
                            continue
                        # Only fuse for linear when the dtype is bf16.
                        if type(computation_node) in [
                            nn.Linear
                        ] and not is_bfloat16_module(computation_node):
                            continue
                        replace_and_fuse_for_binary(
                            computation_node,
                            node,
                            fuse_func,
                            attr if attr != "iadd" else "add",
                            modules,
                            index_node,
                            index_pointwise,
                        )
                        # Make sure the fused node is placed after all of its
                        # input nodes in the graph.
                        node.append(node.args[index_node])
                        gm.graph.erase_node(node)
                        break
    gm.graph.lint()
    gm.recompile()
    return gm


def convert_outplace_to_inplace(gm: torch.fx.GraphModule):
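    """Rewrite ``_convolution_pointwise.binary`` calls to the in-place variant
    (``_convolution_pointwise_.binary``) when the ``other`` input can be safely
    overwritten."""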
    if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
        return gm
    # This pass replaces out-of-place mkldnn ops (emitted as external calls)
    # with in-place variants for better performance; it runs after AOTAutograd.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in [
            torch.ops.mkldnn._convolution_pointwise.binary
        ]:
            # args[0] and args[1] are _convolution_pointwise.binary's inputs;
            # we need to check whether args[1] can be overwritten.
            if node.args[1].op in ["placeholder", "output"]:
                continue
            # TODO: handle the case where node.args[1] has more than one user
            # but is never used after the current node.
            if len(node.args[1].users) > 1:
                continue
            if node.args[1] == node.args[0]:
                continue
            binary_attr = node.args[8]
            unary_attr = node.args[10]
            if binary_attr != "add" or unary_attr not in ["", "relu"]:
                continue
            node.target = torch.ops.mkldnn._convolution_pointwise_.binary
    gm.graph.lint()
    gm.recompile()
    return gm


def pack_module(gm: torch.fx.GraphModule):
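    """Replace eligible fp32 Conv2d/Linear/ConvTranspose2d modules with their
    weight-prepacked counterparts (see ``computation_op_packed_map``)."""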
    modules = dict(gm.named_modules())
    for node in gm.graph.nodes:
        if node.op == "call_module":
            assert isinstance(node.target, str)
            cur_module = modules[node.target]
            if type(cur_module) in computation_op_packed_map:
                if cur_module.training:
                    continue
                computation_node_input_meta = node.args[0].meta.get("tensor_meta")
                if computation_node_input_meta.dtype != torch.float32:
                    continue
                if type(cur_module) in [torch.nn.Linear] and not torch._C.has_mkl:
                    continue
                computation_node_input_size = computation_node_input_meta.shape
                if (
                    type(cur_module) in [torch.nn.Linear]
                    and len(computation_node_input_size) < 2
                ):
                    continue
                if type(cur_module) in [nn.Conv2d] and isinstance(
                    cur_module.padding, str
                ):
                    continue
                # TODO: remove this when group depthwise ConvTranspose is supported.
                if is_group_depthwise_conv_transpose(cur_module):
                    continue
                new_module = computation_op_packed_map[type(cur_module)](
                    cur_module, computation_node_input_size
                )
                assert isinstance(new_module, nn.Module)
                replace_node_module(node, modules, new_module)
    gm.graph.lint()
    gm.recompile()
    return gm


computation_op_unary_op_fusion_map = {
    nn.Conv2d: fused_conv_unary_eval,
    nn.Linear: fused_linear_unary_eval,
    ConvBinary2d: fused_conv_binary_unary_eval,
    nn.ConvTranspose2d: fused_conv_transpose_unary_eval,
}


unary_modules_map = {
    nn.ReLU: UnaryAttr("relu"),
    nn.Sigmoid: UnaryAttr("sigmoid"),
    nn.Tanh: UnaryAttr("tanh"),
    nn.Hardswish: UnaryAttr("hardswish"),
    nn.LeakyReLU: UnaryAttr("leaky_relu", scalars_attr=["negative_slope"]),
    nn.Hardtanh: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]),
    nn.GELU: UnaryAttr("gelu", algorithm_attr="approximate"),
    nn.ReLU6: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]),
    nn.SiLU: UnaryAttr("swish"),
    nn.Hardsigmoid: UnaryAttr("hardsigmoid"),
}

unary_ops = [
    # modules
    nn.ReLU,
    nn.Sigmoid,
    nn.Tanh,
    nn.Hardswish,
    nn.LeakyReLU,
    nn.Hardtanh,
    nn.GELU,
    nn.ReLU6,
    nn.SiLU,
    nn.Hardsigmoid,
    # functional
    F.relu,
    F.sigmoid,
    F.tanh,
    F.hardswish,
    F.leaky_relu,
    F.hardtanh,
    F.gelu,
    F.relu6,
    F.silu,
    F.hardsigmoid,
    torch.relu,
    torch.sigmoid,
    torch.tanh,
    # methods (torch.Tensor.xxx)
    "relu",
    "sigmoid",
    "tanh",
]


binary_attr = {
    torch.add: "add",  # node.op == "call_function"
    "add": "add",  # node.op == "call_method"
    "add_": "iadd",  # node.op == "call_method"
    operator.add: "add",  # node.op == "call_function"
    operator.iadd: "iadd",  # node.op == "call_function"
    torch.sub: "sub",  # node.op == "call_function"
    "sub": "sub",  # node.op == "call_method"
    "sub_": "sub",  # node.op == "call_method"
    operator.sub: "sub",  # node.op == "call_function"
    operator.isub: "sub",  # node.op == "call_function"
}


computation_op_binary_op_fusion_map = {
    nn.Conv2d: fused_conv_binary_eval,
    nn.Linear: fused_linear_binary_eval,
}


computation_op_packed_map = {
    nn.Linear: packed_linear_eval,
    nn.Conv2d: packed_conv_eval,
    nn.ConvTranspose2d: packed_conv_transpose_eval,
}


# For add, we support both conv/linear + other and other + conv/linear.
# For sub/add_/sub_, we only support conv/linear - other
# or conv/linear +(-)= other.
supported_index_list = {
    "add": [
        {"index_computation": 0, "index_pointwise": 1},
        {"index_computation": 1, "index_pointwise": 0},
    ],
    "iadd": [{"index_computation": 0, "index_pointwise": 1}],
    "sub": [{"index_computation": 0, "index_pointwise": 1}],
}
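# Illustrative reading of supported_index_list: for `conv(x) + y` the
# computation operand is args[0] and the pointwise operand is args[1]; for
# `y + conv(x)` the roles are swapped, which only the commutative "add" entry
# allows.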