| """ONNX exporter exceptions.""" |
| from __future__ import annotations |
| |
| import textwrap |
| from typing import Optional |
| |
| from torch import _C |
| from torch.onnx import _constants |
| from torch.onnx._internal import diagnostics |
| |
| __all__ = [ |
| "OnnxExporterError", |
| "OnnxExporterWarning", |
| "CallHintViolationWarning", |
| "CheckerError", |
| "UnsupportedOperatorError", |
| "SymbolicValueError", |
| ] |
| |
| |
class OnnxExporterWarning(UserWarning):
    """Base class for all warnings in the ONNX exporter."""

    pass


class CallHintViolationWarning(OnnxExporterWarning):
    """Warning raised when a type hint is violated during a function call."""

    pass


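# Illustrative emission of CallHintViolationWarning (a sketch, not a call site
# in this module; the message text below is hypothetical):
#
#     import warnings
#     warnings.warn(
#         "Expected a Tensor for argument 'x' but received an int; "
#         "falling back to the un-traced path.",
#         category=CallHintViolationWarning,
#     )

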
class OnnxExporterError(RuntimeError):
    """Base class for errors raised by the ONNX exporter."""

    pass


class CheckerError(OnnxExporterError):
    """Raised when the ONNX checker detects an invalid model."""

    pass


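# Illustrative handling of the exception hierarchy (a sketch, not part of this
# module; `model`, `example_inputs`, and "model.onnx" are placeholders):
#
#     try:
#         torch.onnx.export(model, example_inputs, "model.onnx")
#     except CheckerError:
#         ...  # the exported graph failed ONNX checker validation
#     except OnnxExporterError:
#         ...  # any other exporter failure
#
# Because CheckerError subclasses OnnxExporterError, the more specific handler
# must come first.

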
class UnsupportedOperatorError(OnnxExporterError):
    """Raised when an operator is unsupported by the exporter."""

    def __init__(self, name: str, version: int, supported_version: Optional[int]):
        if supported_version is not None:
            # The operator is supported in a newer opset: point the user at the
            # opset version that adds support.
            diagnostic_rule: diagnostics.infra.Rule = (
                diagnostics.rules.operator_supported_in_newer_opset_version
            )
            msg = diagnostic_rule.format_message(name, version, supported_version)
            diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg)
        elif name.startswith(("aten::", "prim::", "quantized::")):
            # A builtin operator without a symbolic function is a gap in the
            # exporter itself, so direct the user to the PyTorch issue tracker.
            diagnostic_rule = diagnostics.rules.missing_standard_symbolic_function
            msg = diagnostic_rule.format_message(
                name, version, _constants.PYTORCH_GITHUB_ISSUES_URL
            )
            diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg)
        else:
            # The operator lives in a custom namespace; a custom symbolic
            # function must be registered for it.
            diagnostic_rule = diagnostics.rules.missing_custom_symbolic_function
            msg = diagnostic_rule.format_message(name)
            diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg)
        super().__init__(msg)


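# Illustrative construction of UnsupportedOperatorError (a sketch; the operator
# name "custom::my_op" and opset version 17 are hypothetical):
#
#     raise UnsupportedOperatorError("custom::my_op", 17, None)
#
# With supported_version=None and a namespace other than aten/prim/quantized,
# the missing-custom-symbolic-function diagnostic is emitted before the message
# is attached to the exception.

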
class SymbolicValueError(OnnxExporterError):
    """Errors around TorchScript values and nodes."""

    def __init__(self, msg: str, value: _C.Value):
        message = (
            f"{msg}  [Caused by the value '{value}' (type '{value.type()}') in the "
            f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] "
        )

        # Append the source location of the offending node when it is available.
        code_location = value.node().sourceRange()
        if code_location:
            message += f"\n    (node defined in {code_location})"

        try:
            # Add the node's inputs and outputs to the message to aid debugging.
            message += "\n\n"
            message += textwrap.indent(
                (
                    "Inputs:\n"
                    + (
                        "\n".join(
                            f"    #{i}: {input_}  (type '{input_.type()}')"
                            for i, input_ in enumerate(value.node().inputs())
                        )
                        or "    Empty"
                    )
                    + "\n"
                    + "Outputs:\n"
                    + (
                        "\n".join(
                            f"    #{i}: {output}  (type '{output.type()}')"
                            for i, output in enumerate(value.node().outputs())
                        )
                        or "    Empty"
                    )
                ),
                "    ",
            )
        except AttributeError:
            # Some values cannot report their node's inputs/outputs; degrade
            # gracefully instead of masking the original error.
            message += (
                " Failed to obtain its input and output for debugging. "
                "Please refer to the TorchScript graph for debugging information."
            )

        super().__init__(message)
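

# Illustrative raise from inside a symbolic function (a sketch; `symbolic_fn`,
# its arguments, and the constant check are hypothetical, with `dim` being a
# torch._C.Value taken from the TorchScript graph):
#
#     def symbolic_fn(g, input, dim):
#         if dim.node().kind() != "onnx::Constant":
#             raise SymbolicValueError(
#                 "ONNX export requires 'dim' to be a known constant.", dim
#             )
#         ...
#
# The resulting message includes the value, its type, the containing node's
# kind and source range, and the node's inputs and outputs.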