blob: 194e2abbbc2d329cb450d04d127d6525ec5e2dfd [file] [log] [blame]
# Owner(s): ["oncall: jit"]
from typing import List, Any
import torch
import torch.nn as nn
import os
import sys
from torch import Tensor
from torch.testing._internal.jit_utils import JitTestCase, make_global
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class OrigModule(nn.Module):
def __init__(self):
super(OrigModule, self).__init__()
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
return inp1 + inp2 + 1
def two(self, input: Tensor) -> Tensor:
return input + 2
def forward(self, input: Tensor) -> Tensor:
return input + self.one(input, input) + 1
class NewModule(nn.Module):
def __init__(self):
super(NewModule, self).__init__()
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
return inp1 * inp2 + 1
def forward(self, input: Tensor) -> Tensor:
return self.one(input, input + 1)
class TestModuleInterface(JitTestCase):
def test_not_submodule_interface_call(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
pass
class TestNotModuleInterfaceCall(nn.Module):
proxy_mod : ModuleInterface
def __init__(self):
super(TestNotModuleInterfaceCall, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input: Tensor) -> Tensor:
return self.proxy_mod.two(input)
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", "self.proxy_mod.two"):
torch.jit.script(TestNotModuleInterfaceCall())
def test_module_interface(self):
@torch.jit.interface
class OneTwoModule(nn.Module):
def one(self, x: Tensor, y: Tensor) -> Tensor:
pass
def two(self, x: Tensor) -> Tensor:
pass
def forward(self, x: Tensor) -> Tensor:
pass
@torch.jit.interface
class OneTwoClass(object):
def one(self, x: Tensor, y: Tensor) -> Tensor:
pass
def two(self, x: Tensor) -> Tensor:
pass
class FooMod(nn.Module):
def one(self, x: Tensor, y: Tensor) -> Tensor:
return x + y
def two(self, x: Tensor) -> Tensor:
return 2 * x
def forward(self, x: Tensor) -> Tensor:
return self.one(self.two(x), x)
class BarMod(nn.Module):
def one(self, x: Tensor, y: Tensor) -> Tensor:
return x * y
def two(self, x: Tensor) -> Tensor:
return 2 / x
def forward(self, x: Tensor) -> Tensor:
return self.two(self.one(x, x))
@torch.jit.export
def forward2(self, x: Tensor) -> Tensor:
return self.two(self.one(x, x)) + 1
make_global(OneTwoModule, OneTwoClass)
def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
return mod_list[0].forward(x) + mod_list[1].forward(x)
def use_class_interface(mod_list: List[OneTwoClass], x: Tensor) -> Tensor:
return mod_list[0].two(x) + mod_list[1].one(x, x)
scripted_foo_mod = torch.jit.script(FooMod())
scripted_bar_mod = torch.jit.script(BarMod())
self.checkScript(use_module_interface,
([scripted_foo_mod, scripted_bar_mod], torch.rand(3, 4),))
self.checkScript(use_class_interface,
([scripted_foo_mod, scripted_bar_mod], torch.rand(3, 4),))
def call_module_interface_on_other_method(mod_interface: OneTwoModule, x: Tensor) -> Tensor:
return mod_interface.forward2(x)
# ensure error out when we call the module on the method other than the interface specified.
with self.assertRaisesRegexWithHighlight(RuntimeError, "object has no attribute or method", "mod_interface.forward2"):
self.checkScript(call_module_interface_on_other_method, (scripted_bar_mod, torch.rand(3, 4),))
def test_module_doc_string(self):
@torch.jit.interface
class TestInterface(nn.Module):
def one(self, inp1, inp2):
# type: (Tensor, Tensor) -> Tensor
pass
def forward(self, input):
# type: (Tensor) -> Tensor
r"""stuff 1"""
r"""stuff 2"""
pass
r"""stuff 3"""
class TestModule(nn.Module):
proxy_mod : TestInterface
def __init__(self):
super(TestModule, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input):
# type: (Tensor) -> Tensor
return self.proxy_mod.forward(input)
input = torch.randn(3, 4)
self.checkModule(TestModule(), (input,))
def test_module_interface_subtype(self):
@torch.jit.interface
class OneTwoModule(nn.Module):
def one(self, x: Tensor, y: Tensor) -> Tensor:
pass
def two(self, x: Tensor) -> Tensor:
pass
def forward(self, x: Tensor) -> Tensor:
pass
make_global(OneTwoModule)
@torch.jit.script
def as_module_interface(x: OneTwoModule) -> OneTwoModule:
return x
@torch.jit.script
class Foo(object):
def one(self, x: Tensor, y: Tensor) -> Tensor:
return x + y
def two(self, x: Tensor) -> Tensor:
return 2 * x
def forward(self, x: Tensor) -> Tensor:
return self.one(self.two(x), x)
# check class object is not a subtype of module interface
with self.assertRaisesRegex(RuntimeError, "ScriptModule class can be subtype of module interface"):
as_module_interface(Foo())
class WrongMod(nn.Module):
def two(self, x: int) -> int:
return 2 * x
def forward(self, x: Tensor) -> Tensor:
return x + torch.randn(3, self.two(3))
scripted_wrong_mod = torch.jit.script(WrongMod())
# wrong module that is not compatible with module interface
with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
as_module_interface(scripted_wrong_mod)
# Check that interface implementations can be contravariant in argument types and covariant in return type.
@torch.jit.interface
class TensorToAny(nn.Module):
def forward(self, input: torch.Tensor) -> Any:
pass
make_global(TensorToAny)
@torch.jit.script
def as_tensor_to_any(x: TensorToAny) -> TensorToAny:
return x
@torch.jit.interface
class AnyToAny(nn.Module):
def forward(self, input: Any) -> Any:
pass
make_global(AnyToAny)
@torch.jit.script
def as_any_to_any(x: AnyToAny) -> AnyToAny:
return x
class TensorToAnyImplA(nn.Module):
def forward(self, input: Any) -> Any:
return input
class TensorToAnyImplB(nn.Module):
def forward(self, input: Any) -> torch.Tensor:
return torch.tensor([1])
class AnyToAnyImpl(nn.Module):
def forward(self, input: Any) -> torch.Tensor:
return torch.tensor([1])
as_tensor_to_any(torch.jit.script(TensorToAnyImplA()))
as_tensor_to_any(torch.jit.script(TensorToAnyImplB()))
as_any_to_any(torch.jit.script(AnyToAnyImpl()))
def test_module_interface_inheritance(self):
with self.assertRaisesRegex(RuntimeError, "does not support inheritance yet. Please directly"):
@torch.jit.interface
class InheritMod(nn.ReLU):
def three(self, x: Tensor) -> Tensor:
return 3 * x
def test_module_swap(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
pass
def forward(self, input: Tensor) -> Tensor:
pass
class TestModule(nn.Module):
proxy_mod : ModuleInterface
def __init__(self):
super(TestModule, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input: Tensor) -> Tensor:
return self.proxy_mod.forward(input)
scripted_mod = torch.jit.script(TestModule())
input = torch.randn(3, 4)
self.assertEqual(scripted_mod(input), 3 * input + 2)
# module swap with module that have the same interface
scripted_mod.proxy_mod = torch.jit.script(NewModule())
self.assertEqual(scripted_mod(input), input * (input + 1) + 1)
# module swap with non-scripted module should throw error
with self.assertRaisesRegex(RuntimeError, "a ScriptModule with non-scripted module"):
scripted_mod.proxy_mod = NewModule()
def test_module_swap_wrong_module(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
pass
def forward(self, input: Tensor) -> Tensor:
pass
class NewModuleWrong(nn.Module):
def __init__(self):
super(NewModuleWrong, self).__init__()
def forward(self, input: int) -> int:
return input + 1
class TestModule(nn.Module):
proxy_mod : ModuleInterface
def __init__(self):
super(TestModule, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input: Tensor) -> Tensor:
return self.proxy_mod.forward(input)
scripted_mod = torch.jit.script(TestModule())
# module swap with in-compatible interface
with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
scripted_mod.proxy_mod = torch.jit.script(NewModuleWrong())
def test_module_swap_no_lazy_compile(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
pass
def forward(self, input: Tensor) -> Tensor:
pass
class TestModule(nn.Module):
proxy_mod : ModuleInterface
def __init__(self):
super(TestModule, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input: Tensor) -> Tensor:
return self.proxy_mod.forward(input)
class NewModuleMethodNotLazyCompile(nn.Module):
def __init__(self):
super(NewModuleMethodNotLazyCompile, self).__init__()
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
return inp1 * inp2 + 1
def forward(self, input: Tensor) -> Tensor:
return input + 1
scripted_mod = torch.jit.script(TestModule())
# module swap with module that have the same interface, but the method not get
# lazily compiled from forward, user need to export it explicitly for swap to work
with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodNotLazyCompile())
class NewModuleMethodManualExport(nn.Module):
def __init__(self):
super(NewModuleMethodManualExport, self).__init__()
@torch.jit.export
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
return inp1 * inp2 + 1
def forward(self, input: Tensor) -> Tensor:
return input + 1
scripted_mod.proxy_mod = torch.jit.script(NewModuleMethodManualExport())
input = torch.randn(3, 4)
self.assertEqual(scripted_mod(input), input + 1)
def test_module_swap_no_module_interface(self):
# test module swapping with no module interface
class TestNoModuleInterface(nn.Module):
def __init__(self):
super(TestNoModuleInterface, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input: Tensor) -> Tensor:
return self.proxy_mod(input)
scripted_no_module_interface = torch.jit.script(TestNoModuleInterface())
# proxy mod is swapped with the new ScriptModule that share the same JIT type, should succeed.
scripted_no_module_interface.proxy_mod = torch.jit.script(OrigModule())
# proxy_mod is neither a module interface or have the same JIT type, should fail
with self.assertRaisesRegex(RuntimeError,
r"Expected a value of type '__torch__.jit.test_module_interface.OrigModule \(.*\)' " +
r"for field 'proxy_mod', but found '__torch__.jit.test_module_interface.NewModule \(.*\)'"):
scripted_no_module_interface.proxy_mod = torch.jit.script(NewModule())
def test_script_module_as_interface_swap(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
pass
def forward(self, input: Tensor) -> Tensor:
pass
class OrigScriptModule(torch.jit.ScriptModule):
def __init__(self):
super(OrigScriptModule, self).__init__()
@torch.jit.script_method
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
return inp1 + inp2 + 1
@torch.jit.script_method
def forward(self, input: Tensor) -> Tensor:
return input + self.one(input, input) + 1
class NewScriptModule(torch.jit.ScriptModule):
def __init__(self):
super(NewScriptModule, self).__init__()
@torch.jit.script_method
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
return inp1 * inp2 + 1
@torch.jit.script_method
def forward(self, input: Tensor) -> Tensor:
return self.one(input, input + 1)
class TestNNModuleWithScriptModule(nn.Module):
proxy_mod : ModuleInterface
def __init__(self):
super(TestNNModuleWithScriptModule, self).__init__()
self.proxy_mod = OrigScriptModule()
def forward(self, input: Tensor) -> Tensor:
return self.proxy_mod.forward(input)
input = torch.randn(3, 4)
scripted_mod = torch.jit.script(TestNNModuleWithScriptModule())
self.assertEqual(scripted_mod(input), 3 * input + 2)
scripted_mod.proxy_mod = NewScriptModule()
self.assertEqual(scripted_mod(input), input * (input + 1) + 1)
# The call to forward of proxy_mod cannot be inlined. Making sure
# Freezing is throwing an error for now.
def test_freeze_module_with_interface(self):
class SubModule(torch.nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.b = 20
def forward(self, x):
return self.b
class OrigMod(torch.nn.Module):
def __init__(self):
super(OrigMod, self).__init__()
self.a = 0
def forward(self, x):
return self.a
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x: Tensor) -> int:
pass
class TestModule(torch.nn.Module):
proxy_mod : ModInterface
def __init__(self):
super(TestModule, self).__init__()
self.proxy_mod = OrigMod()
self.sub = SubModule() # folded
def forward(self, x):
return self.proxy_mod(x) + self.sub(x)
m = torch.jit.script(TestModule())
m.eval()
mf = torch._C._freeze_module(m._c)
# Assume interface has no aliasing
mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
input = torch.tensor([1])
out_s = m.forward(input)
out_f = mf.forward(input)
self.assertEqual(out_s, out_f)
def test_freeze_module_with_setattr_in_interface(self):
class SubModule(torch.nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.b = 20
def forward(self, x):
self.b += 2
return self.b
@torch.jit.export
def getb(self, x):
return self.b
class OrigMod(torch.nn.Module):
def __init__(self):
super(OrigMod, self).__init__()
self.a = 0
def forward(self, x):
return self.a
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x: Tensor) -> int:
pass
class TestModule(torch.nn.Module):
proxy_mod : ModInterface
def __init__(self):
super(TestModule, self).__init__()
self.proxy_mod = OrigMod()
self.sub = SubModule()
def forward(self, x):
return self.proxy_mod(x) + self.sub.getb(x)
m = torch.jit.script(TestModule())
m.proxy_mod = m.sub
m.eval()
mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
def test_freeze_module_with_inplace_mutation_in_interface(self):
class SubModule(torch.nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.b = torch.tensor([1.5])
def forward(self, x):
self.b[0] += 2
return self.b
@torch.jit.export
def getb(self, x):
return self.b
class OrigMod(torch.nn.Module):
def __init__(self):
super(OrigMod, self).__init__()
self.a = torch.tensor([0.5])
def forward(self, x):
return self.a
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
pass
class TestModule(torch.nn.Module):
proxy_mod : ModInterface
def __init__(self):
super(TestModule, self).__init__()
self.proxy_mod = OrigMod()
self.sub = SubModule()
def forward(self, x):
y = self.proxy_mod(x)
z = self.sub.getb(x)
return y[0] + z[0]
m = torch.jit.script(TestModule())
m.proxy_mod = m.sub
m.sub.b = m.proxy_mod.b
m.eval()
mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
def test_freeze_module_with_mutated_interface(self):
class SubModule(torch.nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.b = torch.tensor([1.5])
def forward(self, x):
return self.b
@torch.jit.export
def getb(self, x):
return self.b
class OrigMod(torch.nn.Module):
def __init__(self):
super(OrigMod, self).__init__()
self.a = torch.tensor([0.5])
def forward(self, x):
return self.a
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
pass
class TestModule(torch.nn.Module):
proxy_mod : ModInterface
def __init__(self):
super(TestModule, self).__init__()
self.proxy_mod = OrigMod()
self.sub = SubModule()
def forward(self, x):
self.proxy_mod = self.sub
y = self.proxy_mod(x)
z = self.sub.getb(x)
return y[0] + z[0]
m = torch.jit.script(TestModule())
m.eval()
with self.assertRaisesRegex(RuntimeError, "Freezing does not support SetAttr on an interface type."):
mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
def test_freeze_module_with_interface_and_fork(self):
class SubModule(torch.nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.b = torch.tensor([1.5])
def forward(self, x):
self.b[0] += 3.2
return self.b
class OrigMod(torch.nn.Module):
def __init__(self):
super(OrigMod, self).__init__()
self.a = torch.tensor([0.5])
def forward(self, x):
return self.a
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
pass
class TestModule(torch.nn.Module):
proxy_mod : ModInterface
def __init__(self):
super(TestModule, self).__init__()
self.proxy_mod = OrigMod()
self.sub = SubModule()
def forward(self, x):
y = self.proxy_mod(x)
z = self.sub(x)
return y + z
class MainModule(torch.nn.Module):
def __init__(self):
super(MainModule, self).__init__()
self.test = TestModule()
def forward(self, x):
fut = torch.jit._fork(self.test.forward, x)
y = self.test(x)
z = torch.jit._wait(fut)
return y + z
m = torch.jit.script(MainModule())
m.eval()
mf = torch._C._freeze_module(m._c, freezeInterfaces=True)
def test_module_apis_interface(self):
@torch.jit.interface
class ModuleInterface(nn.Module):
def one(self, inp1: Tensor, inp2: Tensor) -> Tensor:
pass
class TestModule(nn.Module):
proxy_mod : ModuleInterface
def __init__(self):
super(TestModule, self).__init__()
self.proxy_mod = OrigModule()
def forward(self, input):
return input * 2
@torch.jit.export
def method(self, input):
for module in self.modules():
input = module(input)
return input
with self.assertRaisesRegex(Exception, "Could not compile"):
scripted_mod = torch.jit.script(TestModule())