| # Owner(s): ["oncall: jit"] |
| |
| import unittest |
| import io |
| import os |
| import sys |
| import copy |
| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.autograd import Variable, Function |
| from torch.testing import FileCheck |
| |
| # Make the helper files in test/ importable |
| pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) |
| sys.path.append(pytorch_test_dir) |
| from torch.testing._internal.common_utils import suppress_warnings, \ |
| skipIfCompiledWithoutNumpy, enable_profiling_mode_for_profiling_tests, \ |
| IS_SANDCASTLE, TemporaryFileName, skipIfCrossRef, skipIfTorchDynamo |
| from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, \ |
| _tmp_donotuse_dont_inline_everything, _trace, RUN_CUDA, \ |
| RUN_CUDA_MULTI_GPU, make_global |
| from torch.testing._internal.common_cuda import with_tf32_off |
| from torch import Tensor |
| |
| # Standard library |
| from collections import namedtuple |
| from itertools import chain |
| from typing import Dict, List, Optional, Tuple |
| import warnings |
| |
| if __name__ == '__main__': |
| raise RuntimeError("This test file is not meant to be run directly, use:\n\n" |
| "\tpython test/test_jit.py TESTNAME\n\n" |
| "instead.") |
| |
| class TestTracer(JitTestCase): |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| def test_large_nbr_kernel_args(self): |
| class Recurrence(nn.Module): |
| def __init__(self, seq_len): |
| super(Recurrence, self).__init__() |
| self.seq_len = seq_len |
| |
| def forward(self, input): |
| input = input.transpose(0, 1) |
| |
| # Main loop |
| output = [] |
| for i in range(self.seq_len): |
| b = input[i] * 2 |
| output.append(b) |
| |
| output = torch.cat(output, 0).view(input.size(0), *output[0].size()) |
| output = output.transpose(0, 1) |
| return output |
| |
| input_size = 8 |
| batch_size = 2 |
| seq_len = 130 |
| |
| rec = Recurrence(seq_len) |
| input = torch.rand(batch_size, seq_len, input_size) |
| |
| torch.cuda.set_device(0) |
| rec = rec.cuda() |
| input = input.cuda() |
| |
| traced_rec = torch.jit.trace(rec, (input,)) |
| |
| def test_trace_legacy_ctor(self): |
| class MyModule(nn.Module): |
| def forward(self, x): |
| return (x + 1, torch.FloatTensor([0])) |
| |
| traced_rec = torch.jit.trace(MyModule(), torch.randn(2, 2)) |
| |
| def test_simple(self): |
| x = torch.tensor([0.4], requires_grad=True) |
| y = torch.tensor([0.7], requires_grad=True) |
| |
| def f(x, y): |
| return torch.sigmoid(torch.tanh(x * (x + y))) |
| |
| self.checkTrace(f, (x, y)) |
| |
| def test_trace_checking_with_global_name(self): |
| class MyClass(torch.nn.Module): |
| def __init__(self): |
| super(MyClass, self).__init__() |
| |
| def forward(self, xs: List[Tensor]): |
| y = torch.cat(xs, dim=0) |
| return y |
| |
| model = MyClass() |
| # Simulate these inputs being in the globals, as they would be if, |
| # e.g., they were defined in the outermost scope of a script |
| global input1, input2 |
| input1 = torch.ones(2, 2) |
| input2 = torch.ones(2, 2) |
| m2 = torch.jit.trace(model, ((input1, input2),)) |
| |
| def test_trace_aliased_parameter(self): |
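| # The example input aliases the module's own parameter; the traced module |
| # should still treat the runtime input as a separate tensor. |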
| class M(nn.Module): |
| def __init__(self, x): |
| super(M, self).__init__() |
| self.x = nn.Parameter(x) |
| |
| def forward(self, y): |
| return self.x + y |
| |
| m = M(torch.rand(3, 4)) |
| r = torch.jit.trace(m, m.x) |
| t2 = torch.rand(3, 4) |
| self.assertEqual(r(t2), m.x + t2) |
| |
| def test_trace_nested_fn(self): |
| class TracedInlineDecision(torch.nn.Module): |
| def forward(self, x, flag): |
| @torch.jit.script |
| def make_decision(flag, x): |
| if flag: |
| return x |
| else: |
| return torch.zeros_like(x) |
| x = torch.neg(x) |
| return make_decision(flag, x) |
| |
| |
| decision = TracedInlineDecision() |
| torch.jit.trace(decision, (torch.rand(3, 4), torch.tensor([True], dtype=torch.bool)), check_trace=True) |
| |
| def test_trace_single_tuple(self): |
| x = torch.tensor(2.) |
| |
| def f2(x): |
| return (x,) |
| jit_f2 = torch.jit.trace(f2, x) |
| assert f2(x) == jit_f2(x) |
| |
| def test_trace_namedtuple(self): |
| Point = namedtuple('point', ['x', 'y']) |
| |
| def f(p): |
| if type(p) is tuple: |
| p = Point(*p) |
| return p.x + p.y |
| |
| p = Point(torch.randn(1), torch.randn(1)) |
| traced = torch.jit.trace(f, (p,)) |
| self.assertEqual(f(p), traced(p)) |
| |
| def test_trace_topk(self): |
| class M(torch.nn.Module): |
| def forward(self, x, y): |
| return x.topk(y, dim=1)[1] |
| |
| mod = M() |
| inputs = (torch.randint(0, 10, (20, 20)), torch.tensor(17)) |
| traced_func = torch.jit.trace(mod, inputs) |
| |
| test_inputs = (torch.randint(0, 9, (9, 9)), torch.tensor(8)) |
| eager_out = mod(*test_inputs) |
| traced_out = traced_func(*test_inputs) |
| self.assertNotWarn(lambda: traced_func(*test_inputs), "Shouldn't throw slicing related warn here") |
| self.assertEqual(eager_out, traced_out) |
| |
| test_inputs = (torch.randint(0, 50, (50, 50)), torch.tensor(12)) |
| eager_out = mod(*test_inputs) |
| traced_out = traced_func(*test_inputs) |
| self.assertNotWarn(lambda: traced_func(*test_inputs), "Shouldn't throw slicing related warn here") |
| self.assertEqual(eager_out, traced_out) |
| |
| |
| def test_typeas_trace_check(self): |
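| # Smoke test: type_as should be traceable without raising. |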
| a = torch.tensor([0.4], requires_grad=True) |
| b = torch.tensor([0.7], requires_grad=True) |
| |
| def f(x, y): |
| return x.type_as(y) |
| |
| trace = torch.jit.trace(f, (a, b)) |
| |
| def test_trace_index(self): |
| x = torch.tensor([0.4], requires_grad=True) |
| y = torch.tensor([0], dtype=torch.int64) |
| |
| def fn(x, y): |
| return x[y] |
| |
| fn_traced = torch.jit.trace(fn, (x, y,)) |
| |
| self.assertEqual(fn(x, y), fn_traced(x, y)) |
| |
| # Backwards tracing was broken for indexing by a constant, |
| # because it's internally implemented using as_strided, |
| # and we attempted to trace its derivative (which is not |
| # currently supported). It currently works because |
| # slice() is no longer marked as traceable. |
| def test_trace_index_constant(self): |
| x = torch.tensor([0.4], requires_grad=True) |
| |
| def fn(x): |
| return x[0] |
| |
| def run(f): |
| y = f(x) |
| grad = torch.autograd.grad(y, x)[0].clone() |
| return y, grad |
| |
| traced_fn = torch.jit.trace(fn, torch.ones(1)) |
| self.assertEqual(run(fn), run(traced_fn)) |
| |
| def test_index_put(self): |
| ten = torch.zeros(3, 3) |
| mask = torch.tensor([[True, True, True], |
| [True, False, False], |
| [True, True, False]]) |
| |
| def test_fn(ten, mask): |
| ten[mask] = torch.ones(6) |
| return ten |
| |
| traced_test_fn = torch.jit.trace(test_fn, (ten, mask)) |
| |
| ten = torch.rand(3, 3) |
| self.assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask)) |
| |
| def test_canonicalize_tensor_iterator(self): |
| x = torch.randn(4, 4) |
| |
| def f(x): |
| x = x + 2 |
| x = x - 4 |
| x = x * 6 |
| x = x / 8 |
| return x |
| |
| traced = torch.jit.trace(f, (x,)) |
| f(x) |
| graph = traced.graph_for(x) |
| # There should be 4 int constants, one for the right-hand side of each |
| # operator, plus one for the alpha argument shared by add and sub |
| self.assertEqual(str(traced.graph_for(x)).count(': int = prim::Constant'), 5) |
| |
| @suppress_warnings |
| def test_constant(self): |
| x = torch.randn(2, 2, requires_grad=True) |
| |
| def f(x): |
| return x.matmul(torch.diag(torch.tensor([2., 2.]))) |
| |
| self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),)) |
| |
| def test_wrapped_number(self): |
| # Scalars get converted to 'wrapped' tensors of default tensor type. |
| # Wrapped tensors behave differently in certain promotion operations: |
| # float_tensor * double -> float but wrapped_float * double -> double. |
| # This can cause issues in check-trace if not handled correctly in |
| # `aten::isclose()`. |
| |
| def foobar(): |
| x = -10000.0 |
| result = x * torch.ones(1, dtype=torch.float) |
| return result |
| |
| traced = torch.jit.trace(foobar, (), check_trace=True) |
| |
| |
| def test_inplace_transplant(self): |
| x = torch.tensor([0.], requires_grad=True) |
| |
| def fn(x): |
| y = x.clone() |
| y.add_(2) |
| y.add_(3) |
| return y |
| |
| g, _ = torch.jit._get_trace_graph(fn, (x,)) |
| self.run_pass('dce', g) |
| FileCheck().check_count("aten::clone", 1, exactly=True) \ |
| .check_count("aten::add_", 2, exactly=True) \ |
| .check_next("return").run(str(g)) |
| self.assertExportImport(g, (x,)) |
| |
| def test_inplace_flags(self): |
| class InplaceFn(Function): |
| @staticmethod |
| def forward(ctx, x): |
| ctx.mark_dirty(x) |
| return x.add_(1) |
| |
| @staticmethod |
| def backward(ctx, go): |
| return go |
| |
| class RegularFn(Function): |
| @staticmethod |
| def forward(ctx, x): |
| return x.add(1) |
| |
| @staticmethod |
| def backward(ctx, go): |
| return go |
| |
| x = torch.tensor([0.], requires_grad=True) |
| |
| def fn(x): |
| y = RegularFn.apply(x) |
| y = InplaceFn.apply(y) |
| y = InplaceFn.apply(y) |
| y = RegularFn.apply(y) |
| return y |
| |
| trace_graph, _ = torch.jit._get_trace_graph(fn, (x,), _force_outplace=True) |
| self.run_pass('dce', trace_graph) |
| ops = list(trace_graph.nodes()) |
| for op in ops: |
| self.assertTrue(op.hasAttribute('inplace')) |
| inplace_flags = [False, True, True, False] |
| for op, is_inplace in zip(ops, inplace_flags): |
| self.assertEqual(op.i('inplace'), is_inplace) |
| |
| def test_inplace_check(self): |
| class MyInplaceFn(Function): |
| @staticmethod |
| def forward(ctx, x): |
| x.add_(1) |
| ctx.mark_dirty(x) |
| return x |
| |
| @staticmethod |
| def backward(ctx, grad): |
| return grad |
| |
| def fn(x): |
| return MyInplaceFn.apply(x) |
| |
| x = torch.randn(5, 5) |
| ge = torch.jit.trace(fn, (x,), _force_outplace=True, check_trace=False) |
| with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'): |
| ge(x) |
| |
| def test_force_outplace_check_fill(self): |
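| # With _force_outplace=True, the traced in-place fill_ on a fresh tensor |
| # should still match the eager result. |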
| def f(x): |
| return torch.empty(x.shape).fill_(7) |
| x = torch.randn(10, 15) |
| ft = torch.jit.trace(f, x, _force_outplace=True) |
| self.assertEqual(f(x), ft(x)) |
| |
| def test_force_outplace_check_zero(self): |
| def f(x): |
| return torch.empty(x.shape).zero_() |
| x = torch.randn(10, 15) |
| ft = torch.jit.trace(f, x, _force_outplace=True) |
| self.assertEqual(f(x), ft(x)) |
| |
| def do_trace_size(self, requires_grad): |
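| # view() driven by shape()/size() calls should stay dynamic, so the trace |
| # must also work on an input of a different size. |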
| def fn(x): |
| return x.view(x.shape[1] * 2, x.size(0), 2) |
| |
| x = torch.randn(5, 2, 4, requires_grad=requires_grad) |
| y = torch.randn(4, 8, 4, requires_grad=requires_grad) |
| |
| # Check that it behaves as expected |
| traced_fn = torch.jit.trace(fn, x) |
| self.assertEqual(traced_fn(y), fn(y)) |
| self.assertEqual(traced_fn(x), fn(x)) |
| |
| def test_trace_size(self): |
| self.do_trace_size(False) |
| |
| # test the different graph_executor path that happens when |
| # gradients are required and sizes are involved |
| def test_trace_size_with_grad(self): |
| self.do_trace_size(True) |
| |
| def test_trace_numel(self): |
| def fn(x): |
| return x.numel() |
| |
| x = torch.randn(2, 3, 4) |
| y = torch.randn(4, 5, 6) |
| |
| traced_fn = torch.jit.trace(fn, x) |
| self.assertEqual(traced_fn(y), fn(y)) |
| self.assertEqual(traced_fn(x), fn(x)) |
| |
| def do_trace_arange(self, requires_grad): |
| def arange(x): |
| return torch.arange(x.shape[0]) |
| |
| def arange_scalar(x): |
| return torch.arange(12) |
| |
| def arange_start_end(x): |
| return torch.arange(start=x.shape[0], end=x.shape[0] + 5) |
| |
| x = torch.randn(5, 3, 2, requires_grad=requires_grad) |
| y = torch.randn(8, 2, 4, requires_grad=requires_grad) |
| |
| # Check that it behaves as expected |
| traced_arange = torch.jit.trace(arange, x) |
| self.assertEqual(traced_arange(y), arange(y)) |
| self.assertEqual(traced_arange(x), arange(x)) |
| |
| traced_arange_scalar = torch.jit.trace(arange_scalar, x) |
| self.assertEqual(traced_arange_scalar(y), arange_scalar(y)) |
| self.assertEqual(traced_arange_scalar(x), arange_scalar(x)) |
| |
| traced_arange_start_end = torch.jit.trace(arange_start_end, x) |
| self.assertEqual(traced_arange_start_end(y), arange_start_end(y)) |
| self.assertEqual(traced_arange_start_end(x), arange_start_end(x)) |
| |
| def test_trace_arange(self): |
| self.do_trace_arange(False) |
| |
| # test the different graph_executor path that happens when |
| # gradients are required and sizes are involved |
| def test_trace_arange_with_grad(self): |
| self.do_trace_arange(True) |
| |
| # Test that a trace of torch.full(x.shape) doesn't store the shape as a constant |
| def test_trace_full_dynamic_shape(self): |
| def full_with_shape_like(x): |
| return torch.full(x.shape, 2.) |
| |
| x = torch.randn(3, 4) |
| ge = torch.jit.trace(full_with_shape_like, example_inputs=x) |
| y = torch.randn(2, 7) |
| self.assertEqual(ge(y).shape, y.shape) |
| self.assertEqual(ge(x).shape, x.shape) |
| |
| # Test that the trace of setitem doesn't store shapes as constants |
| # Fixes https://github.com/pytorch/pytorch/issues/43548 |
| def test_trace_slice_setitem_dynamic_shape(self): |
| def slice_setitem(x, y): |
| x[:, 2] = y + 1 |
| return x |
| |
| x = torch.randn(3, 4) |
| traced = torch.jit.trace(slice_setitem, (x, x[:, 0])) |
| x = torch.randn(10, 5) |
| self.assertEqual(traced(x.clone(), x[:, 0]), slice_setitem(x.clone(), x[:, 0])) |
| |
| # Suppression: we are intentionally slicing a tensor and don't care that it |
| # will be constantified |
| @suppress_warnings |
| def do_trace_slice(self, requires_grad): |
| def slice(x): |
| results = [] |
| for i in range(4): |
| results.append(x[:x.size(0) - i, i:x.size(2), i:3]) |
| return tuple(results) |
| |
| def slice_select(x): |
| results = [] |
| for i in range(4): |
| results.append(x[:, i:, x.size(2) - 5]) |
| return tuple(results) |
| |
| x = torch.randn(5, 6, 7, requires_grad=requires_grad) |
| y = torch.randn(7, 8, 9, requires_grad=requires_grad) |
| |
| # Check that it behaves as expected |
| traced_slice = torch.jit.trace(slice, x) |
| self.assertEqual(traced_slice(y), slice(y)) |
| self.assertEqual(traced_slice(x), slice(x)) |
| |
| traced_slice_select = torch.jit.trace(slice_select, x) |
| self.assertEqual(traced_slice_select(y), slice_select(y)) |
| self.assertEqual(traced_slice_select(x), slice_select(x)) |
| |
| def test_trace_slice(self): |
| self.do_trace_slice(False) |
| |
| # test the different graph_executor path that happens when |
| # gradients are required and sizes are involved |
| def test_trace_slice_with_grad(self): |
| self.do_trace_slice(True) |
| |
| |
| def test_trace_casts(self): |
| casts = [ |
| lambda x: x.byte(), |
| lambda x: x.float(), |
| lambda x: x.cpu(), |
| lambda x: x.to(device='cpu'), |
| lambda x: x.to(dtype=torch.int64), |
| lambda x: x.to(device='cpu', dtype=torch.float), |
| lambda x: x.to(x) |
| ] |
| |
| def assertContainsCast(trace): |
| self.assertEqual(sum(n.kind() == 'aten::to' for n in trace.graph.nodes()), 1) |
| |
| for cast in casts: |
| trace = torch.jit.trace(cast, torch.randn(2, 2)) |
| assertContainsCast(trace) |
| x = torch.randn(2, 2) |
| self.assertEqual(trace(x), cast(x)) |
| |
| def to_tensor(x, y): |
| return x.to(y) |
| |
| to_tensor_trace = torch.jit.trace(to_tensor, (torch.randn(2, 2), torch.randn(1, 8))) |
| assertContainsCast(to_tensor_trace) |
| x, y = torch.randn(2, 2), torch.randn(1, 10) |
| self.assertEqual(to_tensor_trace(x, y), to_tensor(x, y)) |
| |
| @skipIfCompiledWithoutNumpy |
| @skipIfCrossRef |
| def test_trace_warn(self): |
| def fn(x): |
| int(x) # Warning 1. |
| y = x * 1 |
| if y: # Warning 2. |
| pass |
| q = [x, x * 4] |
| z = q[y] |
| float(z) # Warning 3. |
| z.tolist() # Warning 4. |
| z.numpy() # Warning 5. |
| for _ in torch.ones(4, 4): # Warning 6. |
| pass |
| return z + 4 |
| |
| with warnings.catch_warnings(record=True) as warns: |
| traced_fn = torch.jit.trace(fn, torch.tensor([1])) |
| for warn in warns: |
| self.assertIs(warn.category, torch.jit.TracerWarning) |
| warns = [str(w.message) for w in warns] |
| self.assertIn('a Python integer', warns[0]) |
| self.assertIn('a Python boolean', warns[1]) |
| self.assertIn('a Python float', warns[2]) |
| self.assertIn('a Python list', warns[3]) |
| self.assertIn('a NumPy array', warns[4]) |
| self.assertIn('Iterating over', warns[5]) |
| |
| def test_trace_tuple(self): |
| def fn(x, y): |
| return x, (x * y[1], x * y[0]) |
| |
| x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2)) |
| traced_fn = torch.jit.trace(fn, (x, y)) |
| self.assertEqual(traced_fn(x, y), fn(x, y)) |
| # should be a tuple nested within another tuple |
| FileCheck().check_count("prim::TupleConstruct", 2, exactly=True).check_next("return") \ |
| .run(str(traced_fn.graph)) |
| self.assertExportImport(traced_fn.graph, (x, y)) |
| |
| def test_trace_random(self): |
| def f(mean, std): |
| return torch.normal(mean, std) |
| |
| traced = torch.jit.trace(f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=False) |
| mean, std = torch.zeros(5, 5), torch.ones(5, 5) |
| with torch.random.fork_rng(devices=[]): |
| output = f(mean, std) |
| traced_output = traced(mean, std) |
| self.assertEqual(output, traced_output) |
| |
| def test_trace_tensor_factory(self): |
| def run(**kwargs): |
| inputs_require_grads = kwargs.pop('inputs_require_grads', True) |
| |
| def fn(x): |
| return x + torch.ones(2, 3, **kwargs) |
| |
| input_kwargs = kwargs.copy() |
| if 'out' in input_kwargs: |
| del input_kwargs['out'] |
| input = torch.ones(2, 3, **input_kwargs) |
| self.checkTrace(fn, (input,), inputs_require_grads=inputs_require_grads) |
| # check we recorded 'ones' and did not just record a constant |
| tfn = torch.jit.trace(fn, input) |
| self.assertTrue("ones" in str(tfn.graph)) |
| run() |
| run(dtype=torch.int, inputs_require_grads=False) |
| run(out=torch.tensor([])) |
| if RUN_CUDA: |
| run(device="cuda:0") |
| if RUN_CUDA_MULTI_GPU: |
| run(device="cuda:1") |
| |
| def test_trace_indexed_assignment(self): |
| def stuff(x, y): |
| x = x.clone() |
| x[0] = y |
| return x |
| example = torch.rand(3, 4) |
| self.checkTrace(stuff, (example, example[0] + 1)) |
| |
| # TODO: implement |
| @unittest.expectedFailure |
| def test_output_unflatten(self): |
| """Check that outputs of traced functions retain the original structure and nesting""" |
| def fn(x): |
| return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4) |
| |
| self.checkTrace(fn, (torch.randn(2, 2),)) |
| |
| def test_input_flatten(self): |
| """Check that inputs to traced functions are flattened""" |
| |
| def fn(x, t): |
| y, z = t |
| return x * y * z |
| |
| inputs = (torch.randn(1), (torch.randn(1), torch.randn(1))) |
| self.checkTrace(fn, inputs) |
| |
| def test_input_dict_empty(self): |
| def test(d): |
| pass |
| |
| with self.assertRaises(RuntimeError): |
| self.checkTrace(test, {}) |
| |
| def test_input_dict_remembers_keys(self): |
| """Check that the trace remembers which keys were in a dict input""" |
| class TestModule(torch.nn.Module): |
| def __init__(self): |
| super(TestModule, self).__init__() |
| |
| def forward(self, dict_input): |
| return dict_input['x'] |
| |
| input_1 = {'x': torch.tensor(1)} |
| m = TestModule() |
| m_traced = torch.jit.trace(m, (input_1, )) |
| self.assertEqual(m_traced(input_1), torch.tensor(1)) |
| |
| # changing the values (but not the keys) should work |
| input_same_key_different_value = {'x': torch.tensor(2)} |
| self.assertEqual(m_traced(input_same_key_different_value), torch.tensor(2)) |
| |
| # error to use something that doesn't have `x` |
| input_different_key = {'y': torch.tensor(3)} |
| with self.assertRaises(RuntimeError): |
| m_traced(input_different_key) |
| |
| # it's okay to have additional elements in the dictionary, so long as 'x' is there |
| input_additional_key = {'x': torch.tensor(4), 'y': torch.tensor(3)} |
| self.assertEqual(m_traced(input_additional_key), torch.tensor(4)) |
| |
| def test_input_dict_insertion_order(self): |
| """Check that dictionary access doesn't care about insertion order""" |
| class TestModule(torch.nn.Module): |
| def __init__(self): |
| super(TestModule, self).__init__() |
| |
| def forward(self, dict_input): |
| return dict_input['x'], dict_input['y'] |
| input_x_then_y = {} |
| input_x_then_y['x'] = torch.tensor(1) |
| input_x_then_y['y'] = torch.tensor(2) |
| |
| m = TestModule() |
| m_traced = torch.jit.trace(m, (input_x_then_y, )) |
| |
| self.assertEqual(m_traced(input_x_then_y), (torch.tensor(1), torch.tensor(2))) |
| |
| input_y_then_x = {} |
| input_y_then_x['y'] = torch.tensor(4) |
| input_y_then_x['x'] = torch.tensor(3) |
| |
| self.assertEqual(m_traced(input_y_then_x), (torch.tensor(3), torch.tensor(4))) |
| |
| def test_input_dict_recursive(self): |
| class TestModule(torch.nn.Module): |
| def __init__(self): |
| super(TestModule, self).__init__() |
| |
| def forward(self, dict_input): |
| return dict_input['x'][1] |
| |
| input_1 = {'x': {1: torch.tensor(1)}} |
| m = TestModule() |
| m_traced = torch.jit.trace(m, (input_1, )) |
| |
| input_2 = {'x': {1: torch.tensor(2)}} |
| self.assertEqual(m_traced(input_2), torch.tensor(2)) |
| |
| def test_input_dict_checkTrace_mut(self): |
| def test(d): |
| d['x'].tanh_() |
| return d['x'] |
| inputs = {'x': torch.rand(3, 4), 'y': torch.rand(3, 4)} |
| self.checkTrace(test, (inputs,), inputs_require_grads=False) |
| |
| def test_input_dict_unify(self): |
| def test(d): |
| return d['int'], d['float'] |
| inputs = {'int': torch.ones((2, 2), dtype=torch.int32), |
| 'float': torch.ones((2, 2), dtype=torch.float32)} |
| self.checkTrace(test, (inputs,), inputs_require_grads=False) |
| |
| def test_input_tuple_of_dicts(self): |
| def test(t): |
| d = t[0] |
| return d['x']['y'] |
| inputs = {'x': {'y': torch.rand(2, 3)}} |
| self.checkTrace(test, ((inputs, inputs),), allow_unused=True) |
| |
| def test_input_dict_of_dicts(self): |
| def test(d): |
| return d['x']['y'] |
| nested_input = {'y': torch.rand(2, 3)} |
| unified_nested = {'y': torch.rand(3, 2)} |
| inputs = {'x': nested_input, 'force_unify': unified_nested} |
| self.checkTrace(test, (inputs,), allow_unused=True) |
| |
| def test_input_dict_of_lists(self): |
| def test(d): |
| return d['x'][0] |
| |
| inputs = {'x': [torch.rand(3, 2)]} |
| self.checkTrace(test, (inputs,)) |
| |
| def test_input_list_toplevel_flatten(self): |
| def test(t1, t2): |
| return torch.add(t1, t2) |
| |
| inputs = [torch.ones(2, 2), torch.rand(2, 2)] |
| self.checkTrace(test, inputs) |
| |
| def test_input_list_toplevel_flatten_direct(self): |
| class Test(torch.nn.Module): |
| def forward(self, t1, t2): |
| return torch.add(t1, t2) |
| inputs = [torch.ones(2, 2), torch.rand(2, 2)] |
| torch.jit.trace(Test(), inputs) |
| |
| def test_input_list_of_tuples(self): |
| def test(l): |
| return l[0][0] |
| inputs = [(torch.ones(2, 2),)] |
| self.checkTrace(test, (inputs,)) |
| |
| def test_input_dict_empty_list(self): |
| def test(d): |
| pass |
| inputs = {1: []} |
| with self.assertRaisesRegex(RuntimeError, 'List trace'): |
| self.checkTrace(test, (inputs,)) |
| |
| def test_input_list_mixed_type(self): |
| def test(d): |
| pass |
| inputs = [torch.rand(2, 3), (torch.ones(2), torch.ones(2))] |
| with self.assertRaisesRegex(RuntimeError, 'consistent'): |
| self.checkTrace(test, (inputs,)) |
| |
| def test_conv(self): |
| x = torch.ones(20, 16, 50, 40) |
| g, outputs, inputs = torch.jit._get_trace_graph(nn.Conv2d(16, 13, 3, bias=False), x, return_inputs=True) |
| m = self.createFunctionFromGraph(g) |
| self.assertEqual(outputs, m(*inputs)) |
| |
| def test_max_pool(self): |
| x = torch.rand(20, 16, 10, 10) |
| |
| def max_pool2d(x): |
| return F.max_pool2d(x, 2) + 2 |
| |
| trace = torch.jit.trace(max_pool2d, (x,)) |
| graph = trace.graph_for(x) |
| FileCheck().check("aten::max_pool2d(").run(graph) |
| self.assertEqual(max_pool2d(x), trace(x)) |
| |
| def test_nested_inplace(self): |
| x = torch.randn(2, 2) |
| g, outputs, inputs = torch.jit._get_trace_graph( |
| lambda x: F.threshold(x, 0, 0, inplace=True), (x, ), return_inputs=True) |
| m = self.createFunctionFromGraph(g) |
| self.assertEqual(outputs, m(*inputs)) |
| FileCheck().check("threshold_").run(str(g)) |
| self.assertExportImport(g, (x,)) |
| |
| def test_repeated_input(self): |
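| # Passing the same tensor for both arguments should still produce distinct |
| # graph inputs. |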
| def fn(a, b): |
| return a + b |
| |
| ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2) |
| inputs = set(ge.graph.inputs()) |
| # 3 instead of 2 because the export/import in checkTrace adds a |
| # `self` module argument |
| self.assertEqual(len(inputs), 3) |
| |
| def test_repeated_output(self): |
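| # Returning the same tensor twice should yield a tuple whose elements refer |
| # to the same Value in the graph. |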
| def fn(a, b): |
| z = a + b |
| return z, z |
| |
| ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)]) |
| tuple_output = list(ge.graph.outputs())[0] |
| tuple_inputs = list(tuple_output.node().inputs()) |
| self.assertTrue(tuple_inputs[0] == tuple_inputs[1]) |
| |
| def test_inplace_copy(self): |
| x = torch.randn(4, 4, requires_grad=True) |
| |
| def f(x): |
| out = torch.zeros(x.size()) |
| out.copy_(x) |
| return out |
| |
| g, outputs, inputs = torch.jit._get_trace_graph(f, (x, ), return_inputs=True) |
| self.run_pass('dce', g) |
| m = self.createFunctionFromGraph(g) |
| self.assertEqual(outputs, m(*inputs)) |
| self.assertExportImport(g, (x,)) |
| |
| def test_inplace_copy_force_outplace(self): |
| x = torch.randn(4, 4, requires_grad=True) |
| |
| def f(x): |
| out = torch.zeros(x.size()) |
| out.copy_(x) |
| return out |
| |
| g, outputs, inputs = torch.jit._get_trace_graph( |
| f, (x, ), return_inputs=True, _force_outplace=True) |
| self.run_pass('dce', g) |
| m = self.createFunctionFromGraph(g) |
| self.assertEqual(outputs, m(*inputs)) |
| self.assertExportImport(g, (x,)) |
| FileCheck().check("expand_as").run(str(g)) |
| |
| def test_shared_param(self): |
| class MyModule(torch.nn.Module): |
| def __init__(self): |
| super(MyModule, self).__init__() |
| self.b = self.a = nn.Parameter(torch.randn(2, 2)) |
| |
| def forward(self, x): |
| return x * self.a + self.b |
| |
| m = MyModule() |
| g, _ = torch.jit._get_trace_graph(m, (torch.randn(2, 2),)) |
| self.run_pass('dce', g) |
| self.assertEqual(len(list(g.inputs())), 2) |
| FileCheck().check("mul").check("add").run(str(g)) |
| |
| def test_trace_c10_ops(self): |
| try: |
| _ = torch.ops._caffe2.GenerateProposals |
| except AttributeError: |
| self.skipTest("Skip the test since c2 ops are not registered.") |
| |
| class MyModel(torch.nn.Module): |
| def __init__(self): |
| super(MyModel, self).__init__() |
| |
| def forward(self, scores, bbox_deltas, im_info, anchors): |
| a, b = torch.ops._caffe2.GenerateProposals( |
| (scores), (bbox_deltas), (im_info), (anchors), |
| 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True, |
| ) |
| return a, b |
| model = MyModel() |
| A = 4 |
| H = 10 |
| W = 8 |
| img_count = 3 |
| scores = torch.ones(img_count, A, H, W, dtype=torch.float32) |
| bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W, |
| dtype=torch.float32) |
| bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W) |
| im_info = torch.ones(img_count, 3, dtype=torch.float32) |
| anchors = torch.ones(A, 4, dtype=torch.float32) |
| inputs = (scores, bbox_deltas, im_info, anchors) |
| traced_model = torch.jit.trace(model, inputs) |
| self.assertEqual(traced_model(*inputs), model(*inputs)) |
| self.assertExportImportModule(traced_model, (scores, bbox_deltas, im_info, anchors)) |
| |
| def run_ge_tests(self, optimize, use_cuda): |
| |
| with enable_profiling_mode_for_profiling_tests(): |
| with torch.jit.optimized_execution(optimize): |
| def rand(*args): |
| t = torch.rand(*args).float() |
| if use_cuda: |
| t = t.cuda() |
| return t |
| self.checkTrace(lambda a, b: a * b + b, |
| [rand(1), rand(1)], [rand(2, 3), rand(2, 3)]) |
| # trivial identity |
| self.checkTrace(lambda a, b: (b, a), [rand(1), rand(1)]) |
| |
| def foo(a): |
| t = a * a |
| return t * t, 4 * t |
| self.checkTrace(foo, [rand(1)]) |
| # unused input |
| self.checkTrace( |
| lambda a, b: a * a, [rand(1), rand(1)], allow_unused=True) |
| # test outputs that do not get used in grad |
| self.checkTrace(foo, [rand(1)], drop=1) |
| # test autograd fallback |
| self.checkTrace(lambda a, b: a * b / |
| (a - 2 * b) + b, [rand(1), rand(1)]) |
| |
| def test_ge_unoptimized(self): |
| self.run_ge_tests(False, False) |
| |
| @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle") |
| @enable_cpu_fuser |
| def test_ge_optimized(self): |
| with enable_profiling_mode_for_profiling_tests(): |
| self.run_ge_tests(True, False) |
| |
| @unittest.skipIf(not RUN_CUDA, "requires CUDA") |
| def test_ge_cuda(self): |
| self.run_ge_tests(True, True) |
| |
| # a more manual test of the graph executor that can be used as a scratchpad |
| def test_ge(self): |
| def foo(a, b): |
| return a * b / (a - b) + b |
| V = Variable |
| a, b = V(torch.rand(1)), V(torch.rand(1)) |
| ge = torch.jit.trace(foo, (a, b)) |
| a, b = V(torch.rand(1), requires_grad=True), V( |
| torch.rand(1), requires_grad=True) |
| r, = ge(a, b) |
| da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True) |
| |
| l2 = (da * db + db * db) |
| g2result = torch.autograd.grad(l2, [da, db]) |
| |
| r = foo(a, b) |
| da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True) |
| self.assertEqual(da, da2) |
| self.assertEqual(db, db2) |
| l3 = (da2 * db2 + db2 * db2) |
| g2result2 = torch.autograd.grad(l3, [da2, db2]) |
| self.assertEqual(g2result, g2result2) |
| |
| def test_trace_annotation(self): |
| @_trace(torch.rand(1)) |
| def foo(a): |
| return a + a + a |
| |
| x = torch.randn(5, 5) |
| self.assertEqual(foo(x), x + x + x) |
| |
| @unittest.skipIf(not RUN_CUDA, "calls .cuda()") |
| # By default, on Ampere or later GPUs, nn.Linear computes float tensors at TF32 precision. |
| # We want float tensors to be computed at full precision so that the CUDA results match the CPU results. |
| @with_tf32_off |
| def test_traced_module_cuda(self): |
| class Model(nn.Module): |
| def __init__(self, num_features, num_layers): |
| super(Model, self).__init__() |
| self.num_layers = num_layers |
| layers = [[nn.Linear(num_features, num_features), nn.Sigmoid()] |
| for _ in range(num_layers)] |
| self.submodule = nn.Sequential(*chain(*layers)) |
| |
| def forward(self, x): |
| for i in range(self.num_layers): |
| x = self.submodule[i](x) + x |
| return x |
| |
| model = Model(5, 3) |
| x = torch.randn(2, 5) |
| traced_model = torch.jit.trace(model, x) |
| |
| # We're missing some attributes these modules had initially. Make sure we can |
| # still get the __repr__() |
| model.__repr__() |
| |
| # XXX: indexing sequentials is broken |
| linear_submodule = next(iter(traced_model.submodule._modules.values())) |
| |
| # All attributes that aren't parameters should raise |
| with self.assertRaises(AttributeError): |
| linear_submodule.in_features |
| linear_submodule.weight |
| linear_submodule.weight = nn.Parameter(torch.randn(linear_submodule.weight.shape)) |
| with self.assertRaises(RuntimeError): |
| del linear_submodule.weight |
| |
| # Submodules can't be called |
| with self.assertRaises(RuntimeError): |
| linear_submodule(x) |
| |
| # Type casts |
| linear_submodule.cuda() |
| traced_model.float().cuda() |
| cuda_out = traced_model(x.float().cuda()) |
| traced_model.cpu() |
| cpu_out = traced_model(x.float()) |
| self.assertEqual(cpu_out, cuda_out) |
| traced_model.to('cuda') |
| cuda_out = traced_model(x.float().cuda()) |
| traced_model.to('cpu') |
| cpu_out = traced_model(x.float()) |
| self.assertEqual(cpu_out, cuda_out) |
| traced_model.double() |
| |
| # state_dict + load_state_dict |
| state = {k: v.clone() for k, v in traced_model.state_dict().items()} |
| new_state = {k: v.clone().fill_(1) for k, v in state.items()} |
| out = traced_model(x) |
| traced_model.load_state_dict(new_state) |
| out_ones = traced_model(x) |
| traced_model.load_state_dict(state) |
| out_state = traced_model(x) |
| self.assertEqual(out, out_state) |
| self.assertNotEqual(out, out_ones) |
| |
| def test_export_no_reorder(self): |
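| # Outputs and gradients should be identical before and after an |
| # export/import round trip. |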
| def func(a, b): |
| return a * b / (a - 2 * b) + b |
| |
| recording_inputs = [torch.tensor([0.55619788169860839844], dtype=torch.float32, requires_grad=True), |
| torch.tensor([0.25947844982147216797], dtype=torch.float32, requires_grad=True)] |
| |
| ge1 = torch.jit.trace(func, recording_inputs) |
| ge2 = self.getExportImportCopy(ge1) |
| |
| outputs_ge1 = ge1(*recording_inputs) |
| outputs_ge2 = ge2(*recording_inputs) |
| |
| grad_ge1 = torch.autograd.grad(outputs_ge1, recording_inputs) |
| grad_ge2 = torch.autograd.grad(outputs_ge2, recording_inputs) |
| self.assertTrue(outputs_ge1 == outputs_ge2) |
| self.assertTrue(grad_ge1 == grad_ge2) |
| |
| def test_python_function(self): |
| class MyFn(Function): |
| @staticmethod |
| def forward(ctx, x): |
| return x + 1 |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| return grad_output |
| |
| @_trace(torch.zeros(2)) |
| def fn(x): |
| return MyFn.apply(x + 2) + 3 |
| |
| x = torch.tensor([1., 2., 3.]) |
| y = torch.randn(2, 2, requires_grad=True) |
| fn(x) |
| fn(y) |
| |
| def test_python_function_tup(self): |
| class MyFn(Function): |
| @staticmethod |
| def forward(ctx, x): |
| return x + 1, x - 1 |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| return grad_output, grad_output |
| |
| @_trace(torch.zeros(2)) |
| def fn(x): |
| a, b = MyFn.apply(x + 2) |
| return a + b + 3 |
| x = torch.tensor([1., 2., 3.]) |
| y = torch.randn(2, 2, requires_grad=True) |
| fn(x) |
| fn(y) |
| |
| def test_trace_detach(self): |
| def foo(x, w): |
| return torch.matmul(x, w).detach() |
| |
| traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5))) |
| |
| FileCheck().check("matmul").check("detach").run(str(traced.graph)) |
| x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True) |
| traced_result = traced(x, w) |
| self.assertEqual(foo(x, w), traced_result) |
| self.assertFalse(traced_result.requires_grad) |
| self.assertIsNone(traced_result.grad_fn) |
| |
| def test_trace_detach_redispatch(self): |
| def foo(x, w): |
| y = torch.matmul(x, w) |
| assert y.requires_grad |
| y = y.detach() |
| # Make sure trace kernel redispatches to the right lower kernel. |
| assert not y.requires_grad |
| return y |
| |
| x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True) |
| # With `check_trace=True` it would run under `torch.no_grad()` and break the asserts above. |
| torch.jit.trace(foo, (x, w), check_trace=False) |
| |
| def test_trace_detach_inplace(self): |
| def foo(x, w): |
| y = torch.matmul(x, w) |
| y.detach_() |
| return y |
| |
| traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5))) |
| |
| FileCheck().check("matmul").check("detach(").run(str(traced.graph)) |
| x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True) |
| traced_result = traced(x, w) |
| self.assertEqual(foo(x, w), traced_result) |
| self.assertFalse(traced_result.requires_grad) |
| self.assertIsNone(traced_result.grad_fn) |
| |
| def test_trace_detach_inplace_redispatch(self): |
| def foo(x, w): |
| y = torch.matmul(x, w) |
| assert y.requires_grad |
| y.detach_() |
| # Make sure trace kernel redispatches to the right lower kernel. |
| assert not y.requires_grad |
| return y |
| |
| x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True) |
| # With `check_trace=True` it would run under `torch.no_grad()` and break the asserts above. |
| torch.jit.trace(foo, (x, w), check_trace=False) |
| |
| def test_trace_detach_onnx_erase(self): |
| class Mod(torch.nn.Module): |
| def forward(self, x, w): |
| return torch.matmul(x, w).detach() |
| |
| torch.onnx.export_to_pretty_string( |
| Mod(), (torch.rand(3, 4), torch.rand(4, 5))) |
| |
| def test_trace_slice_full_dim(self): |
| def foo(x): |
| return x[0:5, 0] + 1.0 |
| |
| traced = torch.jit.trace(foo, (torch.rand(5, 4),)) |
| test_x = torch.rand(6, 3) |
| self.assertEqual(foo(test_x), traced(test_x)) |
| |
| def test_trace_dict_input(self): |
| class Bar(torch.nn.Module): |
| def __init__(self): |
| super(Bar, self).__init__() |
| self.foo = Foo() |
| |
| def forward(self, a, b): |
| return self.foo({'a': a, 'b': b})['a'] |
| |
| class Foo(torch.nn.Module): |
| def forward(self, x): |
| return {'a': x['a'] * x['b']} |
| |
| x = (torch.rand(3), torch.rand(3)) |
| model = Bar() |
| self.checkTrace(model, x) |
| |
| def test_trace_dict_output(self): |
| |
| class TraceDictStrTensor(torch.nn.Module): |
| def forward(self, a, b): |
| return {'a': a, 'b': b} |
| |
| class TraceDictTensorTensor(torch.nn.Module): |
| def forward(self, a, b): |
| return {a: b, b: a} |
| |
| x = (torch.rand(3), torch.rand(3)) |
| with self.assertRaisesRegex(RuntimeError, r"Encountering a dict at the output"): |
| torch.jit.trace(TraceDictStrTensor(), x) |
| |
| traced_dict_str_mod = torch.jit.trace(TraceDictStrTensor(), x, strict=False) |
| self.assertEqual(traced_dict_str_mod(*x), {'a': x[0], 'b': x[1]}) |
| |
| traced_dict_tensor_mod = torch.jit.trace(TraceDictTensorTensor(), x, strict=False) |
| self.assertEqual(traced_dict_tensor_mod(*x), {x[0]: x[1], x[1]: x[0]}) |
| |
| def test_trace_with_tensor_list_output(self): |
| def f(): |
| return [torch.zeros(1), torch.zeros(5)] |
| with self.assertWarnsRegex(torch.jit.TracerWarning, "cause the trace to be incorrect"): |
| torch.jit.trace(f, []) |
| traced_non_strict_f = torch.jit.trace(f, [], strict=False) |
| self.assertEqual(traced_non_strict_f(), f()) |
| |
| def test_trace_with_number_list_output(self): |
| def f(): |
| return [1, 5] |
| with self.assertRaisesRegex(RuntimeError, r"Only tensors.+can be output from traced functions"): |
| traced_f = torch.jit.trace(f, []) |
| |
| def test_trace_with_nested_tensor_list_output(self): |
| def f(): |
| return [[torch.zeros(1)], [torch.zeros(5)]] |
| with self.assertRaisesRegex(RuntimeError, r"Only tensors.+can be output from traced functions"): |
| traced_f = torch.jit.trace(f, []) |
| |
| def test_trace_variable_instantiation(self): |
| def random_foo(x): |
| return Variable(Variable(x) + 1.0) |
| |
| random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),)) |
| |
| x = torch.rand(5, 6) |
| self.assertEqual(random_foo(x), random_foo_traced(x)) |
| |
| def test_trace_slice_expr_complete_type(self): |
| def random_foo(x): |
| return x + 1.0 |
| |
| random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),)) |
| |
| @torch.jit.script |
| def random_bar(x): |
| return random_foo_traced(x)[0:1] |
| |
| x = torch.rand(3, 4) |
| self.assertEqual(random_bar(x), (x + 1)[0:1]) |
| |
| def test_trace_inline_shape(self): |
| # Test that a size() call inside a scripted fn (which the peephole pass can |
| # turn into a constant) is not baked into a constant when the fn is traced |
| |
| @torch.jit.script |
| def tensor_size(x: torch.Tensor) -> torch.Tensor: |
| return torch.tensor([x.size()[0]]) |
| |
| self.assertEqual( |
| tensor_size(torch.rand(15,)), |
| torch.tensor([15]) |
| ) |
| |
| traced_tensor_size = torch.jit.trace(tensor_size, torch.rand(7,)) |
| |
| self.assertEqual( |
| traced_tensor_size(torch.rand(15,)), |
| torch.tensor([15]) |
| ) |
| |
| @torch.jit.script |
| def use_device(x): |
| return torch.zeros_like(x, device=x.device) |
| |
| def foo(x): |
| return use_device(x) |
| |
| traced_tensor_size = torch.jit.trace(foo, torch.rand(7,)) |
| self.run_pass('inline', traced_tensor_size.graph) |
| FileCheck().check("prim::device").run(traced_tensor_size.graph) |
| |
| def test_trace_save(self): |
| def fn(x): |
| return x + 2 |
| |
| def check(func): |
| with TemporaryFileName() as fname: |
| func.save(fname) |
| loaded = torch.jit.load(fname) |
| input = torch.randn(2, 2) |
| self.assertEqual(func(input), loaded(input)) |
| |
| out = torch.jit.trace(fn, (torch.ones(2, 2),)) |
| check(out) |
| |
| def test_trace_optional_dtype(self): |
| class Test(torch.nn.Module): |
| def forward(self): |
| return torch.arange(5) |
| |
| traced = torch.jit.trace(Test(), ()) |
| self.assertTrue(torch.allclose(traced(), Test()())) |
| |
| def test_trace_save_load_copy(self): |
| class Test(torch.nn.Module): |
| def __init__(self): |
| super(Test, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| traced = torch.jit.trace(Test(), torch.rand(1, 3, 224, 224)) |
| buffer = io.BytesIO() |
| torch.jit.save(traced, buffer) |
| buffer.seek(0) |
| loaded = torch.jit.load(buffer) |
| # should work |
| copy.copy(loaded) |
| copy.deepcopy(loaded) |
| |
| def test_trace_export_fns(self): |
| class Foo(torch.nn.Module): |
| def __init__(self): |
| super(Foo, self).__init__() |
| self.a = 3 |
| |
| @torch.jit.export |
| def __getstate__(self): |
| return (3, self.training) |
| |
| @torch.jit.export |
| def __setstate__(self, state): |
| self.a = state[0] |
| self.training = state[1] |
| |
| def forward(self, x): |
| return x + self.a |
| |
| f = Foo() |
| |
| traced = torch.jit.trace(f, (torch.rand(3, 4),)) |
| expected_names = ['__getstate__', '__setstate__'] |
| |
| def check(mod): |
| self.assertTrue(all(name in mod._c._method_names() for name in expected_names)) |
| |
| check(traced) |
| |
| imported = self.getExportImportCopy(traced) |
| check(imported) |
| |
| def test_trace_export_fns_recursive(self): |
| class Foo(torch.nn.Module): |
| def __init__(self): |
| super(Foo, self).__init__() |
| self.a = 3 |
| |
| @torch.jit.export |
| def __getstate__(self): |
| return (3, self.training) |
| |
| @torch.jit.export |
| def __setstate__(self, state): |
| self.a = state[0] |
| self.training = state[1] |
| |
| def forward(self, x): |
| return x + self.a |
| |
| class Wrapper(torch.nn.Module): |
| def __init__(self): |
| super(Wrapper, self).__init__() |
| self.foo = Foo() |
| |
| def forward(self, x): |
| return self.foo(x) |
| |
| f = Wrapper() |
| |
| traced = torch.jit.trace(f, (torch.rand(3, 4),)) |
| expected_names = ['__getstate__', '__setstate__'] |
| |
| def check(mod): |
| self.assertTrue(all(name in mod._c._method_names() for name in expected_names)) |
| |
| check(traced.foo) |
| |
| imported = self.getExportImportCopy(traced) |
| check(imported.foo) |
| |
| # Note that Bar's forward can only be traced, but not scripted |
| class Bar(nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| @torch.jit.export |
| def addTwo(self, x): |
| return x + 2 |
| |
| def forward(self, input): |
| return (lambda a: a + 1)(input) |
| |
| # When tracing Bar as a submodule, we only want to script the |
| # exported methods while keeping forward traced. |
| class WrapperExports(torch.nn.Module): |
| def __init__(self): |
| super(WrapperExports, self).__init__() |
| self.bar = Bar() |
| |
| @torch.jit.export |
| def addOne(self, x): |
| return x + 1 |
| |
| def forward(self, x): |
| return self.bar(x) |
| |
| f = WrapperExports() |
| |
| traced = torch.jit.trace(f, (torch.rand(3, 4),)) |
| expected_names = ['addOne'] |
| check(traced) |
| |
| def test_trace_autograd_function(self): |
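| # Smoke test: a custom autograd.Function applied inside a traced module |
| # hierarchy should trace without error. |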
| class TestFunc(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, input): |
| return torch.neg(input) |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| return torch.neg(grad_output) |
| |
| |
| class TracedModule(torch.nn.Module): |
| def forward(self, x): |
| return torch.relu(TestFunc.apply(x)) |
| |
| |
| class Wrapper(torch.nn.Module): |
| def __init__(self): |
| super(Wrapper, self).__init__() |
| self.tm = TracedModule() |
| |
| def forward(self, x): |
| return self.tm(x) |
| |
| traced = torch.jit.trace(Wrapper(), (torch.rand(3, 4),)) |
| |
| def test_trace_multi_output_function(self): |
| # An autograd.Function with two outputs. |
| # It swaps inputs so we can check if shape |
| # handling is correct in TorchScript. |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x, y): |
| return y, x |
| |
| @staticmethod |
| def backward(ctx, du, dv): |
| return dv, du |
| |
| class Bar(torch.nn.Module): |
| def forward(self, x, y): |
| x = x.relu() |
| y = y.relu() |
| z = Foo.apply(x, y) |
| return z |
| |
| x = torch.rand(3, 2, dtype=torch.double) |
| y = torch.rand(1, 2, dtype=torch.double) |
| |
| # Generate JIT IR. |
| traced = torch.jit.trace(Bar(), (x, y)) |
| print(traced.graph) |
| |
| # Expected output schema of the custom autograd.Function. |
| schema = '(Double(1, 2, strides=[2, 1], requires_grad=0, device=cpu), '\ |
| 'Double(3, 2, strides=[2, 1], requires_grad=0, device=cpu)) '\ |
| '= ^Foo' |
| |
| # See if expected schema exists. |
| FileCheck().check(schema).run(traced.graph) |
| |
| # Also examine if the graph is runnable and produces |
| # the right result. |
| u, v = traced(x, y) |
| self.assertEqual(u, y) |
| self.assertEqual(v, x) |
| |
| def test_interpolate_trace(self): |
| class test(nn.Module): |
| def __init__(self): |
| super(test, self).__init__() |
| self.conv = nn.Conv2d(1, 32, kernel_size=3, padding=1) |
| |
| def forward(self, x): |
| y = self.conv(x) |
| w = nn.functional.interpolate(y, mode='bilinear', align_corners=False, scale_factor=3) |
| return w |
| |
| f = test() |
| # no failure |
| g = torch.jit.trace(f, (torch.zeros(1, 1, 28, 28),)) |
| x = torch.zeros(1, 1, 14, 14) |
| # constants not baked in |
| self.assertEqual(g(x), f(x)) |
| |
| @_tmp_donotuse_dont_inline_everything |
| def test_trace_optional(self): |
| @torch.jit.script |
| def test(x: Optional[Tensor]): |
| if x is None: |
| return torch.zeros(1) |
| else: |
| return x |
| |
| def test_none(): |
| return test(None) |
| |
| def test_tensor(): |
| return test(torch.zeros(2)) |
| |
| f_none = torch.jit.trace(test_none, ()) |
| self.assertEqual(f_none(), torch.zeros(1)) |
| |
| f_tensor = torch.jit.trace(test_tensor, ()) |
| self.assertEqual(f_tensor(), torch.zeros(2)) |
| |
| graph = f_tensor.graph |
| FileCheck().check('name="test"').check_next("prim::CallFunction").run(graph) |
| |
| def test_trace_nested_datatypes(self): |
| @torch.jit.script |
| def foo(x): |
| return [[x + 1, x - 1], [x + 2, x - 2]] |
| |
| def bar(x): |
| list_stuff = foo(x) |
| return list_stuff[0][0], list_stuff[1][1] |
| |
| traced = torch.jit.trace(bar, torch.rand(3, 4)) |
| x = torch.rand(5, 6) |
| self.assertEqual(bar(x), traced(x)) |
| |
| @_tmp_donotuse_dont_inline_everything |
| def test_call_traced_fn_from_traced_module(self): |
| @_trace(torch.rand(3, 4)) |
| def traced_fn(x): |
| return torch.neg(x) |
| |
| class TracedModule(torch.nn.Module): |
| def __init__(self): |
| super(TracedModule, self).__init__() |
| self.param = torch.nn.Parameter(torch.rand(4, 5)) |
| |
| def forward(self, x): |
| return traced_fn(torch.mm(x, self.param)) |
| |
| tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) |
| |
| # Note: with inlining disabled, the call to traced_fn should appear as a prim::CallFunction |
| FileCheck().check("aten::mm") \ |
| .check('name="traced_fn"') \ |
| .check_next("prim::CallFunction") \ |
| .run(str(tm.graph)) |
| |
| @_tmp_donotuse_dont_inline_everything |
| def test_call_traced_module_from_traced_module(self): |
| class TracedModule1(torch.nn.Module): |
| def __init__(self): |
| super(TracedModule1, self).__init__() |
| self.param = torch.nn.Parameter(torch.rand(5, 7)) |
| |
| def forward(self, x): |
| return torch.mm(x, self.param) |
| |
| class TracedModule(torch.nn.Module): |
| def __init__(self): |
| super(TracedModule, self).__init__() |
| self.param = torch.nn.Parameter(torch.rand(4, 5)) |
| self.mod = torch.jit.trace(TracedModule1(), torch.rand(3, 5)) |
| |
| def forward(self, x): |
| return self.mod(torch.mm(x, self.param)) + 1.0 |
| |
| tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) |
| |
| FileCheck().check("aten::mm").check("prim::CallMethod").check_same("forward").check("aten::add").run(str(tm.graph)) |
| |
| def test_index_put_trace_with_view(self): |
| @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4)) |
| def test_index_put(target, indices, rhs): |
| target[indices] = rhs |
| return target |
| |
| FileCheck().check("aten::view").check("index_put_").run(str(test_index_put.graph)) |
| |
| def test_index_put_trace_without_view(self): |
| @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4)) |
| def test_index_put(target, indices, rhs): |
| target[indices] = rhs |
| return target |
| |
| FileCheck().check_not("aten::view").check("index_put_").run(str(test_index_put.graph)) |
| |
| @suppress_warnings |
| def test_trace_checker_dot_data(self): |
| with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Tensor-valued Constant nodes differed in value ' |
| r'across invocations'): |
| @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)]) |
| def foo(x): |
| y = x.data |
| return x + y |
| |
| @suppress_warnings |
| def test_trace_checker_control_flow(self): |
| def foo(x): |
| for _ in range(x.size(0)): |
| x = torch.neg(x) |
| return x |
| |
| with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'): |
| torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)]) |
| |
| @suppress_warnings |
| def test_trace_checker_memoization(self): |
| with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'): |
| def foo(x): |
| if not hasattr(foo, 'cache'): |
| foo.cache = torch.neg(x) |
| return x + foo.cache |
| |
| traced = torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)]) |
| |
| def test_trace_checker_slice_lhs(self): |
| def foo(x): |
| for i in range(3): |
| x[i, :] = torch.zeros(4) |
| return x |
| |
| self.checkTrace(foo, (torch.rand(3, 4),), inputs_require_grads=False) |
| |
| def test_trace_checker_inplace_on_view(self): |
| def foo(x): |
| x.view(-1).add_(-x.view(-1)) |
| return x |
| |
| with self.assertWarnsRegex(torch.jit.TracerWarning, |
| 'Output nr 1. of the traced function does not match the ' |
| 'corresponding output of the Python function'): |
| torch.jit.trace(foo, |
| torch.rand(3, 4), |
| check_inputs=[torch.rand(5, 6)], |
| _force_outplace=True) |
| |
| def test_lhs_index_fails(self): |
| def foo(x): |
| x[0, 1] = 4 |
| return x |
| |
| with self.assertWarnsRegex(torch.jit.TracerWarning, "cause the trace to be incorrect"): |
| torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True) |
| |
| def test_lhs_index_trivial(self): |
| def foo(y, x): |
| y[...] = x |
| return y |
| self.checkTrace(foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False) |
| |
| def test_inplace_warn(self): |
| def foo(x): |
| x.view(-1).add_(-x.view(-1)) |
| return x |
| |
| with self.assertWarnsRegex(torch.jit.TracerWarning, "cause the trace to be incorrect"): |
| torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True) |
| |
| @suppress_warnings |
| def test_trace_checker_dropout_train(self): |
| def foo(x): |
| return torch.dropout(x, p=0.5, train=True) |
| |
| with self.assertWarnsRegex(torch.jit.TracerWarning, |
| 'Output nr 1. of the traced function does not match the ' |
| 'corresponding output of the Python function'): |
| torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]) |
| |
| with self.assertWarnsRegex(torch.jit.TracerWarning, |
| 'Trace had nondeterministic nodes'): |
| torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]) |
| |
| def test_trace_checker_dropout_notrain(self): |
| input = torch.rand(3, 4) |
| |
| @_trace(input) |
| def foo(x): |
| return torch.dropout(x, p=0.5, train=False) |
| |
| self.assertEqual(foo(input), input) |
| |
| def test_trace_contiguous(self): |
| def foo(x): |
| return x[:, :, ::2].contiguous().view(12) |
| |
| x = torch.rand(2, 3, 4) |
| traced = torch.jit.trace(foo, (x,)) |
| y = traced(x) |
| self.assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr()) |
| |
| # This tests the logic in THPVariable_contiguous. There is short-circuiting |
| # code that prevents us from even getting to VariableType::contiguous, since |
| # it is an optimization that prevents us from acquiring the GIL for touching |
| # the device. We needed to add the tracing logic directly into the |
| # THPVariable_contiguous function only for the path where we are skipping |
| # dispatch into contiguous. We should see an aten::contiguous in this trace! |
| def test_trace_contiguous_short_circuit(self): |
| def foo(x): |
| return x.contiguous() |
| |
| x = torch.rand(2, 3, 4) |
| traced = torch.jit.trace(foo, (x,)) |
| FileCheck().check("aten::contiguous").run(str(traced.graph)) |
| |
| def test_trace_inverse(self): |
| def foo(x): |
| return ~x |
| |
| foo_traced = torch.jit.trace(foo, torch.zeros(3, 4, dtype=torch.uint8)) |
| eg = torch.zeros(3, dtype=torch.uint8) |
| self.assertEqual(foo_traced(eg), foo(eg)) |
| |
| def test_trace_modulelist(self): |
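| # Smoke test: iterating over a ModuleList in forward should be traceable. |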
| class MySubmod(torch.nn.Module): |
| def __init__(self): |
| super(MySubmod, self).__init__() |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| return self.relu(x) |
| |
| class MyMod(torch.nn.Module): |
| def __init__(self): |
| super(MyMod, self).__init__() |
| self.ml = torch.nn.ModuleList([ |
| MySubmod(), |
| MySubmod() |
| ]) |
| |
| def forward(self, x): |
| for mod in self.ml: |
| x = mod(x) |
| return x |
| |
| traced = torch.jit.trace(MyMod(), (torch.rand(3, 4),)) |
| |
| def test_trace_fork_join_and_module(self): |
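| # Smoke test: torch.jit._fork/_wait calls over submodules should be traceable. |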
| class MySubmod(torch.nn.Module): |
| def __init__(self): |
| super(MySubmod, self).__init__() |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| return self.relu(x), torch.neg(x) |
| |
| class Mod(torch.nn.Module): |
| def __init__(self): |
| super(Mod, self).__init__() |
| self.ml = torch.nn.ModuleList([ |
| MySubmod() for i in range(2) |
| ]) |
| |
| def forward(self, x): |
| futs = [] |
| for i in range(2): |
| futs.append(torch.jit._fork(self.ml[i], x)) |
| |
| results = [] |
| for i in range(2): |
| results.append(torch.jit._wait(futs[i])[0]) |
| |
| return torch.stack(results) |
| |
| m = Mod() |
| traced = torch.jit.trace(m, torch.rand(3, 4)) |
| |
| def test_trace_invert_module_hierarchy(self): |
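| # Passing a submodule as a runtime argument to another module's forward |
| # should still be traceable. |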
| class MySubmod(torch.nn.Module): |
| def __init__(self): |
| super(MySubmod, self).__init__() |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| return self.relu(x), torch.neg(x) |
| |
| class MyFunctionalMod(torch.nn.Module): |
| def forward(self, x, submod): |
| return submod(x) |
| |
| class Mod(torch.nn.Module): |
| def __init__(self): |
| super(Mod, self).__init__() |
| self.sm = MySubmod() |
| self.fm = MyFunctionalMod() |
| |
| def forward(self, x): |
| return self.fm(x, self.sm) |
| |
| torch.jit.trace(Mod(), (torch.rand(3, 4),)) |
| |
| @skipIfCrossRef |
| def test_trace_records_names(self): |
| def foo(bar, baz): |
| baz = bar + 3 |
| quick_brown_fox = torch.neg(baz) |
| for _ in range(20): |
| yeet = quick_brown_fox - 3.14 |
| return yeet |
| |
| traced = torch.jit.trace(foo, (torch.rand(3, 3), torch.rand(3, 3))) |
| graph_str = str(traced.graph) |
| assert 'bar' in graph_str |
| assert 'baz' in graph_str |
| assert 'quick_brown_fox' in graph_str |
| |
| @skipIfTorchDynamo("Not a suitable test for TorchDynamo") |
| def test_tracing_hooks(self): |
| class Net(nn.Module): |
| def __init__(self): |
| super(Net, self).__init__() |
| |
| def forward(self, x): |
| return x + x |
| |
| def test_hook(is_post_hook, hook, fc): |
| n = Net() |
| if is_post_hook: |
| n.register_forward_hook(hook) |
| else: |
| n.register_forward_pre_hook(hook) |
| |
| module = torch.jit.trace(n, (torch.tensor(1.0),)) |
| |
| eager_input = torch.tensor(1.0) |
| eager_out = n(eager_input) |
| |
| fc.run(module.forward.graph) |
| input = torch.tensor(1.0) |
| output = module(input) |
| |
| self.assertEqual(input, eager_input) |
| self.assertEqual(output, eager_out) |
| |
| def hook_no_return(mod, input, output): |
| input[0].add_(1) |
| output.sub_(1) |
| |
| fc = FileCheck().check("add(").check("add_(").check("sub_(") |
| test_hook(True, hook_no_return, fc) |
| |
| def hook_return(mod, input, output): |
| input[0].add_(1) |
| return output - 3 |
| |
| fc = FileCheck().check("add(").check("add_(").check("sub(") |
| test_hook(True, hook_return, fc) |
| |
| b = torch.tensor(3.0) |
| |
| def captured_hook(mod, input, output): |
| return output - b |
| |
| fc = FileCheck().check("add(").check("sub(") |
| test_hook(True, captured_hook, fc) |
| |
| def pre_hook_no_ret(mod, input): |
| input[0].add_(3) |
| |
| fc = FileCheck().check("add_(").check("add(") |
| test_hook(False, pre_hook_no_ret, fc) |
| |
| def pre_hook_ret(mod, input): |
| return input[0] - 4 |
| |
| fc = FileCheck().check("sub(").check("add(") |
| test_hook(False, pre_hook_ret, fc) |
| |
| def test_tracing_backward_hook_error(self): |
| class Net(nn.Module): |
| def __init__(self): |
| super(Net, self).__init__() |
| |
| def forward(self, x): |
| return x + x |
| |
| n = Net() |
| |
| def backward_hook(module, grad_input, grad_output): |
| pass |
| |
| n.register_backward_hook(backward_hook) |
| with self.assertRaisesRegex(Exception, "backward hooks assigned"): |
| torch.jit.trace(n, (torch.tensor(1.0),)) |
| |
| def test_tracing_multiple_methods(self): |
| class Net(nn.Module): |
| def __init__(self): |
| super(Net, self).__init__() |
| self.conv = nn.Conv2d(1, 1, 3) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| def weighted_kernel_sum(self, weight): |
| return weight * self.conv.weight |
| |
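| # torch.jit.trace_module takes a dict mapping method names to example inputs; |
| # check_inputs supplies additional records of the same form that are used to |
| # re-run and verify the trace. |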
| example_weight = torch.rand(1, 1, 3, 3) |
| example_forward_input = torch.rand(1, 1, 3, 3) |
| inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight} |
| n = Net() |
| module = torch.jit.trace_module(n, inputs) |
| |
| check_inputs = [] |
| for i in range(2): |
| check_weight = torch.rand(1, 1, 3, 3) |
| check_forward_input = torch.rand(1, 1, 3, 3) |
| check_inputs.append({'forward' : check_forward_input, 'weighted_kernel_sum' : check_weight}) |
| module = torch.jit.trace_module(n, inputs, check_trace=True, check_inputs=check_inputs) |
| self.assertTrue(module._c._has_method("forward")) |
| self.assertTrue(module._c._has_method("weighted_kernel_sum")) |
| |
| module = torch.jit.trace(n.forward, example_forward_input) |
| module = torch.jit.trace(n.forward, example_forward_input, check_trace=True, check_inputs=[example_forward_input]) |
| with self.assertRaisesRegex(AttributeError, "trace doesn't support compiling individual module's functions"): |
| module = torch.jit.trace(n.weighted_kernel_sum, inputs) |
| |
| def test_tensor_with_grad_as_constant(self): |
| param = torch.randn(3).requires_grad_() |
| x = torch.randn(3) |
| |
| def f(x): |
| return x + param |
| with self.assertRaisesRegex(RuntimeError, "Cannot insert a Tensor that requires grad as a constant"): |
| torch.jit.trace(f, x) |
| |
| def test_non_tensor_tracing(self): |
| def f(x): |
| # `param` is intentionally undefined here: tracing rejects the non-Tensor |
| # example input before `f` is ever invoked, so this body never executes. |
| return x + param  # noqa: F821 |
| with self.assertRaisesRegex(RuntimeError, r"Type 'Tuple\[int\]' cannot be traced"): |
| torch.jit.trace(f, (1,)) |
| |
| def test_trace_skip_none_submodule(self): |
| class TestModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.submod = torch.nn.Linear(3, 4) |
| self.submod = None |
| |
| def forward(self, inputs): |
| return inputs |
| |
| m = TestModule() |
| tm = torch.jit.trace(m, torch.tensor(1.)) |
| self.assertFalse(hasattr(tm, "submod")) |
| |
| def test_trace_with_conditional_property(self): |
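| # The @property below looks up its target attribute dynamically through |
| # self.attr_name; since forward() never touches it, tracing should succeed |
| # even though the property's backing attribute is only conditionally set. |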
| class Net(nn.Module): |
| def __init__(self, attr=None): |
| super(Net, self).__init__() |
| if attr is not None: |
| self._attr = attr |
| self.attr_name = '_attr' |
| |
| @property |
| def attr(self): |
| return getattr(self, self.attr_name) |
| |
| def forward(self, x): |
| return x |
| |
| x = torch.ones(1) |
| torch.jit.trace(Net(), x) |
| |
| def test_trace_func_argument_names_captured(self): |
| def fn(first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor: |
| return first_arg + second_arg |
| |
| traced_fn = torch.jit.trace(fn, (torch.ones(1), torch.ones(1))) |
| FileCheck().check("first_arg").check_next("second_arg") \ |
| .run(str(traced_fn.graph)) |
| |
| def test_trace_partial_func_argument_names_captured(self): |
| def fn(first_arg: torch.Tensor, second_arg=1) -> torch.Tensor: |
| return first_arg + second_arg |
| |
| traced_fn = torch.jit.trace(fn, (torch.ones(1),)) |
| FileCheck().check("first_arg").check_not("second_arg") \ |
| .run(str(traced_fn.graph)) |
| |
| def test_trace_module_argument_names_captured(self): |
| class TestModule(nn.Module): |
| def __init__(self): |
| super(TestModule, self).__init__() |
| self.conv = nn.Conv2d(1, 1, 3) |
| |
| def forward(self, first_arg: torch.Tensor, second_arg: torch.Tensor): |
| return self.conv(first_arg) + second_arg |
| |
| m = TestModule() |
| example_input = (torch.ones(1, 1, 3, 3), torch.ones(1, 1, 3, 3)) |
| |
| # Explicitly tracing module's forward method |
| traced_module_forward = torch.jit.trace(m.forward, example_input) |
| FileCheck().check("first_arg").check_next("second_arg") \ |
| .run(str(traced_module_forward.graph)) |
| |
| # Tracing the module directly |
| traced_module = torch.jit.trace(m, example_input) |
| FileCheck().check("first_arg").check_next("second_arg") \ |
| .run(str(traced_module.graph)) |
| |
| |
| class TestMixTracingScripting(JitTestCase): |
| def test_trace_script(self): |
| @torch.jit.script |
| def func1(x: Tuple[Tensor, Tensor]) -> Tensor: |
| return x[0] + x[1] |
| |
| @torch.jit.script |
| def func2(x: List[Tensor]) -> Tensor: |
| return x[0] + x[1] |
| |
| a = torch.randn(5) |
| b = torch.randn(5) |
| |
| self.checkTrace(func1, ((a, b),)) |
| self.checkTrace(func2, ((a, b),)) |
| |
| @torch.jit.script |
| def func3(x: Tensor, method: str = 'bilinear', align_corners: bool = True) -> Tensor: |
| hw = x.shape[2:4] |
| return F.interpolate(x, hw, mode=method, align_corners=align_corners) |
| |
| inp = torch.rand(1, 3, 6, 6) |
| self.checkTrace(func3, (inp,)) |
| |
| @torch.jit.script |
| def func4(x: Tensor, a: List[Optional[str]]) -> Tensor: |
| if len(a) == 2: |
| return x + 2 |
| else: |
| return x |
| |
| def test_trace_mixed_by_script_with_dict_output(self): |
| @torch.jit.script |
| def return_dict(input: torch.Tensor) -> Dict[str, torch.Tensor]: |
| return {"foo" : input + 1} |
| |
| class TraceModule(torch.nn.Module): |
| def forward(self, input): |
| out = return_dict(input) |
| return out["foo"] + out["foo"] |
| |
| x = torch.ones(1) |
| tm = torch.jit.trace(TraceModule(), x) |
| self.assertEqual(tm(x), x + 1 + x + 1) |
| |
| def test_trace_of_script(self): |
| @torch.jit.script |
| def foo(a, c): |
| b = 0.0 |
| if bool(a == 0.0): |
| b = 1.0 |
| return b + c |
| |
| a = torch.ones(1, dtype=torch.float) |
| |
| @_trace(torch.zeros(1, dtype=torch.float)) |
| def use(b): |
| return foo(b - 1.0, a) + 1.0 |
| |
| # check that shape information was propagated through the scripted call |
| # (no "Dynamic" types remain in the graph) |
| self.assertTrue("Dynamic" not in str(use.graph)) |
| |
| self.assertEqual(3, use(torch.ones(1, dtype=torch.float))) |
| self.assertEqual(2, use(torch.zeros(1, dtype=torch.float))) |
| |
| def test_trace_with_size(self): |
| @_trace(torch.zeros(1, 1)) |
| def foo(x): |
| return x + 1 |
| |
| @torch.jit.script |
| def bar(x): |
| y = int(foo(x)) |
| if 1 == 1: |
| y = 7 |
| return y + 1 |
| |
| self.assertEqual(8, bar(torch.ones(1, 1))) |
| |
| def test_tracing_slicing(self): |
| @_trace(torch.zeros(10)) |
| def foo_trace(x): |
| return x[-5:-3] |
| |
| @torch.jit.script |
| def foo_script(x): |
| return x[-5:-3] |
| |
| def foo(x): |
| return x[-5:-3] |
| |
| a = torch.arange(0, 8) |
| b = torch.arange(0, 20) |
| self.assertEqual(foo_trace(a), foo_script(a)) |
| self.assertEqual(foo_trace(a), foo(a)) |
| self.assertNotEqual(foo_trace(a), foo_trace(b)) |
| |
| def test_tracing_indexing(self): |
| @_trace(torch.zeros(10)) |
| def foo_trace(x): |
| return x[-2] |
| |
| @torch.jit.script |
| def foo_script(x): |
| return x[-2] |
| |
| def foo(x): |
| return x[-2] |
| |
| a = torch.arange(0, 8) |
| b = torch.arange(0, 20) |
| self.assertEqual(foo_script(a), foo_trace(a)) |
| self.assertEqual(foo_trace(a), foo(a)) |
| self.assertNotEqual(foo_trace(a), foo_trace(b)) |
| |
| def test_trace_hierarchy(self): |
| # Test that we preserve the module hierarchy for a ScriptModule |
| # submodule during tracing |
| |
| class AnotherScriptMod(torch.jit.ScriptModule): |
| def __init__(self): |
| super(AnotherScriptMod, self).__init__() |
| self.param = torch.nn.Parameter(torch.rand(1, 2, 3)) |
| |
| @torch.jit.script_method |
| def bar(self): |
| return torch.zeros(4, 5) |
| |
| class SomeScriptMod(torch.jit.ScriptModule): |
| def __init__(self): |
| super(SomeScriptMod, self).__init__() |
| self.asm = AnotherScriptMod() |
| |
| @torch.jit.script_method |
| def foo(self): |
| return torch.zeros(3, 4) |
| |
| @torch.jit.script_method |
| def bar(self): |
| return torch.zeros(4, 3) |
| |
| class TraceMe(torch.nn.Module): |
| def __init__(self): |
| super(TraceMe, self).__init__() |
| self.ssm = SomeScriptMod() |
| |
| def forward(self, x): |
| return self.ssm.bar() + x |
| |
| orig = TraceMe() |
| traced = torch.jit.trace(orig, (torch.rand(4, 3),)) |
| # For each of these checks, verify that *both* the underlying |
| # _C.ScriptModule object and the Python object that wraps it have |
| # the expected method/param. |
| self.assertTrue(traced.ssm._c._has_method('foo')) |
| self.assertTrue(hasattr(traced.ssm, 'foo')) |
| |
| imported = self.getExportImportCopy(traced) |
| |
| self.assertTrue(imported.ssm._c._has_method('foo')) |
| self.assertTrue(hasattr(imported.ssm, 'foo')) |
| |
| self.assertTrue(imported.ssm.asm._c._has_method('bar')) |
| self.assertTrue(hasattr(imported.ssm.asm, 'bar')) |
| |
| self.assertTrue(hasattr(imported.ssm.asm, 'param')) |
| |
| def test_trace_parameter(self): |
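| # Nesting: M1 (ScriptModule) wraps a trace of M2 (eager), which wraps M3 |
| # (ScriptModule), which wraps a trace of Param (eager, with a registered |
| # parameter). Saving M1 checks that the parameter survives this mixed |
| # script/trace nesting and serialization. |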
| class Param(nn.Module): |
| def __init__(self): |
| super(Param, self).__init__() |
| self.register_parameter("bias", nn.Parameter(torch.empty(4, 4))) |
| |
| def forward(self, x): |
| return x |
| |
| class M3(torch.jit.ScriptModule): |
| def __init__(self, model): |
| super(M3, self).__init__() |
| self.traced = torch.jit.trace(model, (torch.rand(3, 3))) |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| return self.traced(x) |
| |
| class M2(nn.Module): |
| def __init__(self, model): |
| super(M2, self).__init__() |
| self.module = M3(model) |
| |
| def forward(self, x): |
| return self.module(x) |
| |
| class M1(torch.jit.ScriptModule): |
| def __init__(self, model): |
| super(M1, self).__init__() |
| self.traced = torch.jit.trace(M2(model), (torch.rand(3, 3))) |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| return self.traced(x) |
| |
| with torch.jit.optimized_execution(False): |
| module = M1(Param()) |
| f = io.BytesIO() |
| torch.jit.save(module, f) |
| |
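| # The _tmp_donotuse_dont_inline_everything decorator (as its name suggests) |
| # keeps callees from being inlined during these tests, so prim::CallFunction |
| # and prim::CallMethod nodes remain visible to the FileCheck assertions below. |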
| @_tmp_donotuse_dont_inline_everything |
| def test_call_script_fn_from_traced_module(self): |
| @torch.jit.script |
| def scripted_fn(x): |
| return torch.neg(x) |
| |
| class TracedModule(torch.nn.Module): |
| def __init__(self): |
| super(TracedModule, self).__init__() |
| self.param = torch.nn.Parameter(torch.rand(4, 5)) |
| |
| def forward(self, x): |
| return scripted_fn(torch.mm(x, self.param)) |
| |
| tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) |
| FileCheck().check("aten::mm").check("name=\"scripted_fn\"").check("prim::CallFunction").run(str(tm.graph)) |
| |
| @_tmp_donotuse_dont_inline_everything |
| def test_call_script_module_from_traced_module(self): |
| class ScriptMod(torch.jit.ScriptModule): |
| def __init__(self): |
| super(ScriptMod, self).__init__() |
| self.param_foo = torch.nn.Parameter(torch.rand(5, 7)) |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| return torch.mm(x, self.param_foo) |
| |
| class TracedModule(torch.nn.Module): |
| def __init__(self): |
| super(TracedModule, self).__init__() |
| self.param = torch.nn.Parameter(torch.rand(4, 5)) |
| self.mod = ScriptMod() |
| |
| def forward(self, x): |
| return self.mod(torch.mm(x, self.param)) + 1.0 |
| |
| tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) |
| |
| FileCheck().check("aten::mm").check("prim::CallMethod").check_same("forward").check("aten::add").run(str(tm.graph)) |
| |
| @_tmp_donotuse_dont_inline_everything |
| def test_call_traced_fn_from_script_fn(self): |
| @_trace(torch.rand(3, 4)) |
| def traced_fn(x): |
| return torch.neg(x) |
| |
| @torch.jit.script |
| def script_fn(x): |
| return traced_fn(x) + 1 |
| |
| FileCheck().check("prim::CallFunction").check("aten::add").run(str(script_fn.graph)) |
| |
| def test_call_traced_mod_from_script_fn(self): |
| with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"): |
| class TracedModule(torch.nn.Module): |
| def __init__(self): |
| super(TracedModule, self).__init__() |
| |
| def forward(self, x): |
| return torch.mm(x, torch.zeros(4, 3)) |
| |
| tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) |
| |
| @torch.jit.script |
| def script_fn(x): |
| return tm(x) + 1 |
| |
| @_tmp_donotuse_dont_inline_everything |
| def test_call_tracing_fn_from_script_module(self): |
| @_trace(torch.rand(3, 3)) |
| def traced_fn(x): |
| return torch.neg(x) |
| |
| class ScriptMod(torch.jit.ScriptModule): |
| def __init__(self): |
| super(ScriptMod, self).__init__() |
| self.param = torch.nn.Parameter(torch.rand(4, 3)) |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| return traced_fn(torch.mm(x, self.param)) |
| |
| sm = ScriptMod() |
| FileCheck().check("aten::mm").check("prim::CallFunction").run(str(sm.forward.graph)) |
| |
| @_tmp_donotuse_dont_inline_everything |
| def test_call_tracing_mod_from_script_module(self): |
| class TracedMod(torch.nn.Module): |
| def __init__(self): |
| super(TracedMod, self).__init__() |
| self.param = torch.nn.Parameter(torch.rand(3, 5)) |
| |
| def forward(self, x): |
| return torch.mm(x, self.param) |
| |
| class ScriptMod(torch.jit.ScriptModule): |
| def __init__(self): |
| super(ScriptMod, self).__init__() |
| self.param = torch.nn.Parameter(torch.rand(4, 3)) |
| self.tm = torch.jit.trace(TracedMod(), torch.rand(3, 3)) |
| |
| @torch.jit.script_method |
| def forward(self, x): |
| return self.tm(torch.mm(x, self.param)) |
| |
| sm = ScriptMod() |
| FileCheck().check("aten::mm").check("prim::CallMethod").run(str(sm.graph)) |
| |
| def test_script_inline_trace_multiple_args(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| |
| def forward(self, input, input2): |
| return input + input2 |
| |
| class M2(torch.jit.ScriptModule): |
| def __init__(self): |
| super(M2, self).__init__() |
| self.m = torch.jit.trace(M(), (torch.zeros(4, 3), torch.zeros(4, 3))) |
| |
| @torch.jit.script_method |
| def forward(self, inp): |
| return self.m(inp, inp) |
| |
| with torch.jit.optimized_execution(False): |
| m2 = M2() |
| m2(torch.zeros(4, 3)) |
| |
| def test_trace_dict_mix_script(self): |
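| # The traced testA builds a Dict[str, List[Tensor]] and hands it to the |
| # scripted testB; running the trace on a second, differently-valued input map |
| # checks that the traced dict handling generalizes beyond the example input. |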
| class testB(torch.nn.Module): |
| def __init__(self): |
| super(testB, self).__init__() |
| self.linear = torch.nn.Linear(2, 2) |
| |
| def forward(self, feature_map: Dict[str, List[Tensor]]) -> Tensor: |
| output = [] |
| for i, j in feature_map.items(): |
| output.append(self.linear(j[0])) |
| |
| return torch.stack(output) |
| |
| class testA(torch.nn.Module): |
| def __init__(self): |
| super(testA, self).__init__() |
| self.b = torch.jit.script(testB()) |
| |
| def forward(self, input_map: Dict[str, List[Tensor]]) -> Tensor: |
| feature_map = {} |
| for i, j in input_map.items(): |
| feature_map[i] = [j[0]] |
| |
| return self.b(feature_map) |
| |
| input_map = {"1" : [torch.rand(2, 2), torch.rand(2, 2)], "3" : [torch.rand(2, 2), torch.rand(2, 2)]} |
| model = testA() |
| traced_model = torch.jit.trace(model, input_map) |
| new_input_map = {"1" : [torch.rand(2, 2), torch.randn(2, 2)], "3" : [torch.rand(2, 2), torch.rand(2, 2)]} |
| self.assertEqual(model(new_input_map), traced_model(new_input_map)) |
| |
| def test_trace_script_returning_complex_dict(self): |
| """Tracing over a scripted module returning a dictionary should work. |
| The dictionary should be able to contain other containers (like a tuple) recursively. |
| """ |
| class ReturnsDict(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward( |
| self, id_score_list: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] |
| ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: |
| # do some random operations and then return a dict of the same structure |
| v = id_score_list["1000"] |
| idx_keys = v[1] - 1500000 |
| weights = v[2] |
| result = { |
| "1000": (v[0], idx_keys, weights) |
| } |
| return result |
| |
| class ChecksDict(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]): |
| v = input["1000"] |
| return v[1] + 1 |
| |
| class TestModule(torch.nn.Module): |
| def __init__(self, checks_dict, returns_dict): |
| super().__init__() |
| self.checks_dict = checks_dict |
| self.returns_dict = returns_dict |
| |
| def forward(self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]): |
| foo = self.returns_dict(input) |
| return self.checks_dict(foo) |
| |
| input1 = { |
| "1000": ( |
| torch.tensor([0]), |
| torch.tensor([], dtype=torch.int64), |
| torch.tensor([]) |
| ) |
| } |
| |
| input2 = { |
| "1000": ( |
| torch.tensor([0]), |
| torch.tensor([1500000, 1500004], dtype=torch.int64), |
| torch.tensor([2.0, 3.0]) |
| ) |
| } |
| |
| checks_dict = torch.jit.script(ChecksDict()) |
| returns_dict = torch.jit.script(ReturnsDict()) |
| eager_module = TestModule(checks_dict, returns_dict) |
| traced_module = torch.jit.trace(eager_module, input1) |
| self.assertEqual(traced_module(input1), eager_module(input1)) |
| self.assertEqual(traced_module(input2), eager_module(input2)) |
| |
| def test_trace_returning_dict_with_tensor_tuples(self): |
| """Tracing over a module returning a dictionary whose values are tuples of tensors |
| should work. |
| """ |
| class ReturnsDict(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward( |
| self, k: torch.Tensor, v: torch.Tensor |
| ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]: |
| x = 2 * k |
| y = 3 * v |
| result = { |
| "imakey": (x, y) |
| } |
| return result |
| |
| class ReturnsBadDict(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward( |
| self, k: torch.Tensor, v: torch.Tensor |
| ) -> Dict[str, Tuple[torch.Tensor, float]]: |
| x = 2 * k |
| result = { |
| "imakey": (x, 1) |
| } |
| return result |
| |
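| # strict=False is what allows a dict (a mutable container) as the traced |
| # output; ReturnsBadDict still fails because its tuple mixes a Tensor with a |
| # plain Python number, which the tracer cannot represent in the output. |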
| mod = ReturnsDict() |
| traced_module = torch.jit.trace(mod, [torch.ones(1), torch.ones(1)], strict=False) |
| out = traced_module(torch.ones(1), torch.ones(1)) |
| expected = { |
| "imakey": (torch.tensor([2.]), torch.tensor([3.])) |
| } |
| self.assertEqual(out, expected) |
| |
| with self.assertRaisesRegex(RuntimeError, "cannot be understood by the tracer, only outputs matching"): |
| mod = ReturnsBadDict() |
| traced_module = torch.jit.trace(mod, [torch.ones(1), torch.ones(1)], strict=False) |
| |
| def test_trace_linear(self): |
| m = torch.nn.Linear(20, 20) |
| inp = torch.rand([20, 20]) |
| self.checkTrace(m, (inp,)) |
| g = torch.jit.trace(m, (inp,)).graph |
| FileCheck().check("aten::linear").run(g) |
| |
| def test_traced_module_implements_interface(self): |
| @torch.jit.interface |
| class TestModuleInterface(nn.Module): |
| def forward(self, first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor: |
| pass |
| |
| make_global(TestModuleInterface) |
| |
| class TestModule(nn.Module): |
| def __init__(self): |
| super(TestModule, self).__init__() |
| self.conv = nn.Conv2d(1, 1, 3) |
| |
| def forward(self, first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor: |
| return self.conv(first_arg) + second_arg |
| |
| def fn_takes_interface(x: TestModuleInterface): |
| ones = torch.ones(1, 1, 3, 3) |
| return x.forward(ones, ones) |
| |
| scripted_test_module = torch.jit.script(TestModule()) |
| self.checkScript(fn_takes_interface, (scripted_test_module,)) |
| |
| def test_traced_module_contains_scripted_interface_types(self): |
| |
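| # MiddleModule annotates its `lower` attribute with an interface type and is |
| # scripted; TopModule then reaches it both directly (sub1) and through an |
| # eager wrapper (sub2). Tracing TopModule has to cope with the scripted, |
| # interface-typed attribute on both paths. |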
| class LeafModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.weight = torch.nn.Parameter(torch.rand(19)) |
| |
| def forward(self, input: torch.Tensor): |
| return input + self.weight |
| |
| class LowerModuleImpl(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.leaf = LeafModule() |
| |
| def forward(self, input: torch.Tensor) -> torch.Tensor: |
| return self.leaf(input) |
| |
| @torch.jit.interface |
| class LowerModuleInterface(torch.nn.Module): |
| def forward(self, input: torch.Tensor) -> torch.Tensor: |
| pass |
| |
| class MiddleModule(torch.nn.Module): |
| lower: LowerModuleInterface |
| |
| def __init__(self, feature_processor_modules=None): |
| super().__init__() |
| self.lower = LowerModuleImpl() |
| |
| def forward(self, input): |
| return self.lower(input) |
| |
| class WrapperModule(torch.nn.Module): |
| def __init__(self, m): |
| super().__init__() |
| self.middle = m |
| |
| def forward(self, input): |
| return self.middle(input) |
| |
| class TopModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| m = MiddleModule() |
| m = torch.jit.script(m) |
| self.sub1 = m |
| self.sub2 = WrapperModule(m) |
| |
| def forward(self, input: torch.Tensor): |
| return self.sub1(input) + self.sub2(input) |
| |
| top = TopModule() |
| top_example_input = torch.ones(1) |
| torch.jit.trace(top, top_example_input) |