blob: 8b192b6ffd670806511b93bd32b159015306da58 [file] [log] [blame]
import copy
import torch
from torch import nn
import torch.nn.functional as F
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.nn.intrinsic.quantized.dynamic as nniqd
import torch.nn.intrinsic.qat as nniqat
import torch.nn.quantized as nnq
import torch.nn.quantized._reference as nnqr
import torch.nn.quantized.dynamic as nnqd
import torch.nn.qat as nnqat
import torch.nn.qat.dynamic as nnqatd
from typing import Optional, Union, Dict, Set, Callable, Any
import torch.ao.nn as ao_nn
from torch.ao.quantization.stubs import QuantStub, DeQuantStub
from torch.ao.quantization.fake_quantize import (
default_fixed_qparams_range_0to1_fake_quant,
default_fixed_qparams_range_neg1to1_fake_quant,
)
from torch.ao.quantization.utils import get_combined_dict
from torch.nn.utils.parametrize import type_before_parametrizations
# Default map for swapping float module classes (plus the QuantStub /
# DeQuantStub placeholders) to *reference* quantized module classes.
# Read by get_static_quant_module_class(..., is_reference=True) and returned
# as a deep copy by get_default_static_quant_reference_module_mappings().
DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    nn.Linear: nnqr.Linear,
    nn.Conv1d: nnqr.Conv1d,
    nn.Conv2d: nnqr.Conv2d,
    nn.Conv3d: nnqr.Conv3d,
    nn.ConvTranspose1d: nnqr.ConvTranspose1d,
    nn.ConvTranspose2d: nnqr.ConvTranspose2d,
    nn.ConvTranspose3d: nnqr.ConvTranspose3d,
    nn.Embedding: nnqr.Embedding,
    nn.EmbeddingBag: nnqr.EmbeddingBag,
    nn.GRUCell: nnqr.GRUCell,
    nn.LSTMCell: nnqr.LSTMCell,
    nn.RNNCell: nnqr.RNNCell,
    nn.LSTM: nnqr.LSTM,
}
# Default map for swapping float module classes (and their QAT / fused
# intrinsic variants) to statically quantized module classes.
# Read by get_static_quant_module_class() and copied by
# get_default_static_quant_module_mappings() and
# get_embedding_static_quant_module_mappings().
# Each key must appear exactly once: a repeated key is silently collapsed by
# the dict literal, which hides copy/paste mistakes in this table.
DEFAULT_STATIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
    QuantStub: nnq.Quantize,
    DeQuantStub: nnq.DeQuantize,
    nn.BatchNorm2d: nnq.BatchNorm2d,
    nn.BatchNorm3d: nnq.BatchNorm3d,
    nn.Dropout: nnq.Dropout,
    nn.Conv1d: nnq.Conv1d,
    nn.Conv2d: nnq.Conv2d,
    nn.Conv3d: nnq.Conv3d,
    nn.ConvTranspose1d: nnq.ConvTranspose1d,
    nn.ConvTranspose2d: nnq.ConvTranspose2d,
    nn.ConvTranspose3d: nnq.ConvTranspose3d,
    nn.ELU: nnq.ELU,
    nn.Embedding: nnq.Embedding,
    nn.EmbeddingBag: nnq.EmbeddingBag,
    nn.GroupNorm: nnq.GroupNorm,
    nn.Hardswish: nnq.Hardswish,
    nn.InstanceNorm1d: nnq.InstanceNorm1d,
    nn.InstanceNorm2d: nnq.InstanceNorm2d,
    nn.InstanceNorm3d: nnq.InstanceNorm3d,
    nn.LayerNorm: nnq.LayerNorm,
    nn.LeakyReLU: nnq.LeakyReLU,
    nn.modules.linear.NonDynamicallyQuantizableLinear: nnq.Linear,
    nn.Linear: nnq.Linear,
    nn.ReLU6: nnq.ReLU6,
    nn.PReLU: nnq.PReLU,
    # Wrapper Modules:
    nnq.FloatFunctional: nnq.QFunctional,
    # Intrinsic modules:
    nni.BNReLU2d: nniq.BNReLU2d,
    nni.BNReLU3d: nniq.BNReLU3d,
    nni.ConvReLU1d: nniq.ConvReLU1d,
    nni.ConvReLU2d: nniq.ConvReLU2d,
    nni.ConvReLU3d: nniq.ConvReLU3d,
    nni.LinearReLU: nniq.LinearReLU,
    # Fused QAT Conv+BN(+ReLU) / Linear+BN modules map to the quantized module
    # without a separate BatchNorm.
    nniqat.ConvBn1d: nnq.Conv1d,
    nniqat.ConvBn2d: nnq.Conv2d,
    nniqat.ConvBn3d: nnq.Conv3d,
    nniqat.ConvBnReLU1d: nniq.ConvReLU1d,
    nniqat.ConvBnReLU2d: nniq.ConvReLU2d,
    nniqat.ConvBnReLU3d: nniq.ConvReLU3d,
    nniqat.ConvReLU2d: nniq.ConvReLU2d,
    nniqat.ConvReLU3d: nniq.ConvReLU3d,
    nniqat.LinearReLU: nniq.LinearReLU,
    nniqat.LinearBn1d: nnq.Linear,
    # QAT modules:
    nnqat.Linear: nnq.Linear,
    nnqat.Conv2d: nnq.Conv2d,
    nnqat.Conv3d: nnq.Conv3d,
}
# Default map for swapping float module classes (including the fused ``nni``
# intrinsic modules) to their quantization aware training (QAT) counterparts.
# Returned as a deep copy by get_default_qat_module_mappings() and used as the
# base table of get_embedding_qat_module_mappings().
DEFAULT_QAT_MODULE_MAPPINGS : Dict[Callable, Any] = {
    nn.Conv2d: nnqat.Conv2d,
    nn.Conv3d: nnqat.Conv3d,
    nn.Linear: nnqat.Linear,
    nn.modules.linear.NonDynamicallyQuantizableLinear: nnqat.Linear,
    # Intrinsic modules:
    nni.ConvBn1d: nniqat.ConvBn1d,
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBn3d: nniqat.ConvBn3d,
    nni.ConvBnReLU1d: nniqat.ConvBnReLU1d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvBnReLU3d: nniqat.ConvBnReLU3d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.ConvReLU3d: nniqat.ConvReLU3d,
    nni.LinearReLU: nniqat.LinearReLU,
    nni.LinearBn1d: nniqat.LinearBn1d,
}
# Default map for swapping float module classes to their dynamically quantized
# counterparts. Read by get_dynamic_quant_module_class() and exposed through
# get_default_dynamic_quant_module_mappings().
# Note: the Embedding / EmbeddingBag entries map to classes from the static
# ``nnq`` package rather than ``nnqd``.
DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
    nn.GRUCell: nnqd.GRUCell,
    nn.Linear: nnqd.Linear,
    nnqatd.Linear: nnqd.Linear,
    nn.modules.linear.NonDynamicallyQuantizableLinear: nnqd.Linear,
    nn.LSTM: nnqd.LSTM,
    nn.GRU: nnqd.GRU,
    nn.LSTMCell: nnqd.LSTMCell,
    nn.RNNCell: nnqd.RNNCell,
    nni.LinearReLU: nniqd.LinearReLU,
    nn.EmbeddingBag: nnq.EmbeddingBag,
    nn.Embedding: nnq.Embedding,
    # Don't want to enable these by default because the numerical
    # accuracy is poor compared to other dynamic ops
    # nn.Conv1d: nnqd.Conv1d,
    # nn.Conv2d: nnqd.Conv2d,
    # nn.Conv3d: nnqd.Conv3d,
    # nn.ConvTranspose1d: nnqd.ConvTranspose1d,
    # nn.ConvTranspose2d: nnqd.ConvTranspose2d,
    # nn.ConvTranspose3d: nnqd.ConvTranspose3d,
}
# Allowlist of extra module types that should receive a propagated qconfig
# even though they are not keys of any default mapping in this file. It is
# unioned into get_default_qconfig_propagation_list() and
# get_default_compare_output_module_list().
_INCLUDE_QCONFIG_PROPAGATE_LIST : Set[Callable] = {
    nn.Sequential,
}
# Default mapping from floating point functions (or torch ops) to the
# corresponding ops in ``torch.ops.quantized``. Read by
# get_quantized_operator() and copied by
# get_default_float_to_quantized_operator_mappings().
# The key type also admits ``str`` op names, although the default table only
# contains ``torch.nn.functional`` callables.
# TODO: merge with default static mapping
DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS : Dict[Union[Callable, str], Callable] = {
    F.elu: torch.ops.quantized.elu,
    F.hardswish: torch.ops.quantized.hardswish,
    F.instance_norm: torch.ops.quantized.instance_norm,
    F.layer_norm: torch.ops.quantized.layer_norm,
    F.leaky_relu: torch.ops.quantized.leaky_relu,
    F.dropout: torch.ops.quantized.dropout,
}
# Mapping from module class to the activation post process (a fixed-qparams
# fake quant constructor) used for that module's output; it takes priority
# over the activation post process from the qconfig.
# Hardsigmoid / Sigmoid / Softmax use the 0-to-1 range variant and Tanh uses
# the -1-to-1 range variant. Read by _get_special_act_post_process() and
# _has_special_act_post_process().
DEFAULT_MODULE_TO_ACT_POST_PROCESS : Dict[Callable, Callable] = {
    nn.Hardsigmoid: default_fixed_qparams_range_0to1_fake_quant,
    nn.Sigmoid: default_fixed_qparams_range_0to1_fake_quant,
    nn.Softmax: default_fixed_qparams_range_0to1_fake_quant,
    nn.Tanh: default_fixed_qparams_range_neg1to1_fake_quant,
}
# Default map for swapping float module classes to static sparse quantized
# module classes (currently only nn.Linear). Returned as a deep copy by
# get_default_static_sparse_quant_module_mappings().
DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
    nn.Linear: ao_nn.sparse.quantized.Linear
}
# Default map for swapping float module classes to dynamic sparse quantized
# module classes (currently only nn.Linear). Exposed through
# get_default_dynamic_sparse_quant_module_mappings().
DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS : Dict[Callable, Any] = {
    nn.Linear: ao_nn.sparse.quantized.dynamic.Linear
}
def no_observer_set() -> Set[Any]:
    r"""Return the module types that cannot have observers inserted by default.

    A new set is built on every call, so callers may modify the result freely.
    """
    # Set literal instead of ``set([...])``: no throwaway list is built.
    no_observers = {
        nn.quantizable.LSTM,
        nn.quantizable.MultiheadAttention,
    }
    return no_observers
