| import copy |
| import itertools |
| import warnings |
| |
| import torch |
| import torch.nn as nn |
| import torch.ao.nn.quantized as nnq |
| from torch.nn.intrinsic import _FusedModule |
| |
| from torch.ao.quantization.quantization_mappings import ( |
| get_default_dynamic_quant_module_mappings, |
| get_default_static_quant_module_mappings, |
| get_default_static_quant_reference_module_mappings, |
| get_default_qat_module_mappings, |
| get_default_qconfig_propagation_list, |
| no_observer_set, |
| _has_special_act_post_process, |
| _get_special_act_post_process, |
| ) |
| from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations |
| from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper |
| from torch.ao.quantization.qconfig import ( |
| add_module_to_qconfig_obs_ctr, |
| default_dynamic_qconfig, |
| float16_dynamic_qconfig, |
| float_qparams_weight_only_qconfig, |
| float_qparams_weight_only_qconfig_4bit, |
| activation_is_memoryless) |
| from torch.nn.utils.parametrize import type_before_parametrizations |
| |
# Names exported by `from <this module> import *`. Helpers whose names start
# with an underscore (for example `_convert`) are intentionally not listed.
__all__ = [
    "get_default_custom_config_dict",
    "is_activation_post_process",
    "propagate_qconfig_",
    "register_activation_post_process_hook",
    "add_observer_",
    "get_unique_devices_",
    "add_quant_dequant",
    "prepare",
    "quantize",
    "quantize_dynamic",
    "prepare_qat",
    "quantize_qat",
    "convert",
    "swap_module",
    "get_observer_dict",
]
| |
# Custom-module configuration used when the caller supplies none:
# `prepare` reads the 'float_to_observed_custom_module_class' entry and
# `_convert` reads the 'observed_to_quantized_custom_module_class' entry.
_DEFAULT_CUSTOM_CONFIG_DICT = {
    # float module class -> class whose `from_float` builds the observed module
    'float_to_observed_custom_module_class': {
        nn.LSTM: nn.quantizable.LSTM,
        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
    },
    # observed module class -> class whose `from_observed` builds the quantized module
    'observed_to_quantized_custom_module_class': {
        nn.quantizable.LSTM: nn.quantized.LSTM,
        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
    }
}
| |
def get_default_custom_config_dict():
    r"""Return the module-level default custom config dict.

    The dict object itself is returned (no copy is made), so every caller
    receives the same object.

    Return:
        `_DEFAULT_CUSTOM_CONFIG_DICT`
    """
    return _DEFAULT_CUSTOM_CONFIG_DICT
| |
def is_activation_post_process(module):
    r"""Tell whether `module` is an observer or a fake-quantize module.

    Args:
        module: any object

    Return:
        True if `module` is an instance of `ObserverBase` or
        `FakeQuantizeBase`, False otherwise.
    """
    post_process_bases = (
        torch.ao.quantization.ObserverBase,
        torch.ao.quantization.FakeQuantizeBase,
    )
    return isinstance(module, post_process_bases)
