| # Owner(s): ["oncall: jit"] |
| |
| import io |
| import os |
| import sys |
| |
| import torch |
| import torch.nn as nn |
| |
| from typing import Any, Tuple |
| |
| # Make the helper files in test/ importable |
| pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) |
| sys.path.append(pytorch_test_dir) |
| from torch.testing._internal.jit_utils import JitTestCase, _inline_everything |
| from typing import List |
| from torch import Tensor |
| |
class TestAsync(JitTestCase):
    """Exercise ``torch.jit`` fork/wait from eager Python, TorchScript and tracing."""

    # NOTE(review): the tests below use both ``torch.jit.fork``/``wait`` and the
    # underscore-prefixed ``torch.jit._fork``/``_wait`` spellings; presumably
    # these are aliases and both are kept on purpose -- confirm before unifying.

    def test_async_python(self):
        """Smoke-test calling fork/wait on a scripted function from plain Python."""
        @torch.jit.script
        def foo(x):
            return torch.neg(x)

        x = torch.rand(3, 4)
        fut = torch.jit.fork(foo, x)
        foo(x)
        torch.jit.wait(fut)
        # assert nothing; only to make sure the fake python path works

    def test_async_future_type_python(self):
        """Smoke-test the ``List[Future[Tensor]]`` annotation in un-scripted Python."""
        def foo(inp):
            futures = torch.jit.annotate(List[torch.jit.Future[torch.Tensor]], [])
            for _ in range(5):
                futures.append(torch.jit.fork(lambda x: x, inp))
            return [torch.jit.wait(future) for future in futures]

        # assert nothing, just to make sure python type parsing works
        foo(torch.randn(3, 4))

    def test_async_parsing(self):
        """Script a function that collects forked futures in an annotated list.

        ``bar`` forks ``foo`` three times, waits on every future and returns
        the gathered outputs; the test checks that three results come back.
        """
        @torch.jit.script
        def foo(x: Tensor) -> List[Tensor]:
            return [torch.neg(x), x.t()]

        # NOTE(review): the bare name ``Future`` used inside ``bar`` is not among
        # the imports visible at the top of this file; presumably the TorchScript
        # annotation parser resolves it by name -- confirm.
        @torch.jit.script
        def bar(x):
            futures = torch.jit.annotate(List[Future[List[Tensor]]], [])
            for _ in range(3):
                future = torch.jit.annotate(
                    Future[List[Tensor]],
                    torch.jit.fork(foo, x)
                )
                futures.append(future)

            output = torch.jit.annotate(List[List[Tensor]], [])
            for i in range(3):
                output.append(torch.jit.wait(futures[i]))
            return output

        x = torch.rand(3, 3)
        result = bar(x)
        self.assertEqual(len(result), 3)

    def test_async_script(self):
        """A forked call and a direct call of the same script function must agree."""
        @torch.jit.script
        def foo(x):
            return torch.neg(x), x

        x = torch.rand(3, 4)

        @torch.jit.script
        def wait_script(x):
            fut = torch.jit.fork(foo, x)
            y_hat = foo(x)
            y = torch.jit.wait(fut)
            return y, y_hat

        y, y_hat = wait_script(x)

        self.assertEqual(y, y_hat)

    def test_async_script_capture(self):
        """Fork a script method that returns its module's parameter and constant.

        The forked and the direct result of ``Mod.foo`` are compared while
        optimized execution is disabled.
        """
        class Mod(torch.jit.ScriptModule):
            __constants__ = ['const']

            def __init__(self):
                super().__init__()
                self.const = 42
                self.param = nn.Parameter(torch.randn(2, 2))

            @torch.jit.script_method
            def foo(self, x1, x2):
                return torch.neg(x1), self.param, self.const, torch.neg(x2), self.param

            @torch.jit.script_method
            def forward(self, x1, x2):
                fut = torch.jit.fork(self.foo, x1, x2)
                y_hat = self.foo(x1, x2)
                y = torch.jit.wait(fut)
                return y, y_hat

        x1 = torch.rand(3, 4)
        x2 = torch.rand(5, 6)

        m = Mod()

        with torch.jit.optimized_execution(False):
            y, y_hat = m.forward(x1, x2)

        self.assertEqual(y, y_hat)

    def test_async_script_nested(self):
        """Fork a script function that itself forks and waits on another one."""
        @torch.jit.script
        def foo(x):
            return torch.neg(x), x

        x = torch.rand(3, 4)

        @torch.jit.script
        def wait_script(x):
            fut = torch.jit._fork(foo, x)
            y_hat = foo(x)
            y = torch.jit._wait(fut)
            return y, y_hat

        @torch.jit.script
        def wait_script_nest(x):
            fut = torch.jit._fork(wait_script, x)
            return torch.jit._wait(fut)

        y, y_hat = wait_script_nest(x)

        self.assertEqual(y, y_hat)

    def test_async_script_no_script_mod(self):
        """Scripting must fail when the argument ``x`` itself is the fork target.

        The error has to match 'cannot call a value' and highlight the
        offending ``torch.jit._fork(x`` call.
        """
        with self.assertRaisesRegexWithHighlight(RuntimeError, 'cannot call a value', 'torch.jit._fork(x'):
            @torch.jit.script
            def wait_script(x):
                fut = torch.jit._fork(x)
                return fut

    def test_async_script_multi_waits(self):
        """Waiting twice on the same future yields equal results."""
        @torch.jit.script
        def foo(x):
            return torch.neg(x).t() + x

        @torch.jit.script
        def wait_script(x):
            fut = torch.jit._fork(foo, x)

            # wait twice on the same future
            y1 = torch.jit._wait(fut)
            y2 = torch.jit._wait(fut)
            return y1, y2

        x = torch.rand(2, 2)
        y1, y2 = wait_script(x)
        self.assertEqual(y1, y2)

    def test_async_script_multi_forks(self):
        """Launch five forks, wait on three, and compare against direct calls."""
        @torch.jit.script
        def foo1(x):
            return torch.neg(x).t() + x

        @torch.jit.script
        def foo2(x, y):
            return torch.neg(x).t() + x + torch.neg(y).t()

        @torch.jit.script
        def foo3(x, y, z):
            return torch.neg(z).t() + y.t() + x

        x1 = torch.rand(10, 10)
        x2 = torch.rand(10, 10)
        x3 = torch.rand(10, 10)

        @torch.jit.script
        def wait_script(x1, x2, x3):
            f1 = torch.jit._fork(foo1, x1)
            f2 = torch.jit._fork(foo2, x1, x2)
            f3 = torch.jit._fork(foo3, x1, x2, x3)
            f4 = torch.jit._fork(foo1, x2)
            f5 = torch.jit._fork(foo2, x2, x3)

            # ignore some forks
            y1 = torch.jit._wait(f1)
            y2 = torch.jit._wait(f2)
            y3 = torch.jit._wait(f3)

            return y1, y2, y3

        y1, y2, y3 = wait_script(x1, x2, x3)
        self.assertEqual(y1, foo1(x1))
        self.assertEqual(y2, foo2(x1, x2))
        self.assertEqual(y3, foo3(x1, x2, x3))

    def test_async_kwargs(self):
        """Fork with positional/keyword argument permutations.

        Covers bare Python callables, their traced versions and scripted
        wrappers, each invoked with every positional/keyword combination.
        """
        def foo(x1, x2):
            return 2 * x1 + x2

        x1 = torch.rand(3, 4)
        x2 = torch.rand(3, 4)
        y_hat = foo(x1, x2)

        # Cover tracing and bare functions with permutations of args, kwargs
        for func in [
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2)),
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2=x2)),
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)),
            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x2=x2, x1=x1))
        ]:
            for wrapper in [
                func,
                torch.jit.trace(func, (x1, x2)),
            ]:
                self.assertEqual(wrapper(x1, x2), y_hat)
                self.assertEqual(wrapper(x1, x2=x2), y_hat)
                self.assertEqual(wrapper(x1=x1, x2=x2), y_hat)
                self.assertEqual(wrapper(x2=x2, x1=x1), y_hat)

        # Cover scripting
        @torch.jit.script
        def foo_script_args(x1, x2):
            return torch.jit._wait(torch.jit._fork(foo, x1, x2))

        @torch.jit.script
        def foo_script_kwargs(x1, x2):
            return torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2))

        for wrapper in [
            foo_script_args,
            foo_script_kwargs,
        ]:
            self.assertEqual(wrapper(x1, x2), y_hat)
            self.assertEqual(wrapper(x1, x2=x2), y_hat)
            self.assertEqual(wrapper(x1=x1, x2=x2), y_hat)
            self.assertEqual(wrapper(x2=x2, x1=x1), y_hat)

    @_inline_everything
    def test_async_script_trace(self):
        """Trace a module wrapping a ScriptModule that forks a traced submodule.

        Checks the number of ``prim::fork`` and ``aten::neg`` nodes in the
        traced graph (with and without subgraphs) and the computed outputs.
        """
        class Traced(nn.Module):
            def forward(self, x):
                return (torch.neg(x), x)

        class Mod(torch.jit.ScriptModule):
            def __init__(self):
                super().__init__()
                x = torch.rand(3, 3)
                # Example inputs are passed as an explicit one-element tuple.
                self.traced = torch.jit.trace(Traced(), (x,), _force_outplace=True)

            @torch.jit.script_method
            def forward(self, x: Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]:
                future1 = torch.jit._fork(self.traced, x)
                future2 = torch.jit._fork(torch.neg, x)

                tensor_tuple = torch.jit._wait(future1)
                tensor_single = torch.jit._wait(future2)

                tensor_list = []
                tensor_list.append(tensor_tuple[0])
                tensor_list.append(tensor_single)

                # return a nested structure of tensors
                return (tensor_list, tensor_tuple, tensor_tuple[1])

        class TupleCl(nn.Module):
            def __init__(self):
                super().__init__()
                self.module = Mod()

            def forward(self, x):
                z = torch.neg(x)
                y = self.module(x)
                # Flatten the nested result into a single tuple of tensors.
                return (z, y[0][0], y[0][1], y[1][0], y[1][1], y[2])

        x = torch.rand(3, 3)
        module = torch.jit.trace(TupleCl(), (x,), _force_outplace=True)

        # Make sure we have forks
        self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
        # Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs
        self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=1)
        self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=3, consider_subgraphs=True)

        y = torch.neg(x)
        self.assertEqual(module(x), (y, y, y, y, x, x))

    def test_async_script_error(self):
        """Errors from forked script functions carry the expected text and highlight.

        The same failing function is called directly, through one fork, and
        through two nested forks (the last with a different error).
        """
        x = torch.rand(3, 4)

        @torch.jit.script
        def foo(x):
            # error here
            return x.t() + x

        @torch.jit.script
        def wait_script(x):
            fut = torch.jit._fork(foo, x)
            return torch.jit._wait(fut)

        @torch.jit.script
        def wait_script_nest(x):
            fut = torch.jit._fork(wait_script, x)
            return torch.jit._wait(fut)

        # no future
        error_msg = 'The size.*must match the size of tensor'
        with self.assertRaisesRegexWithHighlight(Exception, error_msg, 'x.t() + x'):
            foo(x)

        # one future
        with self.assertRaisesRegexWithHighlight(Exception, error_msg, 'torch.jit._fork(foo, x'):
            wait_script(x)

        # two futures with a different error
        x = torch.rand(3, 4, 5)
        with self.assertRaisesRegexWithHighlight(Exception,
                                                 'expects a tensor with <= 2 dimensions',
                                                 'torch.jit._fork(wait_script, x'):
            wait_script_nest(x)

    def test_async_grad_guard_with_grad(self):
        """Under ``enable_grad`` results require grad inside the fork and after the wait."""
        @torch.jit.script
        def foo(x):
            y = x * 2
            return y.requires_grad

        @torch.jit.script
        def bar(x):
            fut = torch.jit._fork(foo, x)
            requires_grad_in_fork = torch.jit._wait(fut)
            z = x * 2
            return (requires_grad_in_fork, z.requires_grad)

        x = torch.randn(3, requires_grad=True)

        with torch.enable_grad():
            (inside_fork, after_wait) = bar(x)

        self.assertEqual(inside_fork, True)
        self.assertEqual(after_wait, True)

    def test_async_grad_guard_no_grad(self):
        """Under ``no_grad`` results do not require grad inside the fork or after the wait."""
        @torch.jit.script
        def foo(x):
            y = x * 2
            return y.requires_grad

        @torch.jit.script
        def bar(x):
            fut = torch.jit._fork(foo, x)
            requires_grad_in_fork = torch.jit._wait(fut)
            z = x * 2
            return (requires_grad_in_fork, z.requires_grad)

        x = torch.randn(3, requires_grad=True)

        with torch.no_grad():
            (inside_fork, after_wait) = bar(x)

        self.assertEqual(inside_fork, False)
        self.assertEqual(after_wait, False)

    def test_trace_fork_wait(self):
        """Tracing a fork/wait pair matches eager results and records the expected nodes."""
        def fork_body(x):
            return x.neg(), x.neg() + 1

        def fn(x):
            fut = torch.jit._fork(fork_body, x)
            vals = torch.jit._wait(fut)
            return vals[0], vals[1], x - 1

        traced = torch.jit.trace(fn, (torch.rand(3, 4),))
        x = torch.rand(3, 4)
        self.assertEqual(fn(x), traced(x))

        self.assertGraphContainsExactly(traced.graph, kind='prim::fork', num_kind_nodes=1)
        self.assertGraphContainsExactly(traced.graph, kind='aten::wait', num_kind_nodes=1)
        self.assertGraphContainsExactly(traced.graph, kind='aten::neg', num_kind_nodes=2, consider_subgraphs=True)

    def test_trace_fork_wait_leaking(self):
        """Tracing must fail when a fork body leaks a value through a Python list."""
        my_list = []

        def fork_body(x):
            my_list.append(x + 1)
            return x + 1

        def fn(x):
            fut = torch.jit._fork(fork_body, x)
            # The waited value is ignored on purpose; the leaked list entry is returned.
            torch.jit._wait(fut)
            return my_list[0]

        with self.assertRaisesRegexWithHighlight(RuntimeError, 'did not have observable data dependence with trace inputs; '
                                                               'this probably indicates your program cannot be understood '
                                                               'by the tracer.', ''):
            torch.jit.trace(fn, (torch.rand(3, 4),), check_trace=False)

    def test_trace_fork_wait_inline(self):
        """The inline-fork-wait pass removes fork/wait nodes and keeps both adds."""
        def fork_body(x):
            return x + 1, x + 2

        def fn(x):
            fut = torch.jit._fork(fork_body, x)
            val = torch.jit._wait(fut)
            return val[1]

        traced = torch.jit.trace(fn, (torch.rand(3, 4),))
        torch._C._jit_pass_inline_fork_wait(traced.graph)
        self.assertGraphContainsExactly(traced.graph, kind='prim::fork', num_kind_nodes=0)
        self.assertGraphContainsExactly(traced.graph, kind='aten::wait', num_kind_nodes=0)
        self.assertGraphContainsExactly(traced.graph, kind='aten::add', num_kind_nodes=2)

    def test_trace_fork_wait_inline_onnx(self):
        """Smoke-test ONNX export of a module whose forward uses fork/wait."""
        def fork_body(x):
            return torch.neg(x), torch.neg(x)

        class MyMod(torch.nn.Module):
            def forward(self, x):
                fut = torch.jit._fork(fork_body, x)
                val = torch.jit._wait(fut)
                return val[1]

        # smoke test for ONNX export
        f = io.BytesIO()
        torch.onnx.export(MyMod(), (torch.rand(3, 4),), f)

    def test_trace_fork_wait_list_modulecalls(self):
        """Trace a module call that returns a list of futures to its caller."""
        def add_one(input):
            return input + torch.ones(input.size())

        class TestListFutureModule(nn.Module):
            def forward(self, input):
                input_list = [input for _ in range(3)]

                # Spelled ``torch.jit.Future`` because the bare name ``Future``
                # is not provided by the imports visible in this file.
                fut_list: List[torch.jit.Future[torch.Tensor]] = []
                for input_tensor in input_list:
                    fut_list.append(torch.jit._fork(add_one, input_tensor))
                # return list[future[tensor]] here to ensure tracing
                # module calls return the correct types
                return fut_list

        class TestModuleWrapper(nn.Module):
            def __init__(self):
                super().__init__()
                self.list_fut_mod = TestListFutureModule()

            def forward(self, input):
                fut_list = self.list_fut_mod(input)
                res = input
                for fut in fut_list:
                    res = res + fut.wait()
                return res

        self.checkTrace(TestModuleWrapper(), (torch.randn(5, 5),))

    def test_trace_modulecalls_with_different_output_types(self):
        """Trace a module call that returns both a tensor and a future."""
        def add_one(input):
            return input + torch.ones(input.size())

        class DifferentOutputModule(nn.Module):
            def forward(self, input):
                fut_res = torch.jit._fork(add_one, input)

                # return different types from module call
                return input, fut_res

        class TestModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.gen_output = DifferentOutputModule()

            def forward(self, input):
                res, fut_res = self.gen_output(input)
                res = res + fut_res.wait()
                return res

        self.checkTrace(TestModule(), (torch.randn(5, 5),))

    def test_no_future_subtype_message(self):
        """Annotating with an unparameterized ``Future`` must raise a clear error."""
        with self.assertRaisesRegexWithHighlight(RuntimeError, 'Future without a contained type', ''):
            @torch.jit.script
            def forward(self, x):
                futs = torch.jit.annotate(List[torch.jit.Future], [])

    def test_future_subtyping(self):
        """
        Test that futures subtype each other properly.
        """
        # Successful subtyping.
        def returns_int(x: int) -> int:
            return x + x + 1

        def returns_future_any(x: int) -> torch.jit.Future[Any]:
            return torch.jit._fork(returns_int, (x))

        @torch.jit.script
        def fn_int(x: int) -> Any:
            fut = returns_future_any(x)
            return fut.wait()

        # Unsuccessful subtyping.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            r"was annotated as having type Future\[float\] but is actually of type Future\[int\]",
            "fut = returns_future_float(x"
        ):
            def returns_future_float(x: int) -> torch.jit.Future[float]:
                return torch.jit._fork(returns_int, (x))

            @torch.jit.script
            def fn_float(x: int) -> Any:
                fut = returns_future_float(x)
                return fut.wait()
| |
| |
| |
# This module is only meant to be collected through test/test_jit.py; running
# it as a script fails immediately with instructions for the supported way.
if __name__ == '__main__':
    _direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_direct_run_message)