def get_default_static_quant_module_mappings() -> Dict[Callable, Any]:
    """Return the default float -> quantized module mapping used by post
    training static quantization.

    The result is a deep copy of ``DEFAULT_STATIC_QUANT_MODULE_MAPPINGS``, so
    the caller may add or override entries without touching the shared table.
    """
    static_mapping_copy = copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)
    return static_mapping_copy
def get_default_static_quant_reference_module_mappings() -> Dict[Callable, Any]:
    """Return the default float -> reference quantized module mapping used by
    post training static quantization.

    The result is a deep copy of
    ``DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS``; mutating it does not
    affect the shared table.
    """
    reference_mapping_copy = copy.deepcopy(DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS)
    return reference_mapping_copy
def get_embedding_static_quant_module_mappings() -> Dict[Callable, Any]:
    """Return the default static quantization module mapping extended with
    entries for the embedding QAT modules.

    Starts from a deep copy of ``DEFAULT_STATIC_QUANT_MODULE_MAPPINGS`` and
    adds ``nnqat.EmbeddingBag -> nnq.EmbeddingBag`` and
    ``nnqat.Embedding -> nnq.Embedding``.
    """
    embedding_entries = {
        nnqat.EmbeddingBag: nnq.EmbeddingBag,
        nnqat.Embedding: nnq.Embedding,
    }
    combined = copy.deepcopy(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)
    combined.update(embedding_entries)
    return combined
