blob: 1500ed27b7f2b359c88987b8dfcff2f18e691636 [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import io
import torch
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.jit_utils import make_global
from typing import List, Optional, Tuple
from torch import Tensor
from torch._awaits import _Await as Await
class TestAwait(JitTestCase):
def test_await_python(self):
def foo(x: int) -> int:
return x + 13
aw: Await[int] = torch.jit._awaitable(foo, 13)
self.assertTrue(aw.fn()(*aw.args()) == torch.jit._awaitable_wait(aw))
nw = torch.jit._awaitable_nowait(33)
self.assertTrue(nw.is_nowait())
self.assertTrue(nw.args() == (33,))
def test_await_type_python(self):
def foo() -> Tensor:
return torch.randn()
awaits = torch.jit.annotate(List[Await[Tensor]], [])
awaits.append(torch.jit._awaitable(foo))
def test_script(self):
def delayed(z: int) -> int:
return z + 3
def fn(x: Tensor):
aw: Await[int] = torch.jit._awaitable(delayed, 99)
a = torch.eye(2)
b = torch.jit._awaitable_wait(aw)
return a + b + x
inp = torch.zeros(2)
sm = torch.jit.script(fn)
out = fn(inp)
script_out = sm(inp)
self.assertTrue(torch.allclose(torch.eye(2) + 102, script_out))
self.assertTrue(torch.allclose(script_out, out))
def test_nowait(self):
def fn(x: Tensor):
aw = torch.jit._awaitable_nowait(13)
a = torch.eye(2)
b = torch.jit._awaitable_wait(aw)
return a + b + x
inp = torch.zeros(2)
sm = torch.jit.script(fn)
out = fn(inp)
script_out = sm(inp)
self.assertTrue(torch.allclose(torch.eye(2) + 13, script_out))
self.assertTrue(torch.allclose(script_out, out))
def test_nowait_class(self):
class C:
def __init__(self, a: Tensor, b: Tensor):
self._a = a
self._b = b
def a(self) -> Tensor:
return self._a
def fn(x: Tensor):
aw = torch.jit._awaitable_nowait(C(torch.zeros(2), torch.ones(2)))
_a = torch.eye(2)
c = torch.jit._awaitable_wait(aw)
return _a + c.a() + x
make_global(C)
inp = torch.zeros(2)
sm = torch.jit.script(fn)
out = fn(inp)
script_out = sm(inp)
self.assertTrue(torch.allclose(torch.eye(2), script_out))
self.assertTrue(torch.allclose(script_out, out))
def test_await_class_arg(self):
class C:
def __init__(self, a: Tensor, b: Tensor):
self.__a = a
self.__b = b
def a(self) -> Tensor:
return self.__a
make_global(C)
def delayed(c: C) -> Tensor:
return c.a()
def fn(x: Tensor):
c = C(torch.zeros(2), torch.ones(2))
aw = torch.jit._awaitable(delayed, c)
_a = torch.eye(2)
c2_t = torch.jit._awaitable_wait(aw)
return _a + c2_t + x
inp = torch.zeros(2)
sm = torch.jit.script(fn)
out = fn(inp)
script_out = sm(inp)
self.assertTrue(torch.allclose(torch.eye(2), script_out))
self.assertTrue(torch.allclose(script_out, out))
def test_awaitable_to_await(self):
class C:
__slots__ = ["_a", "_b"]
def __init__(self, a: Tensor, b: Tensor):
self._a = a
self._b = b
make_global(C)
# Can not stay in the class as Jit does not support Recursive annotations
# (self in wait_impl can not be annotated as C as C is not defined by this time)
def C_wait_impl(self: C):
return self._a + self._b
def fn(x: Tensor):
aw = torch.jit._awaitable(C_wait_impl, C(torch.zeros(2), torch.ones(2)))
_a = torch.eye(2)
c_wait_impl_res = torch.jit._awaitable_wait(aw)
return _a + c_wait_impl_res + x
inp = torch.ones(2)
sm = torch.jit.script(fn)
out = fn(inp)
script_out = sm(inp)
self.assertTrue(torch.allclose(torch.eye(2) + 2 * torch.ones(2), script_out))
self.assertTrue(torch.allclose(script_out, out))
def test_await_class_return(self):
class C:
__slots__ = ["a", "b"]
def __init__(self, a: Tensor, b: Tensor):
self.a = a
self.b = b
make_global(C)
# Can not stay in the class as Jit does not support Recursive annotations
# (self in wait_impl can not be annotated as C as C is not defined by this time)
def C_wait_impl(self: C) -> C:
return C(self.a * 2, self.b * 3)
def fn_arg_C(x: C) -> Tensor:
return x.a + x.b
def fn(x: Tensor):
aw: Await[C] = torch.jit._awaitable(C_wait_impl, C(x, x))
_a = torch.eye(2)
y = fn_arg_C(torch.jit._awaitable_wait(aw))
return _a + y + x
inp = torch.ones(2)
sm = torch.jit.script(fn)
out = fn(inp)
script_out = sm(inp)
self.assertTrue(torch.allclose(torch.eye(2) + 6 * torch.ones(2), script_out))
self.assertTrue(torch.allclose(script_out, out))
self.assertGraphContainsExactly(sm.graph, kind='prim::awaitable_wait', num_kind_nodes=1)
def test_await_getattr_implicit_convertion(self):
class C:
def __init__(self, a: Tensor, b: Tensor):
self._a = a
self._b = b
def b(self):
return self._b
make_global(C)
# Can not stay in the class as Jit does not support Recursive annotations
# (self in wait_impl can not be annotated as C as C is not defined by this time)
def C_wait_impl(self: C) -> C:
return C(self._a * 2, self._b * 3)
def fn_arg_C(x: C) -> Tensor:
return x._a + x._b
def fn(x: Tensor):
aw: Await[C] = torch.jit._awaitable(C_wait_impl, C(x, x))
_a = torch.eye(2)
ai = aw._a
awb = aw.b()
c = C(2 * x, 2 * x)
return _a + ai + x + c._a + c.b()
inp = torch.ones(2)
sm = torch.jit.script(fn)
out = fn(inp)
script_out = sm(inp)
self.assertTrue(torch.allclose(torch.eye(2) + 7 * torch.ones(2), script_out))
self.assertTrue(torch.allclose(script_out, out))
self.assertGraphContainsExactly(sm.graph, kind='prim::awaitable_wait', num_kind_nodes=2)
def test_await_nested(self):
class C:
def __init__(self, a: Tensor, b: Tensor):
self.__a = a
self.__b = b
def a(self) -> Tensor:
return self.__a
make_global(C)
def delayed(c: C) -> Await[Tensor]:
return torch.jit._awaitable_nowait(3 * c.a())
def fn(x: Tensor) -> Await[Await[Tensor]]:
return torch.jit._awaitable(delayed, C(2 * x, x))
def main(x: Tensor) -> Tensor:
awaw = fn(x)
return torch.jit._awaitable_wait(torch.jit._awaitable_wait(awaw))
inp = torch.eye(2)
sm = torch.jit.script(main)
out = main(inp)
script_out = sm(inp)
self.assertTrue(torch.allclose(6 * torch.eye(2), script_out))
self.assertTrue(torch.allclose(script_out, out))
def test_eager_await_non_scriptable(self):
# Tree type can not be compiled (Recursive type)
class Tree:
def __init__(self, v):
self.parent = torch.jit.annotate(Optional[Tree], None)
self.v = v
make_global(Tree)
def delayed(t: Tree):
t.v = t.v + 1
return t
aw = torch.jit._awaitable(delayed, Tree(2))
t = torch.jit._awaitable_wait(aw)
self.assertTrue(t.v == 3)
def test_await_isinstance(self):
def delayed(x: Tensor) -> Tensor:
return 2 * (x + 1)
def main(x: Tensor) -> Tensor:
aw = torch.jit._awaitable(delayed, x)
if torch.jit.is_scripting():
assert isinstance(aw, torch.jit._Await)
return torch.jit._awaitable_wait(aw)
inp = torch.eye(2)
sm = torch.jit.script(main)
out = main(inp)
script_out = sm(inp)
self.assertTrue(torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out))
self.assertTrue(torch.allclose(script_out, out))
def test_await_eager_lazy(self):
def delayed(x: Tensor) -> Tensor:
return 2 * (x + 1)
t = torch.ones(2, dtype=torch.int64)
aw = torch.jit._awaitable(delayed, t)
self.assertTrue(isinstance(aw, torch._C._Await))
self.assertTrue(t.dtype == aw.dtype)
def test_await_out_of_interpreter(self):
def delayed(x: Tensor) -> Tensor:
return 2 * (x + 1)
def main(x: Tensor) -> Await[Tensor]:
aw = torch.jit._awaitable(delayed, x)
return aw
inp = torch.eye(2)
sm = torch.jit.script(main)
out_aw = main(inp)
out = torch.jit._awaitable_wait(out_aw)
script_out_aw = sm(inp)
script_out = torch.jit._awaitable_wait(script_out_aw)
self.assertTrue(torch.allclose(2 * torch.eye(2) + 2 * torch.ones(2), script_out))
self.assertTrue(torch.allclose(script_out, out))
def test_jit_trace(self):
def gap(x: Tensor):
return torch.relu(x) + torch.sin(x)
def delayed(x: Tensor) -> Tensor:
return 2 * (torch.cos(x) + 1)
def main(x: Tensor, y: Tensor) -> Tensor:
aw = torch.jit._awaitable(delayed, x)
z = gap(y)
k = torch.jit._awaitable_wait(aw)
return y + k
inp = torch.randn(2)
tm = torch.jit.trace(main, (inp, inp))
inp_check = torch.ones(2)
self.assertEqual(main(inp_check, inp_check), tm(inp_check, inp_check))
def test_await_multiout_save(self):
def gap(x: Tensor):
return torch.relu(x) + torch.sin(x)
def delayed(x: Tensor) -> Tuple[Tensor, List[Tensor]]:
l = [x * i for i in range(5)]
return (100 * x, l)
def main(x: Tensor) -> Tensor:
aw = torch.jit._awaitable(delayed, x)
z = gap(x)
(_, l) = torch.jit._awaitable_wait(aw)
return l[3] + z
inp = torch.eye(2)
sm = torch.jit.script(main)
out = main(inp)
script_out = sm(inp)
expected = 4.8415 * torch.eye(2)
self.assertTrue(torch.allclose(expected, script_out))
self.assertTrue(torch.allclose(script_out, out))
iofile = io.BytesIO()
torch.jit.save(sm, iofile)
iofile.seek(0)
sm = torch.jit.load(iofile)
script_out_load = sm(inp)
self.assertTrue(torch.allclose(expected, script_out_load))
def test_await_func_arg(self):
def gap(x: Tensor):
return torch.relu(x) + torch.sin(x)
def delayed(x: Tensor) -> Tensor:
return -1 * x
def fn(aw: Await[Tensor]) -> Tensor:
return 3 * torch.jit._awaitable_wait(aw)
def main(x: Tensor) -> Tensor:
aw = torch.jit._awaitable(delayed, x)
z = gap(x)
y = fn(aw)
return y + x
inp = torch.eye(2)
sm = torch.jit.script(main)
out = main(inp)
script_out = sm(inp)
expected = -2 * torch.eye(2)
self.assertTrue(torch.allclose(expected, script_out))
self.assertTrue(torch.allclose(script_out, out))
iofile = io.BytesIO()
torch.jit.save(sm, iofile)
iofile.seek(0)
sm = torch.jit.load(iofile)
script_out_load = sm(inp)
self.assertTrue(torch.allclose(expected, script_out_load))