# Owner(s): ["module: onnx"]
import copy
import functools
import io
import re
import warnings
from typing import Callable
import onnx
import parameterized
import pytorch_test_common
import torch
import torch.onnx
import torch.utils.cpp_extension
import torchvision
from autograd_helper import CustomFunction as CustomFunction2
from pytorch_test_common import (
skipIfNoCuda,
skipIfUnsupportedMaxOpsetVersion,
skipIfUnsupportedMinOpsetVersion,
)
from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils
from torch.onnx._globals import GLOBALS
from torch.onnx.symbolic_helper import _unpack_list, parse_args
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfNoCaffe2, skipIfNoLapack
from verify import verify
def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str:
"""Remove test environment prefix added to module.
Remove prefix to normalize scope names, since different test environments add
prefixes with slight differences.
Example:
>>> _remove_test_environment_prefix_from_scope_name(
>>> "test_utility_funs.M"
>>> )
"M"
>>> _remove_test_environment_prefix_from_scope_name(
>>> "test_utility_funs.test_abc.<locals>.M"
>>> )
"M"
>>> _remove_test_environment_prefix_from_scope_name(
>>> "__main__.M"
>>> )
"M"
"""
prefixes_to_remove = ["test_utility_funs", "__main__"]
for prefix in prefixes_to_remove:
scope_name = re.sub(f"{prefix}\\.(.*?<locals>\\.)?", "", scope_name)
return scope_name
class _BaseTestCase(pytorch_test_common.ExportTestCase):
def _model_to_graph(
self,
model,
input,
do_constant_folding=True,
training=TrainingMode.EVAL,
operator_export_type=OperatorExportTypes.ONNX,
input_names=None,
dynamic_axes=None,
):
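        """Lower ``model`` to a JIT graph via the internal exporter utilities.

        Returns the ``(graph, params_dict, torch_out)`` triple produced by
        ``utils._model_to_graph``, i.e. what ``torch.onnx.export`` builds
        internally before serializing to ONNX.
        """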
torch.onnx.utils._setup_trace_module_map(model, False)
if training == torch.onnx.TrainingMode.TRAINING:
model.train()
elif training == torch.onnx.TrainingMode.EVAL:
model.eval()
utils._validate_dynamic_axes(dynamic_axes, model, None, None)
graph, params_dict, torch_out = utils._model_to_graph(
model,
input,
do_constant_folding=do_constant_folding,
_disable_torch_constant_prop=True,
operator_export_type=operator_export_type,
training=training,
input_names=input_names,
dynamic_axes=dynamic_axes,
)
return graph, params_dict, torch_out
@common_utils.instantiate_parametrized_tests
class TestUnconvertibleOps(pytorch_test_common.ExportTestCase):
"""Unit tests for the `unconvertible_ops` function."""
def setUp(self):
class EinsumModule(torch.nn.Module):
def forward(self, x):
return torch.einsum("ii", x)
self.einsum_module = EinsumModule()
def test_it_returns_graph_and_unconvertible_ops_at_lower_opset_version(self):
x = torch.randn(4, 4)
# Einsum is supported since opset 12. It should be unconvertible at opset 9.
graph, unconvertible_ops = utils.unconvertible_ops(
self.einsum_module, (x,), opset_version=9
)
nodes = graph.nodes()
self.assertEqual(next(nodes).kind(), "prim::Constant")
self.assertEqual(next(nodes).kind(), "prim::ListConstruct")
self.assertEqual(next(nodes).kind(), "prim::Constant")
self.assertEqual(next(nodes).kind(), "aten::einsum")
self.assertEqual(unconvertible_ops, ["aten::einsum"])
@common_utils.parametrize(
"jit_function",
[
common_utils.subtest(
functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
name="traced",
),
common_utils.subtest(torch.jit.script, name="scripted"),
],
)
def test_it_returns_unconvertible_ops_at_lower_opset_version_for_jit_module(
self, jit_function: Callable
):
module = jit_function(self.einsum_module)
x = torch.randn(4, 4)
# Einsum is supported since opset 12. It should be unconvertible at opset 9.
_, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=9)
self.assertEqual(unconvertible_ops, ["aten::einsum"])
@common_utils.parametrize(
"jit_function",
[
common_utils.subtest(lambda x: x, name="nn_module"),
common_utils.subtest(
functools.partial(torch.jit.trace, example_inputs=torch.randn(4, 4)),
name="traced",
),
common_utils.subtest(torch.jit.script, name="scripted"),
],
)
def test_it_returns_empty_list_when_all_ops_convertible(
self, jit_function: Callable
):
module = jit_function(self.einsum_module)
x = torch.randn(4, 4)
# Einsum is supported since opset 12
_, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=12)
self.assertEqual(unconvertible_ops, [])
def test_it_returns_empty_list_when_model_contains_supported_inplace_ops(self):
class SkipConnectionModule(torch.nn.Module):
def forward(self, x):
out = x
out += x
out = torch.nn.functional.relu(out, inplace=True)
return out
module = SkipConnectionModule()
x = torch.randn(4, 4)
_, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=13)
self.assertEqual(unconvertible_ops, [])
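# Instantiate TestUtilityFuns once per opset version, from ONNX_BASE_OPSET up to
# the highest opset supported by the TorchScript exporter; each generated class
# name embeds the opset version (e.g. ``TestUtilityFuns_opset_<n>``).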
@parameterized.parameterized_class(
[
{"opset_version": opset}
for opset in range(
_constants.ONNX_BASE_OPSET,
_constants.ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET + 1,
)
],
class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}",
)
class TestUtilityFuns(_BaseTestCase):
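    """Tests for exporter utility functions, parameterized over opset versions."""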
opset_version = None
def test_is_in_onnx_export(self):
test_self = self
class MyModule(torch.nn.Module):
def forward(self, x):
test_self.assertTrue(torch.onnx.is_in_onnx_export())
raise ValueError
return x + 1
x = torch.randn(3, 4)
f = io.BytesIO()
try:
torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
except ValueError:
self.assertFalse(torch.onnx.is_in_onnx_export())
def test_validate_dynamic_axes_invalid_input_output_name(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
utils._validate_dynamic_axes(
{"input1": {}, "output": {}, "invalid_name1": {}, "invalid_name2": {}},
None,
["input1", "input2"],
["output"],
)
messages = [str(warning.message) for warning in w]
self.assertIn(
"Provided key invalid_name1 for dynamic axes is not a valid input/output name",
messages,
)
self.assertIn(
"Provided key invalid_name2 for dynamic axes is not a valid input/output name",
messages,
)
self.assertEqual(len(messages), 2)
@skipIfUnsupportedMinOpsetVersion(11)
def test_split_to_slice(self):
class SplitModule(torch.nn.Module):
def forward(self, x, y, t):
splits = (x.size(1), y.size(1))
out, out2 = torch.split(t, splits, dim=1)
return out, out2
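        # The symbolic conversion consults the exporter's module-level GLOBALS for
        # the target opset and operator export type, so set them before lowering.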
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.randn(2, 3)
y = torch.randn(2, 4)
t = torch.randn(2, 7)
graph, _, _ = self._model_to_graph(
SplitModule(),
(x, y, t),
input_names=["x", "y", "t"],
dynamic_axes={"x": [0, 1], "y": [0, 1], "t": [0, 1]},
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::SplitToSequence")
def test_constant_fold_transpose(self):
class TransposeModule(torch.nn.Module):
def forward(self, x):
a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = torch.transpose(a, 1, 0)
return b + x
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(3, 2)
graph, _, __ = self._model_to_graph(
TransposeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Transpose")
self.assertNotEqual(node.kind(), "onnx::Cast")
self.assertEqual(len(list(graph.nodes())), 2)
def test_constant_fold_reduceL2(self):
class ReduceModule(torch.nn.Module):
def forward(self, x):
a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = torch.norm(a, p=2, dim=-2, keepdim=False)
return b + x
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(2, 3)
graph, _, __ = self._model_to_graph(
ReduceModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::ReduceL2")
def test_constant_fold_reduceL1(self):
class NormModule(torch.nn.Module):
def forward(self, x):
a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = torch.norm(a, p=1, dim=-2)
return b + x
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(2, 3)
graph, _, __ = self._model_to_graph(
NormModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::ReduceL1")
def test_constant_fold_slice(self):
class NarrowModule(torch.nn.Module):
def forward(self, x):
a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = torch.narrow(a, 0, 0, 1)
return b + x
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(1, 3)
graph, _, __ = self._model_to_graph(
NarrowModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Slice")
self.assertNotEqual(node.kind(), "onnx::Cast")
self.assertEqual(len(list(graph.nodes())), 2)
def test_constant_fold_slice_index_exceeds_dim(self):
class SliceIndexExceedsDimModule(torch.nn.Module):
def forward(self, x):
a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = a[1:10] # index exceeds dimension
return b + x
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(1, 3)
graph, _, __ = self._model_to_graph(
SliceIndexExceedsDimModule(),
(x,),
input_names=["x"],
dynamic_axes={"x": [0, 1]},
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Slice")
self.assertNotEqual(node.kind(), "onnx::Cast")
self.assertEqual(len(list(graph.nodes())), 2)
def test_constant_fold_slice_negative_index(self):
class SliceNegativeIndexModule(torch.nn.Module):
def forward(self, x):
a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = a[0:-1] # index relative to the end
c = torch.select(a, dim=-1, index=-2)
d = torch.select(a, dim=1, index=0)
return b + x, c + d
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(1, 3)
graph, _, __ = self._model_to_graph(
SliceNegativeIndexModule(),
(x,),
input_names=["x"],
dynamic_axes={"x": [0, 1]},
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Slice")
self.assertNotEqual(node.kind(), "onnx::Cast")
def test_constant_fold_gather(self):
class GatherModule(torch.nn.Module):
def forward(self, x):
a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = torch.select(a, dim=1, index=-2)
c = torch.index_select(a, dim=-2, index=torch.tensor([0, 1]))
return b + 1, c + x
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(1, 3)
model = GatherModule()
model(x)
graph, _, __ = self._model_to_graph(
GatherModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Gather")
def test_constant_fold_unsqueeze(self):
class UnsqueezeModule(torch.nn.Module):
def forward(self, x):
a = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = torch.unsqueeze(a, -2)
return b + x
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(1, 2, 3)
graph, _, __ = self._model_to_graph(
UnsqueezeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Unsqueeze")
self.assertNotEqual(node.kind(), "onnx::Cast")
self.assertEqual(len(list(graph.nodes())), 2)
    def test_constant_fold_unsqueeze_multi_axes(self):
class PReluModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.prelu = torch.nn.PReLU()
def forward(self, x):
a = torch.randn(2, 3, 4, 5, 8, 7)
return self.prelu(x) + a
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.randn(2, 3, 4, 5, 8, 7)
graph, _, __ = self._model_to_graph(
PReluModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3, 4, 5]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Unsqueeze")
self.assertNotEqual(node.kind(), "onnx::Cast")
self.assertEqual(len(list(graph.nodes())), 5)
def test_constant_fold_squeeze_without_axes(self):
class SqueezeModule(torch.nn.Module):
def forward(self, x):
a = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
return torch.squeeze(a) + x + torch.squeeze(a)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(2, 3)
graph, _, __ = self._model_to_graph(
SqueezeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Squeeze")
self.assertNotEqual(node.kind(), "onnx::Cast")
self.assertEqual(len(list(graph.nodes())), 4)
def test_constant_fold_squeeze_with_axes(self):
class SqueezeAxesModule(torch.nn.Module):
def forward(self, x):
a = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
return torch.squeeze(a, dim=-3) + x
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(2, 3)
graph, _, __ = self._model_to_graph(
SqueezeAxesModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Squeeze")
self.assertNotEqual(node.kind(), "onnx::Cast")
self.assertEqual(len(list(graph.nodes())), 2)
def test_constant_fold_concat(self):
class ConcatModule(torch.nn.Module):
def forward(self, x):
                # Why insert a Cast here? There appears to be intentional
                # behavior in ONNX constant folding where constant tensors that
                # are not attached to any onnx operations known to be foldable
                # don't get extracted into the initializer graph. So without
                # these casts, we would fail to pull out one of the constants,
                # thus failing constant folding. The test is arguably wrong; a
                # better approach would be to set up a predicate for the
                # invariants a graph should satisfy after constant folding and
                # then verify that predicate holds. The asserts below are an
                # attempt at such a predicate, but they are not quite right.
#
# More commentary at
# https://github.com/pytorch/pytorch/pull/18698/files#r340107552
a = torch.tensor([[1.0, 2.0, 3.0]]).to(torch.float)
b = torch.tensor([[4.0, 5.0, 6.0]]).to(torch.float)
c = torch.cat((a, b), 0)
d = b + c
return x + d
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(2, 3)
graph, _, __ = self._model_to_graph(
ConcatModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Concat")
self.assertNotEqual(node.kind(), "onnx::Cast")
self.assertEqual(len(list(graph.nodes())), 2)
def test_constant_fold_lstm(self):
class GruNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.mygru = torch.nn.GRU(7, 3, 1, bidirectional=False)
def forward(self, input, initial_state):
return self.mygru(input, initial_state)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
input = torch.randn(5, 3, 7)
h0 = torch.randn(1, 3, 3)
graph, _, __ = self._model_to_graph(
GruNet(),
(input, h0),
input_names=["input", "h0"],
dynamic_axes={"input": [0, 1, 2], "h0": [0, 1, 2]},
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Slice")
self.assertNotEqual(node.kind(), "onnx::Concat")
self.assertNotEqual(node.kind(), "onnx::Unsqueeze")
if self.opset_version <= 12:
self.assertEqual(len(list(graph.nodes())), 3)
else:
            # The Unsqueeze op takes "axes" as an input instead of an attribute
            # when opset version >= 13, which accounts for the extra node.
self.assertEqual(len(list(graph.nodes())), 4)
def test_constant_fold_transpose_matmul(self):
class MatMulNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.B = torch.nn.Parameter(torch.ones(5, 3))
def forward(self, A):
return torch.matmul(A, torch.transpose(self.B, -1, -2))
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
A = torch.randn(2, 3)
graph, _, __ = self._model_to_graph(
MatMulNet(), (A,), input_names=["A"], dynamic_axes={"A": [0, 1]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Transpose")
self.assertEqual(len(list(graph.nodes())), 1)
def test_constant_fold_reshape(self):
class ReshapeModule(torch.nn.Module):
def __init__(
self,
):
super().__init__()
self.register_buffer("weight", torch.ones(5))
def forward(self, x):
b = self.weight.reshape(1, -1, 1, 1)
return x * b
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.randn(4, 5)
graph, _, __ = self._model_to_graph(
ReshapeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Reshape")
self.assertEqual(len(list(graph.nodes())), 1)
def test_constant_fold_div(self):
class Module(torch.nn.Module):
def __init__(
self,
):
super().__init__()
self.register_buffer("weight", torch.ones(5))
def forward(self, x):
div = self.weight.div(torch.tensor([1, 2, 3, 4, 5]))
return div * x
x = torch.randn(2, 5)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, _, __ = self._model_to_graph(
Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Div")
self.assertEqual(len(list(graph.nodes())), 1)
def test_constant_fold_mul(self):
class Module(torch.nn.Module):
def __init__(
self,
):
super().__init__()
self.register_buffer("weight", torch.ones(5))
def forward(self, x):
mul = self.weight.mul(torch.tensor([1, 2, 3, 4, 5]))
return mul / x
x = torch.randn(2, 5)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, _, __ = self._model_to_graph(
Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Mul")
self.assertEqual(len(list(graph.nodes())), 1)
def test_constant_fold_add(self):
class Module(torch.nn.Module):
def __init__(
self,
):
super().__init__()
self.register_buffer("weight", torch.ones(5))
def forward(self, x):
add = self.weight + torch.tensor([1, 2, 3, 4, 5])
return add - x
x = torch.randn(2, 5)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, params_dict, __ = self._model_to_graph(
Module(),
(x,),
do_constant_folding=True,
operator_export_type=OperatorExportTypes.ONNX,
input_names=["x"],
dynamic_axes={"x": [0, 1]},
)
for node in graph.nodes():
            self.assertNotEqual(node.kind(), "onnx::Add")
self.assertEqual(len(list(graph.nodes())), 1)
params = list(params_dict.values())
self.assertEqual(len(params), 1)
weight = params[0]
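        # ones(5) + tensor([1, 2, 3, 4, 5]) should have been folded into the initializer.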
self.assertEqual(weight, torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0]))
def test_constant_fold_sub(self):
class Module(torch.nn.Module):
def __init__(
self,
):
super().__init__()
self.register_buffer("weight", torch.ones(5))
def forward(self, x):
sub = self.weight - torch.tensor([1, 2, 3, 4, 5])
return sub + x
x = torch.randn(2, 5)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, params_dict, __ = self._model_to_graph(
Module(),
(x,),
do_constant_folding=True,
operator_export_type=OperatorExportTypes.ONNX,
input_names=["x"],
dynamic_axes={"x": [0, 1]},
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Sub")
self.assertEqual(len(list(graph.nodes())), 1)
params = list(params_dict.values())
self.assertEqual(len(params), 1)
weight = params[0]
self.assertEqual(weight, torch.tensor([0.0, -1.0, -2.0, -3.0, -4.0]))
def test_constant_fold_sqrt(self):
class Module(torch.nn.Module):
def __init__(
self,
):
super().__init__()
self.register_buffer("weight", torch.ones(5))
def forward(self, x):
sqrt = torch.sqrt(self.weight)
return sqrt / x
x = torch.randn(2, 5)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, _, __ = self._model_to_graph(
Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Sqrt")
self.assertEqual(len(list(graph.nodes())), 1)
def test_constant_fold_shape(self):
class ShapeModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("weight", torch.ones(5))
def forward(self, x):
shape = self.weight.shape[0]
return x + shape
x = torch.randn(2, 5)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, _, __ = self._model_to_graph(
ShapeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::Shape")
self.assertEqual(len(list(graph.nodes())), 2)
def test_constant_fold_upsample_scale_fold_as_constant(self):
        # The upsample scale is a constant, not a model parameter, and therefore
        # should not be added as an initializer after constant folding.
model = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
x = torch.randn(1, 32, 224, 224)
f = io.BytesIO()
torch.onnx.export(model, x, f)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertEqual(len(onnx_model.graph.initializer), 0)
def test_verbose(self):
class MyModule(torch.nn.Module):
def forward(self, input):
return torch.exp(input)
x = torch.randn(3, 4)
def is_model_stripped(f, verbose=None):
if verbose is None:
torch.onnx.export(MyModule(), x, f, opset_version=self.opset_version)
else:
torch.onnx.export(
MyModule(), x, f, verbose=verbose, opset_version=self.opset_version
)
model = onnx.load(io.BytesIO(f.getvalue()))
model_strip = copy.copy(model)
onnx.helper.strip_doc_string(model_strip)
return model == model_strip
# test verbose=False (default)
self.assertTrue(is_model_stripped(io.BytesIO()))
# test verbose=True
self.assertFalse(is_model_stripped(io.BytesIO(), True))
# NB: remove this test once DataParallel can be correctly handled
def test_error_on_data_parallel(self):
model = torch.nn.DataParallel(torch.nn.ReflectionPad2d((1, 2, 3, 4)))
x = torch.randn(1, 2, 3, 4)
f = io.BytesIO()
with self.assertRaisesRegex(
ValueError,
"torch.nn.DataParallel is not supported by ONNX "
"exporter, please use 'attribute' module to "
"unwrap model from torch.nn.DataParallel. Try ",
):
torch.onnx.export(model, x, f, opset_version=self.opset_version)
@skipIfUnsupportedMinOpsetVersion(11)
def test_sequence_dim(self):
class Module(torch.nn.Module):
def forward(self, x, y):
return [x, y]
model = Module()
# Export with scripting to keep output as Sequence type.
# Tracing unpacks the list.
script_model = torch.jit.script(model)
x = torch.randn(2, 3)
# Case 1: dynamic axis
f = io.BytesIO()
y = torch.randn(2, 3)
torch.onnx.export(
script_model,
(x, y),
f,
opset_version=self.opset_version,
input_names=["x", "y"],
dynamic_axes={"y": [1]},
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
loop_output_value_info_proto = onnx_model.graph.output[0]
ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info(
loop_output_value_info_proto.name, 1, [2, None]
)
self.assertEqual(loop_output_value_info_proto, ref_value_info_proto)
# Case 2: no dynamic axes.
f = io.BytesIO()
y = torch.randn(2, 3)
torch.onnx.export(script_model, (x, y), f, opset_version=self.opset_version)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
loop_output_value_info_proto = onnx_model.graph.output[0]
ref_value_info_proto = onnx.helper.make_tensor_sequence_value_info(
loop_output_value_info_proto.name, 1, [2, 3]
)
self.assertEqual(loop_output_value_info_proto, ref_value_info_proto)
def test_export_mode(self):
class MyModule(torch.nn.Module):
def forward(self, x):
y = x + 1
return y
model = MyModule()
x = torch.randn(10, 3, 128, 128)
f = io.BytesIO()
        # Set the model to inference mode and export in training mode.
model.eval()
old_state = model.training
torch.onnx.export(
model,
(x,),
f,
opset_version=self.opset_version,
training=torch.onnx.TrainingMode.TRAINING,
)
# verify that the model state is preserved
self.assertEqual(model.training, old_state)
        # Set the model to training mode and export in inference mode.
model.train()
old_state = model.training
torch.onnx.export(
model,
(x,),
f,
opset_version=self.opset_version,
training=torch.onnx.TrainingMode.EVAL,
)
# verify that the model state is preserved
self.assertEqual(model.training, old_state)
def test_export_does_not_fail_on_frozen_scripted_module(self):
class Inner(torch.nn.Module):
def forward(self, x):
if x > 0:
return x
else:
return x * x
class Outer(torch.nn.Module):
def __init__(self):
super().__init__()
self.inner = torch.jit.script(Inner())
def forward(self, x):
return self.inner(x)
x = torch.zeros(1)
        # Freezing is only implemented in eval mode, so we need to call eval() first.
        outer_module = Outer().eval()
        module = torch.jit.trace_module(outer_module, {"forward": (x,)})
        # jit.freeze removes the training attribute from the module.
module = torch.jit.freeze(module)
torch.onnx.export(module, (x,), io.BytesIO(), opset_version=self.opset_version)
@skipIfUnsupportedMinOpsetVersion(15)
def test_local_function(self):
class N(torch.nn.Module):
def __init__(self, prob):
super().__init__()
self.dropout = torch.nn.Dropout(prob)
def forward(self, x):
return self.dropout(x)
class M(torch.nn.Module):
def __init__(self, num_layers):
super().__init__()
self.num_layers = num_layers
self.lns = torch.nn.ModuleList(
[torch.nn.LayerNorm(3, eps=i) for i in range(num_layers)]
)
self.celu1 = torch.nn.CELU(1.0)
self.celu2 = torch.nn.CELU(2.0)
self.dropout = N(0.5)
def forward(self, x, y, z):
res1 = self.celu1(x)
res2 = self.celu2(y)
for ln in self.lns:
z = ln(z)
return res1 + res2, self.dropout(z)
x = torch.randn(2, 3)
y = torch.randn(2, 3)
z = torch.randn(2, 3)
        # Export the specified modules, including ones that won't exist in the
        # exported model: exporting in inference mode removes the Dropout node,
        # so the dropout module no longer exists in the graph.
f = io.BytesIO()
torch.onnx.export(
M(3),
(x, y, z),
f,
opset_version=self.opset_version,
export_modules_as_functions={
torch.nn.CELU,
torch.nn.Dropout,
torch.nn.LayerNorm,
},
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
# Check function definition
funcs = onnx_model.functions
celu_funcs = [f for f in funcs if f.name == "CELU"]
self.assertEqual(len(celu_funcs), 1)
self.assertEqual(celu_funcs[0].domain, "torch.nn.modules.activation")
self.assertEqual(len(celu_funcs[0].attribute), 3)
ln_funcs = [f for f in funcs if f.name == "LayerNorm"]
self.assertEqual(len(ln_funcs), 1)
self.assertEqual(ln_funcs[0].domain, "torch.nn.modules.normalization")
self.assertEqual(len(ln_funcs[0].attribute), 3)
# Check local function nodes
nodes = onnx_model.graph.node
celu_ns = [n for n in nodes if n.op_type == "CELU"]
ln_ns = [n for n in nodes if n.op_type == "LayerNorm"]
self.assertEqual(len(celu_ns), 2)
self.assertEqual(celu_ns[0].domain, "torch.nn.modules.activation")
self.assertEqual(len(celu_ns[0].attribute), 3)
self.assertEqual(len(ln_ns), 3)
self.assertEqual(ln_ns[0].domain, "torch.nn.modules.normalization")
self.assertEqual(len(ln_ns[0].attribute), 3)
# Export specified modules.
f = io.BytesIO()
torch.onnx.export(
M(3),
(x, y, z),
f,
opset_version=self.opset_version,
export_modules_as_functions={torch.nn.CELU},
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
funcs = onnx_model.functions
self.assertEqual(len(funcs), 1)
self.assertEqual(funcs[0].name, "CELU")
# Export with empty specified modules. Normal export.
f = io.BytesIO()
torch.onnx.export(
M(3),
(x, y, z),
f,
opset_version=self.opset_version,
export_modules_as_functions=set(),
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
funcs = onnx_model.functions
self.assertEqual(len(funcs), 0)
# Export all modules. Should contain {M, CELU, LayerNorm}.
f = io.BytesIO()
torch.onnx.export(
M(3),
(x, y, z),
f,
opset_version=self.opset_version,
export_modules_as_functions=True,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
funcs = onnx_model.functions
self.assertEqual(len(funcs), 3)
@skipIfUnsupportedMinOpsetVersion(15)
def test_local_function_overloads(self):
class NWithOverloads(torch.nn.Module):
def forward(self, x, y=None, z=None):
if y is None:
return x + 1
elif z is None:
return x + y
else:
return x + y, x + z
class M(torch.nn.Module):
def __init__(self, num_layers):
super().__init__()
self.n = NWithOverloads()
def forward(self, x, y, z):
return self.n(x), self.n(x, y), self.n(x, y, z)
x = torch.randn(2, 3)
y = torch.randn(2, 3)
z = torch.randn(2, 3)
f = io.BytesIO()
torch.onnx.export(
M(3),
(x, y, z),
f,
opset_version=self.opset_version,
export_modules_as_functions={NWithOverloads},
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
funcs = onnx_model.functions
self.assertEqual(len(funcs), 3)
func_names = [f.name for f in funcs]
self.assertIn("NWithOverloads", func_names)
self.assertIn("NWithOverloads.1", func_names)
self.assertIn("NWithOverloads.2", func_names)
# Failing after ONNX 1.13.0
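    # The max-opset guard of 1 below effectively skips this test for all
    # exporter opsets (every supported opset is above 1).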
@skipIfUnsupportedMaxOpsetVersion(1)
def test_local_function_infer_scopes(self):
class M(torch.nn.Module):
def forward(self, x):
                # Concatenation of scalars inserts unscoped tensors into the IR graph.
new_tensor_shape = x.size()[:-1] + (1, 1, -1)
tensor = x.view(*new_tensor_shape)
return tensor
x = torch.randn(4, 5)
f = io.BytesIO()
torch.onnx.export(
M(),
(x,),
f,
export_modules_as_functions=True,
opset_version=self.opset_version,
do_constant_folding=False,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
funcs = onnx_model.functions
self.assertIn("M", [f.name for f in funcs])
@skipIfUnsupportedMinOpsetVersion(15)
def test_local_function_predefined_attributes(self):
class M(torch.nn.Module):
num_layers: int
def __init__(self, num_layers):
super().__init__()
self.num_layers = num_layers
self.lns = torch.nn.ModuleList(
[torch.nn.LayerNorm(3, eps=1e-4) for _ in range(num_layers)]
)
def forward(self, x):
for ln in self.lns:
x = ln(x)
return x
x = torch.randn(2, 3)
f = io.BytesIO()
model = M(3)
torch.onnx.export(
model,
(x,),
f,
export_modules_as_functions=True,
opset_version=self.opset_version,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
funcs = onnx_model.functions
m_funcs = [fn for fn in funcs if fn.name == "M"]
self.assertEqual(m_funcs[0].attribute, ["num_layers"])
ln_funcs = [fn for fn in funcs if fn.name == "LayerNorm"]
self.assertEqual(ln_funcs[0].attribute, ["eps", "elementwise_affine"])
from onnx import helper
m_node = [n for n in onnx_model.graph.node if n.op_type == "M"]
self.assertEqual(
m_node[0].attribute[0],
helper.make_attribute("num_layers", model.num_layers),
)
ln_nodes = [n for n in m_funcs[0].node if n.op_type == "LayerNorm"]
expected_ln_attrs = [
helper.make_attribute(
"elementwise_affine", model.lns[0].elementwise_affine
),
helper.make_attribute("eps", model.lns[0].eps),
]
for ln_node in ln_nodes:
self.assertIn(ln_node.attribute[0], expected_ln_attrs)
self.assertIn(ln_node.attribute[1], expected_ln_attrs)
    # This test case checks the issue where an object does not have an attribute.
# When enabling `export_modules_as_functions = True`, the exporter could return an
# AttributeError. With this test case, we check that the export passes successfully
# without any AttributeError exceptions.
# See https://github.com/pytorch/pytorch/pull/109759 for an example. The exception that
# this test tries to avoid is `AttributeError: 'Embedding' object has no attribute 'freeze'`.
@skipIfUnsupportedMinOpsetVersion(15)
def test_local_function_subset_of_predefined_attributes(self):
class M(torch.nn.Module):
num_layers: int
def __init__(self, num_layers):
super().__init__()
self.embed_layer = torch.nn.Embedding.from_pretrained(
torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
)
self.num_layers = num_layers
self.lns = torch.nn.ModuleList(
[torch.nn.LayerNorm(3, eps=1e-4) for _ in range(num_layers)]
)
def forward(self, x):
e = self.embed_layer(torch.LongTensor([1]))
for ln in self.lns:
x = ln(x)
return x, e
x = torch.randn(2, 3)
f = io.BytesIO()
model = M(3)
torch.onnx.export(
model,
(x,),
f,
export_modules_as_functions=True,
opset_version=self.opset_version,
verbose=True, # Allows the test case to print `Skipping module attribute 'freeze'`
)
def test_node_scope(self):
class N(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(x)
class M(torch.nn.Module):
def __init__(self, num_layers):
super().__init__()
self.num_layers = num_layers
self.lns = torch.nn.ModuleList(
[torch.nn.LayerNorm(3, eps=float(i)) for i in range(num_layers)]
)
self.gelu1 = torch.nn.GELU()
self.gelu2 = torch.nn.GELU()
self.relu = N()
def forward(self, x, y, z):
res1 = self.gelu1(x)
res2 = self.gelu2(y)
for ln in self.lns:
z = ln(z)
return res1 + res2, self.relu(z)
x = torch.randn(2, 3)
y = torch.randn(2, 3)
z = torch.randn(2, 3)
model = M(3)
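        # Scope names are "/"-joined segments of the form ``Class::attribute``,
        # rooted at the top-level module (``M::``).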
expected_scope_names = {
"M::/torch.nn.modules.activation.GELU::gelu1",
"M::/torch.nn.modules.activation.GELU::gelu2",
"M::/torch.nn.modules.normalization.LayerNorm::lns.0",
"M::/torch.nn.modules.normalization.LayerNorm::lns.1",
"M::/torch.nn.modules.normalization.LayerNorm::lns.2",
"M::/N::relu/torch.nn.modules.activation.ReLU::relu",
"M::",
}
graph, _, _ = self._model_to_graph(
model, (x, y, z), input_names=[], dynamic_axes={}
)
for node in graph.nodes():
self.assertIn(
_remove_test_environment_prefix_from_scope_name(node.scopeName()),
expected_scope_names,
)
graph, _, _ = self._model_to_graph(
torch.jit.script(model), (x, y, z), input_names=[], dynamic_axes={}
)
for node in graph.nodes():
self.assertIn(
_remove_test_environment_prefix_from_scope_name(node.scopeName()),
expected_scope_names,
)
def test_scope_of_constants_when_combined_by_cse_pass(self):
layer_num = 3
class M(torch.nn.Module):
def __init__(self, constant):
super().__init__()
self.constant = constant
def forward(self, x):
                # 'self.constant' is designed to be the same for all layers,
                # hence it is a common subexpression.
return x + self.constant
class N(torch.nn.Module):
def __init__(self, layers: int = layer_num):
super().__init__()
self.layers = torch.nn.ModuleList(
[M(constant=torch.tensor(1.0)) for i in range(layers)]
)
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
graph, _, _ = self._model_to_graph(
N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
)
        # NOTE: Duplicated constants are populated due to implicit casting in
        # scalar_type_analysis, so we expect 3 constants with different scopes,
        # one for each of the 3 layers. If CSE in the exporter is improved later,
        # this test needs to be updated to expect 1 constant with the same scope
        # as the root.
expected_root_scope_name = "N::"
expected_layer_scope_name = "M::layers"
expected_constant_scope_name = [
f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}"
for i in range(layer_num)
]
constant_scope_names = []
for node in graph.nodes():
if node.kind() == "onnx::Constant":
constant_scope_names.append(
_remove_test_environment_prefix_from_scope_name(node.scopeName())
)
self.assertEqual(constant_scope_names, expected_constant_scope_name)
def test_scope_of_nodes_when_combined_by_cse_pass(self):
layer_num = 3
class M(torch.nn.Module):
def __init__(self, constant, bias):
super().__init__()
self.constant = constant
self.bias = bias
def forward(self, x):
                # 'constant' and 'x' are designed to be the same for all layers,
                # hence `x + self.constant` is a common subexpression.
                # 'bias' is designed to be different for each layer,
                # hence `* self.bias` is not a common subexpression.
return (x + self.constant) * self.bias
class N(torch.nn.Module):
def __init__(self, layers: int = layer_num):
super().__init__()
self.layers = torch.nn.ModuleList(
[
M(constant=torch.tensor([1.0]), bias=torch.randn(1))
for i in range(layers)
]
)
def forward(self, x):
y = []
for layer in self.layers:
y.append(layer(x))
return y[0], y[1], y[2]
graph, _, _ = self._model_to_graph(
N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
)
expected_root_scope_name = "N::"
expected_layer_scope_name = "M::layers"
expected_add_scope_names = [
f"{expected_root_scope_name}/{expected_layer_scope_name}.0"
]
expected_mul_scope_names = [
f"{expected_root_scope_name}/{expected_layer_scope_name}.{i}"
for i in range(layer_num)
]
add_scope_names = []
mul_scope_names = []
for node in graph.nodes():
if node.kind() == "onnx::Add":
add_scope_names.append(
_remove_test_environment_prefix_from_scope_name(node.scopeName())
)
elif node.kind() == "onnx::Mul":
mul_scope_names.append(
_remove_test_environment_prefix_from_scope_name(node.scopeName())
)
self.assertEqual(add_scope_names, expected_add_scope_names)
self.assertEqual(mul_scope_names, expected_mul_scope_names)
def test_aten_fallthrough(self):
        # Test aten export of an op that has no symbolic function.
class Module(torch.nn.Module):
def forward(self, x):
return torch.erfc(x)
x = torch.randn(2, 3, 4)
GLOBALS.export_onnx_opset_version = self.opset_version
graph, _, __ = self._model_to_graph(
Module(),
(x,),
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
input_names=["x"],
dynamic_axes={"x": [0, 1, 2]},
)
iter = graph.nodes()
self.assertEqual(next(iter).kind(), "aten::erfc")
def test_custom_op_fallthrough(self):
# Test custom op
op_source = """
#include <torch/script.h>
torch::Tensor custom_add(torch::Tensor self, torch::Tensor other) {
return self + other;
}
static auto registry =
torch::RegisterOperators("custom_namespace::custom_op", &custom_add);
"""
torch.utils.cpp_extension.load_inline(
name="custom_add",
cpp_sources=op_source,
is_python_module=False,
verbose=True,
)
class FooModel(torch.nn.Module):
def forward(self, input, other):
# Calling custom op
return torch.ops.custom_namespace.custom_op(input, other)
x = torch.randn(2, 3, 4, requires_grad=False)
y = torch.randn(2, 3, 4, requires_grad=False)
model = FooModel()
graph, _, __ = self._model_to_graph(
model,
(x, y),
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
input_names=["x", "y"],
dynamic_axes={"x": [0, 1, 2], "y": [0, 1, 2]},
)
iter = graph.nodes()
self.assertEqual(next(iter).kind(), "custom_namespace::custom_op")
def test_custom_opsets_gelu(self):
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::gelu", 9)
def gelu(g, self, approximate):
return g.op("com.microsoft::Gelu", self).setType(self.type())
torch.onnx.register_custom_op_symbolic("::gelu", gelu, 9)
model = torch.nn.GELU(approximate="none")
x = torch.randn(3, 3)
f = io.BytesIO()
torch.onnx.export(
model,
(x,),
f,
opset_version=self.opset_version,
custom_opsets={"com.microsoft": 1},
)
graph = onnx.load(io.BytesIO(f.getvalue()))
self.assertEqual(graph.graph.node[0].op_type, "Gelu")
self.assertEqual(graph.opset_import[0].version, self.opset_version)
self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
self.assertEqual(graph.opset_import[1].version, 1)
def test_register_aten_custom_op_symbolic(self):
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "aten::gelu", 9)
def gelu(g, self, approximate):
return g.op("com.microsoft::Gelu", self).setType(self.type())
torch.onnx.register_custom_op_symbolic("aten::gelu", gelu, 9)
model = torch.nn.GELU(approximate="none")
x = torch.randn(3, 3)
f = io.BytesIO()
torch.onnx.export(model, (x,), f, opset_version=self.opset_version)
graph = onnx.load(io.BytesIO(f.getvalue()))
self.assertEqual(graph.graph.node[0].op_type, "Gelu")
self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
@skipIfNoLapack
def test_custom_opsets_inverse(self):
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9)
class CustomInverse(torch.nn.Module):
def forward(self, x):
return torch.inverse(x) + x
def linalg_inv(g, self):
return g.op("com.microsoft::Inverse", self).setType(self.type())
torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv, 9)
model = CustomInverse()
x = torch.randn(2, 3, 3)
f = io.BytesIO()
torch.onnx.export(
model,
(x,),
f,
opset_version=self.opset_version,
custom_opsets={"com.microsoft": 1},
)
graph = onnx.load(io.BytesIO(f.getvalue()))
self.assertEqual(graph.graph.node[0].op_type, "Inverse")
self.assertEqual(graph.opset_import[0].version, self.opset_version)
self.assertEqual(graph.opset_import[1].domain, "com.microsoft")
self.assertEqual(graph.opset_import[1].version, 1)
def test_onnx_fallthrough(self):
        # Test ONNX fallthrough export of an aten op that has a symbolic function.
class Module(torch.nn.Module):
def forward(self, x):
return torch.digamma(x)
x = torch.randn(100, 128)
graph, _, __ = self._model_to_graph(
Module(),
(x,),
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
input_names=["x"],
dynamic_axes={"x": [0, 1]},
)
iter = graph.nodes()
self.assertEqual(next(iter).kind(), "aten::digamma")
# prim::ListConstruct is exported as onnx::SequenceConstruct for opset >= 11
@skipIfUnsupportedMaxOpsetVersion(10)
def test_prim_fallthrough(self):
# Test prim op
class PrimModule(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
if isinstance(x, list):
y = x
else:
y = [x]
return y
x = torch.tensor([2])
model = PrimModule()
model.eval()
graph, _, __ = self._model_to_graph(
model,
(x,),
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
input_names=["x"],
dynamic_axes={"x": [0]},
)
iter = graph.nodes()
self.assertEqual(next(iter).kind(), "prim::ListConstruct")
def test_custom_layer_tuple(self):
class CustomFunction(torch.autograd.Function):
@staticmethod
def symbolic(g, input):
return g.op("CustomNamespace::Custom", input, outputs=2)
@staticmethod
def forward(ctx, input):
return input, input
class Custom(torch.nn.Module):
def forward(self, input):
return CustomFunction.apply(input)
model = Custom()
batch = torch.FloatTensor(1, 3)
graph, _, _ = self._model_to_graph(
model, batch, input_names=["batch"], dynamic_axes={"batch": [0, 1]}
)
iter = graph.nodes()
self.assertEqual(next(iter).kind(), "CustomNamespace::Custom")
def test_autograd_onnx_fallthrough(self):
class CustomFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return input.clamp(min=0)
@staticmethod
def backward(ctx, grad_output):
(input,) = ctx.saved_tensors
grad_input = grad_output.clone()
grad_input[input < 0] = 0
return grad_input
class Custom(torch.nn.Module):
def forward(self, input):
return CustomFunction.apply(input)
model = Custom()
batch = torch.FloatTensor(1, 3)
graph, _, _ = self._model_to_graph(
model,
batch,
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
input_names=["batch"],
dynamic_axes={"batch": [0, 1]},
)
iter = graph.nodes()
self.assertEqual(next(iter).kind(), "prim::PythonOp")
def test_autograd_module_name(self):
class CustomFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return input.clamp(min=0)
@staticmethod
def backward(ctx, grad_output):
(input,) = ctx.saved_tensors
grad_input = grad_output.clone()
grad_input[input < 0] = 0
return grad_input
class Custom(torch.nn.Module):
def forward(self, input):
return CustomFunction.apply(input) + CustomFunction2.apply(input)
model = Custom()
batch = torch.FloatTensor(1, 3)
graph, _, _ = self._model_to_graph(
model,
batch,
operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
input_names=["batch"],
dynamic_axes={"batch": [0, 1]},
)
iter = graph.nodes()
autograd1 = next(iter)
autograd2 = next(iter)
self.assertEqual(autograd1.kind(), "prim::PythonOp")
self.assertEqual(autograd2.kind(), "prim::PythonOp")
self.assertNotEqual(autograd1.s("module"), autograd2.s("module"))
def test_unused_initializers(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv2 = torch.nn.ConvTranspose2d(
16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(1, 1)
)
self.k_proj = torch.nn.Linear(5, 5, bias=True)
def forward(self, x):
x = self.conv2(x)
return x
x = torch.randn(20, 16, 50, 100)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
_, params_dict, __ = self._model_to_graph(
Model(),
(x,),
do_constant_folding=False,
operator_export_type=OperatorExportTypes.ONNX,
input_names=["x"],
dynamic_axes={"x": [0, 1, 2, 3]},
)
self.assertEqual(len(params_dict), 2)
def test_scripting_param(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(
3, 16, kernel_size=1, stride=2, padding=3, bias=True
)
self.bn = torch.nn.BatchNorm2d(16, affine=True)
def forward(self, x):
x = self.conv(x)
bn = self.bn(x)
return bn
model = torch.jit.script(MyModule())
x = torch.randn(10, 3, 128, 128)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, _, __ = self._model_to_graph(
model,
(x,),
do_constant_folding=True,
operator_export_type=OperatorExportTypes.ONNX,
training=torch.onnx.TrainingMode.TRAINING,
input_names=["x"],
dynamic_axes={"x": [0, 1, 2, 3]},
)
graph_input_params = [param.debugName() for param in graph.inputs()]
for item in dict(model.named_parameters()):
self.assertIn(
item,
graph_input_params,
"Graph parameter names does not match model parameters.",
)
@skipIfNoCaffe2
def test_modifying_params(self):
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.tensor([2.0]))
def forward(self, x):
y = x * x
self.param.data.add_(1.0)
return y
x = torch.tensor([1, 2])
        # Import locally, as the caffe2 backend requires an additional build flag
        # and is only used in this test case.
import caffe2.python.onnx.backend as backend
verify(MyModel(), x, backend, do_constant_folding=False)
def test_fuse_conv_bn(self):
class Fuse(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(
3, 2, kernel_size=1, stride=2, padding=3, bias=True
)
self.bn = torch.nn.BatchNorm2d(2)
def forward(self, x):
out = self.conv(x)
return self.bn(out)
x = torch.randn(2, 3, 2, 2, requires_grad=True)
graph, _, __ = self._model_to_graph(
Fuse(),
(x,),
training=TrainingMode.EVAL,
input_names=["x"],
dynamic_axes={"x": [0, 1, 2, 3]},
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::BatchNormalization")
self.assertEqual(node.kind(), "onnx::Conv")
self.assertEqual(len(list(graph.nodes())), 1)
def test_fuse_resnet18(self):
model = torchvision.models.resnet18(weights=None)
x = torch.randn(2, 3, 224, 224, requires_grad=True)
graph, _, __ = self._model_to_graph(
model,
(x,),
training=TrainingMode.EVAL,
input_names=["x"],
dynamic_axes={"x": [0, 1, 2, 3]},
)
for node in graph.nodes():
self.assertNotEqual(node.kind(), "onnx::BatchNormalization")
def test_onnx_function_substitution_pass(self):
@torch.jit.script
def f(x: torch.Tensor, y: torch.Tensor):
z = x - y
return x + z
class MyModule(torch.nn.Module):
def forward(self, x, y):
return f(x, y)
input_1 = torch.tensor([11])
input_2 = torch.tensor([12])
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, _, __ = self._model_to_graph(
MyModule(),
(input_1, input_2),
do_constant_folding=True,
operator_export_type=OperatorExportTypes.ONNX,
input_names=["input_1", "input_2"],
dynamic_axes={"input_1": [0], "input_2": [0]},
)
        # Check that the prim::Constant node representing the scripted function
        # `f` is removed and the following prim::CallFunction is replaced by the
        # inlined graph, containing onnx::Sub and onnx::Add nodes.
for node in graph.nodes():
self.assertNotEqual(node.kind(), "prim::Constant")
self.assertEqual(
len(list(graph.nodes())), 2
) # onnx::Sub and onnx::Add nodes only.
def test_onnx_value_name(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.in_weight = torch.nn.Parameter(torch.Tensor(3, 3))
self.in_bias = torch.nn.Parameter(torch.Tensor(3))
def forward(self, x):
start = 0
end = None
weight = self.in_weight
bias = self.in_bias
weight = weight[start:end, :]
if bias is not None:
bias = bias[start:end]
return torch.nn.functional.linear(x, weight, bias)
model = MyModule()
x = torch.randn(3, 3)
f = io.BytesIO()
model.eval()
torch.onnx.export(
model,
(x,),
f,
opset_version=self.opset_version,
keep_initializers_as_inputs=True,
)
graph = onnx.load(io.BytesIO(f.getvalue()))
self.assertEqual(graph.graph.input[1].name, "in_weight")
self.assertEqual(graph.graph.input[2].name, "in_bias")
def test_onnx_node_naming(self):
class MainModule(torch.nn.Module):
def __init__(self):
super().__init__()
self._module_1 = torch.nn.Linear(10, 10)
self._module_2 = torch.nn.Linear(10, 10)
self._module_3 = torch.nn.Linear(10, 10)
self._module_4 = torch.nn.Linear(10, 10)
def forward(self, x):
y = self._module_1(x)
z = self._module_2(y)
z = self._module_3(y * z)
z = self._module_4(y * z)
return z
module = MainModule()
ref_node_names = [
"/_module_1/Gemm",
"/_module_2/Gemm",
"/_module_3/Gemm",
"/_module_4/Gemm",
"/Mul",
"/Mul_1",
]
f = io.BytesIO()
torch.onnx.export(module, torch.ones(1, 10), f, output_names=["y"])
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
for n in onnx_model.graph.node:
self.assertIn(n.name, ref_node_names)
torch.onnx.export(
torch.jit.script(module), torch.ones(1, 10), f, output_names=["y"]
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
for n in onnx_model.graph.node:
self.assertIn(n.name, ref_node_names)
def _test_deduplicate_initializers(self, torchscript=False):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer1 = torch.nn.Linear(3, 3)
self.layer2 = torch.nn.Linear(3, 3)
# Reusing layers.
self.layer3 = self.layer1
# Reusing parameters.
self.layer2.weight = self.layer1.weight
self.layer1.bias = self.layer2.bias
                # Parameters backed by different tensors that are equal in value.
self.param1 = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
self.param2 = torch.nn.Parameter(torch.tensor([1.0, 2.0, 3.0]))
def forward(self, x):
return (
self.layer3(self.layer2(self.layer1(x))) + self.param1 + self.param2
)
model = torch.jit.script(MyModule()) if torchscript else MyModule()
x = torch.randn(3, 3)
param_name_set = {k for k, _ in model.named_parameters()}
# Test training mode.
model.train()
f = io.BytesIO()
torch.onnx.export(
model,
(x,),
f,
training=TrainingMode.TRAINING,
opset_version=self.opset_version,
)
graph = onnx.load(io.BytesIO(f.getvalue()))
self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)
model.train()
f = io.BytesIO()
torch.onnx.export(
model,
(x,),
f,
training=TrainingMode.PRESERVE,
opset_version=self.opset_version,
)
graph = onnx.load(io.BytesIO(f.getvalue()))
self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)
# Test eval mode.
model.eval()
f = io.BytesIO()
torch.onnx.export(model, (x,), f, opset_version=self.opset_version)
graph = onnx.load(io.BytesIO(f.getvalue()))
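        # In eval mode, initializers that are equal in value (param1/param2) are
        # deduplicated, so only param1 is expected to survive.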
param_name_set.remove("param2")
self.assertSetEqual({i.name for i in graph.graph.initializer}, param_name_set)
def test_deduplicate_initializers(self):
self._test_deduplicate_initializers(torchscript=False)
def test_deduplicate_initializers_torchscript(self):
self._test_deduplicate_initializers(torchscript=True)
@skipIfNoCuda
def test_deduplicate_initializers_diff_devices(self):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.w_cpu = torch.nn.Parameter(
torch.ones(3, device=torch.device("cpu"))
)
self.w_cuda = torch.nn.Parameter(
torch.ones(3, device=torch.device("cuda"))
)
def forward(self, x, y):
return x + self.w_cpu, y + self.w_cuda
x = torch.randn(3, 3, device=torch.device("cpu"))
y = torch.randn(3, 3, device=torch.device("cuda"))
f = io.BytesIO()
torch.onnx.export(Model(), (x, y), f, opset_version=self.opset_version)
graph = onnx.load(io.BytesIO(f.getvalue()))
self.assertSetEqual({i.name for i in graph.graph.initializer}, {"w_cpu"})
def test_duplicated_output_node(self):
class DuplicatedOutputNet(torch.nn.Module):
def __init__(self, input_size, num_classes):
super().__init__()
self.fc1 = torch.nn.Linear(input_size, num_classes)
def forward(self, input0, input1):
out1 = self.fc1(input0)
out2 = self.fc1(input1)
return out1, out1, out2, out1, out2
N, D_in, H, D_out = 64, 784, 500, 10
pt_model = DuplicatedOutputNet(D_in, D_out)
f = io.BytesIO()
x = torch.randn(N, D_in)
dynamic_axes = {
"input0": {0: "input0_dim0", 1: "input0_dim1"},
"input1": {0: "input1_dim0", 1: "input1_dim1"},
"output-0": {0: "output-0_dim0", 1: "output-0_dim1"},
"output-1": {0: "output-1_dim0", 1: "output-1_dim1"},
"output-2": {0: "output-2_dim0", 1: "output-2_dim1"},
"output-3": {0: "output-3_dim0", 1: "output-3_dim1"},
"output-4": {0: "output-4_dim0", 1: "output-4_dim1"},
}
torch.onnx.export(
pt_model,
(x, x),
f,
input_names=["input0", "input1"],
output_names=["output-0", "output-1", "output-2", "output-3", "output-4"],
do_constant_folding=False,
training=torch.onnx.TrainingMode.TRAINING,
dynamic_axes=dynamic_axes,
verbose=True,
keep_initializers_as_inputs=True,
)
graph = onnx.load(io.BytesIO(f.getvalue()))
self.assertEqual(graph.graph.input[0].name, "input0")
self.assertEqual(graph.graph.input[1].name, "input1")
for i in range(5):
self.assertEqual(graph.graph.output[i].name, f"output-{i}")
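        # The duplicated outputs are materialized as Identity nodes fed by the
        # two Gemm results.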
self.assertEqual(graph.graph.node[0].op_type, "Gemm")
self.assertEqual(graph.graph.node[1].op_type, "Identity")
self.assertEqual(graph.graph.node[2].op_type, "Identity")
self.assertEqual(graph.graph.node[3].op_type, "Gemm")
self.assertEqual(graph.graph.node[4].op_type, "Identity")
def test_deduplicate_ignore_upsample_scale(self):
        # The upsample scale is a constant, not a model parameter, so it should
        # be ignored by shared-weight deduplication.
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.upsample_1 = torch.nn.Upsample(scale_factor=2)
self.upsample_2 = torch.nn.Upsample(scale_factor=2)
def forward(self, x):
return self.upsample_1(x), self.upsample_2(x)
f = io.BytesIO()
x = torch.randn(1, 32, 224, 224)
torch.onnx.export(Model(), x, f)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        # aten::upsample is exported as an ONNX Resize op.
resize_nodes = [n for n in onnx_model.graph.node if n.op_type == "Resize"]
self.assertEqual(len(resize_nodes), 2)
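        # Each Resize should take its scales from a Constant node rather than
        # from a deduplicated initializer.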
for resize_node in resize_nodes:
scale_node = [
n for n in onnx_model.graph.node if n.output[0] == resize_node.input[2]
]
self.assertEqual(len(scale_node), 1)
self.assertEqual(scale_node[0].op_type, "Constant")
def test_bad_symbolic_registration(self):
_onnx_opset_version = 9
@parse_args("v")
def cat(g, tensor_list, dim):
tensors = _unpack_list(tensor_list)
return g.op("Concat", *tensors, axis_i=dim)
torch.onnx.register_custom_op_symbolic("::cat", cat, _onnx_opset_version)
class CatModel(torch.nn.Module):
def forward(self, x):
return torch.cat((x, x, x), 0)
model = CatModel()
x = torch.randn(2, 3)
f = io.BytesIO()
self.assertExpectedRaisesInline(
AssertionError,
lambda: torch.onnx.export(
model, (x,), f, opset_version=_onnx_opset_version
),
(
"A mismatch between the number of arguments (2) and their descriptors (1) was found at symbolic function "
"'cat'. If you believe this is not due to custom symbolic implementation within your code or an external "
"library, please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to "
"report this bug."
),
)
torch.onnx.unregister_custom_op_symbolic("::cat", _onnx_opset_version)
if __name__ == "__main__":
common_utils.run_tests()