| |
| |
def _propagate_qconfig_helper(module, qconfig_dict,
                              qconfig_parent=None, prefix='', prepare_custom_config_dict=None):
    r"""This is a helper function for `propagate_qconfig_`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name or type of submodule to
                      quantization configuration
        qconfig_parent: quantization config of parent module, we will fallback to
                       this config when there is no specified config for current
                       module
        prefix: qualified (dot separated) name of the current module, used as
                key in qconfig_dict and for the non traceable name lookup
        prepare_custom_config_dict: dictionary for custom handling of modules
                                    see docs for :func:`~torch.ao.quantization.prepare_fx`.
                                    The keys read here are
                                    "non_traceable_module_name" and
                                    "non_traceable_module_class"; a child that
                                    matches either one is skipped together
                                    with its whole subtree.

    Return:
        None, module is modified inplace with qconfig attached
    """

    # Precedence, lowest to highest: the parent's qconfig, the entry keyed by
    # the module type, the entry keyed by the qualified name, and finally a
    # `qconfig` attribute that is already set on the module.
    module_qconfig = qconfig_dict.get(type_before_parametrizations(module), qconfig_parent)
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, 'qconfig', module_qconfig)

    torch.ao.quantization.qconfig.assert_valid_qconfig(module_qconfig, module)

    qconfig_with_device_check = add_module_to_qconfig_obs_ctr(module_qconfig, module)
    module.qconfig = qconfig_with_device_check

    if prepare_custom_config_dict is None:
        non_traceable_names = ()
        non_traceable_classes = ()
    else:
        non_traceable_names = prepare_custom_config_dict.get("non_traceable_module_name", [])
        non_traceable_classes = prepare_custom_config_dict.get("non_traceable_module_class", [])

    for name, child in module.named_children():
        module_prefix = prefix + '.' + name if prefix else name
        # do not propagate qconfig to child if child is non traceable; the
        # name is compared as a qualified name so nested entries can match
        # (for direct children of the root it equals the local name)
        if module_prefix in non_traceable_names or type(child) in non_traceable_classes:
            continue
        # forward the custom config so the filter also applies below the
        # first level of the hierarchy
        _propagate_qconfig_helper(
            child, qconfig_dict, qconfig_with_device_check, module_prefix,
            prepare_custom_config_dict
        )
| |
def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
    r"""Walk the module hierarchy and attach a `qconfig` attribute to the
    modules visited by `_propagate_qconfig_helper`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name or type of submodule to
            quantization configuration, qconfig applies to all submodules of a
            given module unless qconfig for the submodules are specified (when
            the submodule already has qconfig attribute). An empty dict is
            used when None is given.
        prepare_custom_config_dict: dictionary for custom handling of modules
            see docs for :func:`~torch.ao.quantization.prepare_fx`. An empty
            dict is used when None is given.

    Return:
        None, module is modified inplace with qconfig attached
    """
    _propagate_qconfig_helper(
        module,
        {} if qconfig_dict is None else qconfig_dict,
        prepare_custom_config_dict=(
            {} if prepare_custom_config_dict is None else prepare_custom_config_dict
        ),
    )
| |
def _observer_forward_hook(self, input, output):
    r"""Forward hook that passes `output` through
    `self.activation_post_process` and returns the result (`input` is unused)
    """
    post_process = self.activation_post_process
    return post_process(output)
| |
def _observer_forward_pre_hook(self, input):
    r"""Forward pre hook that passes the first positional input, `input[0]`,
    through `self.activation_post_process` and returns the result
    """
    first_arg = input[0]
    return self.activation_post_process(first_arg)
| |
def register_activation_post_process_hook(module, pre_hook=False):
    r"""Register the observer hook on `module` and move it to the front of
    the corresponding hook dict

    Args:
        module: module that already has an `activation_post_process` attribute
        pre_hook: when True register `_observer_forward_pre_hook` as a forward
            pre hook, otherwise register `_observer_forward_hook` as a
            forward hook

    Return:
        None
    """
    assert hasattr(module, 'activation_post_process'), \
        'Expect activation_post_process attribute already attached to the module'
    if pre_hook:
        handle = module.register_forward_pre_hook(_observer_forward_pre_hook)
        hook_map = module._forward_pre_hooks
    else:
        handle = module.register_forward_hook(_observer_forward_hook)
        hook_map = module._forward_hooks
    # the new entry was appended last; make it the first entry instead
    hook_map.move_to_end(handle.id, last=False)
