| r''' |
| **This file is EXPERIMENTAL and is mostly used for testing purposes! Do not |
| rely on it for anything!** |
| ''' |
| from torch.fx import Graph, GraphModule |
| from torch.fx.graph import map_arg |
| from torch.fx.proxy import Proxy |
| import sys |
| import torch |
| from torch.nn.utils import fuse_conv_bn_weights |
| import operator |
| |
# a quantization pattern (the key passed to register_pattern below) can be a
# module type, a builtin function, or a string to match a node's target
| |
def _minmax_scale_zeropoint(min_val, max_val, qmin=-128, qmax=127, eps=torch.finfo(torch.float32).eps):
| min_val = min(0.0, min_val) |
| max_val = max(0.0, max_val) |
| if max_val == min_val: |
| return 1.0, 0 |
| else: |
| scale = (max_val - min_val) / float(qmax - qmin) |
| scale = max(scale, eps) |
| zero_point = qmin - round(min_val / scale) |
| zero_point = max(qmin, zero_point) |
| zero_point = min(qmax, zero_point) |
| zero_point = int(zero_point) |
| return scale, zero_point |
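
# A worked example (hedged) using the asymmetric activation range that
# MinMaxObserver passes in below: _minmax_scale_zeropoint(-1.0, 1.0, qmin=0,
# qmax=255) gives scale = 2/255 ≈ 0.00784 and
# zero_point = 0 - round(-1.0 / 0.00784) = 128.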
| |
| class MinMaxObserver: |
| def __init__(self, quantizer, node): |
| self.min, self.max = float('inf'), float('-inf') |
| self.all_tensors = True |
| |
| def observe(self, node, env): |
| v = env[node.name] |
| if not isinstance(v, torch.Tensor): |
| self.all_tensors = False |
| return |
| self.max = max(self.max, float(v.max())) |
| self.min = min(self.min, float(v.min())) |
| |
| def scale_zeropoint(self): |
| return _minmax_scale_zeropoint(self.min, self.max, qmin=0, qmax=255) |
| |
| class NoObserver: |
| def __init__(self, quantizer, node): |
| pass |
| |
| def observe(self, node, env): |
| pass |
| |
| _DEFAULT_QUANTIZATION_PATTERNS = {} |
| def register_pattern(pattern): |
| def insert(fn): |
| _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn |
| return fn |
| return insert |
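
# A hedged example of registering a handler keyed by a target string
# (hypothetical; no string-keyed handler is defined in this file). String
# patterns compare against node.target, which is how call_method nodes are
# identified:
#
#   @register_pattern('relu')
#   class MethodRelu(NoObserver):
#       def quantize(self, quantizer, node, load_arg):
#           return quantizer.quantized_graph.node_copy(node, load_arg)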
| |
| |
| @register_pattern(operator.add) |
| class Add(MinMaxObserver): |
| def quantize(self, quantizer, node, load_arg): |
| if not self.all_tensors: |
| return NotImplemented |
| scale, zeropoint = self.scale_zeropoint() |
| return quantizer.quantized_graph.create_node( |
| 'call_function', torch.ops.quantized.add, load_arg(node.args), {'scale': scale, 'zero_point': zeropoint}) |
| |
| |
class Relu(NoObserver):  # note: not registered as a default pattern
    def quantize(self, quantizer, node, load_arg):
        # torch.relu has a quantized kernel, so it can run directly on
        # quantized tensors; wrap the argument in a Proxy so the call is
        # traced into the quantized graph
        return torch.relu(Proxy(load_arg(node.args[0]))).node
| |
| # these ops have quantized equivalents that do not need any extra information |
| @register_pattern(torch.nn.ReLU) |
| @register_pattern(torch.nn.AvgPool2d) |
| @register_pattern(torch.nn.MaxPool2d) |
| @register_pattern(torch.nn.AdaptiveAvgPool2d) |
| class CopyNode(NoObserver): |
| def quantize(self, quantizer, node, load_arg): |
| return quantizer.quantized_graph.node_copy(node, load_arg) |
| |
| class IdentityModule(torch.nn.Module): |
| def forward(self, x): |
| return x |
| |
| # handle conv, maybe followed by bn, maybe followed by relu |
| @register_pattern(torch.nn.modules.conv.Conv2d) |
| @register_pattern((torch.nn.ReLU, torch.nn.modules.conv.Conv2d)) |
| @register_pattern((torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d)) |
| @register_pattern((torch.nn.ReLU, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d))) |
| class ConvNormRelu(MinMaxObserver): |
| def __init__(self, quantizer, node): |
| super().__init__(quantizer, node) |
| self.relu_node, self.bn_node = None, None |
| if isinstance(quantizer.modules[node.target], torch.nn.ReLU): |
| self.relu_node = node |
| node = node.args[0] |
| if isinstance(quantizer.modules[node.target], torch.nn.BatchNorm2d): |
| self.bn_node = node |
| self.bn = quantizer.modules[self.bn_node.target] |
| node = node.args[0] |
| assert isinstance(quantizer.modules[node.target], torch.nn.modules.Conv2d) |
| self.conv_node = node |
| self.conv = quantizer.modules[self.conv_node.target] |
| |
| def quantize(self, quantizer, node, load_arg): |
| mod = self.conv |
| weight, bias = mod.weight, mod.bias |
| |
| if self.bn_node is not None: |
| weight, bias = fuse_conv_bn_weights( |
| weight, bias, self.bn.running_mean, self.bn.running_var, |
| self.bn.eps, self.bn.weight, self.bn.bias) |
| |
| min_val, max_val = float(weight.min()), float(weight.max()) |
| |
| act_scale, act_zp = self.scale_zeropoint() |
| |
| weight_scale, weight_zp = _minmax_scale_zeropoint(min_val, max_val) |
| qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zp, torch.qint8) |
| |
| ctor = torch.ao.nn.intrinsic.quantized.ConvReLU2d if self.relu_node is not None else torch.ao.nn.quantized.Conv2d |
| |
| qconv = ctor(mod.in_channels, mod.out_channels, mod.kernel_size, |
| mod.stride, mod.padding, mod.dilation, mod.groups, |
| mod.bias is not None, mod.padding_mode) |
| |
| qconv.set_weight_bias(qweight, bias) |
| qconv.scale = float(act_scale) |
| qconv.zero_point = int(act_zp) |
| parent_name, name = _parent_name(self.conv_node.target) |
| setattr(quantizer.modules[parent_name], name, qconv) |
| if self.bn_node is not None: |
            parent_bn, bn_name = _parent_name(self.bn_node.target)
            # we can't just delete the batchnorm submodule because forwards
            # that are no longer used still refer to it, so replace it with
            # something that does nothing.
            setattr(quantizer.modules[parent_bn], bn_name, IdentityModule())
| |
| return quantizer.quantized_graph.create_node('call_module', self.conv_node.target, (load_arg(self.conv_node.args[0]),), {}) |
| |
| |
# turn 'foo.bar' into the tuple ('foo', 'bar')
| def _parent_name(target): |
| r = target.rsplit('.', 1) |
| if len(r) == 1: |
| return '', r[0] |
| else: |
| return r[0], r[1] |
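
# For example:
#   _parent_name('features.0.conv') == ('features.0', 'conv')
#   _parent_name('conv') == ('', 'conv')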
| |
| |
| |
class DefaultQuant(MinMaxObserver):
    def quantize(self, node):
        # emit a quantize_per_tensor call for this node in the quantized graph
        assert self.all_tensors
        scale, zeropoint = self.scale_zeropoint()
        return torch.quantize_per_tensor(Proxy(node), scale, zeropoint, torch.quint8).node
| |
| def matches(modules, node, pattern, max_uses=sys.maxsize): |
| if isinstance(pattern, tuple): |
| self_match, *arg_matches = pattern |
| else: |
| self_match = pattern |
| arg_matches = None |
| |
| if len(node.users) > max_uses: |
| return False |
| |
| if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module): |
| if node.op != 'call_module': |
| return False |
| if not isinstance(modules[node.target], self_match): |
| return False |
| elif callable(self_match): |
| if node.op != 'call_function' or node.target is not self_match: |
| return False |
| elif node.target != self_match: |
| return False |
| |
| if not arg_matches: |
| return True |
| |
| if len(arg_matches) != len(node.args): |
| return False |
| |
    return all(matches(modules, arg, arg_match, max_uses=1) for arg, arg_match in zip(node.args, arg_matches))
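
# As an illustration: the pattern
# (torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)) registered above
# matches a relu(bn(conv(x))) chain, checked from the outermost node inward;
# each inner node must have exactly one user (max_uses=1), so the chain is
# only fused when nothing else consumes the intermediate results.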
| |
| |
| class Quantizer: |
| def __init__(self, mod, patterns=_DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant): |
| self.root = mod |
| self.graph = mod.graph |
| self.quant_ctor = quant_ctor |
| |
| # cached information for observe |
| self.state_dict = self.root.state_dict() |
| self.modules = dict(self.root.named_modules()) |
| |
| # match the patterns that will get quantized |
| self.matches = self._find_matches(patterns) |
        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a quant_ctor object for each one
| self.quants = self._find_quants(quant_ctor) |
| |
| |
| |
| def observe(self, args): |
| # most of this function is just an interpreter for the graph |
| # it would be possible to put this in some abstraction, but |
| # it is pretty nice to just be able to see exactly what is happening here |
| # and hack on it. |
| # maybe we should just provide an example interpreter that people copy/paste |
| # then edit. |
| args_iter = iter(args) |
| env = {} |
| |
| def load_arg(a): |
| return map_arg(a, lambda node: env[node.name]) |
| |
| for node in self.graph.nodes: |
| if node.op == 'placeholder': |
| result = next(args_iter) |
| elif node.op == 'get_attr': |
| result = self.state_dict[node.target] |
| elif node.op == 'call_function': |
| result = node.target(*load_arg(node.args), **load_arg(node.kwargs)) |
| elif node.op == 'call_method': |
| self_obj, *args = load_arg(node.args) |
| kwargs = load_arg(node.kwargs) |
| result = getattr(self_obj, node.target)(*args, **kwargs) |
| elif node.op == 'call_module': |
| result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs)) |
| elif node.op == 'output': |
| return load_arg(node.args[0]) |
| |
| env[node.name] = result |
| root_node, obj = self.matches.get(node.name, (None, None)) |
| if root_node is node: |
| obj.observe(node, env) |
| if node.name in self.quants: |
| self.quants[node.name].observe(node, env) |
| |
| raise RuntimeError('Graph had no output node!') |
| |
| def quantize(self): |
| self.quantized_graph = Graph() |
| |
| env = {} |
| quant_env = {} |
| |
        def load_arg(n, quantized):
            # fetch the copy of node `n` in the requested form, converting on
            # demand: dequantize a quantized copy when a float value is
            # needed, or quantize a float copy (using its recorded stats)
            # when a quantized value is needed
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(env[n.name])
                return quant_env[n.name]
| |
        def copy_recursive(node):
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                else:
                    return copy_recursive(n)
            r = env[node.name] = self.quantized_graph.node_copy(node, load_or_emit)
            return r
| |
| for node in self.graph.nodes: |
| root_node, obj = self.matches.get(node.name, (None, None)) |
| if root_node is None: |
                # not part of any quantized pattern, just copy it over
| env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False)) |
| |
| elif root_node is node: |
| r = obj.quantize(self, node, lambda a: map_arg(a, lambda n: load_arg(n, quantized=True))) |
                if r is NotImplemented:
                    # the handler chose not to quantize this node; copy the
                    # entire match over unquantized instead
                    env[node.name] = copy_recursive(node)
| else: |
| quant_env[node.name] = r |
| |
| return GraphModule(self.root, self.quantized_graph) |
| |
| def _find_matches(self, patterns): |
| modules = dict(self.root.named_modules()) |
        match_map = {}  # node name -> (root_node, pattern handler instance)
| |
| def apply_match(pattern, node, match): |
| if isinstance(pattern, tuple): |
| s, *args = pattern |
| apply_match(s, node, match) |
| for subpattern, arg in zip(args, node.args): |
| apply_match(subpattern, arg, match) |
| else: |
| match_map[node.name] = match |
| |
| for node in reversed(self.graph.nodes): |
| if node.name not in match_map: |
| for pattern, value in patterns.items(): |
| if matches(modules, node, pattern): |
| apply_match(pattern, node, (node, value(self, node))) |
| |
| return match_map |
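
    # For a traced relu(bn(conv(x))) chain (hedged example), _find_matches maps
    # all three node names to the same (relu_node, ConvNormRelu instance) entry,
    # so quantize() fires the handler only at the root node and skips the
    # interior nodes of the match.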
| |
| def _find_quants(self, quant_ctor): |
| quants = {} |
| |
        def visit_arg(n):
            # note: we have to record quantization statistics even for inputs
            # that may already be quantized, because each match can still
            # return NotImplemented (e.g. an __add__ whose arguments are not
            # all tensors) and fall back to the unquantized path
| if n.name not in quants: |
| quants[n.name] = quant_ctor(self, n) |
| for node in self.graph.nodes: |
| if node.name in self.matches: |
| map_arg(node.args, visit_arg) |
| map_arg(node.kwargs, visit_arg) |
| return quants |