# blob: db073d3274726a1a4f2bb9516f0b3978d5f9bddd [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import os
import sys
import types
import typing
import typing_extensions
from typing import List, Dict, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from torch.testing import FileCheck
from collections import OrderedDict
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, _tmp_donotuse_dont_inline_everything
if __name__ == '__main__':
    # Refuse to run as a script; the message tells the user to go through
    # test/test_jit.py instead.
    _direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_direct_run_message)
class TestRecursiveScript(JitTestCase):
def test_inferred_nonetype(self):
    """Script a module whose only attribute is assigned ``None`` in ``__init__``."""

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.x = None

        def forward(self):
            assert self.x is None

    # Script one instance directly, then hand a second, fresh instance to
    # the checkModule helper.
    scripted = torch.jit.script(M())
    fresh_instance = M()
    self.checkModule(fresh_instance, ())
def test_script_function_attribute(self):
    """Store ``torch.jit.script``-ed functions as module attributes and call them from ``forward``."""

    @torch.jit.script
    def fn1(x):
        return x + x

    @torch.jit.script
    def fn2(x):
        return x - x

    class M(torch.nn.Module):
        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)

    # Build both wrapper modules first, then check each one with its own
    # random input.
    wrappers = [M(scripted_fn) for scripted_fn in (fn1, fn2)]
    for wrapper in wrappers:
        self.checkModule(wrapper, (torch.randn(2, 2),))
def test_python_function_attribute(self):
    """Store ``torch.sigmoid`` as a module attribute and call it from ``forward``."""

    class M(torch.nn.Module):
        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)

    sigmoid_module = M(torch.sigmoid)
    example_inputs = (torch.randn(2, 2),)
    self.checkModule(sigmoid_module, example_inputs)
def test_failed_function_compilation(self):
    """Scripting a module whose function attribute references an undefined name must fail.

    The expected error is a ``RuntimeError`` matching "failed to compile",
    with ``i_dont_exist`` as the highlighted text.
    """

    def fn(x):
        return i_dont_exist

    class M(torch.nn.Module):
        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x):
            return self.fn(x)

    broken_module = M(fn)
    expected_failure = (RuntimeError, "failed to compile", "i_dont_exist")
    with self.assertRaisesRegexWithHighlight(*expected_failure):
        torch.jit.script(broken_module)
def test_init_error(self):
    """Scripting a module whose ``__init__`` never called the base initializer must raise."""

    class M(nn.Module):
        # NOTE: super().__init__() is deliberately not called here; that
        # omission is what this test exercises.
        def __init__(self):
            self.x = 2

        def forward(self):
            pass

    expected_message = "has not been initialized"
    with self.assertRaisesRegex(RuntimeError, expected_message):
        uninitialized = M()
        torch.jit.script(uninitialized)
def test_script_after_eval(self):
    """Each scripted copy keeps the ``training`` flag the module had when it was scripted."""

    class M(nn.Module):
        def forward(self):
            if self.training:
                return 2
            else:
                return 0

    module = M()
    scripted_before_eval = torch.jit.script(module)
    module.eval()
    scripted_after_eval = torch.jit.script(module)

    # The eager module was switched to eval mode, so training is False.
    self.assertFalse(module.training)

    # Scripted while the eager module still had training = True.
    self.assertTrue(scripted_before_eval.training)
    self.assertEqual(
        scripted_before_eval.training,
        scripted_before_eval._c.getattr('training'),
    )
    self.assertEqual(scripted_before_eval(), 2)

    # Scripted after the eager module was eval'ed.
    self.assertFalse(scripted_after_eval.training)
    self.assertEqual(
        scripted_after_eval.training,
        scripted_after_eval._c.getattr('training'),
    )
    self.assertEqual(scripted_after_eval(), 0)
