| # Owner(s): ["oncall: jit"] |
| |
| import os |
| import sys |
| |
| import torch |
| from torch.testing import FileCheck |
| |
| # Make the helper files in test/ importable |
| pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) |
| sys.path.append(pytorch_test_dir) |
| from torch.testing._internal.jit_utils import JitTestCase |
| |
if __name__ == '__main__':
    # This module only defines test cases; refuse to run it as a standalone
    # script and point the user at the test/test_jit.py driver instead.
    _usage = "\n\n".join((
        "This test file is not meant to be run directly, use:",
        "\tpython test/test_jit.py TESTNAME",
        "instead.",
    ))
    raise RuntimeError(_usage)
| |
| class TestFunctionalBlocks(JitTestCase): |
| def test_subgraph_creation(self): |
| def fn(x, y, z): |
| x = x + 1 |
| y = y + 1 |
| z = z + 1 |
| z.add_(2) |
| z = z * z |
| y = y * z |
| if y < 2: |
| y = y + 5 |
| return x + y + z |
| |
| graph = torch.jit.script(fn).graph |
| self.run_pass('create_functional_graphs', graph) |
| |
| # all uses of x and y should be sunk |
| FileCheck().check(r"%x").check_not(r"%x").check("FunctionalGraph").check(r"%x").run(graph) |
| FileCheck().check(r"%y").check_not(r"%y").check("FunctionalGraph").check(r"%y").run(graph) |
| |
| # Don't allow any outputs which escape scope, so there is one final addition in the graph |
| FileCheck().check("Tensor = prim::Functional").check_next("aten::add").run(graph) |
| |
| # z + 1, z.add_(2) considered non functional, z = z * z should be considered functional |
| FileCheck().check("add").check("add_").check_not("mul").check("FunctionalGraph").run(graph) |