| """ |
| Utils shared by different modes of quantization (eager/graph) |
| """ |
| import warnings |
| import functools |
| import torch |
| from torch.ao.quantization.quant_type import QuantType |
| from typing import Tuple, Any, Union, Callable, Dict, Optional |
| from torch.nn.utils.parametrize import is_parametrized |
| from collections import OrderedDict |
| from inspect import signature |
| from inspect import getfullargspec |
| |
# Type for fusion patterns. Actual patterns can be more complicated than this
# type suggests; see pattern.md for docs.
| # TODO: not sure if typing supports recursive data types |
| Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any] |
| Pattern.__module__ = "torch.ao.quantization.utils" |
| |
| # TODO: maybe rename this to MatchInputNode |
| class MatchAllNode: |
| """ A node pattern that matches all nodes, used in defining |
| fusion patterns in FX Graph Mode Quantization |
| """ |
| pass |
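
# Illustrative pattern shapes (not exhaustive). Fusion patterns are written in
# reverse order, starting from the last op in the fused sequence, e.g.:
#   torch.nn.Conv2d                                             # a single op
#   (torch.nn.ReLU, torch.nn.Conv2d)                            # conv followed by relu
#   (torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d))    # conv -> bn -> relu
# MatchAllNode can stand in for a concrete op to match an arbitrary input node.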
| |
| module_type_list = { |
| torch.nn.ReLU, |
| torch.nn.ReLU6, |
| torch.nn.AdaptiveAvgPool1d, |
| torch.nn.AdaptiveAvgPool2d, |
| torch.nn.AdaptiveAvgPool3d, |
| torch.nn.AvgPool1d, |
| torch.nn.AvgPool2d, |
| torch.nn.AvgPool3d, |
| torch.nn.MaxPool1d, |
| torch.nn.MaxPool2d, |
| torch.nn.MaxPool3d, |
| torch.nn.Identity, |
| torch.nn.Hardsigmoid, |
| torch.nn.Sigmoid, |
| torch.nn.Tanh, |
| } |
| func_list = { |
| torch.nn.functional.adaptive_avg_pool1d, |
| torch.nn.functional.adaptive_avg_pool2d, |
| torch.nn.functional.adaptive_avg_pool3d, |
| torch.nn.functional.elu, |
| torch.nn.functional.hardswish, |
| torch.nn.functional.instance_norm, |
| torch.nn.functional.layer_norm, |
| torch.nn.functional.leaky_relu, |
| torch.nn.functional.silu, |
| torch.nn.functional.mish, |
| torch.nn.functional.dropout, |
| torch.nn.functional.max_pool1d, |
| torch.nn.functional.max_pool2d, |
| torch.nn.functional.max_pool3d, |
| torch.nn.functional.relu, |
| torch.nn.functional.hardtanh, |
| torch.nn.functional.hardtanh_, |
| torch.nn.functional.hardsigmoid, |
| torch.nn.functional.sigmoid, |
| torch.transpose, |
| torch.repeat_interleave, |
| torch.sigmoid, |
| torch.squeeze, |
| torch.stack, |
| torch.sum, |
| torch.tanh, |
| torch.unsqueeze, |
| torch.cat, |
| } |
| method_list = { |
    'mean',
| 'relu', |
| 'relu_', |
| 'contiguous', |
| 'detach', |
| 'detach_', |
| 'hardsigmoid', |
| 'hardsigmoid_', |
| 'permute', |
| 'repeat', |
| 'repeat_interleave', |
| 'reshape', |
| 'resize_', |
| 'shape', |
| 'sigmoid', |
| 'sigmoid_', |
| 'size', |
| 'squeeze', |
| 'squeeze_', |
| 'tanh', |
| 'tanh_', |
| 'transpose', |
| 'unsqueeze', |
| 'unsqueeze_', |
| 'view', |
| } |
| |
| # TODO: not used now, remove |
| def check_node(node, modules): |
| # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py |
| is_call_function = node.op == "call_function" and node.target in func_list |
| is_call_method = node.op == "call_method" and node.target in method_list |
| is_call_module = node.op == "call_module" and type(modules[str(node.target)]) in module_type_list |
| return is_call_function, is_call_method, is_call_module |
| |
| def get_combined_dict(default_dict, additional_dict): |
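    """
    Returns a copy of `default_dict` updated with the entries of `additional_dict`,
    e.g. get_combined_dict({"a": 1}, {"a": 2, "b": 3}) returns {"a": 2, "b": 3}.
    """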
| d = default_dict.copy() |
| d.update(additional_dict) |
| return d |
| |
| def is_per_tensor(qscheme): |
| return qscheme == torch.per_tensor_affine or \ |
| qscheme == torch.per_tensor_symmetric |
| |
| def is_per_channel(qscheme): |
| return qscheme in [torch.per_channel_affine, |
| torch.per_channel_affine_float_qparams, |
| torch.per_channel_symmetric] |
| |
| def getattr_from_fqn(obj: Any, fqn: str) -> Any: |
| """ |
    Given an obj and a fqn such as "foo.bar.baz", returns obj.foo.bar.baz.
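
    Example::

        >> m = torch.nn.Sequential(torch.nn.Linear(2, 2))
        >> getattr_from_fqn(m, "0.weight")  # same as m[0].weight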
| """ |
| return functools.reduce(getattr, fqn.split("."), obj) |
| |
| def get_qparam_dict(observer_or_fake_quant): |
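    """
    Returns a dict of quantization parameters ("qscheme", "dtype" and, when a
    qscheme is set, "scale", "zero_point" and possibly "axis") taken from an
    observer or fake quantize module.

    Example (illustrative)::

        >> obs = torch.ao.quantization.MinMaxObserver()
        >> obs(torch.randn(4, 4))
        >> get_qparam_dict(obs)
        {"qscheme": torch.per_tensor_affine, "dtype": torch.quint8,
         "scale": ..., "zero_point": ...}
    """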
| qscheme = observer_or_fake_quant.qscheme if hasattr(observer_or_fake_quant, "qscheme") else None |
| dtype = observer_or_fake_quant.dtype |
| qparams = {"qscheme": qscheme, "dtype": dtype} |
| |
| if not qscheme: |
| return qparams |
| |
| if is_per_tensor(qscheme): |
| qscheme = torch.per_tensor_affine |
| elif is_per_channel(qscheme): |
| # change symmetric to affine since we do not have symmetric |
| # quantized Tensor |
| if qscheme == torch.per_channel_symmetric: |
| qscheme = torch.per_channel_affine |
| qparams["axis"] = observer_or_fake_quant.ch_axis |
| else: |
| raise RuntimeError(f"Unrecognized qscheme: {qscheme}") |
| # update qscheme, since we don't have symmetric quant qscheme |
| # in quantized Tensor |
| qparams["qscheme"] = qscheme |
| |
| scale, zero_point = observer_or_fake_quant.calculate_qparams() |
| qparams["scale"] = scale |
| qparams["zero_point"] = zero_point |
| |
| return qparams |
| |
| |
| def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig): |
| """ Get the observed/quantized custom module class that we need |
| to swap `custom_module` to |
| Input: |
| custom_module: input, can be an instance of either a float or observed custom module |
| custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping |
| qconfig: qconfig configured for the custom module |
| |
| Output: |
| corresponding observed/quantized custom module class for input custom module instance |
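
    Example (illustrative; CustomFloat and CustomObserved are hypothetical classes)::

        >> mapping = {QuantType.STATIC: {CustomFloat: CustomObserved}}
        >> get_swapped_custom_module_class(custom_float_instance, mapping, static_qconfig)
        CustomObserved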
| """ |
| quant_type = get_quant_type(qconfig) |
| class_mapping = custom_module_class_mapping.get(quant_type, {}) |
| assert type(custom_module) in class_mapping, "did not find corresponding observed " \ |
| "module class for {} in mapping: {}".format(type(custom_module), class_mapping) |
| return class_mapping[type(custom_module)] |
| |
| def activation_dtype(qconfig): |
| assert qconfig is not None |
| activation = qconfig.activation() |
| return activation.dtype |
| |
| def weight_dtype(qconfig): |
| assert qconfig is not None |
| weight = qconfig.weight() |
| return weight.dtype |
| |
| def activation_is_statically_quantized(qconfig): |
| """ Given a qconfig, decide if the activation needs to be |
| quantized or not, this includes quantizing to quint8, qint8 and float16 |
| """ |
| return activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16] |
| |
| def activation_is_dynamically_quantized(qconfig): |
| """ Given a qconfig, decide if the activation needs to be |
| dynamically quantized or not, this includes dynamically quantizing to |
| quint8, qint8 and float16 |
| """ |
| activation_dtype, _, activation_compute_dtype = \ |
| get_qconfig_dtypes(qconfig) |
| return activation_dtype == torch.float and \ |
| activation_compute_dtype in [torch.quint8, torch.qint8, torch.float16] |
| |
| def activation_is_int8_quantized(qconfig): |
| """ Given a qconfig, decide if the activation needs to be |
| quantized to int8 or not, this includes quantizing to quint8, qint8 |
| """ |
| return activation_dtype(qconfig) in [torch.quint8, torch.qint8] |
| |
| def activation_is_int32_quantized(qconfig): |
| """ Given a qconfig, decide if the activation needs to be |
| quantized to int32 or not |
| """ |
| return activation_dtype(qconfig) == torch.qint32 |
| |
| def weight_is_quantized(qconfig): |
| """ Given a qconfig, decide if the weight needs to be |
| quantized or not |
| """ |
| return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16, torch.quint4x2] |
| |
| def weight_is_statically_quantized(qconfig): |
| """ Given a qconfig, decide if the weight needs to be statically |
| quantized or not |
| """ |
| return weight_dtype(qconfig) in [torch.quint8, torch.qint8] |
| |
| def op_is_int8_dynamically_quantized(qconfig) -> bool: |
| """ Given a qconfig, returns True if this op is using int8 dynamic |
| quantization |
| """ |
| activation_dtype, weight_dtype, activation_compute_dtype = \ |
| get_qconfig_dtypes(qconfig) |
| return ( |
| activation_dtype is torch.float and |
| # for now, the lines below assume fbgemm or qnnpack |
| weight_dtype is torch.qint8 and |
| activation_compute_dtype is torch.quint8 |
| ) |
| |
| def get_qconfig_dtypes(qconfig): |
| r""" returns the qconfig tuple for qconfig: |
| (activation_dtype, weight_dtype, activation_compute_dtype) |
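
    For example, a typical int8 dynamic quantization qconfig would yield
    (torch.float, torch.qint8, torch.quint8).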
| """ |
| assert qconfig is not None |
| activation = qconfig.activation() |
| weight = qconfig.weight() |
| compute_dtype = activation.compute_dtype if hasattr(activation, 'compute_dtype') else None |
| return (activation.dtype, weight.dtype, compute_dtype) |
| |
| def get_quant_type(qconfig): |
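    """ Returns the QuantType (STATIC, DYNAMIC or WEIGHT_ONLY) corresponding to
    the given qconfig, based on the activation and weight dtypes (and, for
    dynamic quantization, the activation compute_dtype).
    """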
| assert qconfig is not None |
| activation = qconfig.activation() |
| weight = qconfig.weight() |
| static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2] |
| if weight.dtype in static_dtypes: |
| if activation.dtype in static_dtypes: |
| return QuantType.STATIC |
| elif hasattr(activation, 'compute_dtype') and activation.compute_dtype in static_dtypes: |
| return QuantType.DYNAMIC |
| else: |
| return QuantType.WEIGHT_ONLY |
| |
| if weight.dtype == torch.float16: |
| if activation.dtype == torch.float: |
| return QuantType.DYNAMIC |
| elif activation.dtype == torch.float16: |
| return QuantType.STATIC |
| |
| raise Exception("Unrecognized dtype combination in get_quant_type: activation({})," |
| "weight({})".format(activation.dtype, weight.dtype)) |
| |
| def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool: |
| """ Checks if the given minimum and maximum values are valid, meaning that |
| they exist and the min value is less than the max value. |
| """ |
| if min_val.numel() == 0 or max_val.numel() == 0: |
| warnings.warn( |
| "must run observer before calling calculate_qparams. " + |
| "Returning default values." |
| ) |
| return False |
| |
| if min_val.dim() == 0 or max_val.dim() == 0: |
| if min_val == float("inf") and max_val == float("-inf"): |
| warnings.warn( |
| "must run observer before calling calculate_qparams. " + |
| "Returning default values." |
| ) |
| |
| return False |
| |
        assert min_val <= max_val, "min {} should be less than or equal to max {}".format(
            min_val, max_val
        )
    else:
        assert torch.all(
            min_val <= max_val
        ), "min {} should be less than or equal to max {}".format(min_val, max_val)
| |
| return True |
| |
| |
| def calculate_qmin_qmax(quant_min: int, quant_max: int, has_customized_qrange: bool, dtype: torch.dtype, |
| reduce_range: bool) -> Tuple[int, int]: |
| r"""Calculates actual qmin and qmax based on the quantization range, |
| observer datatype and if range is reduced. |
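
    Example::

        >> calculate_qmin_qmax(0, 255, has_customized_qrange=False,
                               dtype=torch.quint8, reduce_range=True)
        (0, 127)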
| """ |
| # TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted. |
| if has_customized_qrange: |
        # This initialization is here to resolve TorchScript compilation issues and to
        # allow using refinement to decouple initial_quant_min and initial_quant_max
        # from the quantization range. The actual values of initial_quant_min and
        # initial_quant_max will be reset below.
| if dtype == torch.qint32: |
| initial_quant_min, initial_quant_max = 0, 2**31 - 1 |
| else: |
| initial_quant_min, initial_quant_max = 0, 255 |
        # The following assignment of quant_min and quant_max to local variables and the
        # if check refine them from Optional values to valid integers, as required by
        # TorchScript.
| custom_quant_min, custom_quant_max = quant_min, quant_max |
| if custom_quant_min is not None and custom_quant_max is not None: |
| initial_quant_min, initial_quant_max = ( |
| custom_quant_min, |
| custom_quant_max, |
| ) |
| |
| qrange_len = initial_quant_max - initial_quant_min + 1 |
| if dtype == torch.qint8: |
| assert ( |
| 0 < qrange_len <= 256 |
| ), "quantization range should be positive and not exceed the maximum bit range (=256)." |
| elif dtype == torch.qint32: |
            assert (
                0 < qrange_len <= 2**31
            ), "quantization range should be positive and not exceed the maximum bit range (=2147483648)."
| if reduce_range: |
| quant_min, quant_max = quant_min // 2, quant_max // 2 |
| else: |
        # Fall back to the default qmin and qmax for the dtype if a customized
        # quantization range is not used.
| if dtype == torch.qint8: |
| if reduce_range: |
| quant_min, quant_max = -64, 63 |
| else: |
| quant_min, quant_max = -128, 127 |
| elif dtype == torch.quint8: |
| if reduce_range: |
| quant_min, quant_max = 0, 127 |
| else: |
| quant_min, quant_max = 0, 255 |
| elif dtype == torch.qint32: |
| quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1 |
        else:
            # remaining dtypes are 4-bit (e.g. torch.quint4x2)
            quant_min, quant_max = 0, 15
| return quant_min, quant_max |
| |
| |
| def _parent_name(target): |
| """ |
    Turn 'foo.bar' into ('foo', 'bar')
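
    Example::

        >> _parent_name('foo.bar.baz')
        ('foo.bar', 'baz')
        >> _parent_name('foo')
        ('', 'foo')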
| """ |
| r = target.rsplit('.', 1) |
| if len(r) == 1: |
| return '', r[0] |
| else: |
| return r[0], r[1] |
| |
| def has_no_children_ignoring_parametrizations(module): |
| """ |
    Checks if module._modules is empty; if the module is parametrized, checks
    that module._modules only contains the 'parametrizations' module
| """ |
| if len(module._modules) == 0: |
| return True |
| elif is_parametrized(module): |
| return len(module._modules) == 1 and 'parametrizations' in module._modules |
| else: |
| return False |
| |
| def _get_path_of_module(root: torch.nn.Module, submodule: torch.nn.Module) -> Optional[str]: |
| """ Get the path (fully qualified name) of a submodule |
| |
| Example:: |
| |
        >> class M(torch.nn.Module):
               def __init__(self):
                   super().__init__()
                   self.linear = torch.nn.Linear(5, 5)
               def forward(self, x):
                   return self.linear(x)
| |
| >> m = M() |
| >> l = m.linear |
| >> _get_path_of_module(m, l) |
| "linear" |
| """ |
| for n, p in root.named_modules(): |
| if submodule is p: |
| return n |
| return None |
| |
| def _get_signature_locals(f: Callable, loc: Dict[str, Any]) -> Dict[str, Any]: |
| """ Get local keyword arguments |
| |
| Example:: |
| |
| >> def f(self, a, b=9): |
| pass |
| >> loc = {"a": 6, "c": 7} |
| >> _get_signature_locals(f, loc) |
| {"a": 6} |
| """ |
| return {k: v for k, v in loc.items() if k in signature(f).parameters} |
| |
| def _get_default_kwargs(f: Callable) -> "OrderedDict[str, Any]": |
| """ Get all default keyword arguments from function signature |
| |
| Example:: |
| |
| >> def f(self, a, b=9): |
| pass |
| >> _get_default_kwargs(f) |
| {"b": 9} |
| """ |
| kwargs = {} |
| for name, param in signature(f).parameters.items(): |
| if param.default is not param.empty: |
| kwargs[name] = param.default |
| elif param.kind is param.VAR_POSITIONAL: |
| kwargs[name] = () |
| elif param.kind is param.VAR_KEYWORD: |
| kwargs[name] = {} |
| return OrderedDict(kwargs) |
| |
| def _normalize_kwargs(func: Callable, loc: Dict[str, Any]) -> "OrderedDict[str, Any]": |
| """ Given a function and local function arguments, normalize the keyword |
| arguments by filling in default arguments from function signature |
| |
| Example:: |
| |
| >> def f(self, key1=3, key2=3): |
| pass |
| >> loc = {"key2": 6} |
| >> _normalize_kwargs(f, loc) |
| {"key1": 3, "key2": 6} |
| """ |
| default_kwargs = _get_default_kwargs(func) |
| local_kwargs = _get_signature_locals(func, loc) |
| normalized_kwargs = default_kwargs.copy() |
| for attr, val in local_kwargs.items(): |
| if attr in normalized_kwargs: |
| # override the default keyword arguments |
| normalized_kwargs[attr] = val |
| return normalized_kwargs |
| |
| def _get_num_pos_args(f: Callable) -> int: |
| """ Get number of positional args for a function |
| |
| Example:: |
| |
| >> def f(self, key1=3, key2=3): |
| pass |
| >> _get_num_pos_args(f) |
| 3 |
| """ |
| return len(getfullargspec(f).args) |
| |
| def get_fqn_to_example_inputs( |
| model: torch.nn.Module, |
| example_inputs: Tuple[Any, ...] |
| ) -> Dict[str, Tuple[Any, ...]]: |
| """ Given a model and its example inputs, return a dictionary from |
| fully qualified name of submodules to example_inputs for that submodule, |
| e.g. {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,), |
| "sub.linear1": (tensor4,), ...} |
| |
    Used to make quantizing submodules easier now that FX Graph Mode Quantization requires
    example inputs.
| |
    Also works for keyword arguments with default values: keyword arguments are
    flattened into positional arguments and missing keyword arguments are filled in
    with their default values. E.g. if we have a forward function:
    def forward(self, x, key1=3, key2=3):
        ...

    and we call it with self.submodule(x, key2=6),
    we'll get example_inputs: (x, 3, 6)

    The user can also override `key1` with a positional argument:
    for self.submodule(x, 5, key2=6)
    we'll get: (x, 5, 6)

    Variable positional arguments and variable keyword arguments in the forward
    function are not currently supported, so please make sure no submodule uses
    them.
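
    Example::

        >> # M is a user model with a submodule "sub" (illustrative)
        >> m = M().eval()
        >> fqn_to_example_inputs = get_fqn_to_example_inputs(m, (torch.rand(1, 5),))
        >> fqn_to_example_inputs["sub"]  # example inputs recorded for m.sub
        (tensor(...),)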
| """ |
| root = model |
| fqn_to_example_inputs = {} |
| |
| def _patched_module_call(self, *args, **kwargs): |
| submodule_example_inputs = list(args).copy() |
| normalized_kwargs = _normalize_kwargs(self.forward, kwargs) |
        # minus 1 to skip counting `self`
| num_args = _get_num_pos_args(self.forward) - 1 |
        # drop the default kwargs that were already supplied positionally; they sit at
        # the front of `normalized_kwargs` (signature order)
        num_to_pop = max(0, len(submodule_example_inputs) + len(normalized_kwargs) - num_args)
| while num_to_pop and normalized_kwargs: |
| normalized_kwargs.popitem(last=False) |
| num_to_pop -= 1 |
| submodule_example_inputs.extend(normalized_kwargs.values()) |
| submodule_example_inputs_tuple = tuple(submodule_example_inputs) |
| fqn = _get_path_of_module(root, self) |
| if fqn is not None: |
| fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple |
| return orig_module_call(self, *args, **kwargs) |
| |
| orig_module_call = torch.nn.Module.__call__ |
| torch.nn.Module.__call__ = _patched_module_call |
| try: |
| model(*example_inputs) |
| finally: |
| # restore the module call even if there is an exception |
| torch.nn.Module.__call__ = orig_module_call |
| return fqn_to_example_inputs |
| |
| |
| __all__ = [ |
| "Pattern", |
| "MatchAllNode", |
| "check_node", |
| "get_combined_dict", |
| "is_per_tensor", |
| "is_per_channel", |
| "getattr_from_fqn", |
| "get_qparam_dict", |
| "get_swapped_custom_module_class", |
| "activation_dtype", |
| "weight_dtype", |
| "activation_is_statically_quantized", |
| "activation_is_dynamically_quantized", |
| "activation_is_int8_quantized", |
| "activation_is_int32_quantized", |
| "weight_is_quantized", |
| "weight_is_statically_quantized", |
| "op_is_int8_dynamically_quantized", |
| "get_qconfig_dtypes", |
| "get_quant_type", |
| "check_min_max_valid", |
| "calculate_qmin_qmax", |
| "has_no_children_ignoring_parametrizations", |
| "get_fqn_to_example_inputs", |
| ] |