def get_default_static_sparse_quant_module_mappings() -> Dict[Callable, Any]:
    """Return the default float -> static sparse quantized module mapping used
    by post training static sparse quantization.

    The result is a deep copy of
    ``DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS``.
    """
    sparse_mapping_copy = copy.deepcopy(DEFAULT_STATIC_SPARSE_QUANT_MODULE_MAPPINGS)
    return sparse_mapping_copy
def get_static_quant_module_class(
        float_module_class: Callable,
        additional_static_quant_mapping: Optional[Dict[Callable, Any]] = None,
        is_reference: bool = False) -> Any:
    r"""Get the statically quantized module class corresponding to
    the floating point module class.

    Args:
        float_module_class: the floating point module class to look up.
        additional_static_quant_mapping: extra float -> quantized entries that
            are combined with the default table via ``get_combined_dict``.
        is_reference: if True, look the class up in
            ``DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS`` instead of
            ``DEFAULT_STATIC_QUANT_MODULE_MAPPINGS``.

    Returns:
        The matching quantized module class.

    Raises:
        AssertionError: if ``float_module_class`` has no mapping.
    """
    if additional_static_quant_mapping is None:
        additional_static_quant_mapping = {}
    default_mappings = (
        DEFAULT_REFERENCE_STATIC_QUANT_MODULE_MAPPINGS if is_reference
        else DEFAULT_STATIC_QUANT_MODULE_MAPPINGS)
    all_mappings = get_combined_dict(default_mappings, additional_static_quant_mapping)
    static_quant_module_class = all_mappings.get(float_module_class, None)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``
    # (otherwise None would be returned silently); AssertionError is kept so
    # existing callers that catch it keep working.
    if static_quant_module_class is None:
        raise AssertionError(
            "Floating point module class {}".format(str(float_module_class)) +
            " does not have a corresponding quantized module class")
    return copy.deepcopy(static_quant_module_class)
