blob: c9fcce329ba0203bdea12206b3d8e0923aa855a8 [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import os
import sys
import torch
from torch._C import parse_ir
from torch.testing import FileCheck
# Put the parent of this file's directory (the test/ tree) on sys.path so the
# helper modules that live there can be imported by the tests below.
pytorch_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(pytorch_test_dir)
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
    # This module only defines test cases that are driven through
    # test/test_jit.py, so refuse to run as a standalone script.
    raise RuntimeError("".join([
        "This test file is not meant to be run directly, use:\n\n",
        "\tpython test/test_jit.py TESTNAME\n\n",
        "instead.",
    ]))
# Tests that arguments equal to their schema defaults are omitted from printed TorchScript code
class TestIgnorableArgs(JitTestCase):
    """Check that printed TorchScript drops arguments equal to schema defaults.

    When a graph is printed back to Python source (``.code``), trailing
    arguments whose values match the operator schema's defaults are omitted.
    These tests pin that behaviour for ``aten::slice`` and for ``torch.add``
    called with an ``out=`` argument.
    """

    def test_slice_ignorable_args_for_slice(self):
        graph_str = """graph():
          %13 : int = prim::Constant[value=0]()
          %10 : bool = prim::Constant[value=0]()
          %8 : NoneType = prim::Constant()
          %0 : int = prim::Constant[value=1]()
          %1 : int = prim::Constant[value=2]()
          %2 : int = prim::Constant[value=3]()
          %3 : int = prim::Constant[value=4]()
          %4 : int = prim::Constant[value=9]()
          %5 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
          %6 : int[] = prim::ListConstruct(%0, %1, %2, %3, %4, %4)
          %7 : int[][] = prim::ListConstruct(%5, %6)
          %val.1 : Tensor = aten::tensor(%7, %8, %8, %10)
          %16 : Tensor = aten::slice(%val.1, %13, %1, %8, %0)
          %20 : Tensor = aten::slice(%16, %0, %8, %0, %0)
          return (%20)"""
        graph = parse_ir(graph_str)
        function = self.createFunctionFromGraph(graph)
        function_copy = self.getExportImportCopy(function)
        src = str(function.code)
        # The relevant schema is (roughly):
        #   aten::slice(Tensor self, int dim=0, int? start=None,
        #               int? end=None, int step=1) -> Tensor
        # Only *trailing* arguments equal to their defaults may be dropped:
        #   - %16 slices dim 0 from start=2; its trailing end=%8 (None) and
        #     step=%0 (1) are both defaults, so only "0, 2" is printed.
        #   - %20 slices dim 1 with end=%0 (1). start=%8 (None) is a default,
        #     but it must still be printed because a non-default argument
        #     follows it; the trailing step=%0 (1) is dropped.
        FileCheck().check("torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)").run(src)
        # NOTE: %7 has only two rows and dim 0 is sliced from start=2, so the
        # result is an empty tensor. The round-trip comparison below therefore
        # checks that the exported/imported copy runs and produces a matching
        # (empty) result, rather than comparing element values.
        self.assertEqual(function(), function_copy())

    def test_add_out_ignorable_args(self):
        @torch.jit.script
        def fn(x: torch.Tensor, y: torch.Tensor):
            torch.add(x, y, out=y)

        # alpha=1 is the schema default for add's out= overload, so it must
        # not show up in the printed call.
        FileCheck().check("torch.add(x, y, out=y)").run(fn.code)