| |
| |
def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
    r"""Add observer for the leaf child of the module.

    This function insert observer module to all leaf child module that
    has a valid qconfig attribute.

    Args:
        module: input module with qconfig attributes for all the leaf modules that we want to quantize
        qconfig_propagation_list: a list of quantizable modules that will have observers added to them
            if they are leaf nodes, defaults to `get_default_qconfig_propagation_list()`
        non_leaf_module_list: list of non-leaf modules we want to add observer
        device: parent device, if any; when None it is derived from the
            parameters and buffers of `module`
        custom_module_class_mapping: dictionary from float module class to a
            class whose `from_float` creates the observed replacement,
            defaults to an empty dict

    Return:
        None, module is modified inplace with added observer modules and forward_hooks
    """
    if qconfig_propagation_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()

    if custom_module_class_mapping is None:
        custom_module_class_mapping = {}

    # respect device affinity when adding observers
    if device is None:
        devices = get_unique_devices_(module)
        assert len(devices) <= 1, (
            "add_observer_ only works with cpu or single-device CUDA modules, "
            "but got devices {}".format(devices)
        )
        device = next(iter(devices)) if len(devices) > 0 else None

    def get_activation_post_process(qconfig, device, special_act_post_process=None):
        if special_act_post_process is None:
            activation = qconfig.activation()
        else:
            activation = special_act_post_process()
        if device is not None:
            activation.to(device)
        return activation

    def needs_observation(m):
        return hasattr(m, 'qconfig') and m.qconfig is not None

    def insert_activation_post_process(m, special_act_post_process=None):
        """ Adds an activation post process module and register
        a pre or post hook that calls the module
        """
        # We don't insert observer/fake_quantize for DeQuantStub
        if isinstance(m, DeQuantStub) or not needs_observation(m):
            return
        # observer and hook will be gone after we swap the module
        m.add_module('activation_post_process', get_activation_post_process(
            m.qconfig, device, special_act_post_process))
        # Register observer as the first entry in the hook list
        # All post forward hooks are preserved and will be executed after the observer before convert
        register_activation_post_process_hook(m, pre_hook=activation_is_memoryless(m.qconfig))

    for name, child in module.named_children():
        child_type = type_before_parametrizations(child)
        # TODO remove Dropout special after codebase stable
        if child_type in [nn.Dropout]:
            continue
        elif child_type in [nnq.FloatFunctional, nnq.QFunctional]:
            if needs_observation(child):
                child.activation_post_process = get_activation_post_process(child.qconfig, device)
        elif isinstance(child, _FusedModule):
            # activation_post_process are now added directly to nn.Sequential/_FusedModule
            if needs_observation(child):
                insert_activation_post_process(child)
        elif _has_special_act_post_process(child):
            special_act_post_process = _get_special_act_post_process(child)
            insert_activation_post_process(child, special_act_post_process)
        elif non_leaf_module_list is not None and child_type in non_leaf_module_list:
            if needs_observation(child):
                insert_activation_post_process(child)
        elif needs_observation(child) and child_type in custom_module_class_mapping:
            observed_class = custom_module_class_mapping[child_type]
            observed_child = observed_class.from_float(child)
            setattr(module, name, observed_child)
            # TODO: These are the modules that cannot be observed
            #       Once there are more, we should move them to a separate list
            if observed_class not in no_observer_set():
                insert_activation_post_process(observed_child)
        else:
            add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)

    # Insert observers only for leaf nodes, note that this observer is for
    # the output of the module, for input QuantStub will observe them
    is_observed_leaf = (
        has_no_children_ignoring_parametrizations(module)
        and not isinstance(module, torch.nn.Sequential)
        and type_before_parametrizations(module) in qconfig_propagation_list
    )
    if is_observed_leaf:
        insert_activation_post_process(module)
| |
def get_unique_devices_(module):
    r"""Collect the devices of all parameters and buffers of `module`

    Args:
        module: input module

    Return:
        A set with the `.device` of every parameter and buffer, empty when
        the module has neither
    """
    tensors = itertools.chain(module.parameters(), module.buffers())
    return {tensor.device for tensor in tensors}
| |
def add_quant_dequant(module):
    r"""Wrap every child-less module that has a truthy `qconfig` in a
    `QuantWrapper`. Children of `module` are replaced in place; when `module`
    itself qualifies, a new wrapper around it is returned instead.

    Args:
        module: input module with qconfig attributes for all the leaf modules
        that we want to quantize

    Return:
        Either the inplace modified module with submodules wrapped in
        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
        wraps the input module, the latter case only happens when the input
        module is a leaf module and we want to quantize it.
    """
    is_quantizable_leaf = (
        has_no_children_ignoring_parametrizations(module)
        and hasattr(module, 'qconfig')
        and module.qconfig
    )
    if is_quantizable_leaf:
        return QuantWrapper(module)

    for child_name, child_module in module.named_children():
        module._modules[child_name] = add_quant_dequant(child_module)
    return module
