| # Owner(s): ["oncall: jit"] |
| |
| import torch |
| from torch.testing import FileCheck |
| from torch.testing._internal.jit_utils import JitTestCase |
| |
# This module is collected through test/test_jit.py; executing it as a script
# is refused with an error that names the supported command.
if __name__ == "__main__":
    direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TestPythonBindings\n\n"
        "instead."
    )
    raise RuntimeError(direct_run_message)
| |
| |
| class TestPythonBindings(JitTestCase): |
| def test_cu_get_functions(self): |
| @torch.jit.script |
| def test_get_python_cu_fn(x: torch.Tensor): |
| return 2 * x |
| |
| cu = torch.jit._state._python_cu |
| self.assertTrue( |
| "test_get_python_cu_fn" in (str(fn.name) for fn in cu.get_functions()) |
| ) |
| |
| def test_cu_create_function(self): |
| @torch.jit.script |
| def fn(x: torch.Tensor): |
| return 2 * x |
| |
| cu = torch._C.CompilationUnit() |
| cu.create_function("test_fn", fn.graph) |
| |
| inp = torch.randn(5) |
| |
| self.assertEqual(inp * 2, cu.find_function("test_fn")(inp)) |
| self.assertEqual(cu.find_function("doesnt_exist"), None) |
| self.assertEqual(inp * 2, cu.test_fn(inp)) |
| with self.assertRaises(AttributeError): |
| cu.doesnt_exist(inp) |
| |
| def test_invalidation(self): |
| @torch.jit.script |
| def test_invalidation_fn(x: torch.Tensor): |
| return 2 * x |
| |
| gr = test_invalidation_fn.graph.copy() |
| n = gr.insertNode(gr.create("prim::profile")) |
| v = n.output() |
| # check that they work |
| str((n, v)) |
| torch._C._jit_pass_dce(gr) |
| with self.assertRaisesRegex(RuntimeError, "invalidated"): |
| str(n) |
| with self.assertRaisesRegex(RuntimeError, "invalidated"): |
| str(v) |
| |
| def test_graph_iterator_keepalive(self): |
| @torch.jit.script |
| def test_iterator_keepalive_fn(x: torch.Tensor): |
| return 2 * x |
| |
| # the list would segfault before because inlined_graph |
| # is temporary and had been deleted (see issue #50454) |
| n = test_iterator_keepalive_fn.inlined_graph.nodes() |
| list(n) |
| i = test_iterator_keepalive_fn.inlined_graph.inputs() |
| list(i) |
| o = test_iterator_keepalive_fn.inlined_graph.outputs() |
| list(o) |
| |
| def test_aliasdb(self): |
| @torch.jit.script |
| def test_aliasdb_fn(x: torch.Tensor): |
| return 2 * x |
| |
| gr = test_aliasdb_fn.graph.copy() |
| alias_db = gr.alias_db() |
| self.assertTrue("WILDCARD" in str(alias_db)) |
| self.assertTrue("digraph alias_db" in alias_db.to_graphviz_str()) |
| |
| def test_graph_create(self): |
| gr = torch._C.Graph() |
| with self.assertRaises(ValueError): |
| gr.create("prim::Constant", [None]) |
| |
| def test_canonicalize(self): |
| ir = """ |
| graph(%p207 : Tensor, |
| %1 : Tensor, |
| %p407 : int): |
| %11 : Tensor = aten::view_expand_placeholder(%1) |
| %12 : Tensor = aten::pointwise_placeholder(%11, %p207, %p407) |
| %13 : Tensor = aten::view_expand_placeholder(%12) |
| %14 : Tensor = aten::pointwise_placeholder(%13) |
| return (%14) |
| """ |
| |
| graph1 = torch._C.parse_ir(ir) |
| graph1 = torch._C._jit_pass_canonicalize(graph1, True) |
| |
| graph2 = torch._C.parse_ir(ir) |
| graph2 = torch._C._jit_pass_canonicalize(graph2) |
| |
| self.assertEqual(str(graph1), str(graph2)) |
| FileCheck().check("%p207").check_not("%14").run(graph1) |
| |
| graph3 = torch._C.parse_ir(ir) |
| graph3 = torch._C._jit_pass_canonicalize(graph3, False) |
| FileCheck().check_not("%p207").run(graph3) |