| # Owner(s): ["oncall: jit"] |
| |
| import os |
| import sys |
| import torch |
| from torch._C import parse_ir |
| from torch.testing import FileCheck |
| |
| # Make the helper files in test/ importable |
| pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) |
| sys.path.append(pytorch_test_dir) |
| from torch.testing._internal.jit_utils import JitTestCase |
| |
if __name__ == '__main__':
    # This module only defines test cases. Refuse direct execution and point
    # the user at test/test_jit.py, the entry point named in the message.
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_jit.py TESTNAME\n\n"
                       "instead.")
| |
class TestIgnorableArgs(JitTestCase):
    """Check that TorchScript's printed code omits ignorable call arguments.

    Each test inspects the Python source that TorchScript prints for a
    function (its ``code`` attribute) and uses ``FileCheck`` to verify how an
    operator call is rendered.
    """

    def test_slice_ignorable_args_for_slice(self):
        # Two chained aten::slice calls on a tensor built from two six-element
        # int lists: %16 slices dim 0 (%13), %20 slices dim 1 (%0) of %16.
        graph_str = """graph():
            %13 : int = prim::Constant[value=0]()
            %10 : bool = prim::Constant[value=0]()
            %8 : NoneType = prim::Constant()
            %0 : int = prim::Constant[value=1]()
            %1 : int = prim::Constant[value=2]()
            %2 : int = prim::Constant[value=3]()
            %3 : int = prim::Constant[value=4]()
            %4 : int = prim::Constant[value=9]()
            %5 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
            %6 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
            %7 : int[][] = prim::ListConstruct(%5, %6)
            %val.1 : Tensor = aten::tensor(%7, %8, %8, %10)
            %16 : Tensor = aten::slice(%val.1, %13, %1, %8, %0)
            %20 : Tensor = aten::slice(%16, %0, %8, %0, %0)
            return (%20)"""
        graph = parse_ir(graph_str)
        function = self.createFunctionFromGraph(graph)
        # NOTE(review): presumably a save/load round trip of the function, per
        # the helper's name; the copy is compared against the original below.
        function_copy = self.getExportImportCopy(function)
        src = str(function.code)
        # For a signature:
        # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor
        # the printed code drops trailing arguments that equal the schema
        # defaults:
        #   * %16 (dim 0): everything after start=2 is dropped, because
        #     end=%8 (None) and step=%0 (1) are the default values.
        #   * %20 (dim 1): only step=%0 (1) is dropped. end=1 is not a default,
        #     so it and the start=None preceding it are still printed.
        FileCheck().check("torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)").run(src)
        # The copy must compute the same result as the original function.
        self.assertEqual(function(), function_copy())

    def test_add_out_ignorable_args(self):
        @torch.jit.script
        def fn(x: torch.Tensor, y: torch.Tensor):
            torch.add(x, y, out=y)
        # The explicitly passed out= keyword must survive in the printed code,
        # directly after the two positional arguments.
        FileCheck().check("torch.add(x, y, out=y)").run(fn.code)