| |
def prepare(model, inplace=False, allow_list=None,
            observer_non_leaf_module_list=None,
            prepare_custom_config_dict=None):
    r"""Propagates qconfig through the model and attaches observer or fake
    quant modules to it, for quantization calibration or quantization-aware
    training.

    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.

    Args:
        `model`: input model; it is deep-copied first unless `inplace` is True
        `inplace`: carry out model transformations in-place, the original module is mutated
        `allow_list`: list of quantizable modules, defaults to
                      `get_default_qconfig_propagation_list()`
        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
        `prepare_custom_config_dict`: customization configuration dictionary for prepare function,
                                      defaults to `get_default_custom_config_dict()`

    Return:
        The prepared model (the input model itself when `inplace` is True)

    .. code-block:: python

       # Example of prepare_custom_config_dict:
       prepare_custom_config_dict = {
           # user will manually define the corresponding observed
           # module class which has a from_float class method that converts
           # float custom module to observed custom module
           "float_to_observed_custom_module_class": {
               CustomModule: ObservedCustomModule
           }
       }

    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

    prepared = model if inplace else copy.deepcopy(model)

    # TODO: remove allow_list
    if allow_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()
    else:
        qconfig_propagation_list = allow_list
    propagate_qconfig_(prepared, qconfig_dict=None)

    # sanity check common API misusage
    has_any_qconfig = any(
        hasattr(submodule, 'qconfig') and submodule.qconfig
        for submodule in prepared.modules()
    )
    if not has_any_qconfig:
        warnings.warn("None of the submodule got qconfig applied. Make sure you "
                      "passed correct configuration through `qconfig_dict` or "
                      "by assigning the `.qconfig` attribute directly on submodules")

    add_observer_(
        prepared, qconfig_propagation_list, observer_non_leaf_module_list,
        custom_module_class_mapping=custom_module_class_mapping)
    return prepared
| |
def _remove_activation_post_process(module):
    r"""Remove the `activation_post_process` attribute (when it is an observer
    or fake quantize module) and the observer pre/post forward hooks from
    `module`. Only `module` itself is touched, not its children.

    Args:
        module: module to be cleaned up
    """
    # TODO: maybe we should change activation_post_process to _activation_post_process
    # to prevent it from being used by user
    attached = getattr(module, 'activation_post_process', None)
    if is_activation_post_process(attached):
        delattr(module, 'activation_post_process')

    # remove activation_post_process pre and post hooks
    hook_maps_and_observer_hooks = (
        (module._forward_pre_hooks, _observer_forward_pre_hook),
        (module._forward_hooks, _observer_forward_hook),
    )
    for hook_map, observer_hook in hook_maps_and_observer_hooks:
        # collect the ids first, the dict cannot shrink while it is iterated
        stale_ids = [
            handle_id for handle_id, hook_fn in hook_map.items()
            if hook_fn is observer_hook
        ]
        for handle_id in stale_ids:
            hook_map.pop(handle_id)
| |
| # TODO: rename to something more general |
def _remove_qconfig(module):
    r"""Recursively delete the `qconfig` attribute, and strip the observer
    attribute and hooks, from `module` and all of its descendants so that a
    new qconfig can be propagated.

    Args:
        module: module to be cleaned up
    """
    # children first, then the module itself
    for submodule in module.children():
        _remove_qconfig(submodule)

    if hasattr(module, "qconfig"):
        delattr(module, "qconfig")

    _remove_activation_post_process(module)
| |
def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize the input float model with post training static quantization.

    The model is put in eval mode and prepared for calibration, `run_fn` is
    called on it to run the calibration step, and the calibrated model is
    then converted to a quantized model.

    Args:
        model: input float model
        run_fn: a calibration function for calibrating the prepared model,
            called as `run_fn(model, *run_args)`
        run_args: positional arguments for `run_fn`
        mapping: correspondence between original module types and quantized
            counterparts, defaults to
            `get_default_static_quant_module_mappings()`
        inplace: carry out model transformations in-place, the original module is mutated

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
    if mapping is None:
        mapping = get_default_static_quant_module_mappings()
    target = model if inplace else copy.deepcopy(model)
    target.eval()
    prepare(target, inplace=True)
    run_fn(target, *run_args)
    convert(target, mapping, inplace=True)
    return target
| |
def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
                     mapping=None, inplace=False):
    r"""Converts a float model to dynamic (i.e. weights-only) quantized model.

    Replaces specified modules with dynamic weight-only quantized versions and output the quantized model.

    For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization
    by default is performed for layers with large weights size - i.e. Linear and RNN variants.

    Fine grained control is possible with `qconfig_spec` and `mapping` that act similarly to `quantize()`.
    If `qconfig_spec` is provided as a dictionary, the `dtype` argument is ignored.

    Args:
        model: input model
        qconfig_spec: Either:

            - A dictionary that maps from name or type of submodule to quantization
              configuration, qconfig applies to all submodules of a given
              module unless qconfig for the submodules are specified (when the
              submodule already has qconfig attribute). Entries in the dictionary
              need to be QConfig instances.

            - A set of types and/or submodule names to apply dynamic quantization to,
              in which case the `dtype` argument is used to specify the bit-width

            - None, in which case a default set of module types is chosen
              based on `dtype`

        dtype: one of `torch.qint8`, `torch.float16`, `torch.quint8` or
            `torch.quint4x2`; only used when `qconfig_spec` is None or a set
        mapping: maps type of a submodule to a type of corresponding dynamically quantized version
            with which the submodule needs to be replaced, defaults to
            `get_default_dynamic_quant_module_mappings()`
        inplace: carry out model transformations in-place, the original module is mutated

    Return:
        The dynamically quantized model, in eval mode.

    Raises:
        ValueError: `qconfig_spec` is None and `dtype` has no default settings
        RuntimeError: `qconfig_spec` is a set and `dtype` is not supported
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
    if qconfig_spec is None:
        # module types that get a weight-only qconfig by default
        weighted_types = (nn.Linear, nn.LSTM, nn.GRU, nn.LSTMCell, nn.RNNCell, nn.GRUCell)
        if dtype == torch.qint8:
            qconfig_spec = dict.fromkeys(weighted_types, default_dynamic_qconfig)
        elif dtype == torch.float16:
            qconfig_spec = dict.fromkeys(weighted_types, float16_dynamic_qconfig)
        elif dtype == torch.quint8:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig,
                nn.Embedding: float_qparams_weight_only_qconfig,
            }
        elif dtype == torch.quint4x2:
            qconfig_spec = {
                nn.EmbeddingBag: float_qparams_weight_only_qconfig_4bit,
            }
        else:
            raise ValueError(
                "Don't know how to quantize with default settings for {}. Provide full qconfig please".format(dtype))
    elif isinstance(qconfig_spec, set):
        if dtype == torch.qint8:
            default_qconfig = default_dynamic_qconfig
        elif dtype == torch.float16:
            default_qconfig = float16_dynamic_qconfig
        elif dtype == torch.quint8:
            default_qconfig = float_qparams_weight_only_qconfig
        elif dtype == torch.quint4x2:
            default_qconfig = float_qparams_weight_only_qconfig_4bit
        else:
            # a single message string; two positional arguments would be
            # rendered as a tuple
            raise RuntimeError('Unknown dtype specified for quantize_dynamic: ' + str(dtype))
        # every requested name/type gets the same dtype-derived qconfig
        qconfig_spec = dict.fromkeys(qconfig_spec, default_qconfig)

    if mapping is None:
        mapping = get_default_dynamic_quant_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    propagate_qconfig_(model, qconfig_spec)
    convert(model, mapping, inplace=True)
    return model
