# blob: f253c2453b3bccf78d839808043c5bba71afe504 [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import os
import sys
from typing import Any, List, Tuple
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.testing._internal.jit_utils import JitTestCase
# Make the helper files in test/ importable
# (the parent of this file's directory is appended to sys.path).
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# Refuse to run as a script; the message points at test/test_jit.py as the
# entry point to use instead.
if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")
class TestModuleContainers(JitTestCase):
def test_sequential_intermediary_types(self):
    # Runs an nn.Sequential whose stages hand different types to each other
    # through self.checkModule: A returns a tensor, B wraps its input in a dict.
    class A(torch.nn.Module):
        def __init__(self):
            super(A, self).__init__()

        def forward(self, x):
            return x + 3

    class B(torch.nn.Module):
        def __init__(self):
            super(B, self).__init__()

        def forward(self, x):
            return {"1": x}

    # C only delegates to the Sequential(A, B) it owns.
    class C(torch.nn.Module):
        def __init__(self):
            super(C, self).__init__()
            self.foo = torch.nn.Sequential(A(), B())

        def forward(self, x):
            return self.foo(x)

    self.checkModule(C(), (torch.tensor(1),))
def test_moduledict(self):
    # Exercises iteration over an nn.ModuleDict: direct iteration, items(),
    # values() and keys(), first plainly (M) and then combined with
    # enumerate()/zip() (M2).
    class Inner(torch.nn.Module):
        def forward(self, x):
            return x + 10

    class Inner2(torch.nn.Module):
        def forward(self, x):
            return x * 2

    class Inner3(torch.nn.Module):
        def forward(self, x):
            return (x - 4) * 3

    # M.forward walks the dict four ways. `skip_name` suppresses one entry in
    # the items() pass. `values` is appended to but is not returned.
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            # Entries are inserted in the order one, two, three.
            modules = OrderedDict([
                ('one', Inner()),
                ('two', Inner2()),
                ('three', Inner3()),
            ])

            self.moduledict = nn.ModuleDict(modules)

        def forward(self, x, skip_name):
            # type: (Tensor, str)
            names = torch.jit.annotate(List[str], [])
            values = []
            for name in self.moduledict:
                names.append(name)

            for name, mod in self.moduledict.items():
                if name != skip_name:
                    names.append(name)
                    x = mod(x)
                    values.append(x)

            for mod in self.moduledict.values():
                x = mod(x)
                values.append(x)

            for key in self.moduledict.keys():
                names.append(key)

            return x, names

    # M2 reuses M's moduledict and repeats the same passes wrapped in
    # enumerate(), accumulating the indices into `iter`, plus a zip() pass
    # that feeds a second tensor `x2`.
    class M2(M):
        def __init__(self):
            super(M2, self).__init__()

        def forward(self, x, skip_name):
            # type: (Tensor, str)
            names = torch.jit.annotate(List[str], [])
            values = []
            x2 = x
            iter = 0

            for name in self.moduledict:
                names.append(name)

            for i, (name, mod) in enumerate(self.moduledict.items()):
                iter += i
                if name != skip_name:
                    names.append(name)
                    x = mod(x)
                    values.append(x)

            for i, mod in enumerate(self.moduledict.values()):
                iter += i
                x = mod(x)
                values.append(x)

            for i, key in enumerate(self.moduledict.keys()):
                iter += i
                names.append(key)

            # NOTE(review): both loop targets are named `mod`, and `i` is the
            # value left over from the preceding enumerate loop. Presumably
            # deliberate for this test; confirm before changing.
            for mod, mod in zip(self.moduledict.values(), self.moduledict.values()):
                iter += i
                x2 = mod(mod(x2))

            return x, x2, names, iter

    # "" matches no key, so nothing is skipped on that iteration.
    for name in ["", "one", "two", "three"]:
        inp = torch.tensor(1)
        self.checkModule(M(), (inp, name))
        self.checkModule(M2(), (inp, name))
