from typing import Any, Dict, Optional, Set, Tuple, Union
import warnings
import torch
from torch.fx import GraphModule
from .fx.tracer import QuantizationTracer
from .fx import fuse # noqa: F401
from .fx import prepare # noqa: F401
from .fx.convert import convert
from .backend_config import get_tensorrt_backend_config_dict # noqa: F401
from .fx.graph_module import ObservedGraphModule
from .fx.custom_config import (
ConvertCustomConfig,
FuseCustomConfig,
PrepareCustomConfig,
)
from .fx.utils import graph_pretty_str # noqa: F401
from .fx.utils import get_custom_module_class_keys # noqa: F401
from .fx.utils import get_skipped_module_name_and_classes
from .qconfig_mapping import QConfigMapping
def _check_is_graph_module(model: torch.nn.Module) -> None:
if not isinstance(model, GraphModule):
raise ValueError(
"input model must be a GraphModule, "
+ "Got type:"
+ str(type(model))
+ " Please make "
+ "sure to follow the tutorials."
)
def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
r""" Swap FloatFunctional with FXFloatFunctional
"""
modules_to_swap = []
for name, module in model.named_children():
if isinstance(module, torch.nn.quantized.FloatFunctional):
modules_to_swap.append(name)
else:
_swap_ff_with_fxff(module)
for name in modules_to_swap:
del model._modules[name]
model._modules[name] = torch.nn.quantized.FXFloatFunctional()
def _fuse_fx(
graph_module: GraphModule,
is_qat: bool,
fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
) -> GraphModule:
r""" Internal helper function to fuse modules in preparation for quantization
Args:
graph_module: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
"""
_check_is_graph_module(graph_module)
return fuse(
graph_module, is_qat, fuse_custom_config, backend_config_dict) # type: ignore[operator]
class Scope(object):
""" Scope object that records the module path and the module type
of a module. Scope is used to track the information of the module
that contains a Node in a Graph of GraphModule. For example::
class Sub(torch.nn.Module):
def forward(self, x):
# This will be a call_method Node in GraphModule,
# scope for this would be (module_path="sub", module_type=Sub)
return x.transpose(1, 2)
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sub = Sub()
def forward(self, x):
# This will be a call_method Node as well,
# scope for this would be (module_path="", None)
x = x.transpose(1, 2)
x = self.sub(x)
return x
"""
def __init__(self, module_path: str, module_type: Any):
super().__init__()
self.module_path = module_path
self.module_type = module_type
class ScopeContextManager(object):
""" A context manager to track the Scope of Node during symbolic tracing.
When entering a forward function of a Module, we'll update the scope information of
the current module, and when we exit, we'll restore the previous scope information.
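    Example (an illustrative sketch, not part of the public API; ``self.scope``,
    ``mod`` and ``module_qualified_name`` are assumed locals inside a custom
    tracer's ``call_module`` override)::

        with ScopeContextManager(self.scope, mod, module_qualified_name):
            return super().call_module(mod, forward, args, kwargs)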
"""
def __init__(
self, scope: Scope, current_module: torch.nn.Module, current_module_path: str
):
super().__init__()
self.prev_module_type = scope.module_type
self.prev_module_path = scope.module_path
self.scope = scope
self.scope.module_path = current_module_path
self.scope.module_type = type(current_module)
def __enter__(self):
return
def __exit__(self, *args):
self.scope.module_path = self.prev_module_path
self.scope.module_type = self.prev_module_type
return
def _prepare_fx(
model: torch.nn.Module,
qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
is_qat: bool,
example_inputs: Tuple[Any, ...],
prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
_equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
is_standalone_module: bool = False,
) -> ObservedGraphModule:
r""" Internal helper function for prepare_fx
Args:
`model`, `qconfig_mapping`, `prepare_custom_config`, `_equalization_config`:
see docs for :func:`~torch.ao.quantization.prepare_fx`
        `is_standalone_module`: a boolean flag that indicates whether we are
        quantizing a standalone module. A standalone module is a submodule of
        the parent module that is not inlined in the forward graph of the
        parent module; the way we quantize a standalone module is described in
        :func:`~torch.ao.quantization._prepare_standalone_module_fx`
"""
if prepare_custom_config is None:
prepare_custom_config = PrepareCustomConfig()
if _equalization_config is None:
_equalization_config = QConfigMapping()
if isinstance(prepare_custom_config, Dict):
warnings.warn(
"Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
"in a future version. Please pass in a PrepareCustomConfig instead.")
prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)
# swap FloatFunctional with FXFloatFunctional
_swap_ff_with_fxff(model)
skipped_module_names, skipped_module_classes = \
get_skipped_module_name_and_classes(prepare_custom_config, is_standalone_module)
preserved_attributes = prepare_custom_config.preserved_attributes
# symbolically trace the model
tracer = QuantizationTracer(skipped_module_names, skipped_module_classes) # type: ignore[arg-type]
graph_module = GraphModule(model, tracer.trace(model))
for attr_name in preserved_attributes:
setattr(graph_module, attr_name, getattr(model, attr_name))
fuse_custom_config = FuseCustomConfig().set_preserved_attributes(prepare_custom_config.preserved_attributes)
graph_module = _fuse_fx(
graph_module,
is_qat,
fuse_custom_config,
backend_config_dict)
prepared = prepare(
graph_module,
qconfig_mapping,
is_qat,
tracer.node_name_to_scope,
example_inputs=example_inputs,
prepare_custom_config=prepare_custom_config,
_equalization_config=_equalization_config,
backend_config_dict=backend_config_dict,
is_standalone_module=is_standalone_module,
) # type: ignore[operator]
for attr_name in preserved_attributes:
setattr(prepared, attr_name, getattr(model, attr_name))
return prepared
def _prepare_standalone_module_fx(
model: torch.nn.Module,
qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
is_qat: bool,
example_inputs: Tuple[Any, ...],
prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
) -> GraphModule:
r""" [Internal use only] Prepare a standalone module, so that it can be used when quantizing the
parent module.
    standalone_module means it is a submodule that is not inlined in the parent module,
    and will be quantized separately as one unit.
    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module.
    Returns:
        * model(GraphModule): prepared standalone module. It has these attributes:
            * `_standalone_module_input_quantized_idxs(List[Int])`: a list of
              indexes for the graph inputs that are expected to be quantized,
              same as the input_quantized_idxs configuration provided
              for the standalone module
            * `_standalone_module_output_quantized_idxs(List[Int])`: a list of
              indexes for the graph outputs that are quantized,
              same as the output_quantized_idxs configuration provided
              for the standalone module
"""
return _prepare_fx(
model,
qconfig_mapping,
is_qat,
example_inputs,
prepare_custom_config,
backend_config_dict=backend_config_dict,
is_standalone_module=True,
)
def fuse_fx(
model: torch.nn.Module,
fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
) -> GraphModule:
r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.
Fusion rules are defined in torch.quantization.fx.fusion_pattern.py
Args:
* `model`: a torch.nn.Module model
* `fuse_custom_config`: custom configurations for fuse_fx.
See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more detail::
from torch.ao.quantization.fx.custom_config import FuseCustomConfig
fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["preserved_attr"])
Example::
from torch.ao.quantization import fuse_fx
m = Model().eval()
m = fuse_fx(m)
"""
if fuse_custom_config is None:
fuse_custom_config = FuseCustomConfig()
if isinstance(fuse_custom_config, Dict):
warnings.warn(
"Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
"in a future version. Please pass in a FuseCustomConfig instead.")
fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)
torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
graph_module = torch.fx.symbolic_trace(model)
preserved_attributes: Set[str] = set()
if fuse_custom_config:
preserved_attributes = set(fuse_custom_config.preserved_attributes)
for attr_name in preserved_attributes:
setattr(graph_module, attr_name, getattr(model, attr_name))
return _fuse_fx(graph_module, False, fuse_custom_config, backend_config_dict)
def prepare_fx(
model: torch.nn.Module,
qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
example_inputs: Tuple[Any, ...],
prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
_equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
) -> ObservedGraphModule:
r""" Prepare a model for post training static quantization
Args:
* `model` (required): torch.nn.Module model, must be in eval mode
* `qconfig_mapping` (required): mapping from model ops to qconfigs::
          from torch.ao.quantization import QConfigMapping
qconfig_mapping = QConfigMapping() \
.set_global(global_qconfig) \
.set_object_type(torch.nn.Linear, qconfig1) \
.set_object_type(torch.nn.functional.linear, qconfig1) \
.set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1) \
.set_module_name_regex("foo.*bar.*", qconfig2) \
.set_module_name_regex("foo.*", qconfig3) \
.set_module_name("module1", qconfig1) \
.set_module_name("module2", qconfig2) \
.set_module_name_object_type_order("module3", torch.nn.functional.linear, 0, qconfig3)
      * `example_inputs`: (required) Example inputs for the forward function of the model
* `prepare_custom_config`: customization configuration for quantization tool.
See :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more detail::
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
prepare_custom_config = PrepareCustomConfig() \
.set_standalone_module_name("module1", qconfig_mapping, example_inputs, \
child_prepare_custom_config, backend_config_dict) \
.set_standalone_module_class(MyStandaloneModule, qconfig_mapping, example_inputs, \
child_prepare_custom_config, backend_config_dict) \
.set_float_to_observed_mapping(FloatCustomModule, ObservedCustomModule) \
.set_non_traceable_module_names(["module2", "module3"]) \
.set_non_traceable_module_classes([NonTraceableModule1, NonTraceableModule2]) \
.set_input_quantized_indexes([0]) \
.set_output_quantized_indexes([0]) \
.set_preserved_attributes(["attr1", "attr2"])
* `_equalization_config`: config for specifying how to perform equalization on the model
      * `backend_config_dict`: a dictionary that specifies how operators are quantized
         in a backend, this includes how the operators are observed,
         supported fusion patterns, how quantize/dequantize ops are
         inserted, supported dtypes etc. The structure of the dictionary is still a
         work in progress and will change in the future; please do not rely on it right now.
    Return:
      A GraphModule with observers (configured by qconfig_mapping), ready for calibration
Example::
import torch
        from torch.ao.quantization import get_default_qconfig
        from torch.ao.quantization import prepare_fx
        from torch.ao.quantization import QConfigMapping
float_model.eval()
qconfig = get_default_qconfig('fbgemm')
def calibrate(model, data_loader):
model.eval()
with torch.no_grad():
for image, target in data_loader:
model(image)
qconfig_mapping = QConfigMapping().set_global(qconfig)
example_inputs = (torch.randn(1, 3, 224, 224),)
prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)
# Run calibration
calibrate(prepared_model, sample_inference_data)
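        # the calibrated model can then be passed to `convert_fx` to produce a
        # quantized model (see :func:`~torch.ao.quantization.quantize_fx.convert_fx`)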
"""
torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
return _prepare_fx(
model,
qconfig_mapping,
False, # is_qat
example_inputs,
prepare_custom_config,
_equalization_config,
backend_config_dict,
)
def prepare_qat_fx(
model: torch.nn.Module,
qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
example_inputs: Tuple[Any, ...],
prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
backend_config_dict: Optional[Dict[str, Any]] = None,
) -> ObservedGraphModule:
r""" Prepare a model for quantization aware training
Args:
* `model`: torch.nn.Module model, must be in train mode
* `qconfig_mapping`: see :func:`~torch.ao.quantization.prepare_fx`
* `example_inputs`: see :func:`~torch.ao.quantization.prepare_fx`
* `prepare_custom_config`: see :func:`~torch.ao.quantization.prepare_fx`
* `backend_config_dict`: see :func:`~torch.ao.quantization.prepare_fx`
Return:
A GraphModule with fake quant modules (configured by qconfig_mapping), ready for
quantization aware training
    Example::

        import torch
        from torch.ao.quantization import get_default_qat_qconfig
        from torch.ao.quantization import prepare_qat_fx
        from torch.ao.quantization import QConfigMapping

        qconfig = get_default_qat_qconfig('fbgemm')
        def train_loop(model, train_data):
            model.train()
            for image, target in train_data:
                ...

        float_model.train()
        qconfig_mapping = QConfigMapping().set_global(qconfig)
        example_inputs = (torch.randn(1, 3, 224, 224),)
        prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs)
        # Run training
        train_loop(prepared_model, train_data)
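        # after training, `convert_fx` can be called on `prepared_model` to
        # obtain a quantized model (see :func:`~torch.ao.quantization.quantize_fx.convert_fx`)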
"""
torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
return _prepare_fx(
model,
qconfig_mapping,
True, # is_qat
example_inputs,
prepare_custom_config,
backend_config_dict=backend_config_dict,
)
def _convert_fx(
graph_module: GraphModule,
is_reference: bool,
convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
is_standalone_module: bool = False,
_remove_qconfig: bool = True,
qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config_dict: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
""" `is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`
"""
if convert_custom_config is None:
convert_custom_config = ConvertCustomConfig()
if isinstance(convert_custom_config, Dict):
warnings.warn(
"Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
"in a future version. Please pass in a ConvertCustomConfig instead.")
convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)
_check_is_graph_module(graph_module)
quantized = convert(
graph_module,
is_reference,
convert_custom_config,
is_standalone_module,
_remove_qconfig_flag=_remove_qconfig,
qconfig_mapping=qconfig_mapping,
backend_config_dict=backend_config_dict,
)
preserved_attributes = convert_custom_config.preserved_attributes
for attr_name in preserved_attributes:
setattr(quantized, attr_name, getattr(graph_module, attr_name))
return quantized
def convert_fx(
graph_module: GraphModule,
convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
_remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config_dict: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
r""" Convert a calibrated or trained model to a quantized model
Args:
* `graph_module`: A prepared and calibrated/trained model (GraphModule)
* `convert_custom_config`: custom configurations for convert function.
See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more detail::
from torch.ao.quantization.fx.custom_config import ConvertCustomConfig
convert_custom_config = ConvertCustomConfig() \
.set_observed_to_quantized_mapping(ObservedCustomModule, QuantizedCustomModule) \
.set_preserved_attributes(["attr1", "attr2"])
* `_remove_qconfig`: Option to remove the qconfig attributes in the model after convert.
* `qconfig_mapping`: config for specifying how to convert a model for quantization.
The keys must include the ones in the qconfig_mapping passed to `prepare_fx` or `prepare_qat_fx`,
with the same values or `None`. Additional keys can be specified with values set to `None`.
For each entry whose value is set to None, we skip quantizing that entry in the model::
          qconfig_mapping = QConfigMapping() \
              .set_global(qconfig_from_prepare) \
              .set_object_type(torch.add, None) \
              .set_object_type(torch.nn.functional.linear, qconfig_from_prepare) \
              .set_module_name("foo.bar", None)
          # entries set to None (torch.add and module "foo.bar" above) are skipped
          # during quantization
* `backend_config_dict`: A configuration for the backend which describes how
         operators should be quantized in the backend; this includes quantization
         mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.),
         observer placement for each operator and fused operators. Detailed
documentation can be found in torch/ao/quantization/backend_config/README.md
Return:
A quantized model (GraphModule)
Example::
# prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
quantized_model = convert_fx(prepared_model)
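        # the quantized model can typically be scripted for deployment, e.g.
        # (illustrative; scripting support depends on the ops in the model)
        # scripted_model = torch.jit.script(quantized_model)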
"""
torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx")
return _convert_fx(
graph_module,
is_reference=False,
convert_custom_config=convert_custom_config,
_remove_qconfig=_remove_qconfig,
qconfig_mapping=qconfig_mapping,
backend_config_dict=backend_config_dict,
)
def convert_to_reference(
graph_module: GraphModule,
convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
_remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config_dict: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
r""" Convert a calibrated or trained model to a reference quantized model, a common interface
between PyTorch quantization with other backends like accelerators. Callers should additionally
lower the returned reference model to the target backend before using the model for inference.
Args:
* `graph_module`: A prepared and calibrated/trained model (GraphModule)
* `convert_custom_config`: custom configurations for convert function.
See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more detail.
* `_remove_qconfig`: Option to remove the qconfig attributes in the model after convert.
* `qconfig_mapping`: config for specifying how to convert a model for quantization.
See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more detail.
* `backend_config_dict`: A configuration for the backend which describes how
operators should be quantized in the backend. See
:func:`~torch.ao.quantization.quantize_fx.convert_fx` for more detail.
Return:
A reference quantized model (GraphModule)
Example::
# prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
reference_model = convert_to_reference(prepared_model)
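        # the reference model must then be lowered to the target backend before
        # inference; lowering APIs are backend-specific and not shown here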
"""
torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_to_reference")
return _convert_fx(
graph_module,
is_reference=True,
convert_custom_config=convert_custom_config,
_remove_qconfig=_remove_qconfig,
qconfig_mapping=qconfig_mapping,
backend_config_dict=backend_config_dict,
)
def _convert_standalone_module_fx(
graph_module: GraphModule,
is_reference: bool = False,
convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
) -> torch.nn.Module:
r""" [Internal use only] Convert a model produced by :func:`~torch.ao.quantization.prepare_standalone_module_fx`
and convert it to a quantized model
Returns a quantized standalone module, whether input/output is quantized is
specified by prepare_custom_config, with
input_quantized_idxs, output_quantized_idxs, please
see docs for prepare_fx for details
"""
return _convert_fx(
graph_module,
is_reference,
convert_custom_config,
is_standalone_module=True,
)