def test_module_name(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.x = 2
def forward(self, t):
return t + self.x
m = torch.jit.script(MyModule())
FileCheck().check("MyModule").run(m.graph)
def test_repeated_error_stack(self):
def d(x):
return "a" - 2
def c(x):
return d(x)
def b(x):
return c(x)
def a(x):
return b(x)
try:
torch.jit.script(a)
except Exception as e:
FileCheck().check_count("is being compiled", 2).run(str(e))
try:
torch.jit.script(a)
except Exception as e:
# Make sure that no entries are left over from the previous failure
FileCheck().check_count("is being compiled", 2).run(str(e))
def test_constants_with_final(self):
    """Check modules whose attribute is annotated ``Final[int]`` in three spellings.

    The spellings are ``torch.jit.Final``, ``typing_extensions.Final`` and,
    on Python 3.8 or newer, ``typing.Final``.
    """

    class M1(torch.nn.Module):
        x: torch.jit.Final[int]

        def __init__(self):
            super().__init__()
            self.x = 2

        def forward(self, t):
            return t + self.x

    self.checkModule(M1(), (torch.randn(2, 2),))

    class M2(torch.nn.Module):
        x: typing_extensions.Final[int]

        def __init__(self):
            super().__init__()
            self.x = 2

        def forward(self, t):
            return t + self.x

    self.checkModule(M2(), (torch.randn(2, 2),))

    # The typing.Final variant is only exercised on Python 3.8 or newer.
    if sys.version_info >= (3, 8):

        class M3(torch.nn.Module):
            x: typing.Final[int]

            def __init__(self):
                super().__init__()
                self.x = 2

            def forward(self, t):
                return t + self.x

        self.checkModule(M3(), (torch.randn(2, 2),))
def test_ignore_class(self):
    """Instantiating a ``@torch.jit.ignore``-d class inside ``forward`` must fail to script.

    The expected error is a ``FrontendError`` matching
    "Cannot instantiate class", with ``MyScriptClass`` highlighted.
    """

    @torch.jit.ignore
    class MyScriptClass(object):
        def unscriptable(self):
            return "a" + 200

    class TestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return MyScriptClass()

    error_type = torch.jit.frontend.FrontendError
    with self.assertRaisesRegexWithHighlight(error_type, "Cannot instantiate class", "MyScriptClass"):
        t = torch.jit.script(TestModule())
def test_method_call(self):
    """``forward`` calls another plain method (``test``) defined on the same module."""

    class M(nn.Module):
        def test(self, x):
            return x

        def forward(self, z):
            y = self.test(z)
            return z + 20 + y

    module = M()
    example_inputs = (torch.randn(2, 2),)
    self.checkModule(module, example_inputs)
def test_module_repr(self):
    """Printing a scripted module should mention its own class and its submodules' classes."""

    class Submodule(nn.Module):
        def forward(self, x):
            return x

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(10, 10, 3)
            self.lin = nn.Linear(10, 10)
            self.sub = Submodule()

        def forward(self, x):
            return self.lin(x) + self.sub(x) + self.conv(x)

    scripted = torch.jit.script(MyModule())

    with self.capture_stdout() as out:
        print(scripted)

    # The names are registered in this order on a single FileCheck instance.
    checker = FileCheck()
    for expected_name in ('MyModule', 'Conv2d', 'Linear', 'Submodule'):
        checker.check(expected_name)
    checker.run(out[0])

    self.assertEqual(scripted.original_name, 'MyModule')
def test_dir(self):
    """Names in ``dir(module)`` must also appear in ``dir`` of the scripted module.

    A fixed list of names is skipped.  A custom module, an
    ``nn.Sequential`` and an ``nn.ModuleDict`` are checked.
    """

    def test_module_dir(mod):
        eager_names = dir(mod)
        scripted_names = set(dir(torch.jit.script(mod)))
        # Names that are not currently copied over to the scripted module.
        not_copied = {"training", "__delitem__", "__setitem__", "clear", "items",
                      "keys", "pop", "update", "values"}
        for attr in eager_names:
            if attr not in not_copied:
                self.assertTrue(attr in scripted_names, attr)

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(10, 10, 3)
            self.lin = nn.Linear(10, 10)

        def forward(self, x):
            return self.lin(x) + self.conv(x)

    test_module_dir(MyModule())

    # Test the custom __dir__ of containers; both containers share the same
    # conv/linear instances.
    conv = nn.Conv2d(10, 10, 3)
    linear = nn.Linear(10, 10)
    sequential = nn.Sequential(conv, linear)
    test_module_dir(sequential)
    named_children = OrderedDict([("conv", conv), ("linear", linear)])
    test_module_dir(nn.ModuleDict(named_children))
def test_class_compile(self):
    """``forward`` instantiates a plain Python class whose method calls a free function."""

    def other_fn(a: int, b: Tensor) -> Tensor:
        return a * b

    class B(object):
        def __init__(self, x):
            self.x = 2

        def helper(self, a):
            return self.x + a + other_fn(self.x, a)

    class N(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            b = B(x)
            return b.helper(x)

    module = N()
    example_inputs = (torch.randn(2, 2),)
    self.checkModule(module, example_inputs)
def test_error_stack(self):
def d(x: int) -> int:
return x + 10
def c(x):
return d("hello") + d(x)
def b(x):
return c(x)
def a(x):
return b(x)
try:
scripted = torch.jit.script(a)
except RuntimeError as e:
checker = FileCheck()
checker.check("Expected a value of type 'int'")
checker.check("def c(x)")
checker.check("def b(x)")
checker.check("def a(x)")
checker.run(str(e))
def test_error_stack_module(self):
def d(x: int) -> int:
return x + 10
def c(x):
return d("hello") + d(x)
def b(x):
return c(x)
class Submodule(torch.nn.Module):
def __init__(self):
super(Submodule, self).__init__()
def forward(self, x):
return b(x)
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.submodule = Submodule()
def some_method(self, y):
return y + self.submodule(y)
def forward(self, x):
return self.some_method(x)
try:
scripted = torch.jit.script(M())
except RuntimeError as e:
checker = FileCheck()
checker.check("Expected a value of type 'int'")
checker.check("'c' is being compiled since it was called from 'b'")
checker.check("'b' is being compiled since it was called from")
checker.run(str(e))
@_tmp_donotuse_dont_inline_everything
def test_script_basic(self):
    """A scripted function may call a plain Python function.

    The graph must contain a ``prim::CallFunction`` node and must not match
    the ``^a_python_fn`` pattern; the result must equal ``t + t + t``.
    """

    def a_python_fn(a, b, c):
        return a + b + c

    @torch.jit.script
    def a_script_fn(d, e, f):
        return a_python_fn(d, e, f)

    graph_text = str(a_script_fn.graph)
    FileCheck().check("prim::CallFunction").run(graph_text)
    FileCheck().check_not("^a_python_fn").run(graph_text)

    ones = torch.ones(2, 2)
    actual = a_script_fn(ones, ones, ones)
    self.assertEqual(actual, ones + ones + ones)
def test_error_stack_class(self):
class X(object):
def bad_fn(self):
import pdb # noqa: F401
def fn(x) -> X:
return X(10)
try:
torch.jit.script(fn)
except Exception as e:
checker = FileCheck()
checker.check("import statements")
checker.check("is being compiled since it was called from")
checker.run(str(e))
def test_error_stack_annotation(self):
class X(object):
def bad_fn(self):
import pdb # noqa: F401
def fn(x) -> X:
return X(10)
try:
torch.jit.script(fn)
except Exception as e:
checker = FileCheck()
checker.check("import statements")
checker.check("is being compiled since it was called from")
checker.check("-> X")
checker.run(str(e))
def test_module_basic(self):
    """A module wrapping a submodule that has a constant, a parameter and an uncalled method.

    ``Other.some_unscriptable_method`` rebinds ``a`` from an int to a list
    and is never called from ``forward``.
    """

    class Other(torch.nn.Module):
        __constants__ = ['x']

        def __init__(self, x):
            super().__init__()
            self.x = x
            self.param = torch.nn.Parameter(torch.ones(2, 2))

        def some_unscriptable_method(self):
            a = 2
            a = [2]
            return a

        def forward(self, t):
            return t + self.x + self.param

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.other = Other(200)

        def forward(self, t):
            return self.other(t) * 2

    module = M()
    example_inputs = (torch.ones(2, 2),)
    self.checkModule(module, example_inputs)
def test_module_function_export(self):
    """A submodule carrying an ``@torch.jit.export`` method next to ``forward``."""

    class Other(torch.nn.Module):
        __constants__ = ['x']

        def __init__(self, x):
            super().__init__()
            self.x = x
            self.param = torch.nn.Parameter(torch.ones(2, 2))

        @torch.jit.export
        def some_entry_point(self, y):
            return y + 20

        def forward(self, t):
            return t + self.x + self.param

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.other = Other(200)

        def forward(self, t):
            return self.other(t) * 2

    module = M()
    example_inputs = (torch.ones(2, 2),)
    self.checkModule(module, example_inputs)
def test_iterable_modules(self):
    """``forward`` iterates over an ``nn.ModuleList`` and calls a nested ``nn.Sequential``."""

    class Inner(torch.nn.Module):
        def forward(self, x):
            return x + 10

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            nested = nn.Sequential(Inner(), Inner())
            self.sequential = nn.Sequential(Inner(), Inner(), nested)
            self.module_list = nn.ModuleList([Inner(), Inner()])

        def forward(self, x):
            for mod in self.module_list:
                x += mod(x)

            x += self.sequential(x)
            return x

    module = M()
    example_inputs = (torch.randn(5, 5),)
    self.checkModule(module, example_inputs)
def test_prepare_scriptable_basic(self):
    """``__prepare_scriptable__`` may return a different module to be scripted.

    The eager module is a SELU subclass whose hook returns ``nn.ReLU()``;
    the eager and scripted outputs are asserted to differ.
    """

    class SeluButReluWhenScripted(torch.nn.SELU):
        def __prepare_scriptable__(self):
            return nn.ReLU()

    sample = torch.randn(5, 5)
    eager_module = SeluButReluWhenScripted()
    scripted_module = torch.jit.script(eager_module)

    eager_out = eager_module(sample)
    script_out = scripted_module(sample)
    self.assertNotEqual(eager_out, script_out)
def test_prepare_scriptable_iterable_modules(self):
    """``__prepare_scriptable__`` modules placed inside containers.

    The same ``shared`` instance appears in the nested ``nn.Sequential``,
    the outer one, and the ``nn.ModuleList``.  The eager and scripted
    outputs are asserted to differ.
    """

    class SeluButReluWhenScripted(torch.nn.SELU):
        def __prepare_scriptable__(self):
            return nn.ReLU()

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            shared = SeluButReluWhenScripted()
            self.sequential = nn.Sequential(
                SeluButReluWhenScripted(),
                SeluButReluWhenScripted(),
                nn.Sequential(
                    SeluButReluWhenScripted(),
                    shared,
                    SeluButReluWhenScripted(),
                ),
                shared,
            )
            self.module_list = nn.ModuleList(
                [SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()]
            )

        def forward(self, x):
            for mod in self.module_list:
                x += mod(x)
            x += self.sequential(x)
            return x

    sample = torch.randn(5, 5)
    eager_module = M()
    # forward mutates its input in place, so each run gets its own clone.
    eager_out = eager_module(sample.clone())
    scripted_module = torch.jit.script(eager_module)
    script_out = scripted_module(sample.clone())
    self.assertNotEqual(eager_out, script_out)
def test_prepare_scriptable_cycle(self):
t = torch.randn(5, 5)
c = torch.nn.Module()
p = torch.nn.Module()
c.__dict__["_p"] = p
p.__dict__["_c"] = c
sm = torch.jit.script(p)
def test_attributes(self):
    """Check a module whose attributes cover many Python value types.

    Attributes are attached with ``setattr`` after construction.  Those in
    ``typed_values`` (empty containers, ``None``, script-class instances)
    get their types from ``M.__annotations__``, which is assigned after the
    class body; the ones in ``untyped_values`` carry no annotation.
    ``forward`` returns a tuple of the attributes, including fields of both
    script-class objects.
    """

    @torch.jit.script
    class Inner2(object):
        def __init__(self):
            self.b = "a string"

    @torch.jit.script
    class Foo(object):
        def __init__(self):
            self.a = 4
            self.inner = Inner2()

    # Same fields as Foo, but with explicit __getstate__/__setstate__.
    @torch.jit.script
    class SFoo(object):
        def __init__(self):
            self.a = 4
            self.inner = Inner2()

        def __setstate__(self, obj: Tuple[int, Inner2]) -> None:
            a, inner = obj
            self.a = a
            self.inner = inner

        def __getstate__(self):
            return (self.a, self.inner)

    untyped_values = (
        ('my_dict', {"I": "am", "a test": "test"}),
        ('my_float', 2.3),
        ('my_int', 99),
        ('my_bool', False),
        ('my_tuple', (1, 2, 3, 4)),
        ('my_list', [(1, 2), (3, 4)]),
        # ('my_tensor', torch.randn(2, 2)),
        ('my_int_list', [1, 2, 3, 4]),
        # ('my_tensor_list', [torch.ones(2, 2) + i for i in range(4)]),
        ('my_bool_list', [True, True, False, True]),
        ('my_float_list', [1., 2., 3., 4.]),
        ('my_str_list', ['hello', 'bye']),
    )
    typed_values = (
        ('my_empty_list', []),
        ('my_empty_dict', {}),
        ('my_none', None),
        ('my_object', Foo()),
        ('my_object2', SFoo()),
    )

    class M(torch.nn.Module):
        # TODO: re-enable this once this test is in a Python 3-only syntax
        # file
        # my_empty_list : List[int]
        # my_empty_dict : Dict[str, int]
        # my_none : Optional[int]

        def __init__(self):
            super(M, self).__init__()

        def forward(self, x):
            return (
                self.my_dict,
                self.my_float,
                self.my_int,
                self.my_bool,
                # self.my_tensor,
                self.my_int_list,
                # self.my_tensor_list,
                self.my_bool_list,
                self.my_float_list,
                self.my_str_list,
                self.my_empty_list,
                self.my_empty_dict,
                self.my_none,
                self.my_object.a,
                self.my_object.inner.b,
                # Was `self.my_object.a` a second time, which left SFoo.a
                # unchecked.
                self.my_object2.a,
                self.my_object2.inner.b,
            )

    # TODO: as a followup, fix this test
    # We can't define class attributes like we should be doing:
    #   class M(torch.nn.Module):
    #       my_empty_list : List[int]
    #       my_empty_dict : Dict[str, int]
    #       my_none : Optional[int]
    #       my_out_of_line_attribute: List[int] = [1, 2, 3]
    # since there's no string frontend for Python classes (so the `define`)
    # trick doesn't work.
    M.__annotations__ = {
        'my_empty_list': List[int],
        'my_empty_dict': Dict[str, int],
        'my_none': Optional[int],
        'my_object': Foo,
        'my_object2': SFoo,
    }

    m = M()
    for name, value in untyped_values + typed_values:
        setattr(m, name, value)

    self.checkModule(m, (torch.randn(5, 5),))
def test_function_attribute_in_submodule(self):
    """A submodule stores ``torch.nn.functional.relu`` as an attribute it never calls in ``forward``."""

    class N(nn.Module):
        def __init__(self, norm):
            super().__init__()
            self.activation = torch.nn.functional.relu
            self.norm = norm

        def forward(self, src):
            output = src
            output = self.norm(output)
            return output

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            encoder_norm = nn.ReLU()
            self.encoder = N(encoder_norm)

        def forward(self, x):
            return self.encoder(x)

    module = M()
    example_inputs = (torch.randn(5, 5),)
    self.checkModule(module, example_inputs)
def test_inner_traced_module(self):
    """A module iterating over an ``nn.ModuleList`` that holds a traced module."""

    class Dummy(nn.Module):
        def forward(self, x):
            return x

    class Model(nn.Module):
        def __init__(self, dummies):
            super().__init__()
            self._dummies = dummies

        def forward(self, x):
            out = []
            for dummy in self._dummies:
                out.append(dummy(x))
            return out

    traced_dummy = torch.jit.trace(Dummy(), torch.randn(1, 2))
    model = Model(nn.ModuleList([traced_dummy]))
    example_inputs = (torch.rand(5, 5),)
    self.checkModule(model, example_inputs)
def test_script_loaded_module(self):
    """
    Test that we can hold a loaded ScriptModule as a submodule.
    """

    class Dummy(nn.Module):
        def forward(self, x):
            return x

    # Script the module, then replace it with the copy returned by
    # getExportImportCopy.
    scripted_dummy = torch.jit.script(Dummy())
    dummy = self.getExportImportCopy(scripted_dummy)

    class ContainsLoaded(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = dummy

        def forward(self, input):
            return self.encoder(input)

    container = ContainsLoaded()
    example_inputs = (torch.rand(2, 3),)
    self.checkModule(container, example_inputs)
def test_optional_module(self):
    """A submodule attribute may be a real module or ``None``; ``forward`` branches on it."""

    class Dummy(nn.Module):
        def __init__(self):
            super().__init__()
            self.foo = nn.Linear(2, 2)

        def forward(self, x):
            if self.foo is not None:
                return self.foo(x)
            return x

    module = Dummy()

    # First with the Linear submodule in place ...
    self.checkModule(module, (torch.rand(2, 2),))

    # ... then again on the same instance with the submodule set to None.
    module.foo = None
    self.checkModule(module, (torch.rand(2, 2),))
def test_override_instance_method_ignore(self):
    """An ignored method rebound on one instance should be the one the scripted module calls."""

    class M(torch.nn.Module):
        @torch.jit.ignore
        def i_am_ignored(self):
            return "old"

    instance = M()

    # Override the ignored method by binding a new method to this instance.
    @torch.jit.ignore
    def i_am_ignored(self):
        return "new"

    instance.i_am_ignored = types.MethodType(i_am_ignored, instance)
    self.assertEqual(instance.i_am_ignored(), "new")

    # ScriptModule should correctly reflect the override.
    scripted = torch.jit.script(instance)
    self.assertEqual(scripted.i_am_ignored(), "new")