| |
def prepare_qat(model, mapping=None, inplace=False):
    r"""
    Prepares the model for quantization-aware training: qconfig is
    propagated, modules are swapped according to `mapping` (keeping their
    qconfig), and `prepare` is then run on the result.

    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.

    Args:
        model: input model, must be in training mode; it is deep-copied first
               unless `inplace` is True
        mapping: dictionary that maps float modules to quantized modules to be
                 replaced, defaults to `get_default_qat_module_mappings()`
        inplace: carry out model transformations in-place, the original module
                 is mutated

    Return:
        The prepared model (the input model itself when `inplace` is True)
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
    assert model.training, "prepare_qat only works on models in training mode"
    if mapping is None:
        mapping = get_default_qat_module_mappings()

    qat_model = model if inplace else copy.deepcopy(model)

    propagate_qconfig_(qat_model, qconfig_dict=None)
    convert(qat_model, mapping=mapping, inplace=True, remove_qconfig=False)
    # the swapped-in module classes are passed as non-leaf modules to observe
    swapped_module_classes = set(mapping.values())
    prepare(qat_model, observer_non_leaf_module_list=swapped_module_classes, inplace=True)
    return qat_model
| |
def quantize_qat(model, run_fn, run_args, inplace=False):
    r"""Do quantization aware training and output a quantized model

    The model is put in training mode, prepared with `prepare_qat`, passed
    to `run_fn`, and finally converted.

    Args:
        model: input model
        run_fn: a function for evaluating the prepared model, can be a
                function that simply runs the prepared model or a training
                loop; called as `run_fn(model, *run_args)`
        run_args: positional arguments for `run_fn`
        inplace: carry out model transformations in-place, the original module
                 is mutated

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
    target = model if inplace else copy.deepcopy(model)
    target.train()
    prepare_qat(target, inplace=True)
    run_fn(target, *run_args)
    convert(target, inplace=True)
    return target