def test_custom_container_forward(self):
    # Subclasses each container type (Sequential, ModuleList, ModuleDict),
    # gives it a custom forward() that iterates over `self`, and runs each
    # through self.checkModule.
    class Inner(torch.nn.Module):
        def forward(self, x):
            return x + 10

    class CustomSequential(nn.Sequential):
        def __init__(self):
            super(CustomSequential, self).__init__(
                nn.ReLU(), Inner())

        def forward(self, x):
            x = x + 3
            for mod in self:
                x = mod(x)
            return x - 5

    self.checkModule(CustomSequential(), (torch.tensor(.5),))

    class CustomModuleList(nn.ModuleList):
        def __init__(self):
            super(CustomModuleList, self).__init__(
                [nn.ReLU(), Inner()])

        def forward(self, x):
            x = x + 3
            for mod in self:
                x = mod(x)
            return x - 5

    self.checkModule(CustomModuleList(), (torch.tensor(.5),))

    # The dict variant iterates with items() and also returns the key names.
    class CustomModuleDict(nn.ModuleDict):
        def __init__(self):
            super(CustomModuleDict, self).__init__(
                OrderedDict([
                    ('one', Inner()),
                    ('two', nn.ReLU()),
                    ('three', Inner()),
                ]))

        def forward(self, x):
            x = x + 3
            names = torch.jit.annotate(List[str], [])
            for name, mod in self.items():
                x = mod(x)
                names.append(name)
            return names, x - 5

    self.checkModule(CustomModuleDict(), (torch.tensor(.5),))
def test_script_module_list_sequential(self):
    # A ScriptModule whose script_method forward iterates over an
    # nn.Sequential stored as an attribute, then goes through
    # self.assertExportImportModule.
    class M(torch.jit.ScriptModule):
        def __init__(self, mod_list):
            super(M, self).__init__()
            self.mods = mod_list

        @torch.jit.script_method
        def forward(self, v):
            for m in self.mods:
                v = m(v)
            return v

    # Construction and the export/import check both happen inside the
    # optimized_execution(False) context.
    with torch.jit.optimized_execution(False):
        m = M(nn.Sequential(nn.ReLU()))
        self.assertExportImportModule(m, (torch.randn(2, 2),))
def test_script_modulelist_index(self):
    # Covers indexing an nn.ModuleList with literal integers (positive and
    # negative), plus the two error cases: an out-of-range literal and a
    # non-literal index.
    class Sub(torch.nn.Module):
        def __init__(self, i):
            super(Sub, self).__init__()
            self.i = i

        def forward(self, thing):
            return thing - self.i

    # Literal indices, calling .forward explicitly on the selected module.
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.mods = nn.ModuleList([Sub(i) for i in range(10)])

        def forward(self, v):
            v = self.mods[4].forward(v)
            v = self.mods[-1].forward(v)
            v = self.mods[-9].forward(v)
            return v

    x = torch.tensor(1)
    self.checkModule(M(), (x,))

    # Same indices, but invoking the selected module via a call.
    class MForward(torch.nn.Module):
        def __init__(self):
            super(MForward, self).__init__()
            self.mods = nn.ModuleList([Sub(i) for i in range(10)])

        def forward(self, v):
            v = self.mods[4](v)
            v = self.mods[-1](v)
            v = self.mods[-9](v)
            return v

    self.checkModule(MForward(), (torch.tensor(1),))

    # -11 is one past the negative end of the 10-element list; scripting is
    # expected to raise with the message below.
    class M2(M):
        def __init__(self):
            super(M2, self).__init__()

        def forward(self, v):
            return self.mods[-11].forward(v)

    with self.assertRaisesRegexWithHighlight(Exception, "Index -11 out of range", "self.mods[-11]"):
        torch.jit.script(M2())

    # Indexing with a local variable instead of a literal is expected to be
    # rejected at scripting time.
    class M3(M):
        def __init__(self):
            super(M3, self).__init__()

        def forward(self, v):
            i = 3
            return self.mods[i].forward(v)

    with self.assertRaisesRegexWithHighlight(Exception, "Enumeration is supported", "self.mods[i]"):
        torch.jit.script(M3())
