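# MKLDNN (oneDNN) weight-prepack pass for the TorchInductor CPU backend.
# Eval-mode Conv2d, Linear, and ConvTranspose2d modules in an FX graph are
# replaced with "packed" variants whose weights are pre-reordered into the
# oneDNN blocked layout, so the reorder cost is paid once rather than on
# every forward call.
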
import copy
from functools import reduce
from typing import Optional
import torch
import torch._dynamo.config as dynamo_config
import torch.nn as nn
import torch.nn.functional as F
from torch._dynamo.utils import detect_fake_mode
from torch.fx.experimental.optimization import replace_node_module
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn.modules.utils import _pair
from . import config
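

# A ConvTranspose2d counts as "group depthwise" when it has more than one
# group and groups == in_channels; packing skips this case (see the TODO in
# pack_module below).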
def is_group_depthwise_conv_transpose(m):
return (
type(m) in [nn.ConvTranspose2d] and m.groups > 1 and m.groups == m.in_channels
)
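

# nn.Conv2d wrapper whose weight is stored as a prepacked MKLDNN tensor and
# whose forward dispatches to torch.ops.mkldnn._convolution_pointwise.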
class PackedConv2d(nn.Conv2d):
def __init__(
self,
conv: nn.Module,
input_size: Optional[list],
):
super().__init__(
conv.in_channels,
conv.out_channels,
conv.kernel_size,
conv.stride,
conv.padding,
conv.dilation,
conv.groups,
conv.bias is not None,
conv.padding_mode,
conv.weight.device,
conv.weight.dtype,
)
self._update_module_params(conv, input_size)
def _update_module_params(self, conv, input_size):
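        # Copy all attributes from the original conv, then swap the weight for
        # an MKLDNN-reordered copy. When the input size is known, the reorder
        # can pick a blocked format specialized for that shape; otherwise a
        # plain to_mkldnn() conversion is used.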
self.__dict__ = copy.deepcopy(conv.__dict__)
self.weight = torch.nn.Parameter(
torch._C._nn.mkldnn_reorder_conv2d_weight(
self.weight.to_mkldnn(),
self.padding,
self.stride,
self.dilation,
self.groups,
input_size,
)
if input_size is not None
else self.weight.to_mkldnn(),
requires_grad=self.weight.requires_grad,
)
def _conv_forward(self, input, weight, bias):
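        # Mirrors nn.Conv2d._conv_forward, but calls the MKLDNN pointwise
        # convolution; non-"zeros" padding modes are applied explicitly with
        # F.pad, and the op itself then runs with zero padding.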
if self.padding_mode != "zeros":
return torch.ops.mkldnn._convolution_pointwise(
F.pad(
input, self._reversed_padding_repeated_twice, mode=self.padding_mode
),
weight,
bias,
_pair(0),
self.stride,
self.dilation,
self.groups,
"none",
[],
"",
)
return torch.ops.mkldnn._convolution_pointwise(
input,
weight,
bias,
self.padding,
self.stride,
self.dilation,
self.groups,
"none",
[],
"",
)
def forward(self, input):
return self._conv_forward(input, self.weight, self.bias)
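

# FP32 nn.Linear wrapper that keeps an MKL-prepacked copy of the weight,
# packed for a fixed batch size, and runs through torch.ops.mkl._mkl_linear.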
class PackedLinearFP32(nn.Linear):
def __init__(self, linear: nn.Module, input_size: Optional[list]):
super().__init__(
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
linear.weight.dtype,
)
self._update_module_params(linear, input_size)
def _update_module_params(self, linear, input_size):
self.__dict__ = copy.deepcopy(linear.__dict__)
self.batch_size = reduce(lambda x, y: x * y, input_size[:-1])
self.packed_weight = torch.nn.Parameter(
torch.ops.mkl._mkl_reorder_linear_weight(
self.weight.to_mkldnn(), self.batch_size
),
requires_grad=self.weight.requires_grad,
)
def forward(self, input):
y = torch.ops.mkl._mkl_linear(
input, self.packed_weight, self.weight, self.bias, self.batch_size
)
return y
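

# BF16 nn.Linear wrapper that stores the weight as a prepacked MKLDNN tensor
# and runs through torch.ops.mkldnn._linear_pointwise. The batch size (and
# hence a size-specific packing) is optional, so this variant also works
# under dynamic shapes.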
class PackedLinearBF16(nn.Linear):
def __init__(self, linear: nn.Module, input_size: Optional[list]):
super().__init__(
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
linear.weight.dtype,
)
self._update_module_params(linear, input_size)
def _update_module_params(self, linear, input_size):
self.__dict__ = copy.deepcopy(linear.__dict__)
self.batch_size = (
reduce(lambda x, y: x * y, input_size[:-1])
if input_size is not None
else None
)
self.packed_weight = torch.nn.Parameter(
torch.ops.mkldnn._reorder_linear_weight(
self.weight.to_mkldnn(),
self.batch_size,
),
requires_grad=self.weight.requires_grad,
)
def forward(self, input):
y = torch.ops.mkldnn._linear_pointwise(
input,
self.packed_weight,
self.bias,
"none",
[],
"",
)
return y
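

# nn.ConvTranspose2d wrapper with an MKLDNN-prepacked weight; forward
# dispatches to torch.ops.mkldnn._convolution_transpose_pointwise and only
# supports the "zeros" padding mode.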
class PackedConvTranspose2d(nn.ConvTranspose2d):
def __init__(
self,
conv_transpose: nn.Module,
input_size: Optional[list],
):
super().__init__(
conv_transpose.in_channels,
conv_transpose.out_channels,
conv_transpose.kernel_size,
conv_transpose.stride,
conv_transpose.padding,
conv_transpose.output_padding,
conv_transpose.groups,
conv_transpose.bias is not None,
conv_transpose.dilation,
conv_transpose.padding_mode,
conv_transpose.weight.device,
conv_transpose.weight.dtype,
)
self._update_module_params(conv_transpose, input_size)
def _update_module_params(self, conv_transpose, input_size):
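        # Copy all attributes from the original module, then swap the weight
        # for a prepacked copy: a size-specific MKLDNN reorder when the input
        # size is known, otherwise a plain MKLDNN conversion of the weight
        # with its first two dims swapped.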
self.__dict__ = copy.deepcopy(conv_transpose.__dict__)
packed_weight = (
torch.ops.mkldnn._reorder_convolution_transpose_weight(
self.weight.to_mkldnn(),
self.padding,
self.output_padding,
self.stride,
self.dilation,
self.groups,
input_size,
)
if input_size is not None
else self.weight.transpose(0, 1).to_mkldnn()
)
self.weight = torch.nn.Parameter(
packed_weight,
requires_grad=self.weight.requires_grad,
)
def _conv_transpose_forward(self, input, weight, bias):
if self.padding_mode != "zeros":
raise ValueError(
"Only `zeros` padding mode is supported for PackedConvTranspose2d"
)
return torch.ops.mkldnn._convolution_transpose_pointwise(
input,
weight,
bias,
self.padding,
self.output_padding,
self.stride,
self.dilation,
self.groups,
"none",
[],
"",
)
def forward(self, input):
return self._conv_transpose_forward(input, self.weight, self.bias)
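

# Factory helpers used by pack_module below: each wraps an eval-mode module
# in its packed counterpart.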
def packed_conv_eval(conv: nn.Module, input_size: Optional[list]):
assert not (conv.training), "Fusion only for eval!"
return PackedConv2d(
conv,
input_size,
)
def packed_conv_transpose_eval(conv_transpose: nn.Module, input_size: Optional[list]):
assert not (conv_transpose.training), "Fusion only for eval!"
return PackedConvTranspose2d(conv_transpose, input_size)
def packed_linear_eval(linear: nn.Module, input_size: Optional[list]):
assert not (linear.training), "Fusion only for eval!"
if linear.weight.dtype == torch.bfloat16:
return PackedLinearBF16(linear, input_size)
return PackedLinearFP32(linear, input_size)
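

# Entry point of the pass: prepack weights in an FX GraphModule when the
# example inputs live on CPU, MKLDNN is available, and autograd/autocast are
# disabled.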
def mkldnn_fuse_fx(gm: torch.fx.GraphModule, example_inputs):
is_cpu = all(
example_input.device == torch.device("cpu")
for example_input in example_inputs
if isinstance(example_input, torch.Tensor)
)
    # Weight prepacking is only valid for inference: bail out when autograd or
    # CPU autocast is enabled.
if torch.is_grad_enabled() or torch.is_autocast_cpu_enabled():
return gm
if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()):
return gm
if not is_cpu:
return gm
fake_mode = detect_fake_mode(example_inputs)
if config.cpp.weight_prepack:
if not dynamo_config.dynamic_shapes:
ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
gm = pack_module(gm)
return gm
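

# Walk the FX graph and replace eligible call_module nodes (eval-mode Conv2d,
# Linear, and ConvTranspose2d with CPU fp32/bf16 weights) with their packed
# counterparts from computation_op_packed_map.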
def pack_module(gm: torch.fx.GraphModule):
modules = dict(gm.named_modules())
for node in gm.graph.nodes:
if node.op == "call_module":
assert isinstance(node.target, str)
cur_module = modules[node.target]
if type(cur_module) in computation_op_packed_map:
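                # Only pack CPU weights in fp32 or bf16, only for modules in
                # eval mode, and only use bf16 when the hardware supports it.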
if cur_module.weight.device != torch.device(
"cpu"
) or cur_module.weight.dtype not in [torch.bfloat16, torch.float32]:
continue
if cur_module.training:
continue
if (
cur_module.weight.dtype == torch.bfloat16
and not torch.ops.mkldnn._is_mkldnn_bf16_supported()
):
continue
if dynamo_config.dynamic_shapes:
computation_node_input_size = None
                    # Conv2d and ConvTranspose2d weight formats depend on the input
                    # size, but ShapeProp may fail to provide it under dynamic
                    # shapes, so we skip those modules here.
if not (
type(cur_module) in [torch.nn.Linear]
and cur_module.weight.dtype == torch.bfloat16
):
continue
else:
computation_node_input_size = (
node.args[0].meta.get("tensor_meta").shape
)
if type(cur_module) in [torch.nn.Linear]:
                        # For FP32 Linear, only pack the weight when MKL is available.
if (
cur_module.weight.dtype == torch.float32
and (not torch._C.has_mkl)
) or len(computation_node_input_size) < 2:
continue
else:
if len(computation_node_input_size) != 4:
continue
if type(cur_module) in [nn.Conv2d] and isinstance(
cur_module.padding, str
):
continue
# TODO: remove this when group depthwise ConvTranspose is supported
if type(cur_module) in [nn.ConvTranspose2d] and (
is_group_depthwise_conv_transpose(cur_module)
or dynamo_config.dynamic_shapes
or len(node.args) > 1
or len(node.kwargs) > 0
):
continue
new_module = computation_op_packed_map[type(cur_module)](
cur_module, computation_node_input_size
)
assert isinstance(new_module, nn.Module)
replace_node_module(node, modules, new_module)
gm.graph.lint()
gm.recompile()
return gm
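

# Maps an eager module type to the factory that builds its packed replacement.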
computation_op_packed_map = {
nn.Linear: packed_linear_eval,
nn.Conv2d: packed_conv_eval,
nn.ConvTranspose2d: packed_conv_transpose_eval,
}
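

# A minimal usage sketch (illustrative only; `MyModel`, the input shape, and
# the tracing step are assumptions, not part of this module):
#
#     model = MyModel().eval()
#     gm = torch.fx.symbolic_trace(model)
#     example_inputs = [torch.randn(8, 3, 224, 224)]
#     with torch.no_grad():  # packing is skipped if autograd is enabled
#         gm = mkldnn_fuse_fx(gm, example_inputs)
#         out = gm(*example_inputs)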