# Owner(s): ["oncall: jit"]
import io
import os
import shutil
import sys
import tempfile
import torch
import torch.nn as nn
from torch.onnx import OperatorExportTypes
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.common_utils import skipIfNoLapack, skipIfCaffe2, skipIfNoCaffe2
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
# Smoke tests for export methods
class TestExportModes(JitTestCase):
class MyModel(nn.Module):
def forward(self, x):
return x.transpose(0, 1)
def test_protobuf(self):
torch_model = TestExportModes.MyModel()
        fake_input = torch.randn(1, 1, 224, 224, requires_grad=True)
        f = io.BytesIO()
        torch.onnx._export(torch_model, (fake_input,), f, verbose=False,
export_type=torch.onnx.ExportTypes.PROTOBUF_FILE)
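        # Hedged sanity check, added for illustration (not part of the
        # original smoke test): PROTOBUF_FILE is assumed to leave a serialized
        # ModelProto in the buffer, which the `onnx` package can parse and
        # validate.
        import onnx
        model_proto = onnx.load_model_from_string(f.getvalue())
        onnx.checker.check_model(model_proto)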
def test_zipfile(self):
torch_model = TestExportModes.MyModel()
        fake_input = torch.randn(1, 1, 224, 224, requires_grad=True)
        f = io.BytesIO()
        torch.onnx._export(torch_model, (fake_input,), f, verbose=False,
export_type=torch.onnx.ExportTypes.ZIP_ARCHIVE)
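        # Hedged sanity check (an addition): ZIP_ARCHIVE is assumed to emit a
        # standard (uncompressed) zip, so the stdlib zipfile module should
        # recognize the rewound buffer.
        import zipfile
        f.seek(0)
        self.assertTrue(zipfile.is_zipfile(f))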
def test_compressed_zipfile(self):
torch_model = TestExportModes.MyModel()
        fake_input = torch.randn(1, 1, 224, 224, requires_grad=True)
        f = io.BytesIO()
        torch.onnx._export(torch_model, (fake_input,), f, verbose=False,
export_type=torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE)
def test_directory(self):
torch_model = TestExportModes.MyModel()
        fake_input = torch.randn(1, 1, 224, 224, requires_grad=True)
        d = tempfile.mkdtemp()
        torch.onnx._export(torch_model, (fake_input,), d, verbose=False,
export_type=torch.onnx.ExportTypes.DIRECTORY)
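        # Hedged check (an addition, not in the original test): DIRECTORY
        # export is assumed to materialize at least one file in the temp dir
        # before it is cleaned up.
        self.assertTrue(os.listdir(d))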
shutil.rmtree(d)
def test_onnx_multiple_return(self):
@torch.jit.script
def foo(a):
return (a, a)
f = io.BytesIO()
x = torch.ones(3)
torch.onnx._export(foo, (x,), f)
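        # Hedged check (an addition): a successful export is assumed to have
        # written a non-empty serialized model into the buffer.
        self.assertGreater(len(f.getvalue()), 0)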
@skipIfNoCaffe2
@skipIfNoLapack
def test_caffe2_aten_fallback(self):
class ModelWithAtenNotONNXOp(nn.Module):
def forward(self, x, y):
abcd = x + y
defg = torch.linalg.qr(abcd)
return defg
x = torch.rand(3, 4)
y = torch.rand(3, 4)
torch.onnx.export_to_pretty_string(
ModelWithAtenNotONNXOp(), (x, y),
add_node_names=False,
do_constant_folding=False,
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
@skipIfCaffe2
@skipIfNoLapack
def test_aten_fallback(self):
class ModelWithAtenNotONNXOp(nn.Module):
def forward(self, x, y):
abcd = x + y
defg = torch.linalg.qr(abcd)
return defg
x = torch.rand(3, 4)
y = torch.rand(3, 4)
torch.onnx.export_to_pretty_string(
ModelWithAtenNotONNXOp(), (x, y),
add_node_names=False,
do_constant_folding=False,
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
            # Support for linalg.qr was only added in later opset versions;
            # pinning opset 9 here exercises the ATen fallback path.
opset_version=9)
    # torch.fmod is used to test ONNX_ATEN.
    # If you plan to remove fmod from ATen, or if this test starts failing,
    # please contact @Rui.
def test_onnx_aten(self):
class ModelWithAtenFmod(nn.Module):
def forward(self, x, y):
return torch.fmod(x, y)
x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
torch.onnx.export_to_pretty_string(
ModelWithAtenFmod(), (x, y),
add_node_names=False,
do_constant_folding=False,
operator_export_type=OperatorExportTypes.ONNX_ATEN)
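        # Hedged follow-up (an addition): with ONNX_ATEN, every aten op is
        # assumed to be exported as an explicit ATen node, so the textual
        # graph returned by export_to_pretty_string should mention it.
        graph_str = torch.onnx.export_to_pretty_string(
            ModelWithAtenFmod(), (x, y),
            add_node_names=False,
            do_constant_folding=False,
            operator_export_type=OperatorExportTypes.ONNX_ATEN)
        self.assertIn("ATen", graph_str)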