def test_module_interface_special_methods(self):
    # Checks __getitem__, __len__, __iter__, __contains__, keys(), items()
    # and values() on container subclasses that also inherit from a custom
    # nn.Module base. The checks are the asserts inside MyModule.forward.
    class CustomModuleInterface(torch.nn.Module):
        def __init__(self):
            super(CustomModuleInterface, self).__init__()

    # Each container initialises both bases explicitly rather than via super().
    class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
        def __init__(self, modules=None):
            CustomModuleInterface.__init__(self)
            torch.nn.ModuleList.__init__(self, modules)

    class CustomSequential(CustomModuleInterface, torch.nn.Sequential):
        def __init__(self, modules=None):
            CustomModuleInterface.__init__(self)
            torch.nn.Sequential.__init__(self, modules)

    class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
        def __init__(self, modules=None):
            CustomModuleInterface.__init__(self)
            torch.nn.ModuleDict.__init__(self, modules)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            # work around aliasing issue for 'is' operator by scripting ReLU up front
            self.submod = torch.jit.script(torch.nn.ReLU())
            # The same scripted submodule is placed in all three containers.
            self.modulelist = CustomModuleList([self.submod])
            self.sequential = CustomSequential(self.submod)
            self.moduledict = CustomModuleDict({"submod": self.submod})

        def forward(self, inputs):
            assert self.modulelist[0] is self.submod, "__getitem__ failing for ModuleList"
            assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
            for module in self.modulelist:
                assert module is self.submod, "__iter__ failing for ModuleList"

            assert self.sequential[0] is self.submod, "__getitem__ failing for Sequential"
            assert len(self.sequential) == 1, "__len__ failing for Sequential"
            for module in self.sequential:
                assert module is self.submod, "__iter__ failing for Sequential"

            assert self.moduledict["submod"] is self.submod, "__getitem__ failing for ModuleDict"
            assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"

            # note: unable to index moduledict with a string variable currently
            i = 0
            for key in self.moduledict:
                i += 1
            assert i == len(self.moduledict), "iteration failing for ModuleDict"

            assert "submod" in self.moduledict, "__contains__ fails for ModuleDict"

            for key in self.moduledict.keys():
                assert key == "submod", "keys() fails for ModuleDict"

            for item in self.moduledict.items():
                assert item[0] == "submod", "items() fails for ModuleDict"
                assert item[1] is self.submod, "items() fails for ModuleDict"

            for value in self.moduledict.values():
                assert value is self.submod, "values() fails for ModuleDict"

            # The input is returned untouched; only the asserts matter.
            return inputs

    m = MyModule()
    self.checkModule(m, [torch.randn(2, 2)])
