# blob: 109a5e3f1b7163e5238ada8be2f5d8d82a1bb882 [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import os
import sys
import unittest
from typing import Tuple
import torch
from jit.test_hooks_modules import (
ModuleDirectforwardSubmodCall, ModuleForwardSingleInput,
ModuleForwardTupleInput, create_forward_tuple_input,
create_module_forward_multiple_inputs, create_module_forward_single_input,
create_module_hook_return_nothing,
create_module_multiple_hooks_multiple_inputs,
create_module_multiple_hooks_single_input, create_module_no_forward_input,
create_module_same_hook_repeated, create_submodule_forward_multiple_inputs,
create_submodule_forward_single_input,
create_submodule_forward_single_input_return_not_tupled,
create_submodule_hook_return_nothing,
create_submodule_multiple_hooks_multiple_inputs,
create_submodule_multiple_hooks_single_input,
create_submodule_no_forward_input, create_submodule_same_hook_repeated,
create_submodule_to_call_directly_with_hooks)
# Make the helper files in test/ importable: resolve this file (following
# symlinks), take the directory it lives in, then step up one more level.
# NOTE(review): the `jit.test_hooks_modules` import at the top of the file
# executes before this append, so it relies on test/ already being importable
# (e.g. when launched via `python test/test_jit.py`) — confirm if reordering.
_this_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(_this_dir)
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
    # This module only defines test cases; direct execution is refused and the
    # message points the user at test/test_jit.py instead.
    _usage = "\n\n".join(
        [
            "This test file is not meant to be run directly, use:",
            "\tpython test/test_jit.py TESTNAME",
            "instead.",
        ]
    )
    raise RuntimeError(_usage)