def get_dynamic_quant_module_class(
        float_module_class: Callable,
        additional_dynamic_quant_mapping: Optional[Dict[Callable, Any]] = None) -> Any:
    r"""Get the dynamically quantized module class corresponding to
    the floating point module class.

    Args:
        float_module_class: the floating point module class to look up.
        additional_dynamic_quant_mapping: extra float -> dynamically quantized
            entries combined with ``DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS``
            via ``get_combined_dict``.

    Returns:
        The matching dynamically quantized module class.

    Raises:
        AssertionError: if ``float_module_class`` has no mapping.
    """
    if additional_dynamic_quant_mapping is None:
        additional_dynamic_quant_mapping = {}
    all_mappings = get_combined_dict(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping)
    dynamic_quant_module_class = all_mappings.get(float_module_class, None)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # AssertionError is kept for backward compatibility with existing callers.
    if dynamic_quant_module_class is None:
        raise AssertionError(
            "Floating point module class {}".format(str(float_module_class)) +
            " does not have a corresponding quantized module class")
    return copy.deepcopy(dynamic_quant_module_class)
def get_default_qat_module_mappings() -> Dict[Callable, Any]:
    """Return the default float -> QAT module mapping used for quantization
    aware training.

    The result is a deep copy of ``DEFAULT_QAT_MODULE_MAPPINGS``, so callers
    may modify it without affecting the shared table.
    """
    qat_mapping_copy = copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS)
    return qat_mapping_copy
def get_embedding_qat_module_mappings() -> Dict[Callable, Any]:
    """Return the module mapping for quantization aware training, including
    the default entries plus QAT support for embeddings.

    Starts from a deep copy of ``DEFAULT_QAT_MODULE_MAPPINGS`` and adds
    ``nn.EmbeddingBag -> nnqat.EmbeddingBag`` and
    ``nn.Embedding -> nnqat.Embedding``.
    """
    embedding_entries = {
        nn.EmbeddingBag: nnqat.EmbeddingBag,
        nn.Embedding: nnqat.Embedding,
    }
    combined = copy.deepcopy(DEFAULT_QAT_MODULE_MAPPINGS)
    combined.update(embedding_entries)
    return combined
def get_default_dynamic_quant_module_mappings() -> Dict[Callable, Any]:
    ''' Get module mapping for post training dynamic quantization.

    Returns a deep copy of ``DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS``, like the
    other ``get_default_*_mappings`` helpers in this module. Without the copy,
    a caller that adds or overrides entries in the returned dict would silently
    change the module-level default for every later caller.
    '''
    return copy.deepcopy(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS)