| |
def convert(
        module, mapping=None, inplace=False, remove_qconfig=True,
        is_reference=False, convert_custom_config_dict=None):
    r"""Converts submodules in input module to a different module according to `mapping`
    by calling `from_float` method on the target module class. And remove qconfig at the
    end if remove_qconfig is set to True.

    Args:
        `module`: prepared and calibrated module
        `mapping`: a dictionary that maps from source module type to target
                   module type, can be overwritten to allow swapping user defined
                   Modules
        `inplace`: carry out model transformations in-place, the original module
                   is mutated
        `remove_qconfig`: when True, `_remove_qconfig` is run on the converted
                          module before it is returned
        `is_reference`: a flag to enable quantized reference module
        `convert_custom_config_dict`: custom configuration dictionary for convert function

    Return:
        The converted module (the input module itself when `inplace` is True)

    .. code-block:: python

       # Example of convert_custom_config_dict:
       convert_custom_config_dict = {
           # user will manually define the corresponding quantized
           # module class which has a from_observed class method that converts
           # observed custom module to quantized custom module
           "observed_to_quantized_custom_module_class": {
               ObservedCustomModule: QuantizedCustomModule
           }
       }

    """
    torch._C._log_api_usage_once("quantization_api.quantize.convert")
    converted = module if inplace else copy.deepcopy(module)
    # the copy (if any) was made above, so the worker always runs in place
    _convert(
        converted, mapping, inplace=True, is_reference=is_reference,
        convert_custom_config_dict=convert_custom_config_dict)
    if remove_qconfig:
        _remove_qconfig(converted)
    return converted
