| """ONNX exporter exceptions.""" |
| from __future__ import annotations |
| |
| import textwrap |
| from typing import Optional |
| |
| from torch import _C |
| from torch.onnx import _constants |
| from torch.onnx._internal import diagnostics |
| |
# Names exported by ``from <this module> import *``: every warning and error
# class defined below in this file.
__all__ = [
    "OnnxExporterError",
    "OnnxExporterWarning",
    "CallHintViolationWarning",
    "CheckerError",
    "UnsupportedOperatorError",
    "SymbolicValueError",
]
| |
| |
class OnnxExporterWarning(UserWarning):
    """Root of the ONNX exporter's warning hierarchy.

    Derives from ``UserWarning`` and adds no behavior of its own.
    """
| |
| |
class CallHintViolationWarning(OnnxExporterWarning):
    """Warning category for a function call that violates a type hint.

    Adds no behavior beyond its ``OnnxExporterWarning`` base.
    """
| |
| |
class OnnxExporterError(RuntimeError):
    """Error type raised by the ONNX exporter.

    Derives from ``RuntimeError`` and adds no behavior of its own.
    """
| |
| |
class CheckerError(OnnxExporterError):
    """Error raised when the ONNX checker detects an invalid model.

    Adds no behavior beyond its ``OnnxExporterError`` base.
    """
| |
| |
class UnsupportedOperatorError(OnnxExporterError):
    """Error for an operator that the exporter cannot convert.

    Besides composing the exception message, constructing this error records
    an ERROR-level diagnostic through ``diagnostics.context.diagnose``.
    """

    def __init__(
        self,
        domain: str,
        op_name: str,
        version: int,
        supported_version: Optional[int],
    ):
        """Compose the error message and emit the matching diagnostic.

        Args:
            domain: Namespace of the operator. The values ``""``, ``"aten"``,
                ``"prim"`` and ``"quantized"`` are handled as known
                namespaces; anything else is reported as unrecognized.
            op_name: Operator name, rendered as ``domain::op_name``.
            version: Opset version the export was attempted with.
            supported_version: Opset version that added support for the
                operator, or ``None``. Only consulted for known namespaces.
        """
        qualified_name = f"{domain}::{op_name}"

        if domain not in {"", "aten", "prim", "quantized"}:
            # Unknown namespace: most likely an unregistered custom operator.
            msg = (
                f"ONNX export failed on an operator with unrecognized namespace '{qualified_name}'. "
                "If you are trying to export a custom operator, make sure you registered "
                "it with the right domain and version."
            )
            rule = diagnostics.rules.missing_custom_symbolic_function
            rule_args = (qualified_name,)
        else:
            msg = f"Exporting the operator '{qualified_name}' to ONNX opset version {version} is not supported. "
            if supported_version is None:
                # No opset is known to support it: point at the issue tracker.
                issues_url = _constants.PYTORCH_GITHUB_ISSUES_URL
                msg += "Please feel free to request support or submit a pull request on PyTorch GitHub: "
                msg += issues_url
                rule = diagnostics.rules.missing_standard_symbolic_function
                rule_args = (qualified_name, version, issues_url)
            else:
                # Another opset supports it: suggest exporting with that one.
                msg += (
                    f"Support for this operator was added in version {supported_version}, "
                    "try exporting with this version."
                )
                rule = diagnostics.rules.operator_supported_in_newer_opset_version
                rule_args = (qualified_name, version, supported_version)

        diagnostics.context.diagnose(
            rule,
            diagnostics.levels.ERROR,
            message_args=rule_args,
        )
        super().__init__(msg)
| |
| |
class SymbolicValueError(OnnxExporterError):
    """Error tied to a specific TorchScript value and its containing node."""

    def __init__(self, msg: str, value: _C.Value):
        """Extend ``msg`` with details about ``value`` and its node.

        Args:
            msg: Description of the problem; becomes the start of the message.
            value: The value at fault. Its type, the kind and source range of
                its node, and that node's inputs and outputs are appended to
                the message.
        """

        def _listing(values):
            # One numbered line per value, or a placeholder when there are none.
            rows = [
                f" #{idx}: {item} (type '{item.type()}')"
                for idx, item in enumerate(values)
            ]
            return "\n".join(rows) if rows else " Empty"

        parts = [
            f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the "
            f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] "
        ]

        code_location = value.node().sourceRange()
        if code_location:
            parts.append(f"\n (node defined in {code_location})")

        try:
            # The separator is added before the listing is built, so it stays
            # in the message even when the listing fails part-way.
            parts.append("\n\n")
            listing = "\n".join(
                [
                    "Inputs:",
                    _listing(value.node().inputs()),
                    "Outputs:",
                    _listing(value.node().outputs()),
                ]
            )
            parts.append(textwrap.indent(listing, " "))
        except AttributeError:
            # Best effort only: fall back to a note instead of failing.
            parts.append(
                " Failed to obtain its input and output for debugging. "
                "Please refer to the TorchScript graph for debugging information."
            )

        super().__init__("".join(parts))