| # mypy: allow-untyped-defs |
| """Functions to verify exported ONNX model is functionally equivalent to original PyTorch model. |
| |
| ONNX Runtime is required, and is used as the default ONNX backend for export verification. |
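| |
| Example (a minimal sketch; ``MyModel`` and the input shape are placeholders):: |
| |
| >>> # xdoctest: +SKIP |
| >>> import torch |
| >>> from torch.onnx import verification |
| >>> model = MyModel() |
| >>> args = (torch.randn(1, 3, 224, 224),) |
| >>> # Exports the model, runs it under the ONNX backend, and compares the |
| >>> # outputs against PyTorch within default tolerances. |
| >>> verification.verify(model, args) |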
| """ |
| |
| from __future__ import annotations |
| |
| import contextlib |
| import copy |
| import dataclasses |
| import datetime |
| import difflib |
| import enum |
| import functools |
| import io |
| import itertools |
| import os |
| import tempfile |
| import warnings |
| from typing import Any, Callable, Collection, Mapping, Sequence, Tuple, Union |
| |
| import numpy as np |
| |
| import torch |
| import torch._C._onnx as _C_onnx |
| from torch import _C |
| from torch.onnx import _constants, _experimental, _exporter_states, utils |
| from torch.onnx._globals import GLOBALS |
| from torch.onnx._internal import onnx_proto_utils |
| from torch.types import Number |
| |
| |
| _ORT_PROVIDERS = ("CPUExecutionProvider",) |
| |
| _NumericType = Union[Number, torch.Tensor, np.ndarray] |
| _ModelType = Union[torch.nn.Module, torch.jit.ScriptModule] |
| _InputArgsType = Union[torch.Tensor, Tuple[Any, ...]] |
| _InputKwargsType = Mapping[str, Any] |
| _OutputsType = Union[Sequence[_NumericType], Sequence] |
| |
| |
| class OnnxBackend(enum.Enum): |
| """Enum class for ONNX backend used for export verification.""" |
| |
| REFERENCE = "ONNXReferenceEvaluator" |
| ONNX_RUNTIME_CPU = "CPUExecutionProvider" |
| ONNX_RUNTIME_CUDA = "CUDAExecutionProvider" |
| |
| |
| @dataclasses.dataclass |
| class VerificationOptions: |
| """Options for ONNX export verification. |
| |
| Attributes: |
| flatten: If True, unpack nested list/tuple/dict inputs into a flattened list of |
| Tensors for ONNX. Set this to False to preserve nested structures for ONNX, |
| which is usually the case when exporting ScriptModules. Defaults to True. |
| ignore_none: Whether to ignore None types in the torch output, which is usually |
| the case with tracing. Set this to False if the torch output should keep None |
| types, which is usually the case when exporting ScriptModules. Defaults to True. |
| check_shape: Whether to check that the shapes of the PyTorch and ONNX Runtime |
| outputs are exactly the same. Set this to False to allow output shape |
| broadcasting. Defaults to True. |
| check_dtype: Whether to check that the dtypes of the PyTorch and ONNX Runtime |
| outputs are consistent. Defaults to True. |
| backend: ONNX backend for verification. Defaults to OnnxBackend.ONNX_RUNTIME_CPU. |
| rtol: Relative tolerance for the comparison between ONNX and PyTorch outputs. |
| atol: Absolute tolerance for the comparison between ONNX and PyTorch outputs. |
| remained_onnx_input_idx: If provided, only the specified inputs will be passed |
| to the ONNX model. Supply a list when the model has unused inputs. |
| Since unused inputs are removed from the exported ONNX model, supplying |
| all inputs would cause an error on the unexpected ones. This parameter tells |
| the verifier which inputs to pass to the ONNX model. |
| acceptable_error_percentage: Acceptable percentage of mismatched elements in |
| the comparison. It should be a float between 0.0 and 1.0. |
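| |
| Example (illustrative; ``model``/``args`` are placeholders and the tolerances are arbitrary):: |
| |
| >>> # xdoctest: +SKIP |
| >>> options = VerificationOptions(rtol=1e-2, atol=1e-4, check_shape=False) |
| >>> # Loosen the numeric comparison and allow broadcastable output shapes. |
| >>> verify(model, args, options=options) |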
| """ |
| |
| flatten: bool = True |
| ignore_none: bool = True |
| check_shape: bool = True |
| check_dtype: bool = True |
| backend: OnnxBackend = OnnxBackend.ONNX_RUNTIME_CPU |
| rtol: float = 1e-3 |
| atol: float = 1e-7 |
| remained_onnx_input_idx: Sequence[int] | None = None |
| acceptable_error_percentage: float | None = None |
| |
| |
| def _flatten_tuples(elem): |
| flattened = [] |
| for t in elem: |
| if isinstance(t, tuple): |
| flattened.extend(_flatten_tuples(t)) |
| else: |
| flattened.append(t) |
| return flattened |
| |
| |
| # TODO(justinchuby): Add type checking by narrowing down the return type when input is None |
| def _to_numpy(elem) -> list | np.ndarray: |
| if isinstance(elem, torch.Tensor): |
| if elem.requires_grad: |
| return elem.detach().cpu().numpy() |
| else: |
| return elem.cpu().numpy() |
| elif isinstance(elem, (list, tuple)): |
| return [_to_numpy(inp) for inp in elem] |
| elif isinstance(elem, (bool, int, float)): |
| return np.array(elem) |
| elif isinstance(elem, dict): |
| flattened = [] |
| for k in elem: |
| flattened.extend([_to_numpy(k), _to_numpy(elem[k])]) |
| return flattened |
| return elem |
| |
| |
| def _inline_flatten_list(inputs, res_list) -> list: |
| for i in inputs: |
| if isinstance(i, (list, tuple)): |
| _inline_flatten_list(i, res_list) |
| else: |
| res_list.append(i) |
| return res_list |
| |
| |
| def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list: |
| value_unpacked = [] |
| for value in values: |
| value_unpacked.extend( |
| utils.unpack_quantized_tensor(value, cast_onnx_accepted=cast_onnx_accepted) |
| ) |
| return [_to_numpy(v) for v in value_unpacked] |
| |
| |
| def _run_onnx(onnx_session, inputs) -> _OutputsType: |
| kw_inputs = {} |
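| # By convention (see _prepare_input_for_export), a trailing dict in the |
| # inputs carries the keyword arguments for the ONNX model. |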
| if inputs and isinstance(inputs[-1], dict): |
| kw_inputs = inputs[-1] |
| inputs = inputs[:-1] |
| inputs = _unpack_to_numpy(_flatten_tuples(inputs)) |
| ort_inputs = {} |
| for input_name, input in kw_inputs.items(): |
| ort_inputs[input_name] = _to_numpy(input) |
| inputs = _to_numpy(inputs) |
| if hasattr(onnx_session, "get_inputs"): |
| # onnxruntime.InferenceSession |
| input_names = [i.name for i in onnx_session.get_inputs()] |
| elif hasattr(onnx_session, "input_names"): |
| # onnx.reference.ReferenceEvaluator |
| input_names = onnx_session.input_names |
| else: |
| raise ValueError(f"Unknown ONNX backend type: {type(onnx_session)}.") |
| |
| for i, input in enumerate(inputs): |
| if i == len(input_names) or input_names[i] in ort_inputs: |
| raise ValueError( |
| f"got too many positional inputs. inputs: {inputs}. kw_inputs: {kw_inputs}. " |
| f"input names: {input_names}." |
| ) |
| ort_inputs[input_names[i]] = input |
| onnx_outs = onnx_session.run(None, ort_inputs) |
| return onnx_outs |
| |
| |
| def _ort_session( |
| model: str | io.BytesIO, ort_providers: Sequence[str] = _ORT_PROVIDERS |
| ): |
| try: |
| import onnxruntime # type: ignore[import] |
| except ImportError as e: |
| raise ImportError("onnxruntime is required for export verification.") from e |
| |
| if ort_providers is None: |
| ort_providers = _ORT_PROVIDERS |
| |
| session_options = onnxruntime.SessionOptions() |
| # Suppress ORT warnings. |
| # Severity levels: 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal. Default is 2. |
| session_options.log_severity_level = 3 |
| ort_session = onnxruntime.InferenceSession( |
| model if isinstance(model, str) else model.getvalue(), |
| session_options, |
| providers=ort_providers, |
| ) |
| return ort_session |
| |
| |
| def _onnx_reference_evaluator_session(model: str | io.BytesIO): |
| try: |
| import onnx |
| from onnx import reference as onnx_reference # type: ignore[attr-defined] |
| except ImportError as exc: |
| raise ImportError("onnx >= 1.13 is required for reference evaluator.") from exc |
| |
| proto = ( |
| onnx.load(model) # type: ignore[attr-defined] |
| if isinstance(model, str) |
| else onnx.load_model_from_string(model.getvalue()) # type: ignore[attr-defined] |
| ) |
| onnx_session = onnx_reference.ReferenceEvaluator(proto) |
| return onnx_session |
| |
| |
| def _onnx_backend_session(model: str | io.BytesIO, backend: OnnxBackend): |
| if backend == OnnxBackend.REFERENCE: |
| onnx_session = _onnx_reference_evaluator_session(model) |
| elif backend in {OnnxBackend.ONNX_RUNTIME_CPU, OnnxBackend.ONNX_RUNTIME_CUDA}: |
| onnx_session = _ort_session(model, (backend.value,)) |
| else: |
| raise ValueError(f"Unsupported backend: {backend}") |
| return onnx_session |
| |
| |
| def _compare_onnx_pytorch_outputs_in_np( |
| onnx_outs: _OutputsType, |
| pt_outs: _OutputsType, |
| options: VerificationOptions, |
| ): |
| assert ( |
| len(onnx_outs) == len(pt_outs) |
| ), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})" |
| acceptable_error_percentage = options.acceptable_error_percentage |
| if acceptable_error_percentage and ( |
| acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0 |
| ): |
| raise ValueError( |
| "If set, acceptable_error_percentage should be between 0.0 and 1.0" |
| ) |
| |
| for ort_out, pt_out in zip(onnx_outs, pt_outs): |
| try: |
| # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed. |
| if not options.check_shape: |
| # Allow different but broadcastable output shapes. |
| ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out) |
| torch.testing.assert_close( |
| ort_out, |
| pt_out, |
| rtol=options.rtol, |
| atol=options.atol, |
| check_dtype=options.check_dtype, |
| equal_nan=True, |
| ) |
| except AssertionError as e: |
| if acceptable_error_percentage: |
| error_percentage = 1 - np.sum( |
| np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol) |
| ) / np.prod(ort_out.shape) |
| if error_percentage <= acceptable_error_percentage: |
| warnings.warn( |
| f"Suppressed AssertionError:\n{e}.\n" |
| f"Error percentage {error_percentage} " |
| f"within acceptable range {acceptable_error_percentage}." |
| ) |
| continue |
| if ort_out.dtype == np.uint8 or ort_out.dtype == np.int8: |
| warnings.warn("ONNX output is quantized") |
| if pt_out.dtype == np.uint8 or pt_out.dtype == np.int8: |
| warnings.warn("PyTorch output is quantized") |
| raise |
| |
| |
| def _compare_onnx_pytorch_outputs( |
| onnx_outs: _OutputsType, |
| pt_outs: Any, |
| options: VerificationOptions, |
| ): |
| """ |
| Compare ONNX and PyTorch outputs. |
| |
| Args: |
| onnx_outs: outputs from ONNX backend. |
| pt_outs: outputs from PyTorch. |
| options: options for verification. |
| |
| Raises: |
| AssertionError: if outputs from ONNX model and PyTorch model are not |
| equal up to specified precision. |
| ValueError: if arguments provided are invalid. |
| """ |
| if options.ignore_none: |
| # torch.jit._flatten filters None type |
| pt_outs, _ = torch.jit._flatten(pt_outs) |
| else: |
| pt_outs = _inline_flatten_list([pt_outs], []) |
| pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False) |
| onnx_outs = _inline_flatten_list(onnx_outs, []) |
| _compare_onnx_pytorch_outputs_in_np(onnx_outs, pt_outs_np, options) |
| |
| |
| def _prepare_input_for_pytorch(args, kwargs): |
| """Prepare input for PyTorch model execution. |
| |
| Any future changes/formatting to the input before dispatching to the PyTorch |
| model should be made in this function. |
| |
| Args: |
| args: positional arguments for PyTorch model forward method. |
| kwargs: keyword arguments for PyTorch model forward method. |
| |
| Returns: |
| args: positional arguments for PyTorch model forward method. |
| kwargs: keyword arguments for PyTorch model forward method. |
| """ |
| if isinstance(args, (torch.Tensor, dict)): |
| args = (args,) |
| # In-place operators will update input tensor data as well. |
| # Thus inputs are replicated before every forward call. |
| args = copy.deepcopy(args) |
| if kwargs: |
| kwargs = copy.deepcopy(kwargs) |
| else: |
| kwargs = {} |
| return args, kwargs |
| |
| |
| def _prepare_input_for_export(args, kwargs): |
| """Prepare input for ONNX model export. |
| |
| Any future changes/formatting to the input before dispatching to the |
| :func:`torch.onnx.export` api should be made in this function. |
| |
| Args: |
| args: positional arguments for PyTorch model forward method. |
| kwargs: keyword arguments for PyTorch model forward method. |
| |
| Returns: |
| onnx_inputs: positional arguments for ONNX model export, as `args` in |
| :func:`torch.onnx.export`. |
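| |
| Example (illustrative; tensor values are elided):: |
| |
| >>> # xdoctest: +SKIP |
| >>> x = torch.randn(2) |
| >>> # Keyword arguments are appended as a trailing dict. |
| >>> _prepare_input_for_export((x,), {"flag": True}) |
| (tensor(...), {'flag': True}) |
| >>> # A genuine trailing dict input is protected with an empty dict. |
| >>> _prepare_input_for_export((x, {"a": 1}), {}) |
| (tensor(...), {'a': 1}, {}) |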
| """ |
| args, kwargs = _prepare_input_for_pytorch(args, kwargs) |
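| # torch.onnx.export treats a trailing dict in `args` as named arguments. |
| # Appending an empty dict ensures that a genuine dict input in the last |
| # position is not misinterpreted as keyword arguments. |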
| if not kwargs and len(args) > 0 and isinstance(args[-1], dict): |
| onnx_inputs = args + ({},) |
| elif kwargs: |
| onnx_inputs = args + (kwargs,) |
| else: |
| onnx_inputs = args |
| return onnx_inputs |
| |
| |
| def _prepare_input_for_onnx( |
| args, kwargs, remained_onnx_input_idx: Sequence[int] | None, flatten: bool |
| ): |
| """Prepare input for ONNX model execution in ONNX backend. |
| |
| Any future changes/formatting to the input before dispatching to the ONNX backend |
| run should be made in this function. |
| |
| Args: |
| args: positional arguments for PyTorch model forward method. |
| kwargs: keyword arguments for PyTorch model forward method. |
| remained_onnx_input_idx: indices of inputs to be used for ONNX model execution. |
| flatten: whether to flatten the input before dispatching to the ONNX model execution. |
| |
| Returns: |
| onnx_inputs: positional arguments for ONNX model execution in ONNX backend. |
| """ |
| onnx_inputs = _prepare_input_for_export(args, kwargs) |
| if flatten: |
| onnx_inputs, _ = torch.jit._flatten(onnx_inputs) |
| elif onnx_inputs and onnx_inputs[-1] == {}: |
| # Handle empty kwargs (normally removed by flatten). |
| onnx_inputs = onnx_inputs[:-1] |
| if remained_onnx_input_idx is not None: |
| return [onnx_inputs[i] for i in remained_onnx_input_idx] |
| else: |
| return onnx_inputs |
| |
| |
| def _try_clone_model(model): |
| """Used for preserving original model in case forward mutates model states.""" |
| try: |
| return copy.deepcopy(model) |
| except Exception: |
| warnings.warn( |
| "Failed to clone model. Model state might be mutated during verification." |
| ) |
| return model |
| |
| |
| def _compare_onnx_pytorch_model( |
| pt_model: _ModelType, |
| onnx_model_f: str | io.BytesIO, |
| input_args: _InputArgsType, |
| input_kwargs: _InputKwargsType | None, |
| additional_test_inputs: Sequence[_InputArgsType] | None, |
| options: VerificationOptions, |
| ): |
| """Compare outputs from ONNX model runs with outputs from PyTorch model runs. |
| |
| Args: |
| pt_model: PyTorch model. |
| onnx_model_f: ONNX model file path or file-like object. |
| input_args: positional arguments for PyTorch model forward method. |
| input_kwargs: keyword arguments for PyTorch model forward method. |
| additional_test_inputs: additional positional arguments for PyTorch model |
| forward method. |
| options: options for verification. |
| |
| Raises: |
| AssertionError: if outputs from ONNX model and PyTorch model are not |
| equal up to specified precision. |
| """ |
| onnx_session = _onnx_backend_session(onnx_model_f, options.backend) |
| |
| def compare_onnx_pytorch_model_with_input(input_args, input_kwargs): |
| pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs) |
| # TODO: remove this and treat mutating model separately. See #77679 |
| pt_model_copy = _try_clone_model(pt_model) |
| pt_outs = pt_model_copy(*pt_args, **pt_kwargs) |
| |
| onnx_inputs = _prepare_input_for_onnx( |
| input_args, input_kwargs, options.remained_onnx_input_idx, options.flatten |
| ) |
| |
| onnx_outs = _run_onnx(onnx_session, onnx_inputs) |
| |
| _compare_onnx_pytorch_outputs( |
| onnx_outs=onnx_outs, |
| pt_outs=pt_outs, |
| options=options, |
| ) |
| |
| compare_onnx_pytorch_model_with_input(input_args, input_kwargs) |
| |
| if additional_test_inputs: |
| for test_input_args in additional_test_inputs: |
| compare_onnx_pytorch_model_with_input(test_input_args, {}) |
| |
| |
| class _GraphDiff: |
| """A class to represent the difference between two graphs.""" |
| |
| def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph): |
| """Construct a _GraphDiff object. |
| |
| Args: |
| graph_a (_C.Graph): First graph to compare. |
| graph_b (_C.Graph): Second graph to compare. |
| """ |
| self.graph_a = graph_a |
| self.graph_b = graph_b |
| |
| def __str__(self): |
| """See function :func:`diff_report`.""" |
| return self.diff_report() |
| |
| def _indent(self, lines: str) -> str: |
| return "\n".join(["\t" + line for line in lines.splitlines()]) |
| |
| def diff_report(self) -> str: |
| """Return a string representation of the graph difference. |
| |
| The report shows the first pair of nodes that diverge, along with the source |
| locations of that pair. |
| |
| Returns: |
| graph_diff_report (str): A string representation of the graph difference. |
| """ |
| graph_a = self.graph_a |
| graph_b = self.graph_b |
| |
| graph_a_str = str(graph_a) |
| graph_b_str = str(graph_b) |
| |
| if graph_a_str == graph_b_str: |
| return "" |
| |
| graph_diff = difflib.ndiff( |
| graph_a_str.splitlines(True), graph_b_str.splitlines(True) |
| ) |
| graph_diff_report = ["Graph diff:", self._indent("".join(graph_diff))] |
| |
| for node_a, node_b in itertools.zip_longest(graph_a.nodes(), graph_b.nodes()): |
| if str(node_a) != str(node_b): |
| graph_diff_report.append("First diverging operator:") |
| node_diff = difflib.ndiff( |
| str(node_a).splitlines(True), str(node_b).splitlines(True) |
| ) |
| source_printout = ["node diff:", self._indent("".join(node_diff))] |
| |
| stack_a = node_a.sourceRange() if node_a else None |
| if stack_a: |
| source_printout.extend( |
| ["Former source location:", self._indent(str(stack_a))] |
| ) |
| stack_b = node_b.sourceRange() if node_b else None |
| if stack_b: |
| source_printout.extend( |
| ["Latter source location:", self._indent(str(stack_b))] |
| ) |
| |
| graph_diff_report.extend(source_printout) |
| |
| break |
| |
| return "\n".join(graph_diff_report) |
| |
| |
| def _check_graph_diff( |
| model: torch.nn.Module | torch.jit.ScriptModule, |
| test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]], |
| export_options: _experimental.ExportOptions, |
| model_to_graph_func: Callable[ |
| [ |
| torch.nn.Module, |
| tuple[Any, ...], |
| Mapping[str, Any], |
| _experimental.ExportOptions, |
| ], |
| _C.Graph, |
| ], |
| ) -> str: |
| """Check if graph produced by `model_to_graph_func` is the same across `test_input_groups`. |
| |
| Args: |
| model: See :func:`check_export_model_diff`. |
| test_input_groups: See :func:`check_export_model_diff`. |
| export_options: See :func:`check_export_model_diff`. |
| model_to_graph_func: A function to convert a PyTorch model to a JIT IR graph. |
| |
| Returns: |
| graph_diff_report (str): A string representation of the graph difference. |
| """ |
| if len(test_input_groups) < 2: |
| raise ValueError("Need at least two groups of test inputs to compare.") |
| |
| ref_jit_graph = None |
| for args, kwargs in test_input_groups: |
| jit_graph = model_to_graph_func(model, args, kwargs, export_options) |
| if ref_jit_graph is None: |
| ref_jit_graph = jit_graph |
| continue |
| |
| graph_diff_report = _GraphDiff(ref_jit_graph, jit_graph).diff_report() |
| if graph_diff_report: |
| return graph_diff_report |
| return "" |
| |
| |
| def _traced_graph_from_model( |
| model: torch.nn.Module | torch.jit.ScriptModule, |
| args: tuple[Any, ...], |
| kwargs: Mapping[str, Any], |
| export_options: _experimental.ExportOptions, |
| ) -> _C.Graph: |
| """As part of the ONNX export steps, create a traced JIT graph from a PyTorch model. |
| |
| Args: |
| model: See :func:`check_export_model_diff`. |
| args: See :func:`check_export_model_diff`. |
| kwargs: See :func:`check_export_model_diff`. |
| export_options: See :func:`check_export_model_diff`. |
| |
| Returns: |
| jit_graph (_C.Graph): A traced JIT graph. |
| """ |
| training = export_options.training |
| verbose = export_options.verbose |
| |
| with utils.exporter_context(model, training, verbose): |
| export_inputs = _prepare_input_for_export(args, kwargs) |
| model = utils._pre_trace_quant_model(model, export_inputs) |
| jit_graph, _, _, _ = utils._create_jit_graph(model, export_inputs) |
| return jit_graph |
| |
| |
| def _onnx_graph_from_model( |
| model: torch.nn.Module | torch.jit.ScriptModule, |
| args: tuple[Any, ...], |
| kwargs: Mapping[str, Any], |
| export_options: _experimental.ExportOptions, |
| ) -> _C.Graph: |
| """As part of the ONNX export steps, export an ONNX JIT graph from a PyTorch model. |
| |
| Args: |
| model: See :func:`check_export_model_diff`. |
| args: See :func:`check_export_model_diff`. |
| kwargs: See :func:`check_export_model_diff`. |
| export_options: See :func:`check_export_model_diff`. |
| |
| Returns: |
| onnx_graph (_C.Graph): An ONNX JIT graph. |
| """ |
| # TODO: refactor utils.py to remove duplicated code of context setup. See #78834 |
| opset_version = export_options.opset_version |
| operator_export_type = export_options.operator_export_type |
| export_modules_as_functions = export_options.export_modules_as_functions |
| training = export_options.training |
| verbose = export_options.verbose |
| dynamic_axes = export_options.dynamic_axes |
| input_names = export_options.input_names |
| output_names = export_options.output_names |
| |
| if opset_version is None: |
| opset_version = _constants.ONNX_DEFAULT_OPSET |
| |
| utils._setup_trace_module_map(model, export_modules_as_functions) |
| |
| if not operator_export_type: |
| operator_export_type = _C_onnx.OperatorExportTypes.ONNX |
| |
| GLOBALS.export_onnx_opset_version = opset_version |
| GLOBALS.operator_export_type = operator_export_type |
| |
| with utils.exporter_context(model, training, verbose): |
| do_constant_folding = utils._decide_constant_folding( |
| export_options.do_constant_folding, operator_export_type, training |
| ) |
| |
| if dynamic_axes is None: |
| dynamic_axes = {} |
| utils._validate_dynamic_axes(dynamic_axes, model, input_names, output_names) |
| |
| export_inputs = _prepare_input_for_export(args, kwargs) |
| export_inputs = utils._decide_input_format(model, export_inputs) |
| onnx_graph, _, _ = utils._model_to_graph( |
| model, |
| export_inputs, |
| verbose, |
| input_names, |
| output_names, |
| operator_export_type, |
| do_constant_folding, |
| training=training, |
| dynamic_axes=dynamic_axes, |
| ) |
| |
| return onnx_graph |
| |
| |
| def _onnx_graph_from_aten_graph( |
| graph: torch.Graph, |
| export_options: _experimental.ExportOptions, |
| params_dict: dict[str, Any] | None = None, |
| ) -> tuple[torch.Graph, dict[str, Any]]: |
| if params_dict is None: |
| params_dict = {} |
| operator_export_type = export_options.operator_export_type |
| dynamic_axes = export_options.dynamic_axes or {} |
| input_names = export_options.input_names |
| training = export_options.training |
| do_constant_folding = export_options.do_constant_folding |
| opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET |
| |
| GLOBALS.export_onnx_opset_version = opset_version |
| GLOBALS.operator_export_type = operator_export_type |
| |
| do_constant_folding = utils._decide_constant_folding( |
| do_constant_folding, operator_export_type, training |
| ) |
| |
| # TODO: Below converts the ATen graph to ONNX. It should be abstracted as a |
| # function in torch/onnx/utils.py. |
| graph = graph.copy() |
| graph = utils._optimize_graph( |
| graph, |
| operator_export_type, |
| params_dict=params_dict, |
| dynamic_axes=dynamic_axes, |
| input_names=input_names, |
| ) |
| |
| if training is None or training == _C_onnx.TrainingMode.EVAL: |
| params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict) |
| |
| if ( |
| do_constant_folding |
| and opset_version >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET |
| ): |
| params_dict = _C._jit_pass_onnx_constant_fold(graph, params_dict, opset_version) |
| _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) |
| |
| if GLOBALS.onnx_shape_inference: |
| _C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, opset_version) |
| |
| params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict) |
| |
| # For ONNX opset < 9, constants only have three data types: float16, float, double. |
| # In this pass, transform constants of other data types into float/double plus a Cast operator. |
| if opset_version < 9: |
| _C._jit_pass_onnx_cast_all_constant_to_floating(graph) |
| |
| params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict) |
| _C._jit_decay_packed_param_input_types(graph) |
| |
| _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) |
| |
| if export_options.verbose: |
| print("ONNX graph: ", graph) |
| |
| return graph, params_dict |
| |
| |
| def _onnx_proto_from_onnx_graph( |
| onnx_graph: torch.Graph, |
| export_options: _experimental.ExportOptions, |
| params_dict: dict[str, Any], |
| ) -> tuple[bytes, Mapping[str, bytes]]: |
| opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET |
| dynamic_axes = export_options.dynamic_axes or {} |
| operator_export_type = export_options.operator_export_type |
| val_keep_init_as_ip = utils._decide_keep_init_as_input( |
| export_options.keep_initializers_as_inputs, |
| operator_export_type, |
| opset_version, |
| ) |
| val_add_node_names = utils._decide_add_node_names(True, operator_export_type) |
| custom_opsets = export_options.custom_opsets or {} |
| |
| proto, export_map, _, _ = onnx_graph._export_onnx( # type: ignore[attr-defined] |
| params_dict, |
| opset_version, |
| dynamic_axes, |
| False, |
| operator_export_type, |
| not export_options.verbose, |
| val_keep_init_as_ip, |
| custom_opsets, |
| val_add_node_names, |
| "", |
| {}, |
| ) |
| |
| return proto, export_map |
| |
| |
| def check_export_model_diff( |
| model: torch.nn.Module | torch.jit.ScriptModule, |
| test_input_groups: Sequence[tuple[tuple[Any, ...], Mapping[str, Any]]], |
| export_options: _experimental.ExportOptions | None = None, |
| ) -> str: |
| """Verify exported model discrepancy between different groups of inputs. |
| |
| A graph is exported for each group of inputs. The exported graphs are then compared |
| to each other, and the first pair of diverging nodes is reported. This function |
| first checks the JIT graph. If no discrepancies are found, it then checks the ONNX |
| graph. |
| |
| Unless otherwise specified, the JIT/ONNX graph is expected to be the same regardless |
| of the inputs used for exporting. A discrepancy implies that the exported graph is |
| not accurate for other groups of inputs, which typically results in |
| runtime errors or mismatched outputs. |
| |
| Args: |
| model (torch.nn.Module or torch.jit.ScriptModule): The model to be exported. |
| test_input_groups (Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]]): A sequence |
| of input groups to be used to export the model. Each input group is a pair of |
| (args, kwargs). |
| export_options (_experimental.ExportOptions, optional): An _experimental.ExportOptions |
| object that controls the export behavior. |
| |
| Returns: |
| str: A string containing the diff of the exported models. |
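| |
| Example (a sketch; the data-dependent model below is hypothetical):: |
| |
| >>> # xdoctest: +SKIP |
| >>> import torch |
| >>> class DynamicModel(torch.nn.Module): |
| ... def forward(self, x): |
| ... # Tracing bakes this data-dependent branch into the graph. |
| ... return x.relu() if x.numel() > 2 else x.sigmoid() |
| >>> diff = check_export_model_diff( |
| ... DynamicModel(), |
| ... [((torch.randn(2, 2),), {}), ((torch.randn(1, 1),), {})], |
| ... ) |
| >>> print(diff) # Non-empty when the exported graphs diverge. |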
| """ |
| export_options = ( |
| _experimental.ExportOptions() if export_options is None else export_options |
| ) |
| |
| jit_diff_report = _check_graph_diff( |
| model, test_input_groups, export_options, _traced_graph_from_model |
| ) |
| if jit_diff_report: |
| return jit_diff_report |
| |
| return _check_graph_diff( |
| model, test_input_groups, export_options, _onnx_graph_from_model |
| ) |
| |
| |
| def verify( |
| model: _ModelType, |
| input_args: _InputArgsType, |
| input_kwargs: _InputKwargsType | None = None, |
| do_constant_folding: bool = True, |
| dynamic_axes: Mapping[str, Mapping[int, str] | Mapping[str, Sequence[int]]] |
| | None = None, |
| input_names: Sequence[str] | None = None, |
| output_names: Sequence[str] | None = None, |
| training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, |
| opset_version: int | None = None, |
| keep_initializers_as_inputs: bool = True, |
| verbose: bool = False, |
| fixed_batch_size: bool = False, |
| use_external_data: bool = False, |
| additional_test_inputs: Sequence[_InputArgsType] | None = None, |
| options: VerificationOptions | None = None, |
| ): |
| """Verify model export to ONNX against original PyTorch model. |
| |
| Args: |
| model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`. |
| input_args (tuple): See :func:`torch.onnx.export`. |
| input_kwargs (dict): See :func:`torch.onnx.export`. |
| do_constant_folding (bool, optional): See :func:`torch.onnx.export`. |
| dynamic_axes (dict, optional): See :func:`torch.onnx.export`. |
| input_names (list, optional): See :func:`torch.onnx.export`. |
| output_names (list, optional): See :func:`torch.onnx.export`. |
| training (torch.onnx.TrainingMode): See :func:`torch.onnx.export`. |
| opset_version (int, optional): See :func:`torch.onnx.export`. |
| keep_initializers_as_inputs (bool, optional): See :func:`torch.onnx.export`. |
| verbose (bool, optional): See :func:`torch.onnx.export`. |
| fixed_batch_size (bool, optional): Legacy argument, used only by rnn test cases. |
| use_external_data (bool, optional): Explicitly specify whether to export the |
| model with external data. |
| additional_test_inputs (list, optional): List of tuples. Each tuple is a group of |
| input arguments to test. Currently only *args are supported. |
| options (VerificationOptions, optional): A VerificationOptions object that |
| controls the verification behavior. |
| |
| Raises: |
| AssertionError: if outputs from ONNX model and PyTorch model are not |
| equal up to specified precision. |
| ValueError: if arguments provided are invalid. |
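| |
| Example (a minimal sketch; the module and shapes are placeholders):: |
| |
| >>> # xdoctest: +SKIP |
| >>> import torch |
| >>> model = torch.nn.Linear(3, 4) |
| >>> # Raises AssertionError if the ONNX backend outputs diverge from |
| >>> # PyTorch beyond the tolerances in `options`. |
| >>> verify(model, (torch.randn(2, 3),), options=VerificationOptions()) |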
| """ |
| if options is None: |
| options = VerificationOptions() |
| |
| if training == torch.onnx.TrainingMode.TRAINING: |
| model.train() |
| elif training == torch.onnx.TrainingMode.EVAL: |
| model.eval() |
| with torch.no_grad(), contextlib.ExitStack() as stack: |
| model_f: str | io.BytesIO = io.BytesIO() |
| if use_external_data: |
| tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory()) |
| model_f = os.path.join(tmpdir_path, "model.onnx") |
| |
| inputs_for_export = _prepare_input_for_export(input_args, input_kwargs) |
| |
| # TODO(#77679): remove this and treat mutating model separately. |
| model_copy = _try_clone_model(model) |
| utils._export( |
| model, |
| inputs_for_export, |
| model_f, |
| opset_version=opset_version, |
| do_constant_folding=do_constant_folding, |
| keep_initializers_as_inputs=keep_initializers_as_inputs, |
| dynamic_axes=dynamic_axes, |
| input_names=input_names, |
| output_names=output_names, |
| fixed_batch_size=fixed_batch_size, |
| training=training, |
| verbose=verbose, |
| ) |
| |
| _compare_onnx_pytorch_model( |
| pt_model=model_copy, |
| onnx_model_f=model_f, |
| input_args=input_args, |
| input_kwargs=input_kwargs, |
| additional_test_inputs=additional_test_inputs, |
| options=options, |
| ) |
| |
| |
| def verify_aten_graph( |
| graph: torch.Graph, |
| input_args: tuple[Any, ...], |
| export_options: _experimental.ExportOptions, |
| params_dict: dict[str, Any] | None = None, |
| verification_options: VerificationOptions | None = None, |
| ) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: |
| if verification_options is None: |
| verification_options = VerificationOptions() |
| if params_dict is None: |
| params_dict = {} |
| |
| original_jit_graph = graph |
| graph = graph.copy() |
| |
| # Execute aten graph and get reference torch jit outputs. |
| graph_inputs = list(graph.inputs()) |
| jit_inputs = tuple(arg for arg in input_args if arg is not None) |
| weights = [params_dict[v.debugName()] for v in graph_inputs[len(jit_inputs) :]] |
| assert all(w is not None for w in weights) |
| # TODO: Only copy the argument if mutation is detected in Graph. |
| jit_inputs = copy.deepcopy(jit_inputs) |
| jit_input_and_parameters = jit_inputs + tuple(weights) |
| jit_outs = torch._C._jit_interpret_graph(graph, jit_input_and_parameters) # type: ignore[attr-defined] |
| if not isinstance(jit_outs, (list, tuple)): |
| jit_outs = [jit_outs] |
| |
| # Convert aten graph to onnx graph. |
| graph, onnx_params_dict = _onnx_graph_from_aten_graph( |
| graph, export_options, params_dict |
| ) |
| |
| proto, export_map = _onnx_proto_from_onnx_graph( |
| graph, export_options, onnx_params_dict |
| ) |
| model_f: str | io.BytesIO = io.BytesIO() |
| export_type = _exporter_states.ExportTypes.PROTOBUF_FILE |
| onnx_proto_utils._export_file(proto, model_f, export_type, export_map) |
| |
| # NOTE: Verification is unstable. Wrap in try/except to emit extra graph information for debugging on failure. |
| try: |
| # NOTE: Inputs might be DCE-ed, so we need to remove those from the input args. |
| new_input_names = {v.debugName() for v in graph.inputs()} |
| new_input_args = [] |
| for v, arg in zip(original_jit_graph.inputs(), input_args): |
| if v.debugName() in new_input_names: |
| new_input_args.append(arg) |
| input_args = tuple(new_input_args) |
| |
| onnx_inputs = _prepare_input_for_onnx( |
| input_args, |
| {}, |
| verification_options.remained_onnx_input_idx, |
| verification_options.flatten, |
| ) |
| |
| onnx_session = _onnx_backend_session(model_f, verification_options.backend) |
| onnx_outs = _run_onnx(onnx_session, onnx_inputs) |
| del onnx_session # To free device memory |
| |
| try: |
| _compare_onnx_pytorch_outputs( |
| onnx_outs=onnx_outs, |
| pt_outs=jit_outs, |
| options=verification_options, |
| ) |
| except AssertionError as e: |
| return e, graph, jit_outs, onnx_outs |
| |
| return None, graph, jit_outs, onnx_outs |
| |
| except Exception as e: |
| print("Unexpected error during verification.") |
| print("jit graph: ", original_jit_graph) |
| print("onnx graph: ", graph) |
| raise e |
| |
| |
| class GraphInfoPrettyPrinter: |
| graph_info: GraphInfo | None |
| upper_printer: GraphInfoPrettyPrinter | None |
| lower_printer: GraphInfoPrettyPrinter | None |
| |
| graph_str_lambdas: Mapping[int, str] |
| connector_str_lambdas: Mapping[int, str] |
| children_str_lambdas: Mapping[int, str] |
| |
| def __init__(self, graph_info: GraphInfo | None): |
| self.graph_info = graph_info |
| if ( |
| graph_info is not None |
| and graph_info.upper_graph_info is not None |
| and graph_info.lower_graph_info is not None |
| ): |
| self.upper_printer = GraphInfoPrettyPrinter(graph_info.upper_graph_info) |
| self.lower_printer = GraphInfoPrettyPrinter(graph_info.lower_graph_info) |
| else: |
| self.upper_printer = None |
| self.lower_printer = None |
| |
| def _total_rows(self) -> int: |
| if self.graph_info is None: |
| return 1 |
| if self.upper_printer and self.lower_printer: |
| return ( |
| self.upper_printer._total_rows() + self.lower_printer._total_rows() + 1 |
| ) |
| return 2 # Two lines: node count + id. |
| |
| def _node_count_segment_str(self) -> str: |
| if self.graph_info is None: |
| return "..." |
| node_count = self.graph_info.essential_node_count() |
| has_mismatch = self.graph_info.has_mismatch() |
| error_node_kind = ( |
| f"({self.graph_info.essential_node_kinds().pop()})" |
| if node_count == 1 and has_mismatch |
| else "" |
| ) |
| |
| return f"{node_count} {'X' if has_mismatch else chr(0x2713)} {error_node_kind}" |
| |
| def _graph_id_segment_str(self) -> str: |
| if self.graph_info is None: |
| return "" |
| return f"id: {self.graph_info.id}" |
| |
| def _max_segment_columns(self) -> int: |
| return max( |
| map(len, (self._node_count_segment_str(), self._graph_id_segment_str())) |
| ) |
| |
| def _graph_segment_str_at_line(self, line: int) -> str: |
| """Get the string representation of the graph segment at the given line.""" |
| if line == 0: |
| result_str = self._node_count_segment_str() |
| result_str += " " * (self._max_segment_columns() - len(result_str)) |
| return result_str |
| if line == 1: |
| result_str = self._graph_id_segment_str() |
| result_str += " " * (self._max_segment_columns() - len(result_str)) |
| return result_str |
| if 0 <= line < self._total_rows(): |
| return " " * self._max_segment_columns() |
| return "" |
| |
| def _connector_segment_str_at_line(self, line: int) -> str: |
| """Get the connector segment string at the given line.""" |
| if self.upper_printer is None and self.lower_printer is None: |
| return "" |
| upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 |
| lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 |
| if line == 0: |
| return " __" |
| elif line < upper_total_rows + 1: |
| return " | " |
| elif line == upper_total_rows + 1: |
| return " |__" |
| elif line < upper_total_rows + lower_total_rows + 1: |
| return " " |
| return "" |
| |
| def _children_str_at_line(self, line: int) -> str: |
| """Get the string representation of the children at the given line. |
| |
| Recursively calls `_str_at_line` on children nodes. |
| """ |
| if self.upper_printer is None and self.lower_printer is None: |
| return "" |
| upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1 |
| lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1 |
| if 0 <= line < upper_total_rows: |
| return ( |
| self.upper_printer._str_at_line(line) if self.upper_printer else "..." |
| ) |
| elif upper_total_rows < line < upper_total_rows + lower_total_rows + 1: |
| return ( |
| self.lower_printer._str_at_line(line - upper_total_rows - 1) |
| if self.lower_printer |
| else "..." |
| ) |
| return "" |
| |
| def _str_at_line(self, line: int) -> str: |
| """Get the string representation of the graph at the given line.""" |
| return ( |
| self._graph_segment_str_at_line(line) |
| + self._connector_segment_str_at_line(line) |
| + self._children_str_at_line(line) |
| ) |
| |
| def pretty_print(self): |
| if self.graph_info is None: |
| print(None) |
| return |
| # Print tree. |
| print(" Tree: ".center(80, "=")) |
| total_rows = self._total_rows() |
| for line in range(total_rows): |
| print(self._str_at_line(line).rstrip()) |
| if self.graph_info.has_mismatch(): |
| # Summarize leaf subgraphs with mismatch. |
| print(" Mismatch leaf subgraphs: ".center(80, "=")) |
| print( |
| [ |
| graph_info.id |
| for graph_info in self.graph_info.all_mismatch_leaf_graph_info() |
| ] |
| ) |
| # Summarize node kinds with mismatch. |
| mismatch_node_kinds: dict[str, int] = {} |
| for graph_info in self.graph_info.all_mismatch_leaf_graph_info(): |
| node_kinds = graph_info.essential_node_kinds() |
| if len(node_kinds) == 1: |
| node_kind = node_kinds.pop() |
| mismatch_node_kinds[node_kind] = ( |
| mismatch_node_kinds.get(node_kind, 0) + 1 |
| ) |
| print(" Mismatch node kinds: ".center(80, "=")) |
| print(mismatch_node_kinds) |
| else: |
| print(" No mismatch found. ".center(80, "=")) |
| |
| |
| class OnnxTestCaseRepro: |
| def __init__(self, repro_dir): |
| self.repro_dir = repro_dir |
| self.proto, self.inputs, self.outputs = onnx_proto_utils.load_test_case( |
| repro_dir |
| ) |
| |
| @classmethod |
| def create_test_case_repro( |
| cls, proto: bytes, inputs, outputs, dir: str, name: str | None = None |
| ): |
| """Create a repro under "{dir}/test_{name}" for an ONNX test case. |
| |
| The test case contains the model and the inputs/outputs data. The directory |
| structure is as follows: |
| |
| dir |
| ├── test_<name> |
| │   ├── model.onnx |
| │   └── test_data_set_0 |
| │       ├── input_0.pb |
| │       ├── input_1.pb |
| │       ├── output_0.pb |
| │       └── output_1.pb |
| |
| Args: |
| proto: ONNX model proto. |
| inputs: Inputs to the model. |
| outputs: Outputs of the model. |
| dir: Directory to save the repro. |
| name: Name of the test case. If not specified, a name based on current time |
| will be generated. |
| Returns: |
| Path to the repro. |
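| |
| Example (a sketch; `proto`, `inputs` and `outputs` are assumed to come |
| from a prior export):: |
| |
| >>> # xdoctest: +SKIP |
| >>> repro_path = OnnxTestCaseRepro.create_test_case_repro( |
| ... proto, inputs, outputs, dir="/tmp/repros" |
| ... ) |
| >>> # Reload the saved test case and re-validate it. |
| >>> OnnxTestCaseRepro(repro_path).validate(VerificationOptions()) |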
| """ |
| if name is None: |
| name = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") |
| return onnx_proto_utils.export_as_test_case( |
| proto, |
| _to_numpy(inputs), |
| _to_numpy(outputs), |
| name, |
| dir, |
| ) |
| |
| def validate(self, options: VerificationOptions): |
| """Run the ONNX test case with options.backend, and compare with the expected outputs. |
| |
| Args: |
| options: Options for validation. |
| |
| Raises: |
| AssertionError: if outputs from options.backend and expected outputs are not |
| equal up to specified precision. |
| """ |
| onnx_session = _onnx_backend_session(io.BytesIO(self.proto), options.backend) |
| run_outputs = onnx_session.run(None, self.inputs) |
| if hasattr(onnx_session, "get_outputs"): |
| output_names = [o.name for o in onnx_session.get_outputs()] |
| elif hasattr(onnx_session, "output_names"): |
| output_names = onnx_session.output_names |
| else: |
| raise ValueError(f"Unknown onnx session type: {type(onnx_session)}") |
| expected_outs = [self.outputs[name] for name in output_names] |
| _compare_onnx_pytorch_outputs_in_np(run_outputs, expected_outs, options) |
| |
| |
| @dataclasses.dataclass |
| class GraphInfo: |
| """GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph.""" |
| |
| graph: torch.Graph |
| input_args: tuple[Any, ...] |
| params_dict: dict[str, Any] |
| export_options: _experimental.ExportOptions = dataclasses.field( |
| default_factory=_experimental.ExportOptions |
| ) |
| mismatch_error: AssertionError | None = dataclasses.field(default=None, init=False) |
| pt_outs: Sequence[_NumericType] | None = dataclasses.field(default=None, init=False) |
| upper_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False) |
| lower_graph_info: GraphInfo | None = dataclasses.field(default=None, init=False) |
| id: str = dataclasses.field(default="") |
| _onnx_graph: torch.Graph | None = dataclasses.field(init=False, default=None) |
| |
| _EXCLUDED_NODE_KINDS: frozenset[str] = frozenset( |
| {"prim::Constant", "prim::ListConstruct", "aten::ScalarImplicit"} |
| ) |
| |
| def clear(self): |
| """Clear states and results of previous verification.""" |
| self.mismatch_error = None |
| self.pt_outs = None |
| self._onnx_graph = None |
| self.upper_graph_info = None |
| self.lower_graph_info = None |
| |
| def pretty_print_tree(self): |
| """Pretty print `GraphInfo` tree. |
| |
| Each node represents a subgraph, showing the number of nodes in the subgraph |
| and an `X` if the subgraph has an output mismatch between torch and ONNX |
| (a check mark otherwise). |
| |
| The id of the subgraph is shown under the node. The `GraphInfo` object for any |
| subgraph can be retrieved by calling `graph_info.find_partition(id)`. |
| |
| Example:: |
| |
| ==================================== Tree: ===================================== |
| 5 X __2 X __1 ✓ |
| id: | id: 0 | id: 00 |
| | | |
| | |__1 X (aten::relu) |
| | id: 01 |
| | |
| |__3 X __1 ✓ |
| id: 1 | id: 10 |
| | |
| |__2 X __1 X (aten::relu) |
| id: 11 | id: 110 |
| | |
| |__1 ✓ |
| id: 111 |
| =========================== Mismatch leaf subgraphs: =========================== |
| ['01', '110'] |
| ============================= Mismatch node kinds: ============================= |
| {'aten::relu': 2} |
| |
| """ |
| GraphInfoPrettyPrinter(self).pretty_print() |
| |
| def pretty_print_mismatch(self, graph: bool = False): |
| """Pretty print details of the mismatch between torch and ONNX. |
| |
| Args: |
| graph: If True, print the ATen JIT graph and ONNX graph. |
| """ |
| print(f" Mismatch info for graph partition {self.id}: ".center(80, "=")) |
| if graph: |
| print(" ATen JIT graph ".center(80, "=")) |
| # TODO: A more compact graph printer. |
| # * Drop stride, grad, device information. |
| # * Show source location on a separate line. |
| print(self.graph) |
| if self._onnx_graph is not None: |
| print(" ONNX graph ".center(80, "=")) |
| print(self._onnx_graph) |
| if self.has_mismatch(): |
| print(" Mismatch error ".center(80, "=")) |
| print(self.mismatch_error) |
| else: |
| print(" No mismatch ".center(80, "=")) |
| |
| def has_mismatch(self) -> bool: |
| """Return True if the subgraph has output mismatch between torch and ONNX.""" |
| return self.mismatch_error is not None |
| |
| def essential_node_count(self) -> int: |
| """Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" |
| return sum( |
| 1 for n in self.graph.nodes() if n.kind() not in self._EXCLUDED_NODE_KINDS |
| ) |
| |
| def essential_node_kinds(self) -> set[str]: |
| """Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`.""" |
| return { |
| n.kind() |
| for n in self.graph.nodes() |
| if n.kind() not in self._EXCLUDED_NODE_KINDS |
| } |
| |
| def all_mismatch_leaf_graph_info(self) -> list[GraphInfo]: |
| """Return a list of all leaf `GraphInfo` objects that have mismatch.""" |
| if not self.has_mismatch(): |
| return [] |
| |
| no_mismatch_children = ( |
| self.upper_graph_info is None or not self.upper_graph_info.has_mismatch() |
| ) and ( |
| self.lower_graph_info is None or not self.lower_graph_info.has_mismatch() |
| ) |
| |
| if no_mismatch_children: |
| return [self] |
| |
| results = [] |
| if self.upper_graph_info is not None: |
| results += self.upper_graph_info.all_mismatch_leaf_graph_info() |
| if self.lower_graph_info is not None: |
| results += self.lower_graph_info.all_mismatch_leaf_graph_info() |
| |
| return results |
| |
| def find_partition(self, id: str) -> GraphInfo | None: |
| """Find the `GraphInfo` object with the given id.""" |
| if id == self.id: |
| return self |
| current_length = len(self.id) |
| if len(id) > current_length: |
| if id[current_length] == "0" and self.upper_graph_info is not None: |
| return self.upper_graph_info.find_partition(id) |
| elif id[current_length] == "1" and self.lower_graph_info is not None: |
| return self.lower_graph_info.find_partition(id) |
| return None |
| |
| def export_repro( |
| self, repro_dir: str | None = None, name: str | None = None |
| ) -> str: |
| """Export the subgraph to ONNX along with the input/output data for repro. |
| |
| The repro directory will contain the following files:: |
| |
| dir |
| ├── test_<name> |
| │   ├── model.onnx |
| │   └── test_data_set_0 |
| │       ├── input_0.pb |
| │       ├── input_1.pb |
| │       ├── output_0.pb |
| │       └── output_1.pb |
| |
| Args: |
| repro_dir: The directory to export the repro files to. Defaults to current |
| working directory if None. |
| name: An optional name for the test case folder: "test_{name}". |
| |
| Returns: |
| The path to the exported repro directory. |
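| |
| Example (a sketch; assumes `graph_info` was produced by :func:`find_mismatch`):: |
| |
| >>> # xdoctest: +SKIP |
| >>> leaf = graph_info.all_mismatch_leaf_graph_info()[0] |
| >>> repro_path = leaf.export_repro() |
| >>> OnnxTestCaseRepro(repro_path).validate(VerificationOptions()) |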
| """ |
| |
| if repro_dir is None: |
| repro_dir = os.getcwd() |
| repro_dir = os.path.join(repro_dir, "onnx_debug") |
| |
| onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph( |
| self.graph, self.export_options, self.params_dict |
| ) |
| |
| proto, _ = _onnx_proto_from_onnx_graph( |
| onnx_graph, self.export_options, onnx_params_dict |
| ) |
| return OnnxTestCaseRepro.create_test_case_repro( |
| proto, self.input_args, self.pt_outs, repro_dir, name |
| ) |
| |
| def _graph_partition_pivot(self) -> int: |
| """Find the pivot index to partition the graph. |
| |
| The pivot is the node that splits the graph into two parts. Each part should |
| have a similar number of nodes, excluding the non-essential ops defined in |
| `_EXCLUDED_NODE_KINDS`, such as `prim::Constant`. |
| If the graph has an odd number of essential nodes, the lower part will have one more node. |
| If the graph has no node that can be partitioned, return -1. |
| |
| Returns: |
| The index of the pivot node. |
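| |
| Example (illustrative): for essential nodes at graph indices [0, 1, 3, 4, 6] |
| (indices 2 and 5 being e.g. `prim::Constant`), half_idx is 5 // 2 - 1 = 1 and |
| the pivot is included_node_indices[1] + 1 = 2, so the upper part gets the two |
| essential nodes before index 2 and the lower part gets the remaining three. |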
| """ |
| included_node_indices = [ |
| i |
| for i, n in enumerate(self.graph.nodes()) |
| if n.kind() not in self._EXCLUDED_NODE_KINDS |
| ] |
| half_idx = len(included_node_indices) // 2 - 1 |
| if half_idx >= 0 and len(included_node_indices) > half_idx: |
| return included_node_indices[half_idx] + 1 |
| return -1 |
| |
| def _partition_upper_graph(self) -> torch.Graph: |
| pivot = self._graph_partition_pivot() |
| if pivot == -1: |
| return torch.Graph() |
| graph = self.graph.copy() # Copy to not mutate parent graph. |
| original_outputs = list(graph.outputs()) |
| |
| def _process_bridge_value_for_upper( |
| new_outputs: list[torch.Value], bridge_value: torch.Value |
| ) -> torch.Value: |
| # Add bridge values as upper graph outputs. |
| new_outputs.append(bridge_value) |
| return bridge_value |
| |
| new_outputs: list[torch.Value] = [] |
| process_bridge_value_for_upper = functools.partial( |
| _process_bridge_value_for_upper, new_outputs |
| ) |
| _, dropped_nodes, complete_upper_nodes_set, _ = self._partition_nodes( |
| graph, pivot, process_bridge_value_for_upper |
| ) |
| |
| for _ in original_outputs: |
| graph.eraseOutput(0) |
| for output in new_outputs: |
| graph.registerOutput(output) |
| |
| for node in reversed(dropped_nodes): |
| node.destroy() |
| |
| for i, input in reversed(list(enumerate(list(graph.inputs())))): |
| if ( |
| not _has_uses_by_nodes(input, complete_upper_nodes_set) |
| and input not in new_outputs |
| ): |
| try: |
| graph.eraseInput(i) |
| except RuntimeError as e: |
| print(input, graph) |
| raise e |
| |
| return graph |
| |
| def _partition_lower_graph(self) -> torch.Graph: |
| pivot = self._graph_partition_pivot() |
| if pivot == -1: |
| return torch.Graph() |
| graph = self.graph.copy() # Copy to not mutate parent graph. |
| original_outputs = list(graph.outputs()) |
| original_inputs = list(graph.inputs()) |
| |
| new_outputs = [] |
| |
| def _process_bridge_value_for_lower( |
| graph: torch.Graph, bridge_value: torch.Value |
| ) -> torch.Value: |
| # Add bridge values as lower graph inputs. |
| new_input = graph.addInput() |
| bridge_value.replaceAllUsesWith(new_input) |
| new_input.copyMetadata(bridge_value) |
| return new_input |
| |
| process_bridge_value_for_lower = functools.partial( |
| _process_bridge_value_for_lower, graph |
| ) |
| |
| upper_nodes, lower_nodes, _, complete_lower_nodes_set = self._partition_nodes( |
| graph, pivot, process_bridge_value_for_lower |
| ) |
| |
| for output in original_outputs: |
| if _produced_by(output, lower_nodes): |
| new_outputs.append(output) |
| for _ in original_outputs: |
| graph.eraseOutput(0) |
| for output in new_outputs: |
| graph.registerOutput(output) |
| |
| for input in original_inputs: |
| if _has_uses_by_nodes(input, complete_lower_nodes_set): |
| new_input = graph.addInput() |
| input.replaceAllUsesWith(new_input) |
| new_input.copyMetadata(input) |
| |
| for node in reversed(upper_nodes): |
| if node not in complete_lower_nodes_set: |
| try: |
| node.destroy() |
| except RuntimeError as e: |
| print(node, graph) |
| raise e |
| |
| for _ in original_inputs: |
| graph.eraseInput(0) |
| |
| return graph |
| |
| def _partition_node( |
| self, |
| node: torch.Node, |
| complete_upper_nodes_set: set[torch.Node], |
| complete_lower_nodes_set: set[torch.Node], |
| original_graph_outputs: set[torch.Value], |
| covered_bridge_values: set[torch.Value], |
| process_bridge_value: Callable[[torch.Value], torch.Value], |
| ): |
| if node in complete_lower_nodes_set: |
| return |
| |
| if ( |
| _node_has_uses_by(node, complete_lower_nodes_set) |
| and node.kind() in self._EXCLUDED_NODE_KINDS |
| ): |
| complete_lower_nodes_set.update(_all_nodes([node])) |
| for input in node.inputs(): |
| if input in covered_bridge_values: |
| continue |
| self._partition_node( |
| input.node(), |
| complete_upper_nodes_set, |
| complete_lower_nodes_set, |
| original_graph_outputs, |
| covered_bridge_values, |
| process_bridge_value, |
| ) |
| else: |
| for output in node.outputs(): |
| if output in covered_bridge_values: |
| continue |
| if ( |
| _has_uses_by_nodes(output, complete_lower_nodes_set) |
| or output in original_graph_outputs |
| ): |
| covered_bridge_values.add(process_bridge_value(output)) |
| |
| def _partition_nodes( |
| self, |
| graph: torch.Graph, |
| pivot: int, |
| process_bridge_value: Callable[[torch.Value], torch.Value], |
| ) -> tuple[list[torch.Node], list[torch.Node], set[torch.Node], set[torch.Node]]: |
| nodes = list(graph.nodes()) |
| upper_nodes = nodes[:pivot] |
| lower_nodes = nodes[pivot:] |
| # `upper_nodes` and `complete_upper_nodes_set` differ in that the latter |
| # recursively contains nodes in the subblocks of `upper_nodes`. |
| # The same applies to `lower_nodes` and `complete_lower_nodes_set`, |
| # with the addition that `complete_lower_nodes_set` also includes nodes |
| # that are determined to be copied from `upper_nodes` to `lower_nodes`. |
| complete_upper_nodes_set = _all_nodes(upper_nodes) |
| complete_lower_nodes_set = _all_nodes(lower_nodes) |
| original_graph_outputs = set(graph.outputs()) |
| # Bridge values are values produced by the upper graph and consumed |
| # by the lower graph. These values need to become upper graph outputs |
| # and lower graph inputs to bridge the interaction. |
| # Start with all graph inputs marked as covered. If any graph input is |
| # needed by the lower graph, it is simply kept as a lower graph input later. |
| covered_bridge_values = set(graph.inputs()) |
| for node in upper_nodes: |
| self._partition_node( |
| node, |
| complete_upper_nodes_set, |
| complete_lower_nodes_set, |
| original_graph_outputs, |
| covered_bridge_values, |
| process_bridge_value, |
| ) |
| return ( |
| upper_nodes, |
| lower_nodes, |
| complete_upper_nodes_set, |
| complete_lower_nodes_set, |
| ) |
| |
| def _bridge_kwargs(self): |
| pt_outs = self.pt_outs |
| graph_outputs = list(self.graph.outputs()) |
| assert pt_outs is not None |
| assert len(graph_outputs) == len( |
| pt_outs |
| ), f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}" |
| return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)} |
| |
| def _args_and_params_for_partition_graph( |
| self, |
| graph: torch.Graph, |
| bridge_kwargs: Mapping[str, _NumericType | Sequence[_NumericType]], |
| full_kwargs: Mapping[str, torch.Tensor], |
| full_params: Mapping[str, torch.Tensor], |
| ): |
| input_names = [input.debugName() for input in graph.inputs()] |
| args = tuple(bridge_kwargs[k] for k in input_names if k in bridge_kwargs) |
| args += tuple(full_kwargs[k] for k in input_names if k in full_kwargs) |
| params = {k: full_params[k] for k in input_names if k in full_params} |
| assert len(args) + len(params) == len( |
| input_names |
| ), f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}" |
| return args, params |
| |
| def verify_export( |
| self, options: VerificationOptions |
| ) -> tuple[AssertionError | None, torch.Graph, _OutputsType, _OutputsType]: |
| """ |
| Verify the export from TorchScript IR graph to ONNX. |
| |
| Export the TorchScript IR graph to ONNX, with the inputs, parameters and export |
| options recorded in this object. Then verify the exported ONNX graph against |
| the original TorchScript IR graph under the provided verification options. |
| |
| Args: |
| options: The verification options. |
| |
| Returns: |
| error: The AssertionError raised during the verification. Returns None if no |
| error is raised. |
| onnx_graph: The exported ONNX graph in TorchScript IR format. |
| pt_outs: The outputs from running the TorchScript IR graph. |
| onnx_outs: The outputs from running the exported ONNX model under the ONNX |
| backend in `options`. |
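| |
| Example (a sketch; assumes `graph_info` came from :func:`find_mismatch`):: |
| |
| >>> # xdoctest: +SKIP |
| >>> error, onnx_graph, pt_outs, onnx_outs = graph_info.verify_export( |
| ... VerificationOptions() |
| ... ) |
| >>> assert error is None, error |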
| """ |
| return verify_aten_graph( |
| self.graph, |
| input_args=self.input_args, |
| params_dict=self.params_dict, |
| export_options=self.export_options, |
| verification_options=options, |
| ) |
| |
| def find_mismatch( |
| self, |
| options: VerificationOptions | None = None, |
| ): |
| """ |
| Find all mismatches between the TorchScript IR graph and the exported ONNX model. |
| |
| Binary searches the model graph to find the minimal subgraph that exhibits the |
| mismatch. A `GraphInfo` object is created for each subgraph, recording the test |
| inputs and export options, as well as the validation results. |
| |
| Args: |
| options: The verification options. |
| """ |
| self.clear() |
| |
| if options is None: |
| options = VerificationOptions() |
| |
| if self.export_options.verbose: |
| print(self.graph) |
| |
| if len(list(self.graph.outputs())) == 0: |
| return |
| |
| assert len(self.input_args) + len(self.params_dict) == len( |
| list(self.graph.inputs()) |
| ), ( |
| f"Number of graph inputs({len(list(self.graph.inputs()))}) does not match " |
| f"the provided tensor arguments({len(self.input_args)} + {len(self.params_dict)})." |
| ) |
| |
| self.mismatch_error, self._onnx_graph, self.pt_outs, _ = self.verify_export( |
| options |
| ) |
| |
| if self.mismatch_error is None: |
| # No mismatch found in graph. |
| return |
| |
| if self.essential_node_count() <= 1: |
| # Reached leaf node, no more partitioning. |
| return |
| |
| full_kwargs = { |
| k.debugName(): v for k, v in zip(self.graph.inputs(), self.input_args) |
| } |
| full_params = self.params_dict |
| |
| upper_graph = self._partition_upper_graph() |
| upper_args, upper_params = self._args_and_params_for_partition_graph( |
| upper_graph, {}, full_kwargs, full_params |
| ) |
| self.upper_graph_info = GraphInfo( |
| upper_graph, |
| upper_args, |
| upper_params, |
| self.export_options, |
| id=self.id + "0", |
| ) |
| |
| self.upper_graph_info.find_mismatch(options) |
| |
| bridge_kwargs = self.upper_graph_info._bridge_kwargs() |
| lower_graph = self._partition_lower_graph() |
| lower_args, lower_params = self._args_and_params_for_partition_graph( |
| lower_graph, bridge_kwargs, full_kwargs, full_params |
| ) |
| self.lower_graph_info = GraphInfo( |
| lower_graph, |
| lower_args, |
| lower_params, |
| self.export_options, |
| id=self.id + "1", |
| ) |
| |
| self.lower_graph_info.find_mismatch(options) |
| |
| |
| def _all_nodes(nodes: Collection[torch.Node]) -> set[torch.Node]: |
| all_nodes = set(nodes) |
| for n in nodes: |
| for b in n.blocks(): |
| all_nodes.update(_all_nodes(list(b.nodes()))) |
| return all_nodes |
| |
| |
| def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool: |
| return any(use.user in nodes for use in value.uses()) |
| |
| |
| def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool: |
| for output in node.outputs(): |
| if _has_uses_by_nodes(output, nodes): |
| return True |
| return False |
| |
| |
| def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool: |
| return value.node() in nodes |
| |
| |
| def find_mismatch( |
| model: torch.nn.Module | torch.jit.ScriptModule, |
| input_args: tuple[Any, ...], |
| do_constant_folding: bool = True, |
| training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL, |
| opset_version: int | None = None, |
| keep_initializers_as_inputs: bool = True, |
| verbose: bool = False, |
| options: VerificationOptions | None = None, |
| ) -> GraphInfo: |
| r"""Find all mismatches between the original model and the exported model. |
| |
| Experimental. The API is subject to change. |
| |
| This tool helps debug the mismatch between the original PyTorch model and exported |
| ONNX model. It binary searches the model graph to find the minimal subgraph that |
| exhibits the mismatch. |
| |
| Args: |
| model: The model to be exported. |
| input_args: The input arguments to the model. |
| do_constant_folding: Same as `do_constant_folding` in :func:`torch.onnx.export`. |
| training: Same as `training` in :func:`torch.onnx.export`. |
| opset_version: Same as `opset_version` in :func:`torch.onnx.export`. |
| keep_initializers_as_inputs: Same as `keep_initializers_as_inputs` in :func:`torch.onnx.export`. |
| verbose: Same as `verbose` in :func:`torch.onnx.export`. |
| options: The options for the mismatch verification. |
| |
| Returns: |
| A GraphInfo object that contains the mismatch information. |
| |
| Example:: |
| |
| >>> import torch |
| >>> import torch.onnx.verification |
| >>> torch.manual_seed(0) |
| >>> opset_version = 15 |
| >>> # Define a custom symbolic function for aten::relu. |
| >>> # The custom symbolic function is incorrect, which will result in mismatches. |
| >>> def incorrect_relu_symbolic_function(g, self): |
| ... return self |
| >>> torch.onnx.register_custom_op_symbolic( |
| ... "aten::relu", |
| ... incorrect_relu_symbolic_function, |
| ... opset_version=opset_version, |
| ... ) |
| >>> class Model(torch.nn.Module): |
| ... def __init__(self) -> None: |
| ... super().__init__() |
| ... self.layers = torch.nn.Sequential( |
| ... torch.nn.Linear(3, 4), |
| ... torch.nn.ReLU(), |
| ... torch.nn.Linear(4, 5), |
| ... torch.nn.ReLU(), |
| ... torch.nn.Linear(5, 6), |
| ... ) |
| ... def forward(self, x): |
| ... return self.layers(x) |
| >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) |
| >>> graph_info = torch.onnx.verification.find_mismatch( |
| ... Model(), |
| ... (torch.randn(2, 3),), |
| ... opset_version=opset_version, |
| ... ) |
| ===================== Mismatch info for graph partition : ====================== |
| ================================ Mismatch error ================================ |
| Tensor-likes are not close! |
| Mismatched elements: 12 / 12 (100.0%) |
| Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed) |
| Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed) |
| ==================================== Tree: ===================================== |
| 5 X __2 X __1 ✓ |
| id: | id: 0 | id: 00 |
| | | |
| | |__1 X (aten::relu) |
| | id: 01 |
| | |
| |__3 X __1 ✓ |
| id: 1 | id: 10 |
| | |
| |__2 X __1 X (aten::relu) |
| id: 11 | id: 110 |
| | |
| |__1 ✓ |
| id: 111 |
| =========================== Mismatch leaf subgraphs: =========================== |
| ['01', '110'] |
| ============================= Mismatch node kinds: ============================= |
| {'aten::relu': 2} |
| |
| """ |
| if options is None: |
| options = VerificationOptions() |
| if opset_version is None: |
| opset_version = _constants.ONNX_DEFAULT_OPSET |
| """From aten graph, do binary search on graph partition to find operator export discrepancy.""" |
| # TODO: Copied from utils.py `export` until `_optimize_graph`. |
| if training == torch.onnx.TrainingMode.TRAINING: |
| model.train() |
| elif training == torch.onnx.TrainingMode.EVAL: |
| model.eval() |
| with torch.no_grad(): |
| inputs_for_export = _prepare_input_for_export(input_args, {}) |
| args = utils._decide_input_format(model, inputs_for_export) |
| |
| model = utils._pre_trace_quant_model(model, args) |
| graph, params, torch_out, module = utils._create_jit_graph(model, args) |
| params_dict = utils._get_named_param_dict(graph, params) |
| |
| utils._apply_friendly_debug_names(graph, params_dict) |
| |
| graph_info = GraphInfo( |
| graph, |
| input_args, |
| params_dict, |
| _experimental.ExportOptions( |
| do_constant_folding=do_constant_folding, |
| training=training, |
| opset_version=opset_version, |
| keep_initializers_as_inputs=keep_initializers_as_inputs, |
| verbose=verbose, |
| ), |
| ) |
| graph_info.find_mismatch(options) |
| graph_info.pretty_print_mismatch() |
| graph_info.pretty_print_tree() |
| |
| return graph_info |