| |
def _convert(
        module, mapping=None, inplace=False,
        is_reference=False, convert_custom_config_dict=None):
    r"""Converts submodules in input module to a different module according to `mapping`
    by calling `from_float` method on the target module class

    Args:
        module: input module
        mapping: a dictionary that maps from source module type to target
                 module type, can be overwritten to allow swapping user defined
                 Modules; when None the default static (or, with
                 `is_reference`, reference) mappings are used
        inplace: carry out model transformations in-place, the original module
                 is mutated
        is_reference: a flag to enable quantized reference module
        convert_custom_config_dict: custom configuration dictionary, its
                 "observed_to_quantized_custom_module_class" entry is read;
                 defaults to `get_default_custom_config_dict()`

    Return:
        The module with its children swapped

    """
    if mapping is None:
        if is_reference:
            mapping = get_default_static_quant_reference_module_mappings()
        else:
            mapping = get_default_static_quant_module_mappings()
    if convert_custom_config_dict is None:
        convert_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})

    if not inplace:
        module = copy.deepcopy(module)
    replacements = {}
    for child_name, child in module.named_children():
        # both fused modules and observed custom modules are
        # swapped as one unit, so their children are not visited
        swapped_as_unit = (
            isinstance(child, _FusedModule)
            or type_before_parametrizations(child) in custom_module_class_mapping
        )
        if not swapped_as_unit:
            _convert(child, mapping, True,  # inplace
                     is_reference, convert_custom_config_dict)
        replacements[child_name] = swap_module(child, mapping, custom_module_class_mapping)

    # assign only after the traversal so the children dict is not modified
    # while it is being iterated
    for child_name, new_child in replacements.items():
        module._modules[child_name] = new_child

    return module
| |
def swap_module(mod, mapping, custom_module_class_mapping):
    r"""Swaps the module if it has a non-None `qconfig` and its type has a
    counterpart in `custom_module_class_mapping` or `mapping`.

    Args:
        mod: input module
        mapping: a dictionary that maps from nn module to nnq module
        custom_module_class_mapping: a dictionary that maps from observed
            custom module class to a class providing `from_observed`; it is
            consulted before `mapping`

    Return:
        The corresponding quantized module of `mod`, or `mod` itself when
        nothing is swapped
    """
    if getattr(mod, 'qconfig', None) is None:
        return mod

    mod_type = type_before_parametrizations(mod)
    if mod_type in custom_module_class_mapping:
        new_mod = custom_module_class_mapping[mod_type].from_observed(mod)
    elif mod_type in mapping:
        qmod = mapping[mod_type]
        if getattr(qmod, '_IS_REFERENCE', False):
            assert mod.qconfig is not None
            # run the weight observer once on the weight to obtain the
            # qparams handed to the reference module
            weight_post_process = mod.qconfig.weight()
            weight_post_process(mod.weight)
            weight_qparams = get_qparam_dict(weight_post_process)
            new_mod = qmod.from_float(mod, weight_qparams)
        else:
            new_mod = qmod.from_float(mod)
    else:
        # no counterpart for this type: nothing to swap
        return mod

    # Preserve module's pre forward hooks. They'll be called on quantized input
    for pre_hook_fn in mod._forward_pre_hooks.values():
        new_mod.register_forward_pre_hook(pre_hook_fn)
    # Preserve module's post forward hooks except _observer_forward_hook
    # After convert they'll work with quantized output
    for hook_fn in mod._forward_hooks.values():
        if hook_fn is not _observer_forward_hook:
            new_mod.register_forward_hook(hook_fn)

    # respect device affinity when swapping modules
    devices = get_unique_devices_(mod)
    assert len(devices) <= 1, (
        "swap_module only works with cpu or single-device CUDA modules, "
        "but got devices {}".format(devices)
    )
    device = next(iter(devices)) if len(devices) > 0 else None
    if device:
        new_mod.to(device)
    return new_mod
| |
def get_observer_dict(mod, target_dict, prefix=""):
    r"""Traverse the modules and save all observers into dict.
    This is mainly used for quantization accuracy debug
    Args:
        mod: the top module we want to save all observers
        prefix: the prefix for the current module
        target_dict: the dictionary used to save all the observers, filled in
            place with keys of the form `<qualified name>.activation_post_process`
    """
    # "" for the root, otherwise the qualified name followed by a dot
    dotted_prefix = prefix + '.' if prefix else ''

    if hasattr(mod, 'activation_post_process'):
        target_dict[dotted_prefix + 'activation_post_process'] = mod.activation_post_process
    for child_name, child in mod.named_children():
        get_observer_dict(child, target_dict, dotted_prefix + child_name)