blob: 8f7198cd39497063d49664a4230bbcbdbe036727 [file] [log] [blame]
"""Functions to export models into the ONNX IR format.
These models can be loaded with the ONNX library and then
converted to models which run on other deep learning frameworks.
"""
from __future__ import annotations
import contextlib
import copy
import inspect
import itertools
import os
import re
import textwrap
import typing
import warnings
import zipfile
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch._C._onnx as _C_onnx
import torch.jit._trace
import torch.serialization
from torch import _C
from torch.onnx import ( # noqa: F401
_constants,
_exporter_states,
_patch_torch,
errors,
symbolic_caffe2,
symbolic_helper,
symbolic_registry,
)
from torch.onnx._globals import GLOBALS
__all__ = [
"is_in_onnx_export",
"select_model_mode_for_export",
"disable_apex_o2_state_dict_hook",
"setup_onnx_logging",
"exporter_context",
"export",
"warn_on_static_input_change",
"unpack_quantized_tensor",
"export_to_pretty_string",
"unconvertible_ops",
"get_ns_op_name_from_custom_op",
"register_custom_op_symbolic",
"unregister_custom_op_symbolic",
]
def is_in_onnx_export() -> bool:
"""Returns whether it is in the middle of ONNX export."""
return GLOBALS.in_onnx_export
# TODO(justinchuby): Remove dependency to this global variable from constant_fold.cpp
# Skip check due to cannot import IValue from torch._C
_params_dict = {} # type: ignore[var-annotated]
@contextlib.contextmanager
def select_model_mode_for_export(model, mode):
if not isinstance(model, torch.jit.ScriptFunction):
is_originally_training = model.training
if mode is None:
mode = _C_onnx.TrainingMode.EVAL
# if the model is in training mode but the user did not specify
# to export the model in training mode, export the model in inference
# mode (default) and warn them
if is_originally_training:
warnings.warn(
"You are exporting the model to ONNX while in training mode with "
"'train' parameter not specified. The model will default to inference mode export. "
"If you wish to export a training amenable ONNX model, specify training=TrainingMode.TRAINING or "
"training=TrainingMode.PRESERVE (to preserve the original model state) in torch.onnx.export()."
)
# if mode == TrainingMode.EVAL or (mode == TrainingMode.PRESERVE and not is_originally_training) => is_training = False
is_export_training = False
# ONNX opset 12 has better support for training amenable models, with updated
# versions of the dropout and batch_norm operators
if mode == _C_onnx.TrainingMode.TRAINING or (
mode == _C_onnx.TrainingMode.PRESERVE and is_originally_training
):
if GLOBALS.export_onnx_opset_version < 12:
warnings.warn(
"You are exporting the model in training mode with onnx opset "
f"version {GLOBALS.export_onnx_opset_version}. "
"Opset versions lower than opset 12 will not be able to export "
"nodes such as Dropout and BatchNorm correctly."
)
is_export_training = True
symbolic_helper._set_training_mode(is_export_training)
model.train(is_export_training)
try:
yield
finally:
if not isinstance(model, torch.jit.ScriptFunction):
# FIXME(justinchuby): is_originally_training is possibly unbound
model.train(is_originally_training)
@contextlib.contextmanager
def disable_apex_o2_state_dict_hook(model):
# Apex O2 hook state_dict to return fp16 weights as fp32.
# Exporter cannot identify them as same tensors.
# Since this hook is only used by optimizer, it is safe to
# remove this hook while exporting.
if not isinstance(model, torch.jit.ScriptFunction):
tmp_map = {} # type: ignore[var-annotated]
for module in model.modules():
for k, v in module._state_dict_hooks.items():
if type(v).__name__ == "O2StateDictHook":
if module not in tmp_map:
tmp_map[module] = {}
tmp_map[module][k] = v
if module in tmp_map:
for k in tmp_map[module].keys():
module._state_dict_hooks.pop(k)
try:
yield
finally:
if not isinstance(model, torch.jit.ScriptFunction):
# FIXME(justinchuby): tmp_map is possibly unbound
for module, m_map in tmp_map.items():
for k, v in m_map.items():
module._state_dict_hooks[k] = v
@contextlib.contextmanager
def setup_onnx_logging(verbose):
is_originally_enabled = torch.onnx.is_onnx_log_enabled()
if is_originally_enabled or verbose:
torch.onnx.enable_log()
try:
yield
finally:
if not is_originally_enabled:
torch.onnx.disable_log()
@contextlib.contextmanager
def exporter_context(model, mode, verbose):
with select_model_mode_for_export(
model, mode
) as mode_ctx, disable_apex_o2_state_dict_hook(
model
) as apex_ctx, setup_onnx_logging(
verbose
) as log_ctx:
yield (mode_ctx, apex_ctx, log_ctx)
def export(
model,
args,
f,
export_params=True,
verbose=False,
training=None,
input_names=None,
output_names=None,
operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
opset_version=None,
do_constant_folding=True,
dynamic_axes=None,
keep_initializers_as_inputs=None,
custom_opsets=None,
export_modules_as_functions=False,
):
_export(
model,
args,
f,
export_params,
verbose,
training,
input_names,
output_names,
operator_export_type=operator_export_type,
opset_version=opset_version,
do_constant_folding=do_constant_folding,
dynamic_axes=dynamic_axes,
keep_initializers_as_inputs=keep_initializers_as_inputs,
custom_opsets=custom_opsets,
export_modules_as_functions=export_modules_as_functions,
)
def _is_constant_tensor_list(node):
if node.kind() != "prim::Constant":
return False
output_type = node.output().type()
if output_type.isSubtypeOf(_C.ListType.ofTensors()):
return True
if output_type.isSubtypeOf(_C.ListType(_C.OptionalType.ofTensor())):
return True
# ONNX can't handle constants that are lists of tensors, which can
# get generated in constant prop. So we split them back into prim::ListConstructs
def _split_tensor_list_constants(g, block):
for node in block.nodes():
for subblock in node.blocks():
_split_tensor_list_constants(g, subblock)
if _is_constant_tensor_list(node):
inputs = []
for val in node.output().toIValue():
input = g.insertConstant(val)
input.node().moveBefore(node)
input.node().copyMetadata(node)
inputs.append(input)
lc = (
g.create("prim::ListConstruct", inputs)
.insertBefore(node)
.output()
.setType(_C.ListType.ofTensors())
)
lc.node().copyMetadata(node)
node.output().replaceAllUsesWith(lc)
def _optimize_graph(
graph: _C.Graph,
operator_export_type: _C_onnx.OperatorExportTypes,
_disable_torch_constant_prop: bool = False,
fixed_batch_size: bool = False,
params_dict=None,
dynamic_axes=None,
input_names=None,
module=None,
):
# Inline everything
_C._jit_pass_inline(graph)
# Remove fork/wait nodes
_C._jit_pass_inline_fork_wait(graph)
_C._jit_pass_lint(graph)
_C._jit_pass_lower_all_tuples(graph)
# we now record some ops like ones/zeros
# into a trace where we previously recorded constants.
# use constant prop to maintain our current level of onnx support
# without implementing symbolics for all of them
if _disable_torch_constant_prop is False:
_C._jit_pass_constant_propagation(graph)
_split_tensor_list_constants(graph, graph)
# run dce to eliminate dead parts of the graph that might have been
# left behind by things like symbolic_override
_C._jit_pass_dce(graph)
_C._jit_pass_lint(graph)
_C._jit_pass_canonicalize_graph_fuser_ops(graph)
_C._jit_pass_lint(graph)
_C._jit_pass_peephole(graph, True)
_C._jit_pass_fuse_addmm(graph)
_C._jit_pass_lint(graph)
_C._jit_pass_peephole(graph, True)
_C._jit_pass_lower_all_tuples(graph)
# in _jit_pass_onnx, symbolic functions are called for each node for conversion.
# However, there are nodes that cannot be converted without additional context.
# For example, the number of outputs from split (and whether it is static or dynamic) is unknown
# until the point where it is unpacked by listUnpack node.
# This pass does a preprocess, and prepares the nodes such that enough context can be received
# by the symbolic function.
_C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module)
_C._jit_pass_onnx_preprocess(graph)
# onnx does not support tuples, so try to remove them
_C._jit_pass_lint(graph)
# onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
_C._jit_pass_prepare_division_for_onnx(graph)
_C._jit_pass_onnx_remove_print(graph)
_C._jit_pass_onnx_preprocess_caffe2(graph)
symbolic_helper._quantized_ops.clear()
# Unpack quantized weights for conv and linear ops and insert into graph.
_C._jit_pass_onnx_unpack_quantized_weights(
graph, params_dict, symbolic_helper.is_caffe2_aten_fallback()
)
if symbolic_helper.is_caffe2_aten_fallback():
# Insert permutes before and after each conv op to ensure correct order.
_C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict)
# Find consecutive permutes that are no-ops and remove them.
_C._jit_pass_custom_pattern_based_rewrite_graph(
textwrap.dedent(
"""\
graph(%Pi):
%Pq = quantized::nhwc2nchw(%Pi)
%Pr = quantized::nchw2nhwc(%Pq)
return (%Pr)"""
),
textwrap.dedent(
"""\
graph(%Ri):
return (%Ri)"""
),
graph,
)
# onnx only supports tensors, so we turn all out number types into tensors
_C._jit_pass_erase_number_types(graph)
if GLOBALS.onnx_shape_inference:
input_names = [] if input_names is None else input_names
dynamic_axes = {} if dynamic_axes is None else dynamic_axes
_C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
_C._jit_pass_onnx_lint(graph)
graph = _C._jit_pass_onnx(graph, operator_export_type)
_C._jit_pass_onnx_lint(graph)
_C._jit_pass_lint(graph)
_C._jit_pass_onnx_scalar_type_analysis(
graph, True, GLOBALS.export_onnx_opset_version
)
_C._jit_pass_lint(graph)
_C._jit_pass_onnx_peephole(
graph, GLOBALS.export_onnx_opset_version, fixed_batch_size
)
_C._jit_pass_lint(graph)
# graph is not a valid jit graph anymore because types have been replaced
# (e.g. int with Tensor), so it now contains operators that don't actually
# exist. We can't run normal dead code elimination because it'd fail trying
# to look up if an operator has side effects, but we can run a dead code
# elimination variant that doesn't need to look up if an op has side effects.
_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
_C._jit_pass_lint(graph)
graph = _C._jit_pass_canonicalize(graph)
_C._jit_pass_lint(graph)
if GLOBALS.onnx_shape_inference:
_C._jit_pass_onnx_graph_shape_type_inference(
graph, params_dict, GLOBALS.export_onnx_opset_version
)
return graph
def warn_on_static_input_change(input_states):
"""Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph.
We accept dictionaries and strings as ONNX inputs, but they should be only for
configuration use. we detect here if these inputs are modified, and if so we warn
the user that the changes won't take effect in the traced ONNX graph.
"""
for input, traced_input in zip(input_states[0], input_states[1]):
if isinstance(input, dict):
if list(input.keys()) != list(traced_input.keys()):
warning = (
"We detected that you are modifying a dictionary that is an input to your "
"model. "
"Note that dictionaries are allowed as inputs in ONNX but they should be "
"handled with care. "
"Usages of dictionaries is not recommended, and should not be used except "
"for configuration use. "
"Also note that the order and values of the keys must remain the same. "
)
warnings.warn(warning)
elif isinstance(input, str):
if input != traced_input:
warning = (
"The model seems to have string inputs/outputs. "
"Note that strings will not appear as inputs/outputs of the ONNX graph. "
)
warnings.warn(warning)
def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
"""Resolves the arguments that are ignored when export_type != operator_export_type.ONNX."""
if (
operator_export_type is not operator_export_type.ONNX
and _C_onnx._CAFFE2_ATEN_FALLBACK
):
if arg_value is True:
warnings.warn(
f"'{arg_name}' can be set to True only when 'operator_export_type' is "
"`ONNX`. Since 'operator_export_type' is not set to 'ONNX', "
f"'{arg_name}' argument will be ignored."
)
arg_value = False
return arg_value
def _decide_keep_init_as_input(
keep_initializers_as_inputs: Optional[bool],
operator_export_type: _C_onnx.OperatorExportTypes,
opset_version: int,
):
"""Decides whether the initializers in the graph should be listed as ONNX graph inputs.
This method encapsulates the logic to decide whether the initializers in the graph
should be listed as ONNX graph inputs (i.e., whether to choose ONNX IR v3 or v4).
If keep_initializers_as_inputs is not specified (None), then we decide whether to keep
initializers as graph inputs (val_keep_init_as_ip) based on export type. If export type
is ONNX, then do not keep initializers as input (val_keep_init_as_ip=False). For all other
export types keep initializers as input (val_keep_init_as_ip=True).
If keep_initializers_as_inputs is specified, then respect it. Unless opset version <= 8,
in which case it must be ignored because for opset version <= 8, all initializers MUST be
part of graph input (only ONNX IR v3 is allowed), i.e. val_keep_init_as_ip=True.
Special handling is needed for opset version 8 or lower, because irrespective
of user input for keep_initializers_as_inputs, the graph must follow ONNX IR v3
semantics, i.e. all initializers must be listed as ONNX graph input.
"""
if opset_version < 9:
if keep_initializers_as_inputs is False:
warnings.warn(
"Setting 'keep_initializers_as_inputs=False' for opset version"
"8 or lower would lead to an invalid ONNX graph. Therefore, "
"'keep_initializers_as_inputs=False' is ignored during export."
"Exported model will have initializers as graph inputs (compliant "
" to ONNX IR v3)."
)
return True # i.e. True == initializers are part of graph input (ONNX IR v3)
val_keep_init_as_ip = (
True if keep_initializers_as_inputs is None else keep_initializers_as_inputs
)
if (
keep_initializers_as_inputs is None
and operator_export_type is _C_onnx.OperatorExportTypes.ONNX
):
val_keep_init_as_ip = False
return val_keep_init_as_ip
def _decide_add_node_names(add_node_names, operator_export_type):
return _resolve_args_by_export_type(
"add_node_names", add_node_names, operator_export_type
)
def _decide_constant_folding(do_constant_folding, operator_export_type, training):
do_constant_folding = _resolve_args_by_export_type(
"do_constant_folding", do_constant_folding, operator_export_type
)
if do_constant_folding and (
training is not None and training is not _C_onnx.TrainingMode.EVAL
):
warnings.warn(
"It is recommended that constant folding be turned off ('do_constant_folding=False') "
"when exporting the model in training-amenable mode, i.e. with 'training=TrainingMode.TRAIN' "
"or 'training=TrainingMode.PRESERVE' (when model is in training mode). Otherwise, some "
"learnable model parameters may not translate correctly in the exported ONNX model "
"because constant folding mutates model parameters. Please consider "
"turning off constant folding or setting the training=TrainingMode.EVAL."
)
return do_constant_folding
def _signature(model) -> inspect.Signature:
should_be_callable = getattr(model, "forward", model)
if callable(should_be_callable):
return inspect.signature(should_be_callable)
raise ValueError("model has no forward method and is not callable")
def _decide_input_format(model, args):
try:
sig = _signature(model)
except ValueError as e:
warnings.warn(f"{e}, skipping _decide_input_format")
return args
try:
ordered_list_keys = list(sig.parameters.keys())
if ordered_list_keys[0] == "self":
ordered_list_keys = ordered_list_keys[1:]
args_dict: Dict = {}
if isinstance(args, list):
args_list = args
elif isinstance(args, tuple):
args_list = list(args)
else:
args_list = [args]
if isinstance(args_list[-1], dict):
args_dict = args_list[-1]
args_list = args_list[:-1]
n_nonkeyword = len(args_list)
for optional_arg in ordered_list_keys[n_nonkeyword:]:
if optional_arg in args_dict:
args_list.append(args_dict[optional_arg])
# Check if this arg has a default value
else:
param = sig.parameters[optional_arg]
if param.default != param.empty:
args_list.append(param.default)
args = args_list if isinstance(args, list) else tuple(args_list)
# Cases of models with no input args
except IndexError:
warnings.warn("No input args, skipping _decide_input_format")
except Exception as e:
warnings.warn(f"Skipping _decide_input_format\n {e.args[0]}")
return args
def _trace(func, args, operator_export_type, return_outs=False):
# Special case for common case of passing a single Tensor
if isinstance(args, torch.Tensor):
args = (args,)
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
func, args, strict=False, _force_outplace=False, _return_inputs_states=True
)
warn_on_static_input_change(inputs_states)
trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={})
if return_outs:
return trace_graph, torch_out
return trace_graph
def _trace_and_get_graph_from_model(model, args):
# A basic sanity check: make sure the state_dict keys are the same
# before and after running the model. Fail fast!
orig_state_dict_keys = torch.jit._unique_state_dict(model).keys()
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
model, args, strict=False, _force_outplace=False, _return_inputs_states=True
)
warn_on_static_input_change(inputs_states)
if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys():
raise RuntimeError(
"state_dict changed after running the tracer; "
"something weird is happening in your model!"
)
return trace_graph, torch_out
def _get_param_count_list(method_graph, args_params):
param_count_list = []
for input_, arg_params_ in zip(method_graph.inputs(), args_params):
if "PackedParams" in str(input_.type()):
in_vars, _ = torch.jit._flatten(arg_params_)
param_count_list.append(len(in_vars))
else:
param_count_list.append(arg_params_ is not None)
return param_count_list
def _check_flatten_did_not_remove(original, jit_flattened):
"""torch.jit._flatten removes None. Check if it did so in this case."""
def flatten(x):
if isinstance(x, (list, tuple)):
for inner in x:
yield from flatten(inner)
elif isinstance(x, dict):
for inner in x.values():
yield from flatten(inner)
else:
yield x
flattened_with_none = list(flatten(original))
num_none = len(flattened_with_none) - len(jit_flattened)
assert num_none >= 0
if num_none:
raise ValueError(
f"args contained {num_none} None's after flattening. "
"When exporting a ScriptModule or ScriptFunction, no args may "
"be None because that breaks type propagation."
)
def _create_jit_graph(model, args):
torch_out = None
params: Union[List, Tuple]
if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
_check_flatten_did_not_remove(args, flattened_args)
if isinstance(model, torch.jit.ScriptModule):
try:
graph = model.forward.graph
except AttributeError as e:
raise RuntimeError("'forward' method must be a script method") from e
_C._jit_pass_onnx_function_substitution(graph)
freezed_m = _C._freeze_module(model._c, preserveParameters=True)
module, params = _C._jit_onnx_list_model_parameters(freezed_m)
method_graph = module._get_method("forward").graph
args_params = tuple(args) + tuple(params)
param_count_list = _get_param_count_list(method_graph, args_params)
in_vars, _ = torch.jit._flatten(args_params)
graph = _C._propagate_and_assign_input_shapes(
method_graph, tuple(in_vars), param_count_list, False, False
)
return graph, params, torch_out, module
elif isinstance(model, torch.jit.ScriptFunction):
params = ()
graph = model.graph
_C._jit_pass_onnx_function_substitution(graph)
param_count_list = _get_param_count_list(graph, args)
# FIXME(justinchuby): flattened_args is possibly unbound
graph = _C._propagate_and_assign_input_shapes(
graph, flattened_args, param_count_list, False, False
)
return graph, params, torch_out, None
else:
graph, torch_out = _trace_and_get_graph_from_model(model, args)
_C._jit_pass_onnx_lint(graph)
state_dict = torch.jit._unique_state_dict(model)
params = list(state_dict.values())
graph_inputs = list(graph.inputs())
user_input_num = len(graph_inputs) - len(state_dict)
param_names = list(state_dict.keys())
for i, inp in enumerate(graph_inputs):
if i >= user_input_num:
inp.setDebugName(param_names[i - user_input_num])
_C._jit_pass_onnx_function_substitution(graph)
return graph, params, torch_out, None
def _get_named_param_dict(graph, params):
input_and_param_names = [val.debugName() for val in graph.inputs()]
param_names = input_and_param_names[len(input_and_param_names) - len(params) :]
_params_dict = dict(zip(param_names, params))
return _params_dict
def _get_example_outputs(model, args):
input_args = copy.deepcopy(args)
input_kwargs = {}
if input_args and isinstance(input_args[-1], dict):
input_kwargs = input_args[-1]
input_args = input_args[:-1]
example_outputs = model(*input_args, **input_kwargs)
if isinstance(example_outputs, list):
example_outputs = [example_outputs]
elif not isinstance(example_outputs, tuple):
example_outputs = (example_outputs,)
return example_outputs
_qtype_vtype_map = {
torch.quint8: torch.uint8,
torch.qint8: torch.int8,
torch.qint32: torch.int32,
torch.quint4x2: torch.int8,
}
def unpack_quantized_tensor(value):
if isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map:
q_value_dequantize = value.dequantize()
q_scale = torch.tensor(value.q_scale(), dtype=torch.double)
q_zero_point = torch.tensor(value.q_zero_point(), dtype=torch.int64)
q_value = q_value_dequantize / q_scale + q_zero_point
q_value = q_value.to(dtype=_qtype_vtype_map[value.dtype])
return q_value, q_scale, q_zero_point
else:
return (value,)
def _pre_trace_quant_model(model, args):
r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise do nothing and return
original model.
This is due to https://github.com/pytorch/pytorch/issues/75761.
"""
if any(
hasattr(m, "_packed_params") for m in getattr(model, "modules", lambda: [])()
) or any(getattr(arg, "is_quantized", False) for arg in args):
return torch.jit.trace(model, args)
return model
def _assign_onnx_node_name(graph, node_names):
"""Takes in ONNX graph, and mapping from _C.Node to node name in exported ONNX ModelProto.
Returns:
graph (_C.Graph): A TorchScript IR Graph with ONNX nodes, where each _C.Node gets its name
in exported ONNX ModelProto assigned as attribute ``onnx_name``.
"""
def n_fn(n, b_fn, node_names):
for b in n.blocks():
b_fn(b, node_names)
if n in node_names:
n.s_("onnx_name", node_names[n])
def b_fn(b, node_names):
for n in b.nodes():
n_fn(n, b_fn, node_names)
b_fn(graph, node_names)
return graph
def _model_to_graph(
model,
args,
verbose=False,
input_names=None,
output_names=None,
operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
do_constant_folding=True,
_disable_torch_constant_prop=False,
fixed_batch_size=False,
training=None,
dynamic_axes=None,
) -> Tuple[
_C.Graph,
Dict[str, torch.Tensor],
Optional[Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]]],
]:
"""Converts model into an ONNX graph.
Returns:
graph: A TorchScript IR Graph with ONNX nodes.
params_dict: Dict from input param name to param value.
torch_out: The output tensors resulting from the trace of ``model``.
If ``model`` is a :class:`torch.jit.ScriptModule` or :class:`torch.jit.ScriptFunction`,
this will be None, since we are not doing any tracing.
"""
# TODO: can we simplify this to always return a tuple of Tensor or None?
# Special case for common case of passing a single Tensor
if isinstance(args, (torch.Tensor, int, float, bool)):
args = (args,)
model = _pre_trace_quant_model(model, args)
graph, params, torch_out, module = _create_jit_graph(model, args)
params_dict = _get_named_param_dict(graph, params)
try:
graph = _optimize_graph(
graph,
operator_export_type,
_disable_torch_constant_prop=_disable_torch_constant_prop,
fixed_batch_size=fixed_batch_size,
params_dict=params_dict,
dynamic_axes=dynamic_axes,
input_names=input_names,
module=module,
)
except Exception as e:
torch.onnx.log("Torch IR graph at exception: ", graph)
raise
is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule))
if is_script:
example_outputs = _get_example_outputs(model, args)
example_outputs_final = ()
for example_output in example_outputs:
example_outputs_final += unpack_quantized_tensor(example_output)
out_vars, desc = torch.jit._flatten(example_outputs_final)
_C._jit_pass_onnx_assign_output_shape(
graph, out_vars, desc, GLOBALS.onnx_shape_inference, is_script
)
# NB: ONNX requires complete information about output types, which might be
# erased by some optimizations, so we need to set it explicitly again.
else:
if not isinstance(torch_out, (list, tuple)):
output_wrapped = [torch_out]
else:
output_wrapped = torch_out # type: ignore[assignment]
output_tensors, out_desc = _C._jit_flatten(tuple(output_wrapped))
# assign_output_shape pass is not compatible with quantized outputs.
# Quantized outputs are flattened to 3 values in ONNX, while packed as
# single value in PyTorch.
if not any(getattr(out, "is_quantized", False) for out in output_tensors):
_C._jit_pass_onnx_assign_output_shape(
graph,
output_tensors,
out_desc,
GLOBALS.onnx_shape_inference,
is_script,
)
_set_input_and_output_names(graph, input_names, output_names)
params_dict = _get_named_param_dict(graph, params)
if training is None or training == _C_onnx.TrainingMode.EVAL:
params_dict = _C._jit_pass_onnx_eval_peephole(graph, params_dict)
if (
do_constant_folding
and GLOBALS.export_onnx_opset_version in _constants.onnx_constant_folding_opsets
):
params_dict = _C._jit_pass_onnx_constant_fold(
graph, params_dict, GLOBALS.export_onnx_opset_version
)
_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
if GLOBALS.onnx_shape_inference:
_C._jit_pass_onnx_graph_shape_type_inference(
graph, params_dict, GLOBALS.export_onnx_opset_version
)
params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
# For ONNX opset < 9, constants only have three data types: float16, float, double.
# In this pass transform constants of other data types to float/double + cast operator.
if GLOBALS.export_onnx_opset_version < 9:
_C._jit_pass_onnx_cast_all_constant_to_floating(graph)
params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
_C._jit_decay_packed_param_input_types(graph)
# If output names lack a proper name and are identified only by their unique
# give them a legible name for debugging purposes
_apply_friendly_debug_names(graph, params_dict)
return graph, params_dict, torch_out
def export_to_pretty_string(
model,
args,
export_params=True,
verbose=False,
training=None,
input_names=None,
output_names=None,
operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
export_type=None,
google_printer=False,
opset_version=None,
keep_initializers_as_inputs=None,
custom_opsets=None,
add_node_names=True,
do_constant_folding=True,
dynamic_axes=None,
):
if opset_version is None:
opset_version = _constants.onnx_default_opset
if custom_opsets is None:
custom_opsets = {}
symbolic_helper._set_opset_version(opset_version)
symbolic_helper._set_operator_export_type(operator_export_type)
symbolic_helper._set_onnx_shape_inference(True)
with exporter_context(model, training, verbose):
val_keep_init_as_ip = _decide_keep_init_as_input(
keep_initializers_as_inputs, operator_export_type, opset_version
)
val_add_node_names = _decide_add_node_names(
add_node_names, operator_export_type
)
val_do_constant_folding = _decide_constant_folding(
do_constant_folding, operator_export_type, training
)
args = _decide_input_format(model, args)
graph, params_dict, torch_out = _model_to_graph(
model,
args,
verbose,
input_names,
output_names,
operator_export_type,
val_do_constant_folding,
training=training,
dynamic_axes=dynamic_axes,
)
return graph._pretty_print_onnx( # type: ignore[attr-defined]
params_dict,
opset_version,
False,
operator_export_type,
google_printer,
val_keep_init_as_ip,
custom_opsets,
val_add_node_names,
)
def unconvertible_ops(
model, args, training=_C_onnx.TrainingMode.EVAL, opset_version=None
):
r"""
Converts the model with operator_export_type set to
torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH once in order to get a list of
all the ops that are not supported/implemented by the exporter.
Args:
model: Same as corresponding arg to torch.onnx.export.
args: Same as corresponding arg to torch.onnx.export.
training: Same as corresponding arg to torch.onnx.export.
opset_version: Same as corresponding arg to torch.onnx.export.
Returns:
Tuple[torch._C.Graph, List[str]], where the list includes the names
of the unconvertible ops.
"""
opset_version = opset_version or _constants.onnx_default_opset
symbolic_helper._set_opset_version(opset_version)
# operator_export_type is set to ONNX_FALLTHROUGH by default so that if an op is not supported
# in ONNX, fall through will occur and export the operator as is, as a custom ONNX op.
with exporter_context(model, training, False):
args = _decide_input_format(model, args)
graph, params_dict, torch_out = _model_to_graph(
model,
args,
# So that if an op connot be converted to ONNX, it will be kept
# as-is rather than cause a failure.
operator_export_type=_C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
)
unsupported_ops = list()
supported_namespaces = ("onnx", "prim", "quantized")
for node in graph.nodes():
if node.kind().split(":")[0] not in supported_namespaces:
unsupported_ops.append(node.kind())
return graph, unsupported_ops
def _setup_trace_module_map(model, export_modules_as_functions):
def __setup_trace_module_map():
trace_module_map = {_m: torch.typename(type(_m)) for _m in model.modules()}
torch.jit._trace._trace_module_map = trace_module_map
return trace_module_map
def __register_attribute_hook():
attr_name = "_onnx_attrs"
def _track_module_attributes_forward_pre_hook(module, input):
setattr(module, attr_name, _get_module_attributes(module))
def _track_module_attributes_forward_hook(module, input, output):
tracing_state = _C._get_tracing_state()
if not tracing_state:
return
graph = tracing_state.graph()
onnx_attrs = {}
if hasattr(module, attr_name):
onnx_attrs = getattr(module, attr_name)
delattr(module, attr_name)
_C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
for m in model.modules():
m.register_forward_hook(_track_module_attributes_forward_hook)
m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook)
if isinstance(export_modules_as_functions, bool) and export_modules_as_functions:
trace_module_map = __setup_trace_module_map()
export_modules_as_functions = {v for k, v in trace_module_map.items()}
elif (
isinstance(export_modules_as_functions, set)
and len(export_modules_as_functions) > 0
):
def _find_typename(v):
if isinstance(v, type):
return torch.typename(v)
else:
raise RuntimeError(
"Only type of the `nn.Module` should be "
"passed in the set for argument `export_modules_as_functions`. "
"Got `%s`." % (type(v).__name__)
)
trace_module_map = __setup_trace_module_map()
module_typenames = {_find_typename(v) for v in export_modules_as_functions}
export_modules_as_functions = module_typenames
else:
export_modules_as_functions = None
if export_modules_as_functions:
__register_attribute_hook()
return export_modules_as_functions
def _reset_trace_module_map():
torch.jit._trace._trace_module_map = None
_C._jit_pass_onnx_clear_scope_records()
def _get_module_attributes(module):
annotations = typing.get_type_hints(type(module))
base_m_annotations = typing.get_type_hints(torch.nn.Module)
[annotations.pop(k, None) for k in base_m_annotations]
return {k: getattr(module, k) for k in annotations}
def _export(
model,
args,
f,
export_params=True,
verbose=False,
training=None,
input_names=None,
output_names=None,
operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
export_type=None,
opset_version=None,
do_constant_folding=True,
dynamic_axes=None,
keep_initializers_as_inputs=None,
fixed_batch_size=False,
custom_opsets=None,
add_node_names=True,
onnx_shape_inference=True,
export_modules_as_functions=False,
):
if export_type is None:
export_type = _exporter_states.ExportTypes.PROTOBUF_FILE
if isinstance(model, torch.nn.DataParallel):
raise ValueError(
"torch.nn.DataParallel is not supported by ONNX "
"exporter, please use 'attribute' module to "
"unwrap model from torch.nn.DataParallel. Try "
"torch.onnx.export(model.module, ...)"
)
assert GLOBALS.in_onnx_export is False
GLOBALS.in_onnx_export = True
try:
symbolic_helper._set_onnx_shape_inference(onnx_shape_inference)
if opset_version is None:
opset_version = _constants.onnx_default_opset
if export_modules_as_functions and opset_version < 15:
raise ValueError(
"`export_modules_as_functions` is not supported for `opset_version` < 15."
"This is because `opset_version` < 15 implies IR version < 8, which means "
"no local function support. "
)
export_modules_as_functions = _setup_trace_module_map(
model, export_modules_as_functions
)
if not operator_export_type:
if _C_onnx._CAFFE2_ATEN_FALLBACK:
operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
else:
operator_export_type = _C_onnx.OperatorExportTypes.ONNX
# By default, training=None, (which defaults to TrainingMode.EVAL),
# which is good because running a model in training mode could result in
# internal buffers getting updated, dropout getting applied, etc.
# If you really know what you're doing, you can turn
# training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE,
# (to preserve whatever the original training mode was.)
symbolic_helper._set_opset_version(opset_version)
symbolic_helper._set_operator_export_type(operator_export_type)
with exporter_context(model, training, verbose):
val_keep_init_as_ip = _decide_keep_init_as_input(
keep_initializers_as_inputs, operator_export_type, opset_version
)
val_add_node_names = _decide_add_node_names(
add_node_names, operator_export_type
)
val_do_constant_folding = _decide_constant_folding(
do_constant_folding, operator_export_type, training
)
# Normally f can be a file-like object, but for large models, the external data format requires a
# valid `model_file_location`. Code in export.cpp will enforce this.
if isinstance(f, str):
model_file_location = f
else:
model_file_location = ""
args = _decide_input_format(model, args)
if dynamic_axes is None:
dynamic_axes = {}
_validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
graph, params_dict, torch_out = _model_to_graph(
model,
args,
verbose,
input_names,
output_names,
operator_export_type,
val_do_constant_folding,
fixed_batch_size=fixed_batch_size,
training=training,
dynamic_axes=dynamic_axes,
)
# TODO: Don't allocate a in-memory string for the protobuf
defer_weight_export = (
export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
)
if custom_opsets is None:
custom_opsets = {}
_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
node_attr_to_name = {} # type: ignore[var-annotated]
if export_modules_as_functions:
# NOTE: cannot call DCE after this pass. DCE will remove function definition nodes.
node_attr_to_name = _C._jit_pass_onnx_function_extraction(
graph, export_modules_as_functions, list(params_dict.keys())
)
params_dict = _C._jit_pass_onnx_deduplicate_initializers( # type: ignore[assignment]
graph, params_dict, getattr(model, "training", False) # type: ignore[arg-type]
)
if export_params:
(
proto,
export_map,
val_use_external_data_format,
node_names,
) = graph._export_onnx( # type: ignore[attr-defined]
params_dict,
opset_version,
dynamic_axes,
defer_weight_export,
operator_export_type,
not verbose,
val_keep_init_as_ip,
custom_opsets,
val_add_node_names,
model_file_location,
node_attr_to_name,
)
else:
(
proto,
export_map,
val_use_external_data_format,
node_names,
) = graph._export_onnx( # type: ignore[attr-defined]
{},
opset_version,
dynamic_axes,
False,
operator_export_type,
not verbose,
val_keep_init_as_ip,
custom_opsets,
val_add_node_names,
model_file_location,
node_attr_to_name,
)
if verbose:
torch.onnx.log(
"Exported graph: ", _assign_onnx_node_name(graph, node_names)
)
if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE:
assert len(export_map) == 0
with torch.serialization._open_file_like(f, "wb") as opened_file:
opened_file.write(proto)
elif export_type in [
_exporter_states.ExportTypes.ZIP_ARCHIVE,
_exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
]:
compression = (
zipfile.ZIP_DEFLATED
if export_type
== _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE
else zipfile.ZIP_STORED
)
with zipfile.ZipFile(f, "w", compression=compression) as z:
z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
for k, v in export_map.items():
z.writestr(k, v)
elif export_type == _exporter_states.ExportTypes.DIRECTORY:
if os.path.exists(f):
assert os.path.isdir(f)
else:
os.makedirs(f)
model_proto_file = os.path.join(
f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME
)
with torch.serialization._open_file_like(
model_proto_file, "wb"
) as opened_file:
opened_file.write(proto)
for k, v in export_map.items():
weight_proto_file = os.path.join(f, k)
with torch.serialization._open_file_like(
weight_proto_file, "wb"
) as opened_file:
opened_file.write(v)
else:
raise RuntimeError("Unknown export type")
# The ONNX checker only works for ONNX graph. So if the operator_export_type is not ONNX,
# we can skip this check.
# If large model format export is enabled, proto will only contain data location instead of
# raw data and _check_onnx_proto() will fail because it can only handle the raw ONNX proto
# string in memory.
if (operator_export_type is _C_onnx.OperatorExportTypes.ONNX) and (
not val_use_external_data_format
):
try:
_C._check_onnx_proto(proto, full_check=True)
except RuntimeError as e:
raise errors.CheckerError(e)
finally:
assert GLOBALS.in_onnx_export
GLOBALS.in_onnx_export = False
_reset_trace_module_map()
return torch_out
def _apply_friendly_debug_names(graph, params):
for n in graph.nodes():
for v in n.inputs():
old_name = v.debugName()
if old_name != str(v.unique()):
continue
new_name = f"{n.kind()}_{v.unique()}"
v.setDebugName(new_name)
if old_name in params:
params[new_name] = params.pop(old_name)
def _set_input_and_output_names(graph, input_names, output_names):
def set_names(node_list, name_list, descriptor):
if name_list is None:
return
if len(name_list) > len(node_list):
raise RuntimeError(
"number of %s names provided (%d) exceeded number of %ss (%d)"
% (descriptor, len(name_list), descriptor, len(node_list))
)
# Mark if the output node DebugName is set before.
output_node_set = set()
for i, (name, node) in enumerate(zip(name_list, node_list)):
# Duplicated output node, insert onnx::Identity to avoid setting the same DebugName after setDebugName().
if descriptor == "output":
if node in output_node_set:
identity_node = graph.create("onnx::Identity")
identity_node.insertAfter(node.node())
identity_node.addInput(node)
identity_node.output().setType(node.type())
graph.return_node().replaceInput(i, identity_node.output())
node = identity_node.output()
output_node_set.add(node)
if node.debugName() != name:
node.setDebugName(name)
set_names(list(graph.inputs()), input_names, "input")
set_names(list(graph.outputs()), output_names, "output")
def _run_symbolic_method(g, op_name, symbolic_fn, args):
r"""
This trampoline function gets invoked for every symbolic method
call from C++.
"""
try:
return symbolic_fn(g, *args)
except TypeError as e:
# Handle the specific case where we didn't successfully dispatch
# to symbolic_fn. Otherwise, the backtrace will have the clues
# you need.
e.args = (f"{e.args[0]} (occurred when translating {op_name})",)
raise
def _add_block(node: _C.Node):
return node.addBlock() # type: ignore[attr-defined]
def _add_input_to_block(block: _C.Block):
return block.addInputToBlock() # type: ignore[attr-defined]
def _add_output_to_block(block: _C.Block, value: _C.Value):
new_output = block.registerOutput(value) # type: ignore[attr-defined]
return new_output
# Note [Export inplace]
# ~~~~~~~~~~~~~~~~~~~~~
# In abstract, it would be better for us to export inplace annotations,
# than to not export them, since it is useful information that can
# help the target of an ONNX export export more efficiently. However,
# ONNX doesn't currently formalize inplace. Fortunately, it's sound to drop
# inplace annotations, but we are losing information this way.
def _find_symbolic_in_registry(
domain: str,
op_name: str,
opset_version: int,
operator_export_type: _C_onnx.OperatorExportTypes,
) -> Optional[Callable]:
"""Looks up for the symbolic function in the registry.
Args:
domain: The domain of the symbolic function.
op_name: The name of the op.
opset_version: Currect opset used.
operator_export_type: An enum in _C_onnx.OperatorExportTypes.
Returns:
The symbolic function if found, None otherwise.
"""
if not symbolic_registry.is_registered_op(op_name, domain, opset_version):
if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
# Use the original node directly
return None
return symbolic_registry.get_registered_op(op_name, domain, opset_version)
def _should_aten_fallback(ns, op_name, opset_version, operator_export_type):
is_exportable_aten_op = symbolic_registry.is_registered_op(
op_name, "", opset_version
)
is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN
is_aten_fallback_export = (
operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
)
return is_onnx_aten_export or (
not is_exportable_aten_op and is_aten_fallback_export
)
def _need_symbolic_context(symbolic_fn) -> bool:
"""Checks if the first argument to symbolic_fn is annotated as type `torch.onnx.SymbolicContext`."""
params = tuple(inspect.signature(symbolic_fn).parameters.values())
# When the annotation is postpone-evaluated, the annotation is a string
# and not a type. We need to use get_type_hints to get the real type.
if not params:
return False
first_param_name = params[0].name
type_hints = typing.get_type_hints(symbolic_fn)
if first_param_name not in type_hints:
return False
param_type = type_hints[first_param_name]
return issubclass(param_type, _exporter_states.SymbolicContext)
def _get_aten_op_overload_name(n: _C.Node) -> str:
# Returns `overload_name` attribute to ATen ops on non-Caffe2 builds
schema = n.schema()
if not schema.startswith("aten::") or symbolic_helper.is_caffe2_aten_fallback():
return ""
return _C.parse_schema(schema).overload_name
def _run_symbolic_function(
g: _C.Graph,
block: _C.Block,
n: _C.Node,
inputs: Any,
env: Dict[_C.Value, _C.Value],
operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
) -> Optional[Union[_C.Value, Tuple[_C.Value, ...]]]:
"""Runs a symbolic function.
The function is used in C++ to export the node to ONNX.
Returns:
A single or a tuple of Values.
None when the node gets cloned as is into the new graph.
"""
opset_version = GLOBALS.export_onnx_opset_version
symbolic_helper.is_caffe2_aten_fallback = symbolic_helper.is_caffe2_aten_fallback
# See Note [Export inplace]
# TODO(ezyang): I think this is not necessary anymore
if n.kind().endswith("_"):
ns_op_name = n.kind()[:-1]
else:
ns_op_name = n.kind()
ns, op_name = ns_op_name.split("::")
try:
symbolic_registry.register_version("", opset_version)
# Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9:
symbolic_caffe2.register_quantized_ops("caffe2", opset_version)
if ns == "aten":
domain = ""
elif ns == "quantized" and symbolic_helper.is_caffe2_aten_fallback():
domain = "caffe2"
else:
domain = ns
if symbolic_registry.is_registered_op(op_name, domain, opset_version):
symbolic_fn = _find_symbolic_in_registry(
domain, op_name, opset_version, operator_export_type
)
assert symbolic_fn is not None
attrs = {k: n[k] for k in n.attributeNames()} # type: ignore[attr-defined]
if _need_symbolic_context(symbolic_fn):
ctx = _exporter_states.SymbolicContext(_params_dict, env, n, block)
return symbolic_fn(ctx, g, *inputs, **attrs)
# PythonOp symbolic need access to the node to resolve the name conflict,
# this is inconsistent with regular op symbolic.
if op_name == "PythonOp":
inputs = (n, *inputs)
return symbolic_fn(g, *inputs, **attrs)
elif ns == "onnx":
# Clone node to trigger ONNX shape inference
attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()} # type: ignore[attr-defined]
return g.op(op_name, *inputs, **attrs, outputs=n.outputsSize()) # type: ignore[attr-defined]
elif _should_aten_fallback(ns, op_name, opset_version, operator_export_type):
# Direct ATen export requested
attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()} # type: ignore[attr-defined]
outputs = n.outputsSize()
attrs["outputs"] = outputs
# `overload_name` is set for non-Caffe2 builds only
return g.at( # type: ignore[attr-defined]
op_name, *inputs, overload_name=_get_aten_op_overload_name(n), **attrs
)
else:
raise errors.UnsupportedOperatorError(
domain,
op_name,
opset_version,
symbolic_registry.get_op_supported_version(
op_name, domain, opset_version
),
)
except RuntimeError:
if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
return None
elif (
operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
and not symbolic_helper.is_caffe2_aten_fallback()
):
# Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK`
attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()} # type: ignore[attr-defined]
return g.at( # type: ignore[attr-defined]
op_name, *inputs, overload_name=_get_aten_op_overload_name(n), **attrs
)
raise
except TypeError as e:
# Handle the specific case where we didn't successfully dispatch.
# Otherwise, the backtrace will have the clues you need.
e.args = (f"{e.args[0]} \n(Occurred when translating {op_name}).",)
raise
def get_ns_op_name_from_custom_op(symbolic_name):
if not bool(
re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)
):
raise ValueError(
f"Failed to register operator {symbolic_name}."
"The symbolic name must match the format Domain::Name, "
"and should start with a letter and contain only "
"alphanumerical characters"
)
ns, op_name = symbolic_name.split("::")
if ns == "onnx":
raise ValueError(
f"Failed to register operator {symbolic_name}. {ns} domain cannot be modified."
)
if ns == "aten":
ns = ""
return ns, op_name
def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
"""Registers a symbolic function for a custom operator.
When the user registers symbolic for custom/contrib ops,
it is highly recommended to add shape inference for that operator via setType API,
otherwise the exported graph may have incorrect shape inference in some extreme cases.
An example of setType is `test_aten_embedding_2` in `test_operators.py`.
"""
ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
for version in itertools.chain(
_constants.onnx_stable_opsets, [_constants.onnx_main_opset]
):
if version >= opset_version:
symbolic_registry.register_op(op_name, symbolic_fn, ns, version)
def unregister_custom_op_symbolic(symbolic_name, opset_version):
ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
for version in itertools.chain(
_constants.onnx_stable_opsets, [_constants.onnx_main_opset]
):
if version >= opset_version:
symbolic_registry.unregister_op(op_name, ns, version)
def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
"""Ensures dynamic axes argument is follows the expected format."""
if len(dynamic_axes) == 0:
return
if hasattr(model, "graph"):
# Extracting set of valid input/output names that shall be used for dynamic_axes
if (input_names is None) or len(input_names) == 0:
input_names = [x.debugName() for x in model.graph.inputs()]
if (output_names is None) or len(output_names) == 0:
output_names = [y.debugName() for y in model.graph.outputs()]
valid_names = set((input_names or []) + (output_names or []))
# If dynamic axes are provided as a list rather than dictionary, they should
# first get converted to a dictionary in expected format. If desired axes names
# are not provided for dynamic axes, automatic names shall be generated for
# provided dynamic axes of specified input/output
for key, value in dynamic_axes.items():
if key not in valid_names:
warnings.warn(
f"Provided key {key} for dynamic axes is not a valid input/output name"
)
if isinstance(value, list):
warnings.warn(
"No names were found for specified dynamic axes of provided input."
f"Automatically generated names will be applied to each dynamic axes of input {key}"
)
value_dict = {}
for i, x in enumerate(value):
if not isinstance(x, int):
raise ValueError(
"The type of axis index is expected to be an integer"
)
if x in value_dict:
warnings.warn(
f"Duplicate dynamic axis index {x} was provided for input {key}."
)
else:
value_dict[x] = str(key) + "_dynamic_axes_" + str(i + 1)
dynamic_axes[key] = value_dict