# Tests for JIT forward hooks and pre-hooks
class TestHooks(JitTestCase):
    """TorchScript tests for ``nn.Module`` forward hooks and forward pre-hooks.

    The first group of tests builds modules with hooks already attached (via
    factories from ``jit.test_hooks_modules``) and runs them through
    ``JitTestCase.checkModule``. The remaining tests register hooks inline and
    check either the scripted module's outputs or the ``RuntimeError`` that
    ``torch.jit.script`` raises for an invalid hook.
    """

    def test_module_no_forward_input(self):
        self.checkModule(create_module_no_forward_input(), ())

    def test_submodule_no_forward_input(self):
        self.checkModule(create_submodule_no_forward_input(), ())

    def test_module_forward_multiple_inputs(self):
        self.checkModule(
            create_module_forward_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_module_multiple_hooks_multiple_inputs(self):
        self.checkModule(
            create_module_multiple_hooks_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_module_forward_single_input(self):
        self.checkModule(create_module_forward_single_input(), ("a",))

    def test_module_same_hook_repeated(self):
        self.checkModule(create_module_same_hook_repeated(), ("a",))

    def test_module_hook_return_nothing(self):
        self.checkModule(create_module_hook_return_nothing(), ("a",))

    def test_module_multiple_hooks_single_input(self):
        self.checkModule(create_module_multiple_hooks_single_input(), ("a",))

    def test_submodule_forward_multiple_inputs(self):
        self.checkModule(
            create_submodule_forward_multiple_inputs(), (["a"], "no_pre_hook")
        )

    def test_submodule_multiple_hooks_multiple_inputs(self):
        self.checkModule(
            create_submodule_multiple_hooks_multiple_inputs(),
            (["a"], "no_pre_hook"),
        )

    def test_submodule_forward_single_input(self):
        self.checkModule(create_submodule_forward_single_input(), ("a",))

    def test_submodule_called_directly_with_hooks(self):
        # Calling the submodule on its own (not through its parent) must give
        # the same result in eager and scripted form.
        module = create_submodule_to_call_directly_with_hooks()
        module_scripted = torch.jit.script(module)

        submodule = module.submodule
        scripted_submodule = module_scripted.submodule

        self.assertEqual(submodule("a"), scripted_submodule("a"))

    def test_submodule_same_hook_repeated(self):
        self.checkModule(create_submodule_same_hook_repeated(), ("a",))

    def test_submodule_hook_return_nothing(self):
        self.checkModule(create_submodule_hook_return_nothing(), ("a",))

    def test_submodule_multiple_hooks_single_input(self):
        # Args are a 1-tuple, matching the other single-input tests.
        self.checkModule(create_submodule_multiple_hooks_single_input(), ("a",))

    def test_forward_tuple_input(self):
        self.checkModule(create_forward_tuple_input(), ((3,),))

    def test_submodule_forward_single_input_return_not_tupled(self):
        self.checkModule(
            create_submodule_forward_single_input_return_not_tupled(), ("a",)
        )

    def test_hook_method_name_collision(self):
        # Hooks can't have the same name as methods: scripting must reject the
        # pre-hook named `foo` (the submodule presumably already defines a
        # method called `foo` — see jit.test_hooks_modules).
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def foo(self, input: Tuple[str]) -> Tuple[str]:
            assert self.name == "inner_mod_name"
            assert input[0] == "a_outermod"
            return ("pre_hook_override_name",)

        m.submodule.register_forward_pre_hook(foo)

        with self.assertRaisesRegex(
            RuntimeError,
            "Can't define hook: foo on class: .+ "
            "because a method or hook with that name already exists.",
        ):
            torch.jit.script(m)

    def test_hook_hook_name_collision(self):
        # Edge case: two hooks that share a name but are different Python
        # functions must be rejected, for pre-hooks and forward hooks alike.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def prehook(self, input: Tuple[str]) -> Tuple[str]:
            return "This is the first hook"

        m.submodule.register_forward_pre_hook(prehook)

        # Intentional redefinition: same __name__, different function object.
        def prehook(self, input: Tuple[str]) -> Tuple[str]:  # noqa: F811
            return "This is the second hook"

        m.submodule.register_forward_pre_hook(prehook)

        with self.assertRaisesRegex(
            RuntimeError,
            "Pre-hook '.+' on .+ has at least two different python "
            "definitions. Please use unique names for all hooks.",
        ):
            torch.jit.script(m)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def hook(self, input: Tuple[str], output: str):
            return "This is the first hook"

        m.submodule.register_forward_hook(hook)

        # Intentional redefinition, as above.
        def hook(self, input: Tuple[str]):  # noqa: F811
            return "This is the second hook"

        m.submodule.register_forward_hook(hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "Hook '.+' on .+ has at least two different python "
            "definitions. Please use unique names for all hooks.",
        ):
            torch.jit.script(m)

    def test_module_direct_forward_invocation(self):
        # Test that hooks are only invoked when the module is
        # called directly and not when forward is called.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            return ("pre_hook_override_name",)

        def forward_hook(self, input: Tuple[str], output: str):
            assert self.name == "outer_mod_name"
            assert input == ("pre_hook_override_name",)
            output = output + "_fh"
            return output

        m.register_forward_pre_hook(pre_hook)
        m.register_forward_hook(forward_hook)

        m_scripted = torch.jit.script(m)

        # Explicit `.forward` calls agree between eager and scripted...
        self.assertEqual(m.forward("a"), m_scripted.forward("a"))
        # ...but calling the scripted module runs the hooks, changing the output.
        self.assertNotEqual(m_scripted("a"), m_scripted.forward("a"))

    def test_submodule_direct_forward_invocation(self):
        # Same idea one level down. ModuleDirectforwardSubmodCall presumably
        # invokes `submodule.forward` directly (so the submodule's hooks do not
        # run), while ModuleForwardSingleInput calls the submodule itself (so
        # they do) — hence the two scripted outputs must differ.
        m_submod_forward_call = ModuleDirectforwardSubmodCall(
            "outer_mod_name", "inner_mod_name"
        )
        m_submod_call = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            return ("pre_hook_override_name",)

        def forward_hook(self, input: Tuple[str], output: str):
            assert input == ("pre_hook_override_name",)
            return output + "_fh"

        m_submod_forward_call.submodule.register_forward_pre_hook(pre_hook)
        m_submod_forward_call.submodule.register_forward_hook(forward_hook)
        m_submod_call.submodule.register_forward_pre_hook(pre_hook)
        m_submod_call.submodule.register_forward_hook(forward_hook)

        m_submod_forward_call_scripted = torch.jit.script(m_submod_forward_call)
        m_submod_call_scripted = torch.jit.script(m_submod_call)

        self.assertEqual(
            m_submod_forward_call_scripted("a"), m_submod_forward_call("a")
        )
        self.assertNotEqual(
            m_submod_forward_call_scripted("a"), m_submod_call_scripted("a")
        )

    # TODO: add this test back once figured out how to print error msg
    @unittest.skip("TODO: re-enable once the hook compilation error hint is printed")
    def test_hook_compilation_hint(self):
        # Tests if hook error message is printed out if erroring after schema check.
        # Useful for when user is scripting hooks while not aware of it.
        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")

        def pre_hook(self, input: Tuple[str]) -> Tuple[str]:
            assert self.name == "outer_mod_name"
            assert input[4] == "a"  # out of bounds tuple range
            return ("pre_hook_override_name",)

        m.register_forward_pre_hook(pre_hook)

        # NOTE(review): the regex must match the real error text verbatim
        # (including its spelling of "occured"); confirm before re-enabling.
        with self.assertRaisesRegex(
            RuntimeError,
            "This error occured while scripting the forward pre-hook 'pre_hook'",
        ):
            torch.jit.script(m)

    def test_wrong_pre_hook_signatures(self):
        # correct signature: pre_hook(self, input: Tuple[str]) -> Tuple[str]
        # (the return may also be str or None, per the cases below).
        # Each case registers one malformed pre-hook on a fresh module and
        # checks the error torch.jit.script reports for it.

        # Wrong inner type in the input tuple.
        def pre_hook_wrong_input1(self, input: Tuple[None]) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input1)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong inner types for the input tuple argument",
        ):
            torch.jit.script(m)

        # Extra positional argument.
        def pre_hook_wrong_input2(self, input: Tuple[str], input2: str) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input2)

        with self.assertRaisesRegex(
            RuntimeError,
            "was expected to only have exactly 2 inputs but it had 3 inputs",
        ):
            torch.jit.script(m)

        # Input not typed as a Tuple.
        def pre_hook_wrong_input3(self, input: int) -> Tuple[str]:
            return ("hello",)

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_input3)

        with self.assertRaisesRegex(
            RuntimeError,
            "expected the input argument to be typed as a Tuple but"
            " found type: 'int' instead",
        ):
            torch.jit.script(m)

        # Wrong return type.
        def pre_hook_wrong_output(self, input: Tuple[str]) -> int:
            return 1  # expecting Tuple[str], str, or None

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_output)

        with self.assertRaisesRegex(
            RuntimeError,
            "returned the wrong type of: 'int'",
        ):
            torch.jit.script(m)

        # Missing return annotation.
        def pre_hook_no_output_annotation(self, input: Tuple[str]):
            return 1  # expecting Tuple[str], str, or None

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_no_output_annotation)

        with self.assertRaisesRegex(
            RuntimeError,
            "is missing a return annotation. Return annotations"
            " are required, please add one.",
        ):
            torch.jit.script(m)

        # forward takes a single tuple argument, so the return must re-wrap it.
        def pre_hook_wrong_tuple_return(self, input: Tuple[Tuple[int]]) -> Tuple[int]:
            return (11,)  # doesn't work with eager, inner tuple lost

        m = ModuleForwardTupleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_pre_hook(pre_hook_wrong_tuple_return)

        with self.assertRaisesRegex(
            RuntimeError,
            "When forward has a single tuple input argument, "
            "the return needs to be 'None' or a nested tuple containing "
            r"forward's input tuple argument as in: 'Tuple\[Tuple\[int\]\]'",
        ):
            torch.jit.script(m)

    def test_wrong_hook_signatures(self):
        # correct signature:
        #   def forward_hook(self, input: Tuple[str], output: str)
        # Each case registers malformed forward hook(s) on a fresh module and
        # checks the error torch.jit.script reports.

        # Too many types in the input tuple.
        def forward_hook_wrong_input1(self, input: Tuple[str, str], output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input1)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong number of contained types for the "
            r"input argument's Tuple. Received type: 'Tuple\[str, str\]'",
        ):
            torch.jit.script(m)

        # Input not typed as a Tuple.
        def forward_hook_wrong_input2(self, input: str, output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input2)

        with self.assertRaisesRegex(
            RuntimeError,
            "expected the input argument to be typed as a Tuple "
            "but found type: 'str' instead.",
        ):
            torch.jit.script(m)

        # Wrong inner type in the input tuple.
        def forward_hook_wrong_input3(self, input: Tuple[None], output: str):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_input3)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong inner types for the input tuple"
            r" argument. Received type: 'Tuple\[NoneType\]'",
        ):
            torch.jit.script(m)

        # Output argument typed differently from forward's return type.
        def forward_hook_wrong_output(self, input: Tuple[str], output: Tuple[str]):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_wrong_output)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong type for the output argument. Received"
            r" type: 'Tuple\[str\]'. Expected type: 'str'",
        ):
            torch.jit.script(m)

        # A later hook's `output` must match what the previous hook returned:
        # forward_hook_correct returns Tuple[str], but the next hook expects str.
        def forward_hook_correct(self, input: Tuple[str], output: str):
            return (output,)

        def forward_hook_wrong_output_from_prev_hook(
            self, input: Tuple[str], output: str
        ):
            return output

        m = ModuleForwardSingleInput("outer_mod_name", "inner_mod_name")
        m.register_forward_hook(forward_hook_correct)
        m.register_forward_hook(forward_hook_wrong_output_from_prev_hook)

        with self.assertRaisesRegex(
            RuntimeError,
            "has the wrong type for the output argument. "
            r"Received type: 'str'. Expected type: 'Tuple\[str\]'",
        ):
            torch.jit.script(m)