def test_special_method_with_override(self):
    """
    Test that a Python-level ``__len__`` override on a ModuleList subclass
    is the one observed by ``len()`` inside ``forward``.
    """
    class CustomModuleInterface(torch.nn.Module):
        def __init__(self):
            super(CustomModuleInterface, self).__init__()

    class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
        def __init__(self, modules=None):
            CustomModuleInterface.__init__(self)
            torch.nn.ModuleList.__init__(self, modules)

        def __len__(self):
            # this is arbitrary, just to check that the overridden py __len__ from
            # CustomModuleList takes precedence over the automatically generated
            # __len__ added by the jit compiler
            return 2

    class MyModule(torch.nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            # work around aliasing issue for 'is' operator by scripting ReLU up front
            self.submod = torch.jit.script(torch.nn.ReLU())
            # Only one element is stored, so a length of 2 can only come
            # from the override above.
            self.modulelist = CustomModuleList([self.submod])

        def forward(self, inputs):
            assert len(self.modulelist) == 2, "__len__ failing for ModuleList"
            return inputs

    m = MyModule()
    self.checkModule(m, [torch.randn(2, 2)])
    # Script the module directly as well; the result is not needed, so it is
    # no longer bound to an unused local.
    torch.jit.script(m)
def test_moduledict_getitem(self):
    # Checks that indexing an nn.ModuleDict with string literals yields the
    # very same submodule objects that were stored in it.
    class MyModule(torch.nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            # Both submodules are scripted before being put in the dict and
            # are also kept as direct attributes for the identity checks.
            self.relu = torch.jit.script(torch.nn.ReLU())
            self.tanh = torch.jit.script(torch.nn.Tanh())
            self.moduledict = torch.nn.ModuleDict({"relu": self.relu,
                                                   "tanh": self.tanh})

        def forward(self, input):
            assert self.moduledict['relu'] is self.relu
            assert self.moduledict['tanh'] is self.tanh
            return input

    m = MyModule()
    self.checkModule(m, [torch.randn(2, 2)])
def test_moduledict_keyerror(self):
    # Two failure modes of ModuleDict indexing at scripting time: a literal
    # key that is not in the dict, and a key given through a variable.
    class BadModule(torch.nn.Module):
        def __init__(self):
            super(BadModule, self).__init__()
            self.moduledict = torch.nn.ModuleDict({"foo": None,
                                                   "bar": None})

        def forward(self, input):
            # 'blah' is not one of the keys ("foo", "bar").
            assert self.moduledict['blah'] == "blah", "this is a keyerror"

    with self.assertRaisesRegexWithHighlight(RuntimeError, "Key Error, blah", "self.moduledict['blah'"):
        b = BadModule()
        torch.jit.script(b)

    class AnotherBadModule(torch.nn.Module):
        def __init__(self):
            super(AnotherBadModule, self).__init__()
            self.moduledict = torch.nn.ModuleDict({"foo": None,
                                                   "bar": None})

        def forward(self, input):
            # The key goes through a local variable instead of a literal.
            idx = 'blah'
            assert self.moduledict[idx] == "blah", "this is a string literal error"

    with self.assertRaisesRegexWithHighlight(RuntimeError, "Unable to extract string literal index. "
                                                           "ModuleDict indexing is only supported with string literals.",
                                             "self.moduledict[idx]"):
        b = AnotherBadModule()
        torch.jit.script(b)
def test_normal_list_attribute_with_modules_error(self):
    """
    Test that an attempt to script a module with a regular list attribute
    containing other modules fails with a relevant error message.
    """
    class Mod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # A plain Python list, not an nn.ModuleList.
            self.a = [torch.nn.ReLU(), torch.nn.ReLU()]

        def forward(self):
            return len(self.a)

    # Expected error text, with `self.a` as the highlighted source.
    error_msg = "Could not infer type of list element: Cannot infer concrete type of torch.nn.Module"
    with self.assertRaisesRegexWithHighlight(RuntimeError, error_msg, "self.a"):
        torch.jit.script(Mod())
def test_empty_dict_override_contains(self):
    # Checks `not in` against a ModuleDict subclass that was constructed
    # with no entries.
    class CustomModuleInterface(torch.nn.Module):
        def __init__(self):
            super(CustomModuleInterface, self).__init__()

    class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
        def __init__(self, modules=None):
            CustomModuleInterface.__init__(self)
            torch.nn.ModuleDict.__init__(self, modules)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            # work around aliasing issue for 'is' operator by scripting ReLU up front
            self.submod = torch.jit.script(torch.nn.ReLU())
            # No modules are passed, so the dict stays empty; `submod` is
            # deliberately not added to it.
            self.moduledict = CustomModuleDict()

        def forward(self, inputs):
            assert "submod" not in self.moduledict, "__contains__ fails for ModuleDict"
            return inputs

    m = MyModule()
    self.checkModule(m, [torch.randn(2, 2)])
def test_typed_module_dict(self):
    """
    Test that a type annotation can be provided for a ModuleDict that allows
    non-static indexing.
    """
    # Interface used as the annotation on the indexed value below.
    @torch.jit.interface
    class ModuleInterface(torch.nn.Module):
        def forward(self, inp: Any) -> Any:
            pass

    # Its forward has the same (Any) -> Any signature as the interface.
    class ImplementsInterface(torch.nn.Module):
        def forward(self, inp: Any) -> Any:
            if isinstance(inp, torch.Tensor):
                return torch.max(inp, dim=0)

            return inp

    # Its forward is typed Tensor -> Tuple[Tensor, Tensor] instead.
    class DoesNotImplementInterface(torch.nn.Module):
        def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            return torch.max(inp, dim=0)

    # Test annotation of submodule.
    class Mod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.d = torch.nn.ModuleDict({"module": ImplementsInterface()})

        def forward(self, x: torch.Tensor, key: str) -> Any:
            # `key` is a runtime argument, not a string literal.
            value: ModuleInterface = self.d[key]
            return value.forward(x)

    m = Mod()
    self.checkModule(m, (torch.randn(2, 2), "module"))

    # Test annotation of self.
    class ModDict(torch.nn.ModuleDict):
        def __init__(self):
            super().__init__({"module": ImplementsInterface()})

        def forward(self, x: torch.Tensor, key: str) -> Any:
            submodule: ModuleInterface = self[key]
            return submodule.forward(x)

    m = ModDict()
    self.checkModule(m, (torch.randn(2, 2), "module"))

    # Test error message thrown when annotated attribute does not comply with the
    # annotation.
    class ModWithWrongAnnotation(torch.nn.ModuleDict):
        def __init__(self):
            super().__init__()
            self.d = torch.nn.ModuleDict({"module": DoesNotImplementInterface()})

        def forward(self, x: torch.Tensor, key: str) -> Any:
            submodule: ModuleInterface = self.d[key]
            return submodule.forward(x)

    with self.assertRaisesRegexWithHighlight(RuntimeError, r"Attribute module is not of annotated type", "self.d[key]"):
        torch.jit.script(ModWithWrongAnnotation())
def test_typed_module_list(self):
    """
    Test that a type annotation can be provided for a ModuleList that allows
    non-static indexing.
    """
    # Interface used as the annotation on the indexed value below.
    @torch.jit.interface
    class ModuleInterface(torch.nn.Module):
        def forward(self, inp: Any) -> Any:
            pass

    # Its forward has the same (Any) -> Any signature as the interface.
    class ImplementsInterface(torch.nn.Module):
        def forward(self, inp: Any) -> Any:
            if isinstance(inp, torch.Tensor):
                return torch.max(inp, dim=0)

            return inp

    # Its forward is typed Tensor -> Tuple[Tensor, Tensor] instead.
    class DoesNotImplementInterface(torch.nn.Module):
        def forward(self, inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            return torch.max(inp, dim=0)

    # Test annotation of submodule.
    class Mod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l = torch.nn.ModuleList([ImplementsInterface()])

        def forward(self, x: torch.Tensor, idx: int) -> Any:
            # `idx` is a runtime argument, not an integer literal.
            value: ModuleInterface = self.l[idx]
            return value.forward(x)

    m = Mod()
    self.checkModule(m, (torch.randn(2, 2), 0))

    # Test annotation of self.
    class ModList(torch.nn.ModuleList):
        def __init__(self):
            super().__init__([ImplementsInterface()])

        def forward(self, x: torch.Tensor, idx: int) -> Any:
            submodule: ModuleInterface = self[idx]
            return submodule.forward(x)

    m = ModList()
    self.checkModule(m, (torch.randn(2, 2), 0))

    # Test error message thrown when annotated attribute does not comply with the
    # annotation.
    class ModWithWrongAnnotation(torch.nn.ModuleList):
        def __init__(self):
            super().__init__()
            self.l = torch.nn.ModuleList([DoesNotImplementInterface()])

        def forward(self, x: torch.Tensor, idx: int) -> Any:
            submodule: ModuleInterface = self.l[idx]
            return submodule.forward(x)

    with self.assertRaisesRegexWithHighlight(RuntimeError, r"Attribute 0 is not of annotated type", "self.l[idx]"):
        torch.jit.script(ModWithWrongAnnotation())
def test_module_properties(self):
    # Covers Python properties on modules: one with getter and setter, one
    # that is getter-only, and properties excluded from scripting via
    # __jit_unused_properties__ or @torch.jit.unused.
    class ModuleWithProperties(torch.nn.Module):
        # Lists `ignored_attr`; the end of this test expects the scripted
        # module to have no such attribute.
        __jit_unused_properties__ = ["ignored_attr"]

        def __init__(self, a: int):
            super().__init__()
            self.a = a

        def forward(self, a: int, b: int):
            # Goes through the `attr` setter, then reads back via the getter.
            self.attr = a + b
            return self.attr

        @property
        def attr(self):
            return self.a

        @property
        def ignored_attr(self):
            return sum([self.a])

        @torch.jit.unused
        @property
        def ignored_attr_2(self):
            return sum([self.a])

        @ignored_attr_2.setter
        def ignored_attr_2(self, value):
            self.a = sum([self.a])

        # The setter clamps non-positive values to 0.
        @attr.setter
        def attr(self, a: int):
            if a > 0:
                self.a = a
            else:
                self.a = 0

    class ModuleWithNoSetter(torch.nn.Module):
        def __init__(self, a: int):
            super().__init__()
            self.a = a

        def forward(self, a: int, b: int):
            # NOTE(review): the expression is evaluated but not returned, so
            # forward yields None. Presumably intentional (only the property
            # read matters); confirm.
            self.attr + a + b

        @property
        def attr(self):
            return self.a + 1

    # Positive and negative sums hit both branches of the `attr` setter.
    self.checkModule(ModuleWithProperties(5), (5, 6,))
    self.checkModule(ModuleWithProperties(5), (-5, -6,))
    self.checkModule(ModuleWithNoSetter(5), (5, 6,))
    self.checkModule(ModuleWithNoSetter(5), (-5, -6,))

    mod = ModuleWithProperties(3)
    scripted_mod = torch.jit.script(mod)

    with self.assertRaisesRegex(AttributeError, "has no attribute"):
        scripted_mod.ignored_attr
def test_module_inplace_construct(self):
    # Scripts two unrelated modules, calls n._reconstruct(m._c), and then
    # checks that n produces m's output while n's ignored method still works.
    class M(nn.Module):
        def __init__(self, start: int):
            super().__init__()
            self.linear = nn.Linear(3, 3)
            self.attribute = start
            self.parameter = nn.Parameter(torch.tensor(3, dtype=torch.float))

        def method(self) -> int:
            return self.attribute

        @torch.jit.unused
        def unused_method(self):
            return self.attribute + self.attribute

        def forward(self, x):
            return self.linear(self.linear(x))

    # N differs from M in layer size (4x4 vs 3x3) and in its forward.
    class N(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        @torch.jit.ignore
        def ignored_method(self, x):
            return x

        def forward(self, x):
            return self.linear(x)

    m = torch.jit.script(M(3))
    n = torch.jit.script(N())

    # NOTE(review): presumably rebuilds n in place from m's underlying
    # module object (`m._c`); confirm against `_reconstruct`.
    n._reconstruct(m._c)

    # A 3-element input, matching M's Linear(3, 3).
    inp = torch.rand((3))

    # Check that both modules produce the same output.
    with torch.no_grad():
        m_out = m(inp)
        n_out = n(inp)
        self.assertEqual(m_out, n_out)

    # Check that ignored method is still intact.
    self.assertEqual(inp, n.ignored_method(inp))
def test_parameterlist_script_getitem(self):
    """
    Test that a module whose forward indexes both an nn.ModuleList and an
    nn.ParameterList with a literal index passes checkModule.
    """
    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
            self.parameter_list = nn.ParameterList([nn.Parameter(torch.zeros(1)) for _ in range(10)])

        def forward(self, x):
            # The indexing expressions are the point of the test; their
            # results are discarded and the input is returned unchanged.
            self.module_list[0]
            self.parameter_list[0]
            return x

    # The argument pack must be a 1-tuple. Without the trailing comma,
    # `(torch.zeros(1))` is just a parenthesized tensor, so the helper would
    # receive a bare tensor instead of a tuple of arguments.
    self.checkModule(MyModule(), (torch.zeros(1),))
def test_parameterlist_script_iter(self):
    # Checks iterating an nn.ParameterList with enumerate() inside forward.
    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            # module_list is constructed but not used by forward below.
            self.module_list = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])
            self.parameter_list = nn.ParameterList([nn.Parameter(torch.zeros(1)) for _ in range(10)])

        def forward(self, x):
            r = x
            # Accumulates every parameter plus its position in the list.
            for i, p in enumerate(self.parameter_list):
                r = r + p + i
            return r

    self.checkModule(MyModule(), (torch.zeros(1),))
def test_parameterdict_script_getitem(self):
    # Checks indexing an nn.ParameterDict with string literals inside forward.
    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            # Three zero-initialised parameters keyed 'a', 'b' and 'c'.
            self.parameter_dict = nn.ParameterDict({k: nn.Parameter(torch.zeros(1)) for k in ['a', 'b', 'c']})

        def forward(self, x):
            return self.parameter_dict['a'] * x + self.parameter_dict['b'] * self.parameter_dict['c']

    self.checkModule(MyModule(), (torch.ones(1),))