| # Owner(s): ["oncall: fx"] |
| |
| from __future__ import annotations # type: ignore[attr-defined] |
| import torch |
| import typing |
| from torch.fx import symbolic_trace |
| |
class A:
    """Plain (non-``nn.Module``) callable handed to the modules below.

    Calling an instance adds the given tensor to itself with ``torch.add``.
    """

    def __call__(self, x: torch.Tensor):
        doubled = torch.add(x, x)
        return doubled
| |
# No string-quoted forward references in the annotations
class M1(torch.nn.Module):
    """Module whose ``forward`` signature uses only unquoted annotations."""

    def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
        # Delegate all computation to the caller-supplied callable ``a``.
        out = a(x)
        return out
| |
# String-quoted forward references in the annotations
class M2(torch.nn.Module):
    """Module whose ``forward`` annotations are all string-quoted references."""

    def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor':
        # Delegate all computation to the caller-supplied callable ``a``.
        out = a(x)
        return out
| |
| # Non-torch annotation with no internal forward references |
class M3(torch.nn.Module):
    """Module taking a ``typing.List`` of tensors, with no quoted references.

    ``forward`` applies the callable ``a`` to the first list element only.
    """

    def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor:
        first = x[0]
        return a(first)
| |
| # Non-torch annotation with internal forward references |
class M4(torch.nn.Module):
    """Module taking a ``typing.List`` whose element type is a quoted reference.

    ``forward`` applies the callable ``a`` to the first list element only.
    """

    def forward(self, x: typing.List['torch.Tensor'], a: A) -> 'torch.Tensor':
        first = x[0]
        return a(first)
| |
# Trace each module and check that the traced version reproduces the eager
# result exactly. These checks run at module top level (there is no
# ``__main__`` guard), so they execute whenever this file is run or imported.
#
# torch.testing.assert_close(..., rtol=0, atol=0) is used instead of a bare
# ``assert torch.all(torch.eq(...))`` for three reasons:
#   * it still runs under ``python -O``, which strips ``assert`` statements;
#   * it compares shape and dtype rather than silently broadcasting;
#   * on failure it reports the mismatching values.
# It raises AssertionError, just like the original ``assert``.
x = torch.rand(2, 3)

ref = torch.add(x, x)

traced1 = symbolic_trace(M1())
res1 = traced1(x, A())
torch.testing.assert_close(res1, ref, rtol=0, atol=0)

traced2 = symbolic_trace(M2())
res2 = traced2(x, A())
torch.testing.assert_close(res2, ref, rtol=0, atol=0)

# M3 and M4 take a list of tensors, so the input is wrapped in a list.
traced3 = symbolic_trace(M3())
res3 = traced3([x], A())
torch.testing.assert_close(res3, ref, rtol=0, atol=0)

traced4 = symbolic_trace(M4())
res4 = traced4([x], A())
torch.testing.assert_close(res4, ref, rtol=0, atol=0)