def get_default_dynamic_sparse_quant_module_mappings() -> Dict[Callable, Any]:
    ''' Get module mapping for post training dynamic sparse quantization.

    Returns a deep copy of ``DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS``,
    consistent with ``get_default_static_sparse_quant_module_mappings``, so
    mutating the result cannot alter the module-level default.
    '''
    return copy.deepcopy(DEFAULT_DYNAMIC_SPARSE_QUANT_MODULE_MAPPINGS)
def get_default_qconfig_propagation_list() -> Set[Callable]:
    """Return the default set of module types that get a ``qconfig``
    attribute attached during prepare.

    The set is the union of the keys of the static, QAT and dynamic default
    mappings, plus the extra types in ``_INCLUDE_QCONFIG_PROPAGATE_LIST``.
    """
    static_types = set(DEFAULT_STATIC_QUANT_MODULE_MAPPINGS.keys())
    qat_types = set(DEFAULT_QAT_MODULE_MAPPINGS.keys())
    dynamic_types = set(DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS.keys())
    propagate_types = static_types | qat_types | dynamic_types | _INCLUDE_QCONFIG_PROPAGATE_LIST
    return copy.deepcopy(propagate_types)
def get_default_compare_output_module_list() -> Set[Callable]:
    """Return the set of module types whose outputs the numeric suite records.

    Covers both sides (values and keys) of the static, QAT and dynamic default
    mappings, plus ``_INCLUDE_QCONFIG_PROPAGATE_LIST``.
    """
    default_tables = (
        DEFAULT_STATIC_QUANT_MODULE_MAPPINGS,
        DEFAULT_QAT_MODULE_MAPPINGS,
        DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
    )
    # Quantized-side types first, then float-side types, then the allowlist.
    pieces = [set(table.values()) for table in default_tables]
    pieces += [set(table.keys()) for table in default_tables]
    pieces.append(_INCLUDE_QCONFIG_PROPAGATE_LIST)
    recorded_types = pieces[0]
    for piece in pieces[1:]:
        recorded_types = recorded_types | piece
    return copy.deepcopy(recorded_types)
def get_default_float_to_quantized_operator_mappings() -> Dict[Union[Callable, str], Callable]:
    """Return a deep copy of the default mapping from floating point functions
    (such as ``F.elu``) to their ``torch.ops.quantized`` counterparts.
    """
    operator_mapping = copy.deepcopy(DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS)
    return operator_mapping
# TODO: merge with get_static_quant_module_class
def get_quantized_operator(float_op: Union[Callable, str]) -> Callable:
    ''' Get the quantized operator corresponding to the float operator.

    Args:
        float_op: a key of ``DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS``
            (e.g. ``F.elu``).

    Returns:
        The mapped quantized operator.

    Raises:
        AssertionError: if ``float_op`` has no quantized counterpart.
    '''
    quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op, None)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # AssertionError is kept for backward compatibility with existing callers.
    if quantized_op is None:
        raise AssertionError(
            'Operator {} does not have corresponding quantized op'.format(str(float_op)))
    return quantized_op
def _get_special_act_post_process(module: torch.nn.Module) -> Optional[Callable]:
    r""" Get the special activation post process for `module`. This has
    higher priority than the activation post process in `qconfig`.

    The lookup key is ``type_before_parametrizations(module)``, so a
    parametrized module resolves to its original class. Returns None when the
    class has no entry in ``DEFAULT_MODULE_TO_ACT_POST_PROCESS``.

    e.g.
    input: torch.nn.Sigmoid
    output: default_fixed_qparams_range_0to1_fake_quant
    """
    return DEFAULT_MODULE_TO_ACT_POST_PROCESS.get(type_before_parametrizations(module), None)
def _has_special_act_post_process(module: torch.nn.Module) -> bool:
    """Return True if ``module`` is in training mode and its class has an
    entry in ``DEFAULT_MODULE_TO_ACT_POST_PROCESS``.

    The class is resolved with ``type_before_parametrizations``, the same
    lookup ``_get_special_act_post_process`` uses. This keeps the two helpers
    in agreement for parametrized modules, whose ``type(module)`` is a
    generated subclass rather than the original class.
    """
    return module.training and type_before_parametrizations(module) in DEFAULT_MODULE_TO_ACT_POST_PROCESS