| # Owner(s): ["oncall: pt2"] |
| |
| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| from typing import Union, Callable, List, Any, Optional, Dict |
| from unittest.mock import patch |
| from torch.testing._internal.common_utils import ( |
| TestCase, |
| run_tests, |
| IS_ARM64, |
| IS_MACOS, |
| IS_WINDOWS, |
| IS_X86, |
| compare_equal_outs_and_grads, |
| outs_and_grads, |
| skipIfRocm, |
| ) |
| from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode |
| import copy |
| import torch |
| import torch.nn as nn |
| import torch.utils._pytree as pytree |
| import unittest |
| import warnings |
| import itertools |
| from contextlib import nullcontext |
| from functools import partial |
| from torch.nn.utils.rnn import PackedSequence |
| from torch.testing._internal.common_device_type import instantiate_device_type_tests, toleranceOverride, tol |
| from torch.testing._internal.common_methods_invocations import op_db |
| from torch.testing._internal.common_modules import module_db, modules |
| from torch.testing._internal.common_utils import parametrize, instantiate_parametrized_tests |
| from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db |
| from torch.testing._internal.optests import _test_aot_autograd_forwards_backwards_helper, aot_autograd_check |
| from torch._higher_order_ops.out_dtype import out_dtype |
| from functorch import ( |
| grad, vjp, vmap, jacrev, |
| make_fx |
| ) |
| from torch._functorch.aot_autograd import aot_module_simplified, aot_export_module, aot_export_joint_simple |
| from functorch.compile import ( |
| nnc_jit, compiled_function, compiled_module, |
| min_cut_rematerialization_partition, aot_function, aot_module, |
| nop, default_partition, default_decompositions, |
| memory_efficient_fusion, get_aot_compilation_context, make_boxed_compiler |
| ) |
| from torch._decomp import decomposition_table |
| |
| from torch.testing._internal.common_device_type import ops |
| from common_utils import ( |
| decorate, |
| xfail, |
| skip, |
| skipOps, |
| decorateForModules, |
| ) |
| from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode |
| from torch.fx.experimental.proxy_tensor import is_sym_node |
| from torch.fx.experimental.symbolic_shapes import ShapeEnv, GuardOnDataDependentSymNode |
| |
| USE_TORCHVISION = False |
| try: |
| import torchvision |
| USE_TORCHVISION = True |
| except ImportError: |
| warnings.warn("Couldn't import torchvision. Some of our tests use it, try " |
| "to install it with commands from pytorch.org, post-fixed with " |
| "`--no-deps` to avoid overwriting the pytorch installation", |
| UserWarning) |
| |
| USE_NETWORKX = False |
| try: |
| import networkx # noqa: F401 |
| USE_NETWORKX = True |
| except ImportError: |
| warnings.warn("Some tests use networkx but it was not installed", |
| UserWarning) |
| |
| # NB: numpy is a testing dependency! |
| |
| class AOTTestCase(TestCase): |
| def setUp(self): |
| super().setUp() |
| |
| class TestPythonKey(AOTTestCase): |
| def test_make_fx(self, device): |
| def f(x): |
| return torch.sin(x) |
| inp = torch.randn(3) |
| fx_f = make_fx(f)(inp) |
| |
| new_inp = torch.randn(3) |
| self.assertEqual(fx_f(new_inp), f(new_inp)) |
| |
| def test_make_fx_grad(self, device): |
| def f(x): |
| return torch.sin(x).sum() |
| inp = torch.randn(3) |
| f = grad(f) |
| fx_f = make_fx(f)(inp) |
| |
| new_inp = torch.randn(3) |
| self.assertEqual(fx_f(new_inp), f(new_inp)) |
| |
| def test_scalar_device(self, device): |
| def f(a, b): |
| return a + b |
| inps = [torch.randn(3, device=device), torch.tensor(5)] |
| fx_f = make_fx(f)(*inps) |
| self.assertEqual(fx_f(*inps), f(*inps)) |
| |
| def test_make_fx_vmap(self, device): |
| def f(x): |
| return torch.sin(x) |
| inp = torch.randn(5, 3) |
| f = vmap(f) |
| fx_f = make_fx(f)(inp) |
| new_inp = torch.randn(5, 3) |
| self.assertEqual(fx_f(new_inp), f(new_inp)) |
| |
| def test_make_fx_jacrev(self, device): |
| def f(x): |
| return x.sin().sum() |
| inp = torch.randn(3) |
| f = jacrev(jacrev(f)) |
| fx_f = make_fx(f)(inp) |
| new_inp = torch.randn(3) |
| self.assertEqual(fx_f(new_inp), f(new_inp)) |
| |
| def test_make_fx_vjp(self, device): |
| def f(x): |
| return torch.sin(x).sum() |
| |
| primals = torch.randn(3) |
| _, vjp_fn = vjp(f, primals) |
| cotangent = torch.randn(()) |
| fx_f = make_fx(vjp_fn)(cotangent, True, True) |
| new_cotangent = torch.randn(()) |
| self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent)) |
| |
| def test_make_fx_functionalize(self, device): |
| from functorch.experimental import functionalize |
| |
| def fn(a): |
| a = a * 2 |
| a.relu_() |
| return a |
| |
| a = torch.randn(3, device=device) |
| symbolic_gm = torch.fx.symbolic_trace(fn) |
| includes_method_relu_ = any( |
| str(n.target) == "relu_" for n in symbolic_gm.graph.nodes |
| ) |
| self.assertTrue(includes_method_relu_) |
| # Also verifies fix for https://github.com/pytorch/pytorch/issues/84570 |
| gm = make_fx(functionalize(symbolic_gm))(a) |
| includes_aten_relu = any( |
| n.target == torch.ops.aten.relu.default for n in gm.graph.nodes |
| ) |
| self.assertTrue(includes_aten_relu) |
| |
| def test_make_fx_no_decompose(self, device): |
| # FIXME |
| return self.skipTest("error: maximum recursion reached") |
| |
| def f(x): |
| return torch.tanh(x).sum() |
| |
| fx_f = make_fx(grad(f))(torch.randn(5)) |
| ops = {i.target for i in fx_f.graph.nodes} |
| |
| self.assertEqual(torch.ops.aten.tanh_backward in ops, True) |
| |
| fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5)) |
| ops = {i.target for i in fx_f.graph.nodes} |
| self.assertEqual(torch.ops.aten.tanh_backward in ops, False) |
| |
| def test_nnc_jit(self, device): |
| def f(x): |
| return torch.sin(x) |
| |
| jit_f = nnc_jit(f) |
| |
| inp = torch.randn(3) |
| self.assertEqual(jit_f(inp), f(inp)) |
| |
| def test_nnc_scalar(self, device): |
| def f(x): |
| return torch.sin(x) |
| |
| jit_f = nnc_jit(f) |
| |
| inp = torch.randn(()) |
| self.assertEqual(jit_f(inp), f(inp)) |
| |
| def test_nnc_pytrees(self, device): |
| def f(x): |
| return [torch.sin(x[0])] |
| |
| jit_f = nnc_jit(f) |
| |
| inp = [torch.randn(3)] |
| self.assertEqual(jit_f(inp), f(inp)) |
| |
| def test_external_calls(self, device): |
| def f(a, b): |
| return torch.mv(a, b) |
| jit_f = nnc_jit(f) |
| inp = [torch.randn(3, 3), torch.randn(3)] |
| self.assertEqual(jit_f(*inp), f(*inp)) |
| |
| def test_nnc_passthrough(self, device): |
| def f(x, y): |
| return x + y, y |
| inp = (torch.randn(3), torch.randn(3)) |
| jit_f = nnc_jit(f) |
| self.assertEqual(jit_f(*inp), f(*inp)) |
| |
| def f(x): |
| x['a'] = x['a'] * 2 |
| return x |
| inp = ({'a': torch.randn(3), 'b': torch.randn(3)},) |
| jit_f = nnc_jit(f) |
| self.assertEqual(jit_f(*inp), f(*inp)) |
| |
| @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") |
| def test_resnet18_backward_trace(self, device): |
| mod = torchvision.models.resnet18() |
| |
| def f(x): |
| out = mod(x) |
| out.sum().backward() |
| return [a.grad for a in mod.parameters()] |
| |
| inp = torch.randn(3, 3, 250, 250, requires_grad=True) |
| grads = f(inp) |
| |
| mod.zero_grad() |
| mod(inp).sum().backward() |
| grads2 = [a.grad for a in mod.parameters()] |
| self.assertEqual(grads, grads2) |
| |
| def get_base(t): |
| return t._base if t._is_view() else t |
| |
| def is_in_base(t, maybe_tensors): |
| t_base = get_base(t) |
| for maybe_tensor in maybe_tensors: |
| if isinstance(maybe_tensor, torch.Tensor): |
| if t_base is get_base(maybe_tensor): |
| return True |
| return False |
| |
| class TestAOTAutograd(AOTTestCase): |
| # test_mutation will: |
| # - Ensure that inputs are non-leaves, so our graphs can mutate them |
| # - try to mutate outputs of the graph (to ensure that autograd meta is set properly on outputs) |
| @patch("functorch.compile.config.debug_assert", True) |
| def verify_aot_autograd( |
| self, |
| f, |
| inp_: Union[Callable, List[Any]], |
| *, |
| test_mutation: bool = False, |
| only_keep_inference_mutations: bool = False, |
| decompositions: Optional[Dict] = None, |
| dynamic: bool = False, |
| # Only active when inp_ is Callable. |
| # TODO: probably consolidate all tests to make inp a Callable. |
| make_inputs_subclasses: bool = False, |
| ): |
| for keep_input_mutations in [True] if only_keep_inference_mutations else [True, False]: |
| # Some tests pass in a callable for inp, to generate the inputs |
| # (useful if we want to generate complicated aliasing inputs) |
| if isinstance(inp_, Callable): |
| inp_callable = inp_ |
| # The callable should return a tuple of f_inputs, f_graph_inputs |
| # (The idea is that we might want to compile a function with the graph inputs, |
| # but test autograd backprop all the way through the actual inputs) |
| with TwoTensorMode() if make_inputs_subclasses else nullcontext(): |
| inp_copy, graph_inps_copy = inp_callable() |
| inp, graph_inps = inp_callable() |
| else: |
| inp_copy = [] |
| inp = [] |
| # Our input clones need to mimic when inputs are duplicates of one another |
| dupes_map = {} |
| for i, x in enumerate(inp_): |
| if x in dupes_map: |
| x_dupe_idx = dupes_map[x] |
| inp_copy.append(inp_copy[x_dupe_idx]) |
| inp.append(inp[x_dupe_idx]) |
| else: |
| dupes_map[x] = i |
| if not isinstance(x, torch.Tensor): |
| x_copy = x |
| x_copy2 = x |
| else: |
| x_copy = x.clone().detach().requires_grad_(x.requires_grad) |
| x_copy2 = x.clone().detach().requires_grad_(x.requires_grad) |
| if x.requires_grad and not x.is_leaf: |
| x_copy = x_copy.clone() |
| x_copy2 = x_copy2.clone() |
| inp_copy.append(x_copy) |
| inp.append(x_copy2) |
| |
| if test_mutation: |
| # For graphs where we mutate inputs, need our test to make sure inputs aren't leaves |
| graph_inps = [x.add(1) for x in inp] |
| graph_inps_copy = [x.add(1) for x in inp_copy] |
| else: |
| graph_inps = inp |
| graph_inps_copy = inp_copy |
| fw_graph_cell = [None] |
| if isinstance(f, nn.Module): |
| compiled_f = aot_module( |
| f, |
| fw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=fw_graph_cell)), |
| bw_compiler=nop, |
| decompositions=decompositions, |
| keep_inference_input_mutations=keep_input_mutations, |
| dynamic=dynamic |
| ) |
| else: |
| compiled_f = aot_function( |
| f, |
| fw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=fw_graph_cell)), |
| bw_compiler=nop, |
| decompositions=decompositions, |
| keep_inference_input_mutations=keep_input_mutations, |
| dynamic=dynamic |
| ) |
| ref_out, ref_grad = outs_and_grads(f, graph_inps, inp) |
| test_out, test_grad = outs_and_grads(compiled_f, graph_inps_copy, inp_copy) |
| self.assertEqual(ref_grad, test_grad) |
| |
| if isinstance(ref_out, torch.Tensor): |
| self.assertTrue(isinstance(test_out, torch.Tensor)) |
| ref_out, test_out = [ref_out], [test_out] |
| for ref_o, test_o in zip(ref_out, test_out): |
| if isinstance(ref_o, torch.Tensor): |
| self.assertEqual(ref_o.requires_grad, test_o.requires_grad) |
| self.assertEqual(ref_o.is_leaf, test_o.is_leaf) |
| ref_is_view_of_non_interm = is_in_base(ref_o, graph_inps) or is_in_base(ref_o, ref_out) |
| test_is_view_of_non_interm = is_in_base(test_o, graph_inps_copy) or is_in_base(test_o, test_out) |
| self.assertEqual(ref_is_view_of_non_interm, test_is_view_of_non_interm) |
| self.assertEqual(ref_o, test_o) |
| if test_mutation: |
| # This tests that autograd meta is set properly on the output we can |
| # mutate it. |
| ref_o.mul_(2) |
| test_o.mul_(2) |
| self.assertEqual(ref_o, test_o) |
| for ref_i, test_i in zip(inp, inp_copy): |
| if isinstance(ref_i, torch.Tensor): |
| self.assertEqual(ref_i.requires_grad, test_i.requires_grad) |
| self.assertEqual(ref_i, test_i) |
| return fw_graph_cell[0] |
| |
| def test_non_tensor_and_none_inputs(self): |
| # int, None, Tensor |
| def f(a, b, c): |
| return a * c |
| inp = [2, None, torch.ones(3, 3, dtype=torch.float32, requires_grad=True)] |
| self.verify_aot_autograd(f, inp) |
| inp = [2, None, torch.ones(3, 3, dtype=torch.float32, requires_grad=False)] |
| self.verify_aot_autograd(f, inp) |
| |
| def test_single_output(self): |
| def f(a, b): |
| return a + b |
| inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] |
| self.verify_aot_autograd(f, inp) |
| inp = [torch.randn(3, 3, requires_grad=False), torch.randn(3, 3)] |
| self.verify_aot_autograd(f, inp) |
| |
| def test_multi_output(self): |
| def f(a, b): |
| return a + b, a - b |
| inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] |
| self.verify_aot_autograd(f, inp) |
| inp = [torch.randn(3, 3, requires_grad=False), torch.randn(3, 3)] |
| self.verify_aot_autograd(f, inp) |
| |
| def test_multi_output_list(self): |
| def f(a, b): |
| return [a + b, a - b] |
| inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] |
| self.verify_aot_autograd(f, inp) |
| inp = [torch.randn(3, 3, requires_grad=False), torch.randn(3, 3)] |
| self.verify_aot_autograd(f, inp) |
| |
| # Test for bug occurring at the intersection of fake tensors & functionalization. |
| def test_squeeze_mutation(self): |
| def f(a): |
| b = a.clone().squeeze(-1) |
| b.add_(1.) |
| return a + b |
| |
| inp = [torch.randn(3, 1, requires_grad=True)] |
| self.verify_aot_autograd(f, inp, dynamic=True) |
| inp = [torch.randn(3, 1, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, dynamic=True) |
| |
| def test_complex_linear(self): |
| # https://github.com/pytorch/pytorch/issues/93424 |
| inp = [torch.randn(1, 10, 10, dtype=torch.complex64)] |
| |
| class F(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = nn.Linear(10, 10, dtype=torch.complex64) |
| |
| def forward(self, x): |
| return self.linear(x).sum().abs() |
| |
| self.verify_aot_autograd(F(), inp) |
| |
| def test_embedding_bag_view_dynamic(self): |
| # Backwards pass tries to wrap a sparse tensor in a FunctionalTensorWrapper; |
| # test that this works even though the sparse tensor has no storage. |
| |
| class F(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.emb = torch.nn.EmbeddingBag(100, 8, sparse=True) |
| |
| def forward(self, x, y): |
| return self.emb(x, y).view(-1) |
| |
| x = torch.arange(3) |
| y = torch.arange(3) |
| self.verify_aot_autograd(F(), [x, y], dynamic=False) |
| self.verify_aot_autograd(F(), [x, y], dynamic=True) |
| |
| |
| |
| def test_input_mutation_simple(self): |
| def f(a): |
| a.mul_(2) |
| return a * 3 |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| # Things to note: |
| # - the extra clone is because we need to pass the pre-mutated input to grad(), |
| # but autograd operates above functionalization so we need to manually clone. |
| # Hopefully backends can optimize this easily. |
| # - The extra return arg is because the compiled forward returns (mutated inputs + outputs) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None |
| mul_1 = torch.ops.aten.mul.Tensor(mul, 3) |
| return [mul, mul_1]""") |
| |
| def test_input_mutation_set__input_mutation(self): |
| def f(a): |
| b = torch.arange(9, dtype=a.dtype).reshape(3, 3) |
| with torch.no_grad(): |
| a.set_(b) |
| return a * b |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_set__steals_view_chain(self): |
| def f(a, b): |
| a_ = a.mul(2) |
| b_ = b.mul(2) |
| b_slice = b_[1].view(3, 3) |
| # a_clone should inherit the view chain from b_slice |
| a_.set_(b_slice) |
| # Also mutates b_, |
| a_.view(-1).mul_(2) |
| return a_ * b_slice |
| inp = [torch.ones(3, 3, requires_grad=False), torch.zeros(3, 9, requires_grad=False)] |
| self.verify_aot_autograd(f, inp) |
| |
| def test_set__and_data_mutation_good(self): |
| def f(a, b): |
| # The data mutation happens *after* the set_(). This is ok (see the graph below) |
| with torch.no_grad(): |
| a.set_(b) |
| b.mul_(2) |
| return a + b |
| inp = [torch.ones(3, 3, requires_grad=True), torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=False), torch.zeros(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| # Important things to note: |
| # - "return a.set_(b)" desugars into "return b" |
| # - Both a and b are recorded as experiencing mutations, |
| # which is why we see "b_updated" (output of the mul) twice in the graph outputs. |
| # a is recorded as both a data mutation and a metadata mutation (due to set_ swapping its storage). |
| # - the runtime epilogue for a is "a.set_(mul)" |
| # - the runtime epilogue for b is "b.copy_(mul)" |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| clone = torch.ops.aten.clone.default(primals_2); primals_2 = None |
| mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None |
| add = torch.ops.aten.add.Tensor(mul, mul) |
| return [mul, mul, add]""") |
| |
| # This is a (hopefully) extremely rare case that is difficult to handle, |
| # so we ban it. |
| def test_set__and_data_mutation_bad(self): |
| def f(a): |
| a_view = a.view(-1) |
| tmp = torch.ones(3, 3, requires_grad=True) |
| # Now, any mutations on either tmp |
| # will be tracked as graph input mutations. |
| with torch.no_grad(): |
| a.set_(tmp) |
| # BAD: a_view is now detached from every graph input, |
| # so we won't recognize that this caused an input mutation! |
| a_view.mul_(2) |
| return a + tmp |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| with self.assertRaisesRegex(RuntimeError, "cannot mutate tensors with frozen storage"): |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_input_mutation_set__nop(self): |
| def f(a): |
| b = torch.arange(9, dtype=a.dtype) |
| a_old = torch.ops.aten.alias.default(a) |
| with torch.no_grad(): |
| a.set_(b) |
| a.set_(a_old) |
| return a + b.reshape(3, 3) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| # Things to note: |
| # - There are no set_() calls in the graph (we functionalize a.set_(b) into "b") |
| # - There is only **1** graph output. We properly realized that the two set_() calls |
| # undo each other, and so effectively no inputs are mutated. |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) |
| alias = torch.ops.aten.alias.default(primals_1); primals_1 = None |
| view = torch.ops.aten.view.default(arange, [3, 3]); arange = None |
| add = torch.ops.aten.add.Tensor(alias, view); alias = view = None |
| return [add]""") |
| |
| def test_input_mutation_simple_with_none_and_nontensor(self): |
| # Tensor, None, int |
| def f(a, b, c): |
| return a * c |
| f_compiled = aot_function(f, nop) |
| for req_grad in [True, False]: |
| inp = [torch.ones(3, 3, requires_grad=req_grad), None, 3] |
| out_ref = f(*inp) |
| out_test = f_compiled(*inp) |
| self.assertEqual(out_ref, out_test) |
| |
| # https://github.com/pytorch/pytorch/issues/93363 |
| def test_mutates_input_noncontiguous(self): |
| def f(a): |
| a.add_(1) |
| return () |
| |
| f_compiled = aot_function(f, nop) |
| ref = torch.ones(4, requires_grad=True) + 0 |
| ref_view = ref[0::2] |
| |
| test = torch.ones(4, requires_grad=True) + 0 |
| test_view = test[0::2] |
| |
| out_ref = f(ref_view) |
| out_test = f_compiled(test_view) |
| print(ref) |
| print(test) |
| self.assertEqual(ref, test) |
| |
| def test_input_mutation_modifies_autograd_meta_of_aliases(self): |
| def f(a): |
| a.mul_(2) |
| out = a + 1 |
| return out.detach() |
| |
| x_ref = torch.ones(3, 3, requires_grad=True).clone() |
| x_ref_view = x_ref.view(3, 3) |
| |
| x_test = torch.ones(3, 3, requires_grad=True).clone() |
| x_test_view = x_test.view(3, 3) |
| |
| f_compiled = aot_function(f, nop, keep_inference_input_mutations=True) |
| f(x_ref) |
| f_compiled(x_test) |
| # f will mutate aliases of the input, including its autograd metadata! |
| # y.grad_fn is AsStridedBackward |
| self.assertEqual(x_ref_view, x_test_view) |
| self.assertEqual(x_ref_view._version, x_test_view._version) |
| self.assertEqual(x_ref_view.grad_fn.__class__, x_test_view.grad_fn.__class__) |
| # Test the actual gradients are correct |
| (x_ref * x_ref_view).sum().backward() |
| (x_test * x_test_view).sum().backward() |
| self.assertEqual(x_ref.grad, x_test.grad) |
| self.assertEqual(x_ref_view.grad, x_test_view.grad) |
| |
| |
| def test_outputs_are_aliased(self): |
| # Tensor, None, int |
| def f(a): |
| b = a.mul(2) |
| c = b.view(-1) |
| return b, c |
| f_compiled = aot_function(f, nop) |
| for req_grad in [True, False]: |
| inp = torch.ones(3, requires_grad=req_grad) |
| out_ref = f(inp) |
| out_test = f_compiled(inp) |
| self.assertEqual(out_ref[0], out_test[0]) |
| self.assertEqual(out_ref[1], out_test[1]) |
| # Try mutating one of the outputs, which is aliased. |
| out_ref[0].mul_(3) |
| out_test[0].mul_(3) |
| # Assert that the aliasing relationship was preserved |
| self.assertEqual(out_ref[0], out_test[0]) |
| self.assertEqual(out_ref[1], out_test[1]) |
| |
| def test_input_mutation_is_output(self): |
| def f(a): |
| a.mul_(2) |
| return a |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None |
| return [mul, mul]""") |
| |
| def test_input_mutation_multiple(self): |
| def f(a, b, c): |
| a.mul_(2) |
| c.mul_(2) |
| return a + b + c |
| |
| def create_inp(req_grad): |
| return [ |
| torch.ones(3, 3, requires_grad=req_grad), |
| torch.ones(3, 3, requires_grad=req_grad), |
| torch.ones(3, 3, requires_grad=req_grad), |
| ] |
| |
| self.verify_aot_autograd(f, create_inp(False), test_mutation=True) |
| |
| fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2, primals_3): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| clone_1 = torch.ops.aten.clone.default(primals_3); primals_3 = None |
| mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None |
| mul_1 = torch.ops.aten.mul.Tensor(clone_1, 2); clone_1 = None |
| add = torch.ops.aten.add.Tensor(mul, primals_2); primals_2 = None |
| add_1 = torch.ops.aten.add.Tensor(add, mul_1); add = None |
| return [mul, mul_1, add_1]""") |
| |
| def test_input_mutation_metadata(self): |
| def f(a, b): |
| a.transpose_(1, 0) |
| return a + b |
| |
| def create_inp(req_grad): |
| return [ |
| torch.ones(3, 3, requires_grad=req_grad), |
| torch.ones(3, 3, requires_grad=req_grad), |
| ] |
| |
| self.verify_aot_autograd(f, create_inp(True), test_mutation=True) |
| self.verify_aot_autograd(f, create_inp(False), test_mutation=True) |
| |
| def test_input_output_aliase_custom_autograd_function(self): |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| return x |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx * 0.5 |
| |
| def f(x): |
| return Foo.apply(x) |
| |
| inp = [torch.ones(2, 2, requires_grad=True)] |
| self.verify_aot_autograd(f, inp, test_mutation=False) |
| |
| def test_input_mutation_requires_grad_detach(self): |
| # Here, "a" requires grad, and gets mutated, so we append a copy_() to the end of the graph. |
| # Its mutation doesn't take part in autograd though, because we mutated a detach'd view. |
| # Need to make sure that this copy_() doesn't error, and doesn't participate in autograd either. |
| def f(a): |
| a.detach().mul_(2) |
| return a + 3 |
| inp = [torch.ones(4, requires_grad=True)] |
| self.verify_aot_autograd(f, inp, test_mutation=False) |
| inp = [torch.ones(4, requires_grad=True)] |
| # test_mutation=True will first do some compute on inp, so it is no longer an autograd leaf |
| # by the time it becomes a graph input. Good to test both cases. |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_input_mutation_requires_grad_no_grad(self): |
| def f(a): |
| with torch.no_grad(): |
| a.mul_(2) |
| return a + 3 |
| inp = [torch.ones(4, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True, only_keep_inference_mutations=True) |
| # Even though the input requires_grad, we expect the keep the input mutation in the graph |
| # (Even though this is a training graph!) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 2) |
| add = torch.ops.aten.add.Tensor(mul, 3) |
| copy_ = torch.ops.aten.copy_.default(primals_1, mul); primals_1 = mul = None |
| return [add]""") |
| |
| def test_input_mutation_requires_grad_no_grad_inference_graph(self): |
| def f(a): |
| with torch.no_grad(): |
| a.mul_(2) |
| return a + 3 |
| inp = [torch.ones(4, requires_grad=True)] |
| # Even though the input requires_grad, we expect the keep the input mutation in the graph |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True, only_keep_inference_mutations=True) |
| |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, arg0_1): |
| mul = torch.ops.aten.mul.Tensor(arg0_1, 2) |
| add = torch.ops.aten.add.Tensor(mul, 3) |
| copy_ = torch.ops.aten.copy_.default(arg0_1, mul); arg0_1 = mul = None |
| return (add,)""") |
| |
| def test_input_mutation_requires_grad_no_grad_detach_mixed(self): |
| # Perform a mix of mutations on a: |
| # 1 normal, 1 in no_grad, 1 on a detach'd tensor. |
| # Only the first should participate in gradient computation. |
| def f(a): |
| a.detach().mul_(2) |
| a.mul_(3) |
| with torch.no_grad(): |
| a.mul_(4) |
| return a + 5 |
| inp = [torch.ones(4, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_input_mutation_metadata2(self): |
| def f(a): |
| a.transpose_(1, 0) |
| a.mul_(2) |
| return a + 1 |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_input_mutation_batchnorm(self): |
| def f(inpt, weight, bias, running_mean, running_var): |
| # This is additionally a good test, because the input tensors that we mutate |
| # are *also* saved for backwards. |
| # This tests that what we save for the backward is actually cloned inputs, |
| # and not the original inputs that got mutated. |
| return torch._native_batch_norm_legit(inpt, weight, bias, running_mean, running_var, True, 0.5, 1e-5) |
| |
| def create_inp(req_grad): |
| return [ |
| torch.ones(2, 5, 5, 5, requires_grad=req_grad), |
| torch.ones(5, requires_grad=req_grad), |
| torch.ones(5, requires_grad=req_grad), |
| torch.ones(5), |
| torch.ones(5), |
| ] |
| |
| from torch._decomp import get_decompositions |
| # This simulates what inductor does (running the fw + bw decompositions) |
| decompositions = get_decompositions([ |
| torch.ops.aten._native_batch_norm_legit_functional, |
| torch.ops.aten.native_batch_norm_backward, |
| ]) |
| self.verify_aot_autograd(f, create_inp(True), test_mutation=True, decompositions=decompositions) |
| self.verify_aot_autograd(f, create_inp(False), test_mutation=True, decompositions=decompositions) |
| |
| def test_batchnorm_inference(self): |
| inp = [ |
| torch.ones(2, 5, 5, 5, requires_grad=True), |
| torch.ones(5, requires_grad=True), |
| torch.ones(5, requires_grad=True), |
| torch.ones(5), |
| torch.ones(5), |
| ] |
| |
| m = torch.nn.BatchNorm2d(4, 4) |
| m.eval() |
| fw_graph_cell = [None] |
| inp = torch.ones(4, 4, 4, 4) |
| fw_graph_cell = [None] |
| compiled_m = aot_module( |
| m, |
| fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), |
| bw_compiler=nop, |
| keep_inference_input_mutations=True, |
| ) |
| inp = torch.ones(4, 4, 4, 4) |
| with torch.no_grad(): |
| out = compiled_m(inp) |
| # expectation: there are no copy_() calls in the decomposed batch norm when running under training=False (eval mode) |
| code = fw_graph_cell[0].code.strip() |
| self.assertTrue("copy_" not in str(code)) |
| |
| def test_input_output_view_simple(self): |
| def f(a): |
| return a.view(-1) |
| inp = [torch.ones(2, 2, requires_grad=False).add(1)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(2, 2, requires_grad=True).add(1)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| # Outputs that alias inputs are pulled out of the graph entirely, so we don't compile anything here |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| view = torch.ops.aten.view.default(primals_1, [-1]); primals_1 = None |
| return [view]""") |
| |
| def test_input_output_view_mutate_multiple(self): |
| def f(a, b, c): |
| a.mul_(2) |
| c.mul_(3) |
| return b.view(2, 2), c.view(2, 2) |
| |
| def create_inp(req_grad): |
| return [ |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| ] |
| |
| self.verify_aot_autograd(f, create_inp(False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) |
| # The original function returned two outputs, both of which aliased inputs. |
| # We expect two outputs in the functional graph, a_updated and c_updated. |
| # The actual aliased outputs themselves aren't in the compiled forward graph; |
| # Instead, they're generated outside of the graph. |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2, primals_3): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| clone_1 = torch.ops.aten.clone.default(primals_3); primals_3 = None |
| mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None |
| mul_1 = torch.ops.aten.mul.Tensor(clone_1, 3); clone_1 = None |
| view = torch.ops.aten.view.default(primals_2, [2, 2]); primals_2 = None |
| view_2 = torch.ops.aten.view.default(mul_1, [2, 2]) |
| return [mul, mul_1, view, view_2]""") |
| |
| def test_input_output_view_metadata_mutate_multiple(self): |
| def f(a, b, c): |
| b.mul_(3) |
| c.t_() |
| return a.view(2, 2), b.view(2, 2), c.view(2, 2) |
| |
| def create_inp(req_grad): |
| return [ |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| ] |
| |
| self.verify_aot_autograd(f, create_inp(False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) |
| # Important thing to check here: of the three inputs: |
| # Only the b.mul_(3) should show up in the graph (we functionalize it and return it). |
| # Everything else that does not show up in the graph includes: |
| # - The metadata mutation on c (we do it outside the graph) |
| # - All 3 original fw outputs, which are aliases of inputs (we regenerate them outside of the graph) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2, primals_3): |
| clone = torch.ops.aten.clone.default(primals_2); primals_2 = None |
| view = torch.ops.aten.view.default(primals_3, [2, 2]); primals_3 = None |
| mul = torch.ops.aten.mul.Tensor(clone, 3); clone = None |
| t = torch.ops.aten.t.default(view); view = None |
| view_1 = torch.ops.aten.view.default(primals_1, [2, 2]); primals_1 = None |
| view_3 = torch.ops.aten.view.default(t, [2, 2]) |
| view_4 = torch.ops.aten.view.default(mul, [2, 2]) |
| return [mul, t, view_1, view_4, view_3]""") |
| |
| def test_input_mutation_and_output_view(self): |
| def f(a): |
| a.add_(1) |
| return a.view(-1) |
| inp = [torch.ones(2, 2, requires_grad=False).add(1)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(2, 2, requires_grad=True).add(1)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| # Here, total # of outputs is 1 because: |
| # - num_mutated_inps = 1 (a_updated) |
| # - num_fw_outputs = 0 (the output is an alias of the input, so we move it outside the compiled fw) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| add = torch.ops.aten.add.Tensor(clone, 1); clone = None |
| view_1 = torch.ops.aten.view.default(add, [-1]) |
| return [add, view_1]""") |
| |
| |
| def test_input_mutation_output_view_multiple(self): |
| def f(a, b, c, d): |
| b.transpose_(1, 0) |
| c.add_(1) |
| return d + 1, b.diagonal(), a + c |
| |
| def create_inp(req_grad): |
| return [ |
| torch.arange(4, requires_grad=req_grad, dtype=torch.float32).view(2, 2).add(1), |
| torch.arange(4, requires_grad=req_grad, dtype=torch.float32).view(2, 2).add(1), |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| ] |
| |
| self.verify_aot_autograd(f, create_inp(False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2, primals_3, primals_4): |
| view = torch.ops.aten.view.default(primals_2, [2, 2]); primals_2 = None |
| clone = torch.ops.aten.clone.default(primals_3); primals_3 = None |
| transpose = torch.ops.aten.transpose.int(view, 1, 0); view = None |
| add = torch.ops.aten.add.Tensor(clone, 1); clone = None |
| add_1 = torch.ops.aten.add.Tensor(primals_4, 1); primals_4 = None |
| diagonal = torch.ops.aten.diagonal.default(transpose) |
| add_2 = torch.ops.aten.add.Tensor(primals_1, add); primals_1 = None |
| return [transpose, add, add_1, diagonal, add_2]""") |
| |
| def test_output_aliases_intermediate_single(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| return out.view(-1) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| # In AOTAutograd, we are obligated to make the compiled forward directly return `out`, |
| # and reconstruct `out.view(-1)` as a fresh output. |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None |
| view = torch.ops.aten.view.default(mul, [-1]); mul = None |
| return [view]""") |
| |
| def test_output_aliases_input_multi_output_view_should_raise_autograd_error(self): |
| def f1(a): |
| return list(a.unbind(0)) |
| |
| f1_compiled = aot_function(f1, nop) |
| |
| inp1 = torch.ones(3, 3, requires_grad=True).clone() |
| inp2 = torch.ones(3, 3, requires_grad=True).clone() |
| inp3 = torch.ones(3, 3, requires_grad=True).clone() |
| |
| with self.assertRaisesRegex(RuntimeError, "Such functions do not allow the output views"): |
| out_test1 = f1_compiled(inp1) |
| # This raises a runtime error from autograd in eager mode |
| out_test1[0].mul_(2) |
| |
| with self.assertRaisesRegex(RuntimeError, "Such functions do not allow the output views"): |
| out_test2 = f1_compiled(inp2) |
| inp2.mul_(2) |
| # In eager mode, if we mutate a tensor, any multi-output-view aliases |
| # get their grad_fn replaced with error nodes, so accessing grad_fn should error |
| grad_fn = out_test2[0].grad_fn |
| |
| with self.assertRaisesRegex(RuntimeError, "Such functions do not allow the output views"): |
| out_test3 = f1_compiled(inp3) |
| out_test1[0].detach().mul_(2) |
| # The above case also applies to detached aliases (they turn the multi-output-view |
| # alias's grad_fns into error nodes) |
| grad_fn = out_test2[0].grad_fn |
| |
| |
| def test_output_aliases_input_multi_output_view(self): |
| # All aliased outs are from multi-output views, so AOTAutograd will hide the aliasing from autograd. |
| def f1(a): |
| return list(a.unbind(0)) |
| |
| inp = torch.ones(3, 3, requires_grad=True) |
| inp_ref = torch.ones(3, 3, requires_grad=True) |
| f1_compiled = aot_function(f1, nop) |
| |
| out_ref = f1(inp_ref) |
| out_test = f1_compiled(inp) |
| # Assert that we get CompiledFunctionBackward in the backward graph, |
| # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case. |
| # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] |
| self.assertTrue(all('CompiledFunctionBackward' in str(o.grad_fn) for o in out_test)) |
| |
| sum(out_ref).sum().backward() |
| sum(out_test).sum().backward() |
| self.assertEqual(inp_ref.grad, inp.grad) |
| |
| # Several of the outputs are from multi-output views. |
| # However: they are part of the same alias set as "a", and "a.view(out.shape)", |
| # which are both user-visible. |
| # AOTAutograd will not try to be smart here and hide the aliasing relationships from autograd. |
| # Instead, it will perform its "output aliases input" logic, and regenerate all aliases. |
| def f3(a): |
| return *list(a.unbind(0)), a.view(a.shape) |
| |
| inp = torch.ones(3, 3, requires_grad=True) |
| inp_ref = torch.ones(3, 3, requires_grad=True) |
| f3_compiled = aot_function(f3, nop) |
| |
| inp_ref_clone = inp_ref.clone() |
| inp_clone = inp.clone() |
| out_ref = f3(inp_ref_clone) |
| out_test = f3_compiled(inp_clone) |
| self.assertTrue(all('UnbindBackward' in str(o.grad_fn) for o in out_test[:3])) |
| |
| # The last output is not from a multi-output view, so autograd will let us mutate it. |
| out_ref[-1].mul_(2) |
| out_test[-1].mul_(2) |
| # Also mutate the input, which should affect the aliased output. |
| inp_ref_clone.view(-1).mul_(3) |
| inp_clone.view(-1).mul_(3) |
| # Do backward |
| (inp_ref + out_ref[-1]).sum().backward() |
| (inp + out_test[-1]).sum().backward() |
| self.assertEqual(inp_ref.grad, inp.grad) |
| |
| |
| def test_output_aliases_intermediate_multi_output_view(self): |
| # All aliased outs are from multi-output views, so AOTAutograd will hide the aliasing from autograd. |
| def f1(a): |
| out = torch.mul(a, 3) |
| return list(out.unbind(0)) |
| |
| inp = torch.ones(3, 3, requires_grad=True) |
| inp_ref = torch.ones(3, 3, requires_grad=True) |
| f1_compiled = aot_function(f1, nop) |
| |
| out_ref = f1(inp_ref) |
| out_test = f1_compiled(inp) |
| # Assert that we get CompiledFunctionBackward in the backward graph, |
| # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case. |
| # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] |
| self.assertTrue(all('CompiledFunctionBackward' in str(o.grad_fn) for o in out_test)) |
| |
| sum(out_ref).sum().backward() |
| sum(out_test).sum().backward() |
| self.assertEqual(inp_ref.grad, inp.grad) |
| |
| # All aliased outs but one are from multi-output views, so AOTAutograd will hide the aliasing from autograd. |
| def f2(a): |
| out = torch.mul(a, 3) |
| return *list(out.unbind(0)), out |
| |
| inp = torch.ones(3, 3, requires_grad=True) |
| inp_ref = torch.ones(3, 3, requires_grad=True) |
| f2_compiled = aot_function(f2, nop) |
| |
| out_ref = f2(inp_ref) |
| out_test = f2_compiled(inp) |
| # Assert that we get CompiledFunctionBackward in the backward graph, |
| # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case. |
| # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] |
| self.assertTrue(all('CompiledFunctionBackward' in str(o.grad_fn) for o in out_test)) |
| |
| # The last output is not from a multi-output view, so autograd will let us mutate it. |
| out_ref[-1].mul_(2) |
| out_test[-1].mul_(2) |
| out_ref[-1].sum().backward() |
| out_test[-1].sum().backward() |
| self.assertEqual(inp_ref.grad, inp.grad) |
| |
| # All aliased outs but one are from multi-output views, so AOTAutograd will hide the aliasing from autograd. |
| def f3(a): |
| out = torch.mul(a, 3) |
| return *list(out.unbind(0)), out.view(out.shape) |
| |
| inp = torch.ones(3, 3, requires_grad=True) |
| inp_ref = torch.ones(3, 3, requires_grad=True) |
| f3_compiled = aot_function(f3, nop) |
| |
| out_ref = f3(inp_ref) |
| out_test = f3_compiled(inp) |
| # Assert that we get CompiledFunctionBackward in the backward graph, |
| # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case. |
| # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] |
| self.assertTrue(all('CompiledFunctionBackward' in str(o.grad_fn) for o in out_test)) |
| |
| # The last output is not from a multi-output view, so autograd will let us mutate it. |
| out_ref[-1].mul_(2) |
| out_test[-1].mul_(2) |
| out_ref[-1].sum().backward() |
| out_test[-1].sum().backward() |
| self.assertEqual(inp_ref.grad, inp.grad) |
| |
| # There are 5 outputs that all alias each other. |
| # 3 of them come from multi-output views, but the other 3 are "ordinary" aliases. |
| # Therefore, AOTAutograd will not attempt the multi-output-view optimization, |
| # and apply the intermediate_base logic to all aliases. |
| # (In theory we could probably get AOTAutograd to only apply the intermediate base |
| # logic to the last 2 outputs and not the first 3. We should probably |
| # just do the graph partitioning defined in this doc instead though). |
| # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit |
| def f4(a): |
| out = torch.mul(a, 3) |
| # also return the graph intermediate directly, |
| # which will force AOTAutograd to do the "intermediate base" logic. |
| # (Why? The user can mutate "out", which should change the autograd metadata |
| # of the other aliased outputs) |
| return *list(out.unbind(0)), out, out.view(out.shape) |
| |
| inp = torch.ones(3, 3, requires_grad=True) |
| inp_ref = torch.ones(3, 3, requires_grad=True) |
| f4_compiled = aot_function(f4, nop) |
| |
| out_ref = f4(inp_ref) |
| out_test = f4_compiled(inp) |
| # Mutate the last output of f4 (autograd will allow this, since it is not a multi-output view, |
| # as long as *only* the non-multi-output views participate in the backward) |
| # Note: We could probably try to hide **only** the multi-output views from autograd here |
| # and only do the intermediate base logic for the last two aliases. |
| # Longer term solution of graph partitioning is probably cleaner though (see the note). |
| out_ref[-1].mul_(2) |
| out_test[-1].mul_(2) |
| |
| out_ref_sum = out_ref[-1] + out_ref[-2] |
| out_test_sum = out_test[-1] + out_test[-2] |
| out_ref_sum.sum().backward() |
| out_test_sum.sum().backward() |
| self.assertEqual(inp_ref.grad, inp.grad) |
| |
| |
| def test_output_aliases_intermediate_mutation_linear(self): |
| def f(x): |
| return (x + 1).view(-1) |
| |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| # use inductor's decomps (which will e.g. turn _unsafe_view() into view()) |
| from torch._inductor.decomposition import decompositions |
| f_compiled = aot_function(f, nop, decompositions=decompositions) |
| |
| out_ref = f(*inp) |
| out_test = f_compiled(*inp) |
| |
| out_ref.mul_(2) |
| out_test.mul_(2) |
| self.assertEqual(out_ref, out_test) |
| |
| def test_output_aliases_intermediate_no_grad(self): |
| def f(a, b): |
| out = torch.mul(a, 3) |
| # First output is an alias of an intermediate that doesn't require grad |
| return out.view(-1), b.add(1) |
| inp = [torch.ones(3, 3), torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3), torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| # important bit: we don't bother generating an intermediate base as an output in the graph, |
| # because the intermediate base itself didn't require gradients. |
| # (the only problematic case is when both the base and the aliasesed output require gradients). |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None |
| view = torch.ops.aten.view.default(mul, [-1]); mul = None |
| add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None |
| return [view, add]""") |
| |
| def test_output_aliases_intermediate_returned_multiple_times(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| out_view = out.view(-1) |
| return out, out_view, out |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_output_aliases_intermediate_multiple(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| # AOTAutograd should manually generate these two output views in the epilogue. |
| return out.view(-1), out.view(-1) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None |
| view = torch.ops.aten.view.default(mul, [-1]) |
| view_1 = torch.ops.aten.view.default(mul, [-1]) |
| return [view, view_1, mul]""") |
| |
| def test_output_aliases_intermediate_and_returned(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| # AOTAutograd should manually generate the first output (a view of an intermediate) |
| # but not the second (which is itself the intermediate for the first) |
| return out.view(-1), out |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None |
| view = torch.ops.aten.view.default(mul, [-1]) |
| return [view, mul]""") |
| |
| def test_output_aliases_intermediate_and_returned_flipped(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| # AOTAutograd should manually generate the first output (a view of an intermediate) |
| # but not the second (which is itself the intermediate for the first) |
| return out, out.view(-1) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None |
| view = torch.ops.aten.view.default(mul, [-1]) |
| return [mul, view]""") |
| |
| def test_output_aliases_intermediate_and_returned_different_grad(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| # AOTAutograd should manually generate the first output (a view of an intermediate) |
| # but not the second (which is itself the intermediate for the first) |
| return out.view(-1), out, out[0].detach() |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None |
| view = torch.ops.aten.view.default(mul, [-1]) |
| select = torch.ops.aten.select.int(mul, 0, 0) |
| detach = torch.ops.aten.detach.default(select); select = None |
| detach_1 = torch.ops.aten.detach.default(detach); detach = None |
| detach_2 = torch.ops.aten.detach.default(detach_1); detach_1 = None |
| return [view, mul, detach_2]""") |
| |
| def test_output_aliases_intermediate_inplace_view(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| out.t_() |
| return out |
| inp = [torch.ones(2, 4, requires_grad=True)] |
| |
| # TODO: fix this test. |
| # See https://github.com/pytorch/pytorch/issues/90507 |
| # self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_output_aliases_intermediate_inplace_view_with_detach(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| out.t_() |
| out.detach_() |
| # Thanks to the detach_() AOT Autograd doesn't need to do anything. |
| # `out` will show up as having OutputType.non_alias, |
| # and ._is_view() == False |
| return out, a + 1 |
| inp = [torch.ones(2, 4, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(2, 4, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3) |
| t = torch.ops.aten.t.default(mul); mul = None |
| add = torch.ops.aten.add.Tensor(primals_1, 1); primals_1 = None |
| return [t, add]""") |
| |
| |
| def test_output_aliases_intermediate_inplace_view_and_view(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| out_view = out.unsqueeze(0) |
| out.t_() |
| out_view2 = out.unsqueeze(0) |
| return out_view, out, out_view2 |
| inp = [torch.ones(2, 4, requires_grad=True)] |
| |
| # TODO: fix this test. |
| # See <github issue link> |
| # self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_output_aliases_intermediate_multiple_mixed(self): |
| def f(a): |
| out1 = torch.mul(a, 3) |
| out2 = torch.mul(a, 4) |
| # AOTAutograd should manually generate these two output views in the epilogue. |
| return out1.view(-1), out2.transpose(1, 0), out1.transpose(1, 0) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3) |
| mul_1 = torch.ops.aten.mul.Tensor(primals_1, 4); primals_1 = None |
| view = torch.ops.aten.view.default(mul, [-1]) |
| transpose = torch.ops.aten.transpose.int(mul_1, 1, 0); mul_1 = None |
| transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0) |
| return [view, transpose, transpose_1, mul]""") |
| |
| def test_output_all_alias_types(self): |
| # There are 3 types of aliasing that require us to return metadata in the compiled fw: |
| # (1) outputs that are views of inputs |
| # (2) outputs that are views of intermediates |
| # (3) inputs that get metadata mutations |
| # test all 3 of them here |
| def f(a): |
| a.transpose_(1, 0) |
| tmp = a.mul(2) |
| return tmp.squeeze(), tmp.transpose(1, 0), a.unsqueeze(0) |
| |
| def inp_callable(req_grad): |
| x = torch.ones(1, 2, 4, requires_grad=req_grad).clone() |
| return [(x,), (x,)] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| # TODO: make this test run with dynamic shapes so it is more meaningful |
| # metadata output order: (a_updated_meta, out1_meta, out2_meta, out3_meta) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| view = torch.ops.aten.view.default(primals_1, [1, 2, 4]); primals_1 = None |
| transpose = torch.ops.aten.transpose.int(view, 1, 0); view = None |
| mul = torch.ops.aten.mul.Tensor(transpose, 2) |
| squeeze = torch.ops.aten.squeeze.default(mul) |
| transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0) |
| unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0) |
| return [transpose, squeeze, transpose_1, unsqueeze, mul]""") |
| |
| @parametrize("req_grad", [False, True]) |
| def test_subclass_metadata_mutation(self, req_grad): |
| def f(a): |
| a.transpose_(1, 0) |
| tmp = a.mul(2) |
| return tmp.transpose(1, 0) |
| |
| def inp_callable(req_grad): |
| x = torch.ones(1, 2, 4, requires_grad=req_grad).clone() |
| return [(x,), (x,)] |
| |
| # See https://github.com/pytorch/pytorch/issues/114975 |
| with self.assertRaisesRegex(RuntimeError, "Metadata mutations are currently not allowed on tensor subclasses"): |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=req_grad), test_mutation=True, make_inputs_subclasses=True) |
| |
| def test_input_data_and_metadata_mutation(self): |
| def f(a): |
| a.t_() |
| a[0].mul_(2) |
| return a.view(a.shape) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| t = torch.ops.aten.t.default(clone) |
| select = torch.ops.aten.select.int(t, 0, 0); t = None |
| mul = torch.ops.aten.mul.Tensor(select, 2); select = None |
| t_1 = torch.ops.aten.t.default(clone); clone = None |
| select_scatter = torch.ops.aten.select_scatter.default(t_1, mul, 0, 0); t_1 = mul = None |
| t_2 = torch.ops.aten.t.default(select_scatter); select_scatter = None |
| t_4 = torch.ops.aten.t.default(t_2) |
| t_6 = torch.ops.aten.t.default(t_2); t_2 = None |
| view_1 = torch.ops.aten.view.default(t_6, [3, 3]); t_6 = None |
| return [t_4, view_1]""") |
| |
| def test_view_and_inplace_view(self): |
| def f(a, b): |
| a.t_() |
| return b.view(b.shape), a.view(a.shape) |
| |
| def create_inp(req_grad): |
| return [ |
| torch.ones(3, 3, requires_grad=req_grad), |
| torch.ones(3, 3, requires_grad=req_grad) |
| ] |
| |
| self.verify_aot_autograd(f, create_inp(False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| view = torch.ops.aten.view.default(primals_1, [3, 3]); primals_1 = None |
| t = torch.ops.aten.t.default(view); view = None |
| view_1 = torch.ops.aten.view.default(primals_2, [3, 3]); primals_2 = None |
| view_2 = torch.ops.aten.view.default(t, [3, 3]) |
| return [t, view_1, view_2]""") |
| |
| def test_view_detach(self): |
| def f(a): |
| tmp = a.detach() |
| a.mul_(2) |
| return a, tmp |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_input_inplace_requires_grad_true(self): |
| def f(a, b): |
| a.requires_grad_(True) |
| return a.mul(3), b.mul(4) |
| inp = [ |
| # First inp doesnt require grad, but we switch it on |
| torch.ones(3, 3, requires_grad=False), |
| torch.ones(3, 3, requires_grad=True), |
| ] |
| |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None |
| mul_1 = torch.ops.aten.mul.Tensor(primals_2, 4); primals_2 = None |
| return [mul, mul_1]""") |
| |
| # This is a torture test: |
| # a and b get turned into a synthetic base in the compiled graph |
| # One gets a data mutation, the other gets a metadata mutation. |
| # We need to make sure that the metadata mutation gets propagated |
| # back to the original input. |
| def test_input_data_and_metadata_mutation_aliases_other_input(self): |
| # a and b are aliased |
| def f(a, b): |
| a.mul_(2) |
| b.t_() |
| return a.mul(b) |
| |
| def inp_callable(req_grad): |
| base = torch.ones(2, 2, requires_grad=req_grad) |
| # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. |
| x = base.add(1) |
| inp1 = x[0] |
| inp2 = x[0] |
| return [base], [inp1, inp2] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| with self.assertRaisesRegex(RuntimeError, "Encountered aliased inputs that are mutated in the graph, but"): |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True) |
| with self.assertRaisesRegex(RuntimeError, "Encountered aliased inputs that are mutated in the graph, but"): |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True) |
| |
| # https://github.com/pytorch/pytorch/issues/106456 |
| def test_input_mutation_noncontiguous(self): |
| def f(a): |
| a.mul_(2) |
| return a + 1 |
| |
| def inp_callable(req_grad): |
| base = torch.ones(2, 2, requires_grad=req_grad) |
| x = base.add(1) |
| # create a non-contiguous view to pass as an input to the compiler |
| inp = x[:, 0] |
| return [base], [inp] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True) |
| with self.assertRaisesRegex(AssertionError, "attempted to compile the backward with incorrect subclass metadata"): |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True) |
| |
| # Mutations in the backward are allowed as long as the mutated object does not require grad |
| def test_backward_mutation_data(self): |
| class BwMutation(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| ctx.save_for_backward(x) |
| return x.clone() |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| x, = ctx.saved_tensors |
| # bw mutation |
| x.mul_(2) |
| return grad_output.clone() |
| |
| def f(a, b): |
| out = BwMutation.apply(b) |
| return a * out |
| |
| inp_no_grad = [ |
| torch.ones(3, 3, requires_grad=True), |
| torch.ones(3, 3, requires_grad=False), |
| ] |
| |
| # Mutation on buffer that does not require grad during the backward is allowed |
| self.verify_aot_autograd(f, inp_no_grad, test_mutation=True) |
| |
| inp_grad = [ |
| torch.ones(3, 3, requires_grad=True), |
| torch.ones(3, 3, requires_grad=True), |
| ] |
| with self.assertRaisesRegex(AssertionError, "input that requires_grad and was mutated in the backward"): |
| self.verify_aot_autograd(f, inp_grad, test_mutation=True) |
| |
| def test_backward_mutation_metadata(self): |
| class BwMutation(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, a, b): |
| ctx.save_for_backward(b) |
| return a.clone(), b.clone() |
| |
| @staticmethod |
| def backward(ctx, grad_a, grad_b): |
| b, = ctx.saved_tensors |
| # bw metadata mutation |
| b.transpose_(1, 0) |
| return grad_a.clone(), grad_b.clone() |
| |
| def f(a, b): |
| a_, b_ = BwMutation.apply(a, b) |
| out = a_ * b_ |
| return out |
| |
| inp_no_grad = [ |
| torch.ones(3, 3, requires_grad=True), |
| torch.ones(3, 3, requires_grad=False), |
| ] |
| |
| with self.assertRaisesRegex(AssertionError, "input that had its metadata mutated in the backward"): |
| self.verify_aot_autograd(f, inp_no_grad, test_mutation=True) |
| |
| def test_backward_mutation_on_grad_out(self): |
| class BwMutation(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| return x.clone() |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| grad_output.mul_(2) |
| return grad_output.clone() |
| |
| def f(a, b): |
| tmp = a * b |
| out = BwMutation.apply(tmp) |
| return out |
| |
| inp_grad = [ |
| torch.ones(3, 3, requires_grad=True), |
| torch.ones(3, 3, requires_grad=True), |
| ] |
| f_compiled = aot_function(f, nop) |
| with self.assertRaisesRegex(AssertionError, "input to the backward that was mutated during the backward"): |
| out = f_compiled(*inp_grad) |
| |
| # Partially addresses https://github.com/pytorch/pytorch/issues/106457 |
| def test_input_mutation_false_aliasing(self): |
| def f(a, b): |
| a.mul_(3) |
| b.mul_(2) |
| return a + b |
| |
| # No overlap, contiguous |
| def inp_callable1(req_grad): |
| base = torch.ones(4, 4, requires_grad=req_grad) |
| x = base.add(1) |
| # create two non-contiguous views that share storage, but are actually non-overlapping |
| a = x[0:2] |
| b = x[2:4] |
| return [base], [a, b] |
| |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable1, req_grad=False), test_mutation=True) |
| self.verify_aot_autograd(f, partial(inp_callable1, req_grad=True), test_mutation=True) |
| self.verify_aot_autograd(f, partial(inp_callable1, req_grad=False), test_mutation=True, make_inputs_subclasses=True) |
| # Input mutations on subclasses with training graphs fail backward guards today. |
| with self.assertRaisesRegex(AssertionError, "attempted to compile the backward with incorrect subclass metadata"): |
| self.verify_aot_autograd(f, partial(inp_callable1, req_grad=True), test_mutation=True, make_inputs_subclasses=True) |
| |
| # Important characteristic: the graph takes in 2 inputs! |
| # That shows that we didn't try to run our complicated synthetic base logic, |
| # because we successfully detected false aliasing across the two inputs. |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, arg0_1, arg1_1): |
| mul = torch.ops.aten.mul.Tensor(arg0_1, 3); arg0_1 = None |
| mul_1 = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None |
| add = torch.ops.aten.add.Tensor(mul, mul_1) |
| return (mul, mul_1, add)""") |
| |
| # No overlap, non-contiguous: first tensor ends before second tensor start |
| def inp_callable2(req_grad): |
| base = torch.ones(256, requires_grad=req_grad) |
| x = base.add(1) |
| a = x.as_strided((4, 4), (8, 1), storage_offset=0) |
| b = x.as_strided((4, 4), (8, 1), storage_offset=28) |
| return [base], [a, b] |
| |
| # No overlap, non-contiguous: tensors are perfectly interleaved |
| def inp_callable3(req_grad): |
| base = torch.ones(4, 4, requires_grad=req_grad) |
| x = base.add(1) |
| a = x[:, 0:2] |
| b = x[:, 2:4] |
| return [base], [a, b] |
| |
| # No overlap, non-contiguous |
| def inp_callable4(req_grad): |
| base = torch.ones(256, requires_grad=req_grad) |
| x = base.add(1) |
| a = x.as_strided((4, 4), (9, 1), storage_offset=0) |
| b = x.as_strided((4, 4), (9, 1), storage_offset=22) |
| return [base], [a, b] |
| |
| # No overlap, non-contiguous |
| def inp_callable5(req_grad): |
| base = torch.ones(256, requires_grad=req_grad) |
| x = base.add(1) |
| a = x.as_strided((4, 4), (9, 1), storage_offset=0) |
| b = x.as_strided((4, 4), (9, 1), storage_offset=23) |
| return [base], [a, b] |
| |
| # overlap! non-contiguous |
| def inp_callable_overlap1(req_grad): |
| base = torch.ones(256, requires_grad=req_grad) |
| x = base.add(1) |
| a = x.as_strided((4, 4), (9, 1), storage_offset=0) |
| b = x.as_strided((4, 4), (9, 1), storage_offset=24) |
| return [base], [a, b] |
| |
| # overlap! non-contiguous |
| def inp_callable_overlap2(req_grad): |
| base = torch.ones(256, requires_grad=req_grad) |
| x = base.add(1) |
| a = x.as_strided((4, 4), (9, 1), storage_offset=0) |
| b = x.as_strided((4, 4), (9, 1), storage_offset=25) |
| return [base], [a, b] |
| |
| fw_graph2 = self.verify_aot_autograd(f, partial(inp_callable2, req_grad=False), test_mutation=True) |
| fw_graph3 = self.verify_aot_autograd(f, partial(inp_callable3, req_grad=False), test_mutation=True) |
| fw_graph4 = self.verify_aot_autograd(f, partial(inp_callable4, req_grad=False), test_mutation=True) |
| fw_graph5 = self.verify_aot_autograd(f, partial(inp_callable5, req_grad=False), test_mutation=True) |
| |
| fw_graph_overlap1 = self.verify_aot_autograd(f, partial(inp_callable_overlap2, req_grad=False), test_mutation=True) |
| fw_graph_overlap2 = self.verify_aot_autograd(f, partial(inp_callable_overlap1, req_grad=False), test_mutation=True) |
| |
| # All non-overlap graphs should be the same since we detected false aliasing |
| self.assertEqual(str(fw_graph.code), str(fw_graph2.code)) |
| self.assertEqual(str(fw_graph.code), str(fw_graph3.code)) |
| self.assertEqual(str(fw_graph.code), str(fw_graph4.code)) |
| self.assertEqual(str(fw_graph.code), str(fw_graph5.code)) |
| |
| # All overlap graphs should be the same since we detected real aliasing |
| self.assertNotEqual(str(fw_graph.code), str(fw_graph_overlap1.code)) |
| self.assertNotEqual(str(fw_graph.code), str(fw_graph_overlap2.code)) |
| self.assertTrue('as_strided_scatter' in str(fw_graph_overlap1.code)) |
| self.assertTrue('as_strided_scatter' in str(fw_graph_overlap2.code)) |
| |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") |
| def test_mem_leak_from_save_for_bw(self): |
| # See a full diagnosis at this issue: https://github.com/pytorch/pytorch/issues/94990 |
| # Note [Detaching saved tensors in AOTAutograd] |
| # This program creates a ref-cycle. Long term, we should fix this ref cycle |
| # (since it can arise, naturally albeit rarely, from uses of autograd.Function). |
| # But AOTAutograd makes it more likely to show up from tracing user programs, |
| # so we deal with it by manually detaching the tensors that we save for backward. |
| # This is completely wrong and would give wrong results if we were to do double backward. |
| # Fortunately today, double backward is explicitly banned in AOTAutograd. |
| def f(a, b): |
| add = a + a |
| split = torch.functional.split(add, [4, 4], dim=1) |
| getitem_2 = split[1] |
| unsqueeze = getitem_2.unsqueeze(-1) |
| mul = unsqueeze * b |
| return (getitem_2, mul) |
| |
| f_compiled = aot_function(f, nop) |
| inps = [ |
| torch.ones(8, 8, device='cuda', requires_grad=True), |
| torch.ones(1, 4, 1, device='cuda', requires_grad=True), |
| ] |
| mem_before = torch.cuda.memory_allocated() |
| f_compiled(*inps) |
| mem_after = torch.cuda.memory_allocated() |
| self.assertTrue(mem_after == mem_before) |
| |
| def test_output_aliases_multiple_inputs_get_correct_one(self): |
| # a and b are aliased, but have different shapes |
| # The first output should view off the first input, the 2nd output should view off the 2nd input |
| def f(a, b): |
| return a.view(a.shape), b.view(b.shape) |
| |
| def inp_callable(req_grad): |
| base = torch.ones(2, 2, requires_grad=req_grad) |
| # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. |
| x = base.mul(2) |
| inp1 = x.view(-1) |
| inp2 = x[0] |
| return [base], [inp1, inp2] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True) |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True) |
| |
| def test_input_mutation_aliases_other_input(self): |
| def f(a, b): |
| a.add_(1) |
| return a + b |
| |
| def inp_callable(req_grad): |
| base = torch.ones(4, 2, requires_grad=req_grad) |
| # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. |
| x = base.add(1) |
| inp1 = x[0] |
| inp2 = x[0] |
| return [base], [inp1, inp2] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| # Important parts of the graph: |
| # - the compiled graph takes in a base, and we generate a and b (the views) off of the base |
| # - clone() is still in the graph, because we need to call grad() on the original (non-mutated) inputs |
| # - We re-generate the views *after* the clone, to preserve view relationships. |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0) |
| add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None |
| as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None |
| as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0) |
| as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0) |
| add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None |
| return [as_strided_scatter, add_1]""") # noqa: B950 |
| |
| def test_input_mutation_aliases_other_input2(self): |
| def f(a, b): |
| a.add_(1) |
| return a + b |
| |
| def inp_callable(req_grad): |
| base = torch.ones(2, 2, requires_grad=req_grad) |
| x = base.add(1) |
| inp1 = x[0] |
| # Here, one of the aliased inputs is the base itself |
| inp2 = x |
| return [base], [inp1, inp2] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0) |
| add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None |
| as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None |
| as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0) |
| as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0) |
| add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None |
| return [as_strided_scatter, add_1]""") # noqa: B950 |
| |
| def test_input_mutation_aliases_and_output_alias(self): |
| def f(a, b): |
| # Here, we need to take care:that because and b are aliased |
| # since a and b are aliased, we generate a view off of "updated b" |
| a.add_(1) |
| return b.view(b.shape) |
| |
| def inp_callable(req_grad): |
| base = torch.ones(2, 2, requires_grad=req_grad) |
| x = base.add(1) |
| return [base], [x.view(-1), x.view(-1)] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0) |
| add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None |
| as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None |
| as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) |
| view_1 = torch.ops.aten.view.default(as_strided_8, [4]); as_strided_8 = None |
| return [as_strided_scatter, view_1]""") # noqa: B950 |
| |
| def test_input_aliased_with_mutation_output_alias(self): |
| def f(a, b, c): |
| # a and c alias |
| c.mul_(2) |
| # The main thing we're testing here is that |
| # (1) We need to reconstruct c.view(-1) from the 3rd input to the forward |
| # (2) But we need to be careful to do this *before* converting aliased inputs into synthetic bases. |
| # The original fw takes in 3 args, but the compiled fw takes in only 2 args. |
| return b.add(1), c.view(-1) |
| |
| def inp_callable(req_grad): |
| base1 = torch.ones(2, 2, requires_grad=req_grad) |
| base2 = torch.ones(2, 2, requires_grad=req_grad) |
| x = base1.add(1) |
| y = base2.add(1) |
| return [base1, base2], [x.view(-1), y, x.view(-1)] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0) |
| mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None |
| as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None |
| add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None |
| as_strided_7 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) |
| view_1 = torch.ops.aten.view.default(as_strided_7, [-1]); as_strided_7 = None |
| return [as_strided_scatter, add, view_1]""") # noqa: B950 |
| |
| def test_input_metadata_mutation_aliases(self): |
| def f(a, b): |
| # a and b alias, and we do a metadata mutation on a |
| # Since we're not mutating data, then b isn't affected at all. |
| # We expect aot autograd to not bother with constructing a synthetic base. |
| a.t_() |
| return a + b |
| |
| def inp_callable(req_grad): |
| base = torch.ones(2, 2, requires_grad=req_grad) |
| x = base.add(1) |
| return [base], [x.view(-1), x.view(-1)] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| # Expectation: fwd() takes in 2 args, and we don't construct a synthetic base. |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| t = torch.ops.aten.t.default(primals_1); primals_1 = None |
| add = torch.ops.aten.add.Tensor(t, primals_2); t = primals_2 = None |
| return [add]""") |
| |
| def test_input_mutation_aliases_and_none_require_gradients(self): |
| def f(a, b, c): |
| # a and b alias, but neither require gradients (so they don't have a _base) |
| # aot autograd should construct the synthetic base from `torch.Tensor(a.storage())` |
| a.mul_(2) |
| return b + 1, c + 1 |
| |
| def inp_callable(req_grad): |
| base = torch.ones(2, 2) |
| c_arg = torch.ones(2, 2, requires_grad=req_grad) |
| x = base.add(1) |
| return [base, c_arg], [x.view(-1), x.view(-1), c_arg] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| |
| with self.assertRaisesRegex(RuntimeError, "is a tensor subclass. This is not supported today"): |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True) |
| |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| as_strided = torch.ops.aten.as_strided.default(primals_1, [4], [1], 0) |
| mul = torch.ops.aten.mul.Tensor(as_strided, 2); as_strided = None |
| as_strided_scatter = torch.ops.aten.as_strided_scatter.default(primals_1, mul, [4], [1], 0); primals_1 = mul = None |
| as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) |
| add = torch.ops.aten.add.Tensor(as_strided_3, 1); as_strided_3 = None |
| add_1 = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None |
| return [as_strided_scatter, add, add_1]""") # noqa: B950 |
| |
| def test_input_mutation_aliases_bases_out_of_order(self): |
| # This tests our calling convention: if b and d are aliased, then the outer calling convention |
| # that we send to the compiled forward becomes: |
| # (b_d_base, a, c) |
| # Importantly, even though a and c alias in our test, neither inputs are mutated, |
| # So we don't need to do the base construction / deconstruction |
| def f(a, b, c, d): |
| b.add_(1) |
| d.unsqueeze_(0) |
| return a + c + d, b.view(-1) |
| |
| def inp_callable(req_grad): |
| base1 = torch.ones(2, 2, requires_grad=req_grad) |
| base2 = torch.ones(2, 2, requires_grad=req_grad) |
| x1 = base1.add(1) |
| x2 = base2.add(1) |
| # a and c alias, b and d alias |
| return [base1, base2], [x1.view(-1), x2.view(-1), x1.view(-1), x2.view(-1)] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| |
| with self.assertRaisesRegex(RuntimeError, "Metadata mutations are currently not allowed on tensor subclasses"): |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True) |
| |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| # 3 graph inputs: (b_d_base, a, c) |
| # 2 returns: (b_updated, a+c+d) |
| # (there are 2 original fw outs, but one is a view of b so it's not part of the graph) |
| # (there are also 2 input mutations, but one is a metadata-only mutation so the compiled forward doesn't return it) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2, primals_3): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0) |
| add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None |
| as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None |
| add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None |
| as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) |
| unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_5, 0); as_strided_5 = None |
| add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1); add_1 = None |
| as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) |
| view_2 = torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None |
| return [as_strided_scatter, add_2, view_2, unsqueeze_1]""") # noqa: B950 |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") |
| def test_synthetic_base_base_attribute_is_none(self): |
| def f(a, b): |
| a.add_(1) |
| return a + b |
| |
| def inp_callable(): |
| base = torch.ones(4, 4, device='cuda') |
| # detach() so that none of the inputs have a ._base attribute. |
| a = base[0].detach() |
| b = base[1].detach() |
| base2 = torch.ones(2, 2, requires_grad=True) |
| return [base], [a, b] |
| |
| self.verify_aot_autograd(f, inp_callable, test_mutation=True) |
| |
| |
| def test_input_mutation_alias_everything(self): |
| # Mondo test that tests a combination of: |
| # input is mutated, that aliases another input (so we make a synthetic base) |
| # an output is an alias of another output |
| # an output is an alias of an intermediate |
| # a and c are aliased |
| def f(a, b, c): |
| c.mul_(2) # mutates c |
| b.t_() # metadata mutate b |
| tmp = a + c |
| out1 = tmp.view(-1) |
| out2 = b.t() |
| out3 = out1.unsqueeze(0) |
| # out1 and out3 are aliases of an intermediate, and alias each other! |
| # out2 aliases an input, so we don't return it |
| return out1, out2, out3 |
| |
| def inp_callable(req_grad): |
| base1 = torch.ones(2, 2, requires_grad=req_grad) |
| base2 = torch.ones(2, 2, requires_grad=req_grad) |
| # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. |
| base1_ = base1.add(1) |
| base2_ = base2.add(1) |
| a = base1_.view(-1) |
| b = base2_ |
| c = base1_.view(-1) |
| return [base1, base2], [a, b, c] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| # Expected: |
| # - 2 inputs in the forward: synthetic_base_a_c, b |
| # - 1 output in the forward: "tmp" |
| # out2 is an alias of an input, and will be generated off of b outside of the compiled fn |
| # out1 and out3 are aliases of tmp, that we generate outside of the compiled function |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| view = torch.ops.aten.view.default(primals_2, [2, 2]); primals_2 = None |
| as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0) |
| mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None |
| as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None |
| as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) |
| t = torch.ops.aten.t.default(view); view = None |
| as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) |
| add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2); as_strided_5 = as_strided_2 = None |
| view_1 = torch.ops.aten.view.default(add, [-1]) |
| t_1 = torch.ops.aten.t.default(t) |
| unsqueeze = torch.ops.aten.unsqueeze.default(view_1, 0) |
| return [as_strided_scatter, t, view_1, t_1, unsqueeze, add]""") # noqa: B950 |
| |
| def test_dynamic_shape_output_not_in_bw_graph(self): |
| def f(x): |
| return [x + 1, x.shape[0]] |
| inp = torch.ones(5, requires_grad=True) |
| bw_graph_cell = [None] |
| compiled_f = aot_function( |
| f, |
| fw_compiler=nop, |
| bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), |
| decompositions={}, |
| keep_inference_input_mutations=False, |
| dynamic=True, |
| ) |
| out = compiled_f(inp) |
| out[0].sum().backward() |
| # The important bit: the forward fn returns 2 outputs, |
| # but one of them is a symint so we should only see |
| # 1 grad_output as an input to the backward graph. |
| # (Otherwise, autograd will plumb a None as the value of the grad_output, |
| # which causes inductor to complain). |
| self.assertExpectedInline(bw_graph_cell[0].code.strip(), """\ |
| def forward(self, tangents_1): |
| return [tangents_1]""") |
| |
| def test_no_grad_input_output(self): |
| def f(a, b): |
| return a.cos(), b.cos(), a * b |
| |
| inp_thunks = [lambda: torch.randn(5, requires_grad=True), lambda: torch.randn(5, requires_grad=False)] |
| for inps in itertools.product(inp_thunks, repeat=2): |
| inps = [i() for i in inps] |
| self.verify_aot_autograd(f, inps) |
| |
| def test_some_output_requires_grad_input_doesnt(self): |
| def f(a, b): |
| a_view = a.view(-1) |
| a_view.requires_grad_(True) |
| return a_view |
| inp = [torch.randn(3, 3), torch.randn(3, 3, requires_grad=True)] |
| self.verify_aot_autograd(f, inp) |
| |
| def test_some_outputs_dont_require_grad_view(self): |
| def f(a, b): |
| return a.detach(), b |
| inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3, requires_grad=True)] |
| self.verify_aot_autograd(f, inp) |
| |
| def test_some_outputs_dont_require_grad_non_view(self): |
| def f(a, b): |
| return a.add(1).detach(), b |
| inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3, requires_grad=True)] |
| self.verify_aot_autograd(f, inp) |
| |
| def test_inner_grad(self): |
| def foo(x): |
| y = torch.exp(x) |
| z = torch.autograd.grad(y, x) |
| return z |
| inps = [torch.randn((), requires_grad=True)] |
| self.verify_aot_autograd(foo, inps) |
| |
| def test_grad_context(self): |
| def foo(x): |
| return x * 2 |
| inps = [torch.randn((), requires_grad=True)] |
| graph_size = None |
| |
| def get_graph_size(fx_g, _): |
| nonlocal graph_size |
| graph_size = len(fx_g.graph.nodes) |
| return fx_g |
| |
| f = aot_function(foo, nop, get_graph_size) |
| with torch.set_grad_enabled(False): |
| f(*inps) |
| self.assertIsNone(graph_size) |
| |
| f = aot_function(foo, nop, get_graph_size) |
| with torch.set_grad_enabled(True): |
| out = f(*inps) |
| self.assertIsNone(graph_size) |
| out.sum().backward() |
| self.assertTrue(graph_size > 2) |
| |
| def test_output_dict(self): |
| def f(x): |
| return {'a': x, 'b': x} |
| inp = [torch.randn(3, 3, requires_grad=True)] |
| self.verify_aot_autograd(f, inp) |
| |
| def f(x, y): |
| return {'a': x, 'b': y + x} |
| inp = [torch.randn(3, requires_grad=True), torch.randn(3)] |
| self.verify_aot_autograd(f, inp) |
| |
| def f(x): |
| new_d = {} |
| for k in x: |
| new_d[k] = x[k] * 2 |
| return new_d |
| |
| a = torch.randn(3, requires_grad=True) |
| b = torch.randn(3, requires_grad=True) |
| |
| def inp_callable(): |
| inps = [{'a': a, 'b': b}] |
| return inps, inps |
| |
| self.verify_aot_autograd(f, inp_callable) |
| |
| def test_module(self): |
| mod = nn.Sequential(nn.Linear(32, 32), nn.ReLU()) |
| compiled_mod = compiled_module(mod, nop, nop) |
| inp = torch.randn(32, 32) |
| ref_out = mod(inp) |
| ref_out.sum().backward() |
| ref_grads = sorted([(name, p.grad) for name, p in mod.named_parameters()]) |
| out = compiled_mod(inp) |
| out.sum().backward() |
| grads = sorted([(name, p.grad) for name, p in mod.named_parameters()]) |
| self.assertEqual((out, grads), (ref_out, ref_grads)) |
| |
| def test_batchnorm(self): |
| mod = compiled_module(nn.BatchNorm2d(4), nop, nop) |
| x = torch.ones(1, 4, 2, 2) |
| mod(x).sum().backward() |
| |
| def test_list_codegen(self): |
| def list_nop(f, _): |
| def g(inps): |
| return f(*inps) |
| g._boxed_call = True |
| return g |
| |
| def f(a, b, c): |
| return a.sin() * b.cos() * c.sin() |
| f = aot_function(f, list_nop) |
| inp = [torch.randn(5, requires_grad=True) for _ in range(3)] |
| f(*inp).sum().backward() |
| |
| @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) |
| def test_compilation_context(self, counter): |
| def f(x): |
| return x.sin().sin() |
| count = [] |
| |
| def compiler(fx_g, _): |
| context = get_aot_compilation_context() |
| count.append((context[0], len(fx_g.graph.nodes))) |
| return fx_g |
| |
| f = aot_function(f, compiler) |
| out = f(torch.randn(5, requires_grad=True)) |
| f = aot_function(f, compiler) |
| f(torch.randn(5)) |
| out.sum().backward() |
| self.assertExpectedInline(str(count), """[(['0_forward'], 4), (['1_inference'], 4), (['0_backward'], 8)]""") |
| |
| def test_dupe_arg(self): |
| def f(x, y): |
| return x + y |
| |
| x = torch.randn(3, 3, requires_grad=True) |
| self.verify_aot_autograd(f, [x, x]) |
| |
| def test_dupe_arg_torture(self): |
| def f(x, y): |
| x.t_() |
| y.unsqueeze_(0) |
| return x + y |
| |
| x = torch.randn(3, 3, requires_grad=True).clone() |
| self.verify_aot_autograd(f, [x, x]) |
| |
| # See https://github.com/pytorch/pytorch/issues/100224 |
| def test_dupe_arg_returned_as_output(self): |
| def f(a, b, a_): |
| a[0].add_(1) |
| return a_ |
| f_compiled = aot_function(f, nop) |
| a = torch.ones(2) |
| b = torch.ones(2) |
| out_ref = f(a, b, a) |
| |
| a2 = torch.ones(2) |
| b2 = torch.ones(2) |
| out_test = f_compiled(a2, b2, a2) |
| |
| self.assertEqual(out_ref, out_test) |
| self.assertEqual(a, a2) |
| |
| @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) |
| @patch("torch._functorch.config.debug_assert", True) |
| def test_invalid_dupe_left_bias(self, counter): |
| # This test checks that, just because only the first |
| # argument did a metadata mutation, we still correctly |
| # switch to strategy 2 (deduplicate) |
| # See: https://github.com/pytorch/pytorch/pull/89896#discussion_r1036224447 |
| class F(torch.nn.Module): |
| def forward(self, x, y): |
| x.t_() |
| return (x + y,) |
| |
| x = torch.randn(3, 3, requires_grad=True).clone() |
| y = torch.randn(3, 3, requires_grad=True) |
| self.verify_aot_autograd(F(), [x, x]) |
| |
| fxx = aot_module_simplified(F(), (x, x), nop) |
| self.assertExpectedRaisesInline( |
| AssertionError, lambda: fxx(x, y), |
| """At compilation time, graph 2 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""" # noqa: B950 |
| ) |
| |
| |
| @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) |
| @patch("torch._functorch.config.debug_assert", True) |
| def test_invalid_dupe(self, counter): |
| self._test_invalid_dupe(counter, fake=False) |
| |
| # See Note: Dynamo recompilation guarding invalid grad for why this test exists |
| @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) |
| @patch("torch._functorch.config.debug_assert", True) |
| def test_invalid_dupe_fake(self, counter): |
| self._test_invalid_dupe(counter, fake=True) |
| |
| |
| def _test_invalid_dupe(self, counter, fake): |
| class F(torch.nn.Module): |
| def forward(self, x, y): |
| x.unsqueeze_(0) |
| y.unsqueeze_(0) |
| return (x + y,) |
| |
| x = torch.randn(3, 3, requires_grad=True).clone() |
| y = torch.randn(3, 3, requires_grad=True).clone() |
| |
| if fake: |
| shape_env = ShapeEnv() |
| fake_mode = FakeTensorMode(shape_env=shape_env) |
| |
| fake_x = fake_mode.from_tensor(x) |
| fake_y = fake_mode.from_tensor(y) |
| |
| if fake: |
| fxy = aot_module_simplified(F(), (fake_x, fake_y), nop) |
| else: |
| fxy = aot_module_simplified(F(), (x, y), nop) |
| |
| fxy(x, y) |
| x = torch.randn(3, 3, requires_grad=True).clone() |
| y = torch.randn(3, 3, requires_grad=True).clone() |
| fxy(x, x) # is ok! |
| |
| if fake: |
| fxx = aot_module_simplified(F(), (fake_x, fake_x), nop) |
| else: |
| fxx = aot_module_simplified(F(), (x, x), nop) |
| |
| x = torch.randn(3, 3, requires_grad=True).clone() |
| y = torch.randn(3, 3, requires_grad=True).clone() |
| fxx(x, x) |
| # Note This should not raise! Once we have guards in place here, |
| # we will have this working correctly, as it should recompile. |
| x = torch.randn(3, 3, requires_grad=True).clone() |
| y = torch.randn(3, 3, requires_grad=True).clone() |
| self.assertExpectedRaisesInline( |
| AssertionError, lambda: fxx(x, y), |
| """At compilation time, graph 1 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""" # noqa: B950 |
| ) |
| |
| |
| @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) |
| @patch("torch._functorch.config.debug_assert", True) |
| def test_invalid_requires_grad(self, counter): |
| self._test_invalid_requires_grad(counter, fake=False) |
| |
| # See Note: Dynamo recompilation guarding invalid grad for why this test exists |
| @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) |
| @patch("torch._functorch.config.debug_assert", True) |
| def test_invalid_requires_grad_fake(self, counter): |
| self._test_invalid_requires_grad(counter, fake=True) |
| |
| def _test_invalid_requires_grad(self, counter, fake): |
| class F(torch.nn.Module): |
| def forward(self, x, y): |
| return (x + y,) |
| |
| x = torch.randn(3, 3, requires_grad=True) |
| y = torch.randn(3, 3, requires_grad=True) |
| z = torch.randn(3, 3, requires_grad=False) |
| |
| if fake: |
| shape_env = ShapeEnv() |
| fake_mode = FakeTensorMode(shape_env=shape_env) |
| |
| fake_x = fake_mode.from_tensor(x) |
| fake_y = fake_mode.from_tensor(y) |
| fake_z = fake_mode.from_tensor(z) |
| |
| if fake: |
| fxy = aot_module_simplified(F(), (fake_x, fake_y), nop) |
| else: |
| fxy = aot_module_simplified(F(), (x, y), nop) |
| |
| compare_equal_outs_and_grads(self, F(), fxy, (x, y)) |
| compare_equal_outs_and_grads(self, F(), fxy, (x, z)) |
| |
| if fake: |
| fxz = aot_module_simplified(F(), (fake_x, fake_z), nop) |
| else: |
| fxz = aot_module_simplified(F(), (x, z), nop) |
| |
| compare_equal_outs_and_grads(self, F(), fxz, (x, z)) |
| |
| self.assertExpectedRaisesInline( |
| AssertionError, lambda: fxz(x, y), |
| """At compilation time, graph 1 was compiled under the assumption that input 1 would not require grad, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""" # noqa: B950 |
| ) |
| |
| def test_custom_autograd(self): |
| class CustomFn(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| return x.clone() |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| return grad_output + 1 |
| |
| def f(x): |
| return CustomFn.apply(x) |
| |
| self.verify_aot_autograd(f, [torch.randn(3)]) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") |
| def test_autocast_disable_guard(self): |
| with torch._C._DisableAutocast(): |
| x = torch.rand([4, 4]).cuda() |
| y = x @ x |
| self.assertEqual(y.dtype, torch.float32) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") |
| def test_nonidempotent_amp(self): |
| def f(self_s_emb, add_3): |
| einsum_2 = torch.functional.einsum('ah,th->t', self_s_emb, add_3) |
| log_softmax_2 = einsum_2.log_softmax(-1) |
| return (log_softmax_2,) |
| |
| args = [torch.rand((1, 256), dtype=torch.float32, device='cuda'), torch.rand((30, 256), dtype=torch.float16, device='cuda')] |
| with torch.cuda.amp.autocast(enabled=True): |
| self.verify_aot_autograd(f, args) |
| |
| args = [e.requires_grad_(True) for e in args] |
| with torch.cuda.amp.autocast(enabled=True): |
| self.verify_aot_autograd(f, args) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") |
| @unittest.skipIf(not torch.backends.cudnn.is_available(), "CUDNN is unavailable") |
| @skipIfRocm # https://github.com/pytorch/pytorch/issues/96560 |
| def test_batch_norm_amp(self): |
| device = "cuda" |
| input_dtype = torch.float16 |
| param_dtype = torch.float32 |
| weight, bias = (torch.ones(64, device=device, dtype=param_dtype, requires_grad=True) for _ in range(2)) |
| running_mean, running_var = (torch.ones(64, device=device, dtype=param_dtype) for _ in range(2)) |
| |
| def bn(x): |
| return torch.ops.aten.cudnn_batch_norm( |
| x, |
| weight, |
| bias, |
| running_mean, |
| running_var, |
| False, |
| 0.1, |
| 1e-05, |
| ) |
| inp = torch.ones(torch.Size([16, 64, 112, 112]), dtype=input_dtype, device=device) |
| |
| ref = bn(inp) |
| cudnn_batch_norm_decomp = torch._decomp.get_decompositions({torch.ops.aten.cudnn_batch_norm}) |
| aot_fn = make_fx(bn, decomposition_table=cudnn_batch_norm_decomp)(inp) |
| res = aot_fn(inp) |
| for a, b in zip(ref, res): |
| assert torch.allclose(a, b) |
| |
| def test_output_op_depending_on_symint(self): |
| """ |
| It won't be obvious from reading this test what it's testing for. We should probably make it into a more |
| focused unit test. |
| |
| An issue with the following program was the expand op would end up depending on a symint whose proxy was |
| incorrectly associated with one of the grad tensors rather than input tensors. It broke partitioner logic |
| and the net result was aot_function failed to produce a function and threw an exception instead. |
| """ |
| inp = torch.randn(5, requires_grad=True) |
| |
| def f(x): |
| return x.expand(x.shape) |
| |
| # TODO(whc) make this work (test setup is wrong somehow) |
| # joint_forward_backward = create_joint_forward_backward(f) |
| # out = f(inp) |
| # joint_inputs = ([inp], [out.detach().contiguous()]) |
| # fx_g = make_fx(joint_forward_backward)(*joint_inputs) |
| # TODO: assert outputs of fwd graph trace to correct symint |
| |
| # e2e test that fails without symint clone fix |
| af = aot_function(f, nop, partition_fn=partial(min_cut_rematerialization_partition, compiler="inductor"), dynamic=True) |
| out = af(inp) |
| self.assertEqual(out, f(inp)) |
| |
| def test_inference_mode(self): |
| m = torch.nn.Linear(4, 4) |
| inp = torch.randn(4, 4) |
| |
| aot_mod = aot_module(m, fw_compiler=nop) |
| |
| with torch.inference_mode(): |
| out_ref = m(inp) |
| out_test = aot_mod(inp) |
| self.assertEqual(out_ref, out_test) |
| |
| def test_default_partitioner_saves_symints_not_tensors_for_bw(self): |
| """ |
| In this test, the important thing is that primals_1 is **only** needed in the backward |
| in order to grab its sizes. |
| We need to assert that what we save for the backward are the tensor's sizes, and not the tensor itself. |
| |
| The way this test is set up, it will actually fail if we try to save the input tensor for backward. |
| Why? |
| b.masked_fill_(c, 0) has a backward that requires knowing a's sizes |
| b.masked_fill_(c, 0) **also** mutates a (because b and a are aliased) |
| The autograd engine yells at us if we save "a" for backward, and then try to mutate it. |
| """ |
| inp = torch.randn(2, 2, requires_grad=True) |
| |
| def f(a): |
| b = a[0] |
| c = torch.ones_like(b, dtype=torch.bool) |
| d = b.masked_fill_(c, 0) |
| return d |
| |
| compiled_f = aot_function(f, nop, dynamic=True) |
| inp_ref = torch.ones(2, 2, requires_grad=True) |
| inp_test = torch.ones(2, 2, requires_grad=True) |
| |
| out_ref = f(inp_ref.clone()) |
| out_test = compiled_f(inp_test.clone()) |
| |
| self.assertEqual(out_ref, out_test) |
| |
| out_ref.sum().backward() |
| out_test.sum().backward() |
| |
| self.assertEqual(inp_ref.grad, inp_test.grad) |
| |
| def test_buffer_copied_in_graph(self): |
| class MyModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.register_buffer("buf", torch.zeros(1)) |
| self.w1 = torch.nn.Parameter(torch.zeros(1)) |
| self.w2 = torch.nn.Parameter(torch.zeros(1)) |
| |
| def forward(self, x): |
| self.buf.add_(1) |
| return (self.w1 * x * self.w2).sum() + self.buf.sum() |
| |
| model_for_eager = MyModel() |
| model_for_compile = copy.deepcopy(model_for_eager) |
| |
| fw_graph_cell = [None] |
| compiled_f = aot_module( |
| model_for_compile, |
| fw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=fw_graph_cell)), |
| bw_compiler=nop, |
| keep_inference_input_mutations=True, |
| ) |
| inp_ref = torch.ones(1, requires_grad=True) |
| inp_test = torch.ones(1, requires_grad=True) |
| |
| out_ref = model_for_eager(inp_ref.clone()) |
| out_test = compiled_f(inp_test.clone()) |
| |
| self.assertExpectedInline(fw_graph_cell[0].code.strip(), """\ |
| def forward(self, primals_1, primals_2, primals_3, primals_4): |
| add = torch.ops.aten.add.Tensor(primals_3, 1) |
| mul = torch.ops.aten.mul.Tensor(primals_1, primals_4) |
| mul_1 = torch.ops.aten.mul.Tensor(mul, primals_2) |
| sum_1 = torch.ops.aten.sum.default(mul_1); mul_1 = None |
| sum_2 = torch.ops.aten.sum.default(add) |
| add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None |
| copy_ = torch.ops.aten.copy_.default(primals_3, add); primals_3 = add = None |
| return [add_1, primals_1, primals_2, primals_4, mul]""") |
| |
| self.assertEqual(out_ref, out_test) |
| |
| out_ref.sum().backward() |
| out_test.sum().backward() |
| |
| eager_grads = [p.grad for _, p in model_for_eager.named_parameters()] |
| compile_grads = [p.grad for _, p in model_for_compile.named_parameters()] |
| |
| self.assertEqual(eager_grads, compile_grads) |
| self.assertEqual(inp_ref.grad, inp_test.grad) |
| |
| def test_buffer_copied_in_graph_with_different_shapes(self): |
| class MyModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.register_buffer("buf", torch.ones(4, 4)) |
| self.w = torch.nn.Parameter(torch.Tensor([[4, 5], [1, 2], [6, 7], [8, 9]])) |
| |
| def forward(self, x): |
| self.buf.add_(1) |
| return (self.w @ x).sum() + self.buf.sum() |
| |
| model_for_eager = MyModel() |
| model_for_compile = copy.deepcopy(model_for_eager) |
| |
| fw_graph_cell = [None] |
| compiled_f = aot_module( |
| model_for_compile, |
| fw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=fw_graph_cell)), |
| bw_compiler=nop, |
| keep_inference_input_mutations=True, |
| ) |
| inp_ref = torch.ones(2, 4, requires_grad=True) |
| inp_test = torch.ones(2, 4, requires_grad=True) |
| |
| out_ref = model_for_eager(inp_ref.clone()) |
| out_test = compiled_f(inp_test.clone()) |
| |
| self.assertExpectedInline(fw_graph_cell[0].code.strip(), """\ |
| def forward(self, primals_1, primals_2, primals_3): |
| add = torch.ops.aten.add.Tensor(primals_2, 1) |
| mm = torch.ops.aten.mm.default(primals_1, primals_3) |
| sum_1 = torch.ops.aten.sum.default(mm); mm = None |
| sum_2 = torch.ops.aten.sum.default(add) |
| add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None |
| copy_ = torch.ops.aten.copy_.default(primals_2, add); primals_2 = add = None |
| return [add_1, primals_1, primals_3]""") |
| self.assertEqual(out_ref, out_test) |
| |
| out_ref.sum().backward() |
| out_test.sum().backward() |
| |
| eager_grads = [p.grad for _, p in model_for_eager.named_parameters()] |
| compile_grads = [p.grad for _, p in model_for_compile.named_parameters()] |
| |
| self.assertEqual(eager_grads, compile_grads) |
| |
| self.assertEqual(inp_ref.grad, inp_test.grad) |
| |
| def test_buffer_batch_norm(self): |
| class MyModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.m = torch.nn.BatchNorm1d(100) |
| |
| def forward(self, x): |
| return self.m(x) |
| |
| model_for_eager = MyModel() |
| model_for_compile = copy.deepcopy(model_for_eager) |
| |
| fw_graph_cell = [None] |
| bw_graph_cell = [None] |
| compiled_f = aot_module( |
| model_for_compile, |
| fw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=fw_graph_cell)), |
| bw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=bw_graph_cell)), |
| keep_inference_input_mutations=True, |
| ) |
| inp_ref = torch.ones(20, 100, requires_grad=True) |
| inp_test = torch.ones(20, 100, requires_grad=True) |
| |
| out_ref = model_for_eager(inp_ref.clone()) |
| out_test = compiled_f(inp_test.clone()) |
| |
| self.assertExpectedInline(fw_graph_cell[0].code.strip(), """\ |
| def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6): |
| add = torch.ops.aten.add.Tensor(primals_5, 1) |
| _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(primals_6, primals_1, primals_2, primals_3, primals_4, True, 0.1, 1e-05); primals_2 = None |
| getitem = _native_batch_norm_legit_functional[0] |
| getitem_1 = _native_batch_norm_legit_functional[1] |
| getitem_2 = _native_batch_norm_legit_functional[2] |
| getitem_3 = _native_batch_norm_legit_functional[3] |
| getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None |
| copy_ = torch.ops.aten.copy_.default(primals_3, getitem_3); primals_3 = None |
| copy__1 = torch.ops.aten.copy_.default(primals_4, getitem_4); primals_4 = None |
| copy__2 = torch.ops.aten.copy_.default(primals_5, add); primals_5 = add = None |
| return [getitem, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4]""") # noqa: B950 |
| |
| self.assertEqual(out_ref, out_test) |
| |
| out_ref.sum().backward() |
| out_test.sum().backward() |
| |
| eager_grads = [p.grad for _, p in model_for_eager.named_parameters()] |
| compile_grads = [p.grad for _, p in model_for_compile.named_parameters()] |
| self.assertEqual(eager_grads, compile_grads) |
| |
| self.assertExpectedInline(bw_graph_cell[0].code.strip(), """\ |
| def forward(self, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4, tangents_1): |
| native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(tangents_1, primals_6, primals_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); tangents_1 = primals_6 = primals_1 = getitem_3 = getitem_4 = getitem_1 = getitem_2 = None |
| getitem_5 = native_batch_norm_backward[0] |
| getitem_6 = native_batch_norm_backward[1] |
| getitem_7 = native_batch_norm_backward[2]; native_batch_norm_backward = None |
| return [getitem_6, getitem_7, None, None, None, getitem_5]""") # noqa: B950 |
| |
| self.assertEqual(inp_ref.grad, inp_test.grad) |
| |
| def test_new_inp_requires_grad_now(self): |
| def f(x, y): |
| return x.add_(y) |
| |
| fw_graph_cell = [None] |
| bw_graph_cell = [None] |
| compiled_f = aot_function( |
| f, |
| fw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=fw_graph_cell)), |
| bw_compiler=make_boxed_compiler(partial(extract_graph, graph_cell=bw_graph_cell)), |
| keep_inference_input_mutations=True, |
| ) |
| |
| inp_ref = (torch.ones(20, 100, requires_grad=False), torch.ones(20, 100, requires_grad=True)) |
| inp_test = (torch.ones(20, 100, requires_grad=False), torch.ones(20, 100, requires_grad=True)) |
| |
| out_ref = f(*inp_ref) |
| out_test = compiled_f(*inp_test) |
| |
| # There is no copy_ method |
| self.assertExpectedInline(fw_graph_cell[0].code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| add = torch.ops.aten.add.Tensor(clone, primals_2); clone = primals_2 = None |
| return [add, add]""") # noqa: B950 |
| |
| self.assertEqual(out_ref, out_test) |
| |
| out_ref.sum().backward() |
| out_test.sum().backward() |
| |
| self.assertExpectedInline(bw_graph_cell[0].code.strip(), """\ |
| def forward(self, tangents_1): |
| return [None, tangents_1]""") # noqa: B950 |
| |
| def test_real_weights_in_symbolic_mode(self): |
| from functorch.experimental import functionalize |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 5) |
| |
| def forward(self, x): |
| x = self.linear(x) |
| return x |
| |
| m = M().eval() |
| |
| inp = torch.randn(2, 5) |
| |
| gm = make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp) |
| self.assertEqual(gm(torch.ones(2, 5)), m(torch.ones(2, 5))) |
| |
| gm_functionalized = make_fx(functionalize(gm,), tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp) |
| self.assertEqual(gm_functionalized(torch.ones(2, 5)), m(torch.ones(2, 5))) |
| |
| inp_count = 0 |
| for node in gm.graph.nodes: |
| if node.op == "placeholder": |
| inp_count += 1 |
| |
| # No more param lifting |
| self.assertEqual(inp_count, 1) |
| |
| inp_count = 0 |
| for node in gm_functionalized.graph.nodes: |
| if node.op == "placeholder": |
| inp_count += 1 |
| |
| # No more param lifting |
| self.assertEqual(inp_count, 1) |
| |
| with self.assertRaisesRegex(Exception, "Please convert all Tensors to FakeTensors"): |
| make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=False)(torch.randn(2, 5)) |
| |
| def test_real_weights_in_symbolic_mode_with_inplace_ops(self): |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.register_buffer("buffer", torch.ones(4, 5)) |
| |
| def forward(self, x): |
| y = self.buffer.add_(3) |
| y.resize_([20]) |
| assert y.shape == self.buffer.shape |
| return x.sum() + self.buffer.sum() |
| |
| m = M().eval() |
| inp = torch.randn(2, 5) |
| # inplace mutation on attr is not allowed |
| with self.assertRaisesRegex(Exception, "Can't call metadata"): |
| make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp) |
| |
| |
| def extract_graph(fx_g, _, graph_cell): |
| graph_cell[0] = fx_g |
| return fx_g |
| |
| |
| def get_ins_outs(fx_g): |
| ins = [] |
| outs = [] |
| for n in fx_g.graph.nodes: |
| if n.op == 'placeholder': |
| ins.append(n) |
| elif n.op == 'output': |
| outs = tuple(n.args[0]) |
| return ins, outs |
| |
| |
| def get_num_ins_outs(fx_g): |
| return tuple(len(i) for i in get_ins_outs(fx_g)) |
| |
| |
| def get_fw_bw_graph(f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False): |
| fw_graph_cell = [None] |
| bw_graph_cell = [None] |
| aot_function(f, |
| fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), |
| bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), |
| partition_fn=partitioner, |
| decompositions=default_decompositions, |
| dynamic=dynamic)(*inps).sum().backward() |
| return (fw_graph_cell[0], bw_graph_cell[0]) |
| |
| class TestMod(torch.nn.Module): |
| def __init__(self, fn): |
| super().__init__() |
| self.p = torch.nn.Parameter(torch.ones(2, requires_grad=True)) |
| self.fn = fn |
| |
| def forward(self, *args): |
| return self.fn(self.p, *args) |
| |
| class TestAOTExport(AOTTestCase): |
| |
| def test_aot_export_ban_dropout_mut_pre_dispatch(self): |
| def fn(p, x): |
| y = torch.ops.aten.dropout.default(x, 0.1, train=False) |
| y.add_(1) |
| return (y,) |
| |
| mod = TestMod(fn) |
| inp = torch.randn(2, 2) |
| |
| with self.assertRaisesRegex(RuntimeError, "cannot mutate tensors with frozen storage"): |
| aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) |
| |
| gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=False) |
| self.assertExpectedInline(str(gm.code).strip(), """\ |
| def forward(self, arg0_1, arg1_1): |
| clone = torch.ops.aten.clone.default(arg1_1); arg1_1 = None |
| add = torch.ops.aten.add.Tensor(clone, 1); clone = None |
| return (add,)""") |
| |
| fw_graph_cell = [None] |
| bw_graph_cell = [None] |
| |
| compiled_outs = aot_function( |
| fn, |
| fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), |
| bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), |
| partition_fn=default_partition, |
| decompositions=default_decompositions, |
| dynamic=True)(*inp) |
| fw_graph = fw_graph_cell[0] |
| bw_graph = bw_graph_cell[0] |
| |
| self.assertExpectedInline(str(fw_graph.code).strip(), """\ |
| def forward(self, arg0_1, arg1_1): |
| clone = torch.ops.aten.clone.default(arg1_1); arg1_1 = None |
| add = torch.ops.aten.add.Tensor(clone, 1); clone = None |
| return (add,)""") |
| |
| |
| def test_aot_export_predispatch_func_simple(self): |
| def fn(p, x): |
| y = x + 2 |
| with torch.no_grad(): |
| y.add_(2) |
| return (x * 2 + y,) |
| |
| mod = TestMod(fn) |
| inp = torch.randn(2, 2) |
| |
| gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) |
| self.assertExpectedInline(str(gm.code).strip(), """\ |
| def forward(self, arg0_1, arg1_1): |
| add = torch.ops.aten.add.Tensor(arg1_1, 2) |
| _set_grad_enabled = torch._C._set_grad_enabled(False) |
| add_1 = torch.ops.aten.add.Tensor(add, 2); add = None |
| _set_grad_enabled_1 = torch._C._set_grad_enabled(False) |
| mul = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None |
| add_2 = torch.ops.aten.add.Tensor(mul, add_1); mul = add_1 = None |
| return (add_2,)""") |
| |
| def test_aot_export_predispatch_func_composite_implicit(self): |
| def fn(p, x): |
| with torch.enable_grad(): |
| y = x @ x |
| y.add_(2) |
| return (x.sum() + y.sum(),) |
| |
| mod = TestMod(fn) |
| inp = torch.randn(2, 2) |
| |
| gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) |
| self.assertExpectedInline(str(gm.code).strip(), """\ |
| def forward(self, arg0_1, arg1_1): |
| _set_grad_enabled = torch._C._set_grad_enabled(True) |
| matmul = torch.ops.aten.matmul.default(arg1_1, arg1_1) |
| _set_grad_enabled_1 = torch._C._set_grad_enabled(False) |
| add = torch.ops.aten.add.Tensor(matmul, 2); matmul = None |
| sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None |
| sum_2 = torch.ops.aten.sum.default(add); add = None |
| add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None |
| return (add_1,)""") |
| |
| def test_aot_export_predispatch_composite_implicit_inplace(self): |
| def fn(x, p): |
| return (torch.ops.aten.absolute_.default(x.clone()),) |
| |
| mod = TestMod(fn) |
| inp = torch.randn(2, 2) |
| |
| gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) |
| self.assertExpectedInline(str(gm.code).strip(), """\ |
| def forward(self, arg0_1, arg1_1): |
| clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None |
| abs_1 = torch.ops.aten.abs.default(clone); clone = None |
| return (abs_1,)""") |
| |
| def test_aot_export_predispatch_composite_implicit_linear(self): |
| class MM(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(2, 2) |
| |
| def forward(self, x): |
| return (self.linear(x),) |
| |
| mod = MM() |
| inp = torch.randn(2, 2) |
| |
| gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) |
| self.assertExpectedInline(str(gm.code).strip(), """\ |
| def forward(self, arg0_1, arg1_1, arg2_1): |
| linear = torch.ops.aten.linear.default(arg2_1, arg0_1, arg1_1); arg2_1 = arg0_1 = arg1_1 = None |
| return (linear,)""") |
| |
| @unittest.expectedFailure |
| def test_aot_export_predispatch_outdtype(self): |
| class M(torch.nn.Module): |
| def __init__(self, weight): |
| super().__init__() |
| self.weight = weight |
| |
| def forward(self, x): |
| y = x + 2 |
| y.add_(5) |
| return (out_dtype( |
| torch.ops.aten.mm.default, torch.int32, y, self.weight |
| ),) |
| |
| weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8) |
| mod = M(weight) |
| inp = torch.randint(-128, 127, (5, 5), dtype=torch.int8) |
| |
| gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) |
| self.assertExpectedInline(str(gm.code).strip(), """\ |
| def forward(self, arg0_1, arg1_1): |
| _set_grad_enabled = torch._C._set_grad_enabled(True) |
| mm = torch.ops.aten.mm.default(arg1_1, arg1_1) |
| _set_grad_enabled_1 = torch._C._set_grad_enabled(False) |
| add = torch.ops.aten.add.Tensor(mm, 2); mm = None |
| sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None |
| sum_2 = torch.ops.aten.sum.default(add); add = None |
| add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None |
| return (add_1,)""") |
| |
| def test_aot_export_predispatch_func_view(self): |
| def fn(p, x): |
| y = x @ x |
| y.add_(2) |
| return (x.sum() + y.view(1, 4).sum(),) |
| |
| mod = TestMod(fn) |
| inp = torch.randn(2, 2) |
| |
| gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) |
| self.assertExpectedInline(str(gm.code).strip(), """\ |
| def forward(self, arg0_1, arg1_1): |
| matmul = torch.ops.aten.matmul.default(arg1_1, arg1_1) |
| add = torch.ops.aten.add.Tensor(matmul, 2); matmul = None |
| sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None |
| view_1 = torch.ops.aten.view.default(add, [1, 4]); add = None |
| sum_2 = torch.ops.aten.sum.default(view_1); view_1 = None |
| add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None |
| return (add_1,)""") |
| |
| def test_aot_export_predispatch_buffer_mutation_metadata(self): |
| class Foo(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.register_buffer('foo', torch.zeros(2, 2)) |
| |
| def forward(self, x): |
| self.foo.add_(4) |
| return (x.sum() + self.foo.sum(),) |
| |
| inp = torch.randn(2, 2) |
| |
| gm, graph_sig = aot_export_module(Foo(), [inp], trace_joint=False, pre_dispatch=True) |
| self.assertExpectedInline(str(gm.code).strip(), """\ |
| def forward(self, arg0_1, arg1_1): |
| add = torch.ops.aten.add.Tensor(arg0_1, 4); arg0_1 = None |
| sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None |
| sum_2 = torch.ops.aten.sum.default(add) |
| add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None |
| return (add, add_1)""") |
| eager_mod = Foo() |
| output_1, output_2 = gm(torch.zeros(2, 2), inp) |
| eager_output = eager_mod(inp) |
| self.assertTrue(torch.allclose(output_2, eager_output[0])) |
| |
| _, output_2 = gm(output_1, inp) |
| eager_output = eager_mod(inp) |
| self.assertTrue(torch.allclose(output_2, eager_output[0])) |
| self.assertTrue("foo" in graph_sig.buffers) |
| self.assertTrue(graph_sig.inputs_to_buffers["arg0_1"] == "foo") |
| |
| def test_aot_export_predispatch_with_autograd_op(self): |
| def foo(p, x): |
| with torch.enable_grad(): |
| y = x + 5 |
| y.add_(5) |
| y.add_(7) |
| return (x.cos() + y.sin(),) |
| |
| inp = torch.randn(2, 2) |
| mod = TestMod(foo) |
| |
| gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) |
| self.assertExpectedInline(str(gm.code).strip(), """\ |
| def forward(self, arg0_1, arg1_1): |
| _set_grad_enabled = torch._C._set_grad_enabled(True) |
| add = torch.ops.aten.add.Tensor(arg1_1, 5) |
| add_1 = torch.ops.aten.add.Tensor(add, 5); add = None |
| add_2 = torch.ops.aten.add.Tensor(add_1, 7); add_1 = None |
| cos = torch.ops.aten.cos.default(arg1_1); arg1_1 = None |
| sin = torch.ops.aten.sin.default(add_2); add_2 = None |
| add_3 = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None |
| _set_grad_enabled_1 = torch._C._set_grad_enabled(False) |
| return (add_3,)""") |
| |
| # TODO(tmanlaibaatar) properly support functionalizing HOO in |
| # predispatch tracing mode |
| @unittest.expectedFailure |
| def test_aot_export_predispatch_with_cond(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.register_buffer("buffer", torch.randn(4, 4)) |
| |
| def forward(self, x): |
| def true_fn(x): |
| self.buffer.add_(5) |
| return x.cos() + self.buffer.sum() |
| |
| def false_fn(x): |
| self.buffer.add_(6) |
| return x.sin() + self.buffer.sum() |
| |
| a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x]) |
| return (a + 3, a + 4) |
| |
| inp = torch.randn(2, 2) |
| gm, _ = aot_export_module(M(), [inp], trace_joint=False, pre_dispatch=True) |
| |
| def test_aot_export_predispatch_conv_and_bn(self): |
| class ConvBatchnorm(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 3, 1, 1) |
| self.bn = torch.nn.BatchNorm2d(3) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.bn(x) |
| return (x,) |
| |
| mod = ConvBatchnorm() |
| mod.train() |
| inp = torch.randn(1, 1, 3, 3) |
| |
| gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) |
| self.assertExpectedInline(str(gm.code).strip(), """\ |
| def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1): |
| conv2d = torch.ops.aten.conv2d.default(arg7_1, arg0_1, arg1_1); arg7_1 = arg0_1 = arg1_1 = None |
| add = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None |
| _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05); conv2d = arg2_1 = arg3_1 = arg4_1 = arg5_1 = None |
| getitem = _native_batch_norm_legit_functional[0] |
| getitem_3 = _native_batch_norm_legit_functional[3] |
| getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None |
| return (getitem_3, getitem_4, add, getitem)""") # noqa: B950 |
| |
| def test_aot_export_predispatch_reshape(self): |
| class Reshape(torch.nn.Module): |
| def forward(self, x): |
| y = x.reshape(4, 4) |
| return (y.sum(),) |
| |
| mod = Reshape() |
| inp = torch.randn(2, 8) |
| |
| gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) |
| self.assertExpectedInline(str(gm.code).strip(), """\ |
| def forward(self, arg0_1): |
| view = torch.ops.aten.view.default(arg0_1, [4, 4]); arg0_1 = None |
| sum_1 = torch.ops.aten.sum.default(view); view = None |
| return (sum_1,)""") # noqa: B950 |
| |
| def test_aot_export_predispatch_contiguous(self): |
| class Cont(torch.nn.Module): |
| def forward(self, x): |
| y = torch.ops.aten.contiguous.default(x) |
| return (y.sum(),) |
| |
| mod = Cont() |
| inp = torch.randn(2, 8) |
| |
| gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) |
| self.assertExpectedInline(str(gm.code).strip(), """\ |
| def forward(self, arg0_1): |
| sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None |
| return (sum_1,)""") # noqa: B950 |
| |
| def test_aot_export_module_joint(self): |
| class ConvBatchnormRelu(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 3, 1, 1) |
| self.bn = torch.nn.BatchNorm2d(3) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.bn(x) |
| user_out = torch.nn.functional.relu(x) |
| loss = user_out.sum() |
| return loss, user_out.detach() |
| |
| mod = ConvBatchnormRelu() |
| mod.train() |
| inp = torch.randn(1, 1, 3, 3) |
| o_ref = mod(inp) |
| fx_g, signature = aot_export_module(mod, [inp], trace_joint=True, output_loss_index=0) |
| # Some important characteristics of the exported graph below: |
| # 8 arguments: 2 params from conv, 2 params from batchnorm, 2 buffers from 1 batchnorm, 1 user input |
| # 9 outputs: 3 mutated buffers (from batchnorm), 2 user outputs and 4 gradients (since there were 4 parameters) |
| self.assertExpectedInline(fx_g.print_readable(print_output=False), """\ |
| class <lambda>(torch.nn.Module): |
| def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"): |
| # No stacktrace found for following nodes |
| convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg1_1 = None |
| add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None |
| _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05); arg3_1 = arg4_1 = arg5_1 = None |
| getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] |
| getitem_1: "f32[3]" = _native_batch_norm_legit_functional[1] |
| getitem_2: "f32[3]" = _native_batch_norm_legit_functional[2] |
| getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] |
| getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None |
| relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None |
| detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu) |
| detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu) |
| detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None |
| detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_2); detach_2 = None |
| detach_4: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_3); detach_3 = None |
| sum_1: "f32[]" = torch.ops.aten.sum.default(relu) |
| detach_5: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None |
| detach_6: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_5); detach_5 = None |
| detach_7: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_6); detach_6 = None |
| detach_8: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_7); detach_7 = None |
| detach_9: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_8); detach_8 = None |
| detach_10: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_9); detach_9 = None |
| ones_like: "f32[]" = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format) |
| expand: "f32[1, 3, 3, 3]" = torch.ops.aten.expand.default(ones_like, [1, 3, 3, 3]); ones_like = None |
| detach_11: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_4); detach_4 = None |
| detach_12: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_11); detach_11 = None |
| detach_13: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_12); detach_12 = None |
| detach_14: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_13); detach_13 = None |
| threshold_backward: "f32[1, 3, 3, 3]" = torch.ops.aten.threshold_backward.default(expand, detach_14, 0); expand = detach_14 = None |
| native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(threshold_backward, convolution, arg2_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); threshold_backward = convolution = arg2_1 = getitem_1 = getitem_2 = None |
| getitem_5: "f32[1, 3, 3, 3]" = native_batch_norm_backward[0] |
| getitem_6: "f32[3]" = native_batch_norm_backward[1] |
| getitem_7: "f32[3]" = native_batch_norm_backward[2]; native_batch_norm_backward = None |
| convolution_backward = torch.ops.aten.convolution_backward.default(getitem_5, arg7_1, arg0_1, [3], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]); getitem_5 = arg7_1 = arg0_1 = None |
| getitem_8 = convolution_backward[0] |
| getitem_9: "f32[3, 1, 1, 1]" = convolution_backward[1] |
| getitem_10: "f32[3]" = convolution_backward[2]; convolution_backward = None |
| return (getitem_3, getitem_4, add, sum_1, detach_10, getitem_9, getitem_10, getitem_6, getitem_7) |
| """) # noqa: B950 |
| |
| |
| self.assertExpectedInline(str(signature.parameters), """['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias']""") |
| self.assertExpectedInline(str(signature.buffers), """['bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked']""") |
| self.assertExpectedInline(str(signature.user_inputs), """['arg7_1']""") |
| self.assertExpectedInline(str(signature.inputs_to_parameters), """{'arg0_1': 'conv.weight', 'arg1_1': 'conv.bias', 'arg2_1': 'bn.weight', 'arg3_1': 'bn.bias'}""") # noqa: B950 |
| self.assertExpectedInline(str(signature.inputs_to_buffers), """{'arg4_1': 'bn.running_mean', 'arg5_1': 'bn.running_var', 'arg6_1': 'bn.num_batches_tracked'}""") # noqa: B950 |
| self.assertExpectedInline(str(signature.buffers_to_mutate), """{'getitem_3': 'bn.running_mean', 'getitem_4': 'bn.running_var', 'add': 'bn.num_batches_tracked'}""") # noqa: B950 |
| self.assertExpectedInline(str(signature.backward_signature.gradients_to_parameters), """{'getitem_9': 'conv.weight', 'getitem_10': 'conv.bias', 'getitem_6': 'bn.weight', 'getitem_7': 'bn.bias'}""") # noqa: B950 |
| self.assertExpectedInline(str(signature.backward_signature.gradients_to_user_inputs), """{}""") |
| self.assertExpectedInline(str(signature.backward_signature.loss_output), """getitem_3""") |
| |
| # Also check the inference graph |
| # Main important thing here is that there are 5 total outputs: 3 total mutated buffers (from batchnorm), 2 user outputs. |
| fx_g_inference, signature_inference = aot_export_module(mod, [inp], trace_joint=False) |
| self.assertExpectedInline(fx_g_inference.print_readable(print_output=False), """\ |
| class <lambda>(torch.nn.Module): |
| def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"): |
| # No stacktrace found for following nodes |
| convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg7_1 = arg0_1 = arg1_1 = None |
| add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None |
| _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05); convolution = arg2_1 = arg3_1 = arg4_1 = arg5_1 = None |
| getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] |
| getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] |
| getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None |
| relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None |
| sum_1: "f32[]" = torch.ops.aten.sum.default(relu) |
| detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None |
| detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach); detach = None |
| detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None |
| return (getitem_3, getitem_4, add, sum_1, detach_2) |
| """) # noqa: B950 |
| # Some important characteristics of the exported graph below: |
| # 8 arguments: 2 params from conv, 2 params from batchnorm, 2 buffers from 1 batchnorm, 1 user input |
| # 9 outputs: 2 mutated buffers (from batchnorm), 2 user outputs and 4 gradients (since there were 4 parameters) |
| |
| def test_aot_export_simplified_basic(self): |
| def f(x, y): |
| return x * y, y * y.detach() |
| |
| x = torch.randn(2, requires_grad=True) |
| y = torch.randn(2, requires_grad=True) |
| |
| f_graph_fw = aot_export_joint_simple(f, [x, y], trace_joint=False) |
| out_ref = f(x, y) |
| # No calling convention changes necessary to invoke the traced graph |
| out_test = f_graph_fw(x, y) |
| self.assertEqual(out_ref, out_test) |
| |
| # Now test the backward |
| x = torch.randn(2, requires_grad=True) |
| y = torch.randn(2, requires_grad=True) |
| x2 = x.clone().detach().requires_grad_(True) |
| y2 = y.clone().detach().requires_grad_(True) |
| x3 = x.clone().detach().requires_grad_(True) |
| y3 = y.clone().detach().requires_grad_(True) |
| f_graph_joint = aot_export_joint_simple(f, [x, y], trace_joint=True) |
| num_fw_outputs = 2 |
| fw_g, bw_g = default_partition(f_graph_joint, [x, y], num_fwd_outputs=num_fw_outputs) |
| out_ref2 = f(x2, y2) |
| fw_outs = fw_g(x3, y3) |
| out_test2, activations = fw_outs[:num_fw_outputs], fw_outs[num_fw_outputs:] |
| self.assertEqual(out_ref2, out_test2) |
| |
| # Test running the traced backward graph with a mocked-up grad_output |
| grad_outs = [torch.ones_like(x) for x in out_ref2] |
| grads_ref = torch.autograd.grad(out_ref2, [x2, y2], grad_outputs=grad_outs) |
| grads_test = bw_g(*activations, *grad_outs) |
| for g_ref, g_test in zip(grads_ref, grads_test): |
| self.assertEqual(g_ref, g_test) |
| |
| def test_aot_export_metadata_mutation_banned(self): |
| def fn(p, x): |
| x.t_() |
| return (x * 2,) |
| mod = TestMod(fn) |
| inp = torch.randn(2, 4) |
| with self.assertRaisesRegex( |
| RuntimeError, "Found an input that received a metadata mutation" |
| ): |
| aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) |
| aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) |
| aot_export_module(mod, [inp], trace_joint=False) |
| |
| def test_aot_export_forward_mutation_no_buffer_mut(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.register_buffer("buffer1", torch.ones(6, 4)) |
| |
| def forward(self, x): |
| x.add_(4) |
| return (x.cos().sum() + self.buffer1.sum(),) |
| |
| mod = M() |
| inp = torch.ones(6, 4) |
| gm, sig = aot_export_module(mod, [inp], trace_joint=False) |
| self.assertExpectedInline(str(gm.code).strip(), """\ |
| def forward(self, arg0_1, arg1_1): |
| add = torch.ops.aten.add.Tensor(arg1_1, 4); arg1_1 = None |
| cos = torch.ops.aten.cos.default(add) |
| sum_1 = torch.ops.aten.sum.default(cos); cos = None |
| sum_2 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None |
| add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None |
| return (add, add_1)""") # noqa: B950 |
| self.assertEqual(sig.user_inputs_to_mutate, {"add": "arg1_1"}) |
| |
| def test_aot_export_forward_mutation_multiple_mut(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.register_buffer("buffer1", torch.ones(6, 4)) |
| |
| def forward(self, x, y): |
| y.add_(4) |
| self.buffer1.add_(5) |
| return (x.cos().sum() + y.sin().sum(), self.buffer1.sum(),) |
| |
| mod = M() |
| inp = [torch.ones(6, 4), torch.zeros(6, 4)] |
| gm, sig = aot_export_module(mod, inp, trace_joint=False) |
| self.assertExpectedInline(str(gm.code).strip(), """\ |
| def forward(self, arg0_1, arg1_1, arg2_1): |
| add = torch.ops.aten.add.Tensor(arg2_1, 4); arg2_1 = None |
| add_1 = torch.ops.aten.add.Tensor(arg0_1, 5); arg0_1 = None |
| cos = torch.ops.aten.cos.default(arg1_1); arg1_1 = None |
| sum_1 = torch.ops.aten.sum.default(cos); cos = None |
| sin = torch.ops.aten.sin.default(add) |
| sum_2 = torch.ops.aten.sum.default(sin); sin = None |
| add_2 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None |
| sum_3 = torch.ops.aten.sum.default(add_1) |
| return (add_1, add, add_2, sum_3)""") # noqa: B950 |
| self.assertEqual(sig.user_inputs_to_mutate, {"add": "arg2_1"}) |
| self.assertEqual(sig.buffers_to_mutate, {"add_1": "buffer1"}) |
| |
| def test_aot_export_input_mutation_on_input_requiring_grad_banned(self): |
| class M(torch.nn.Module): |
| def forward(self, x): |
| x.add_(4) |
| return (x,) |
| |
| mod = M() |
| inp = torch.randn(2, requires_grad=True) |
| with self.assertRaisesRegex( |
| RuntimeError, "Found a graph input that requires gradients, and received a mutation" |
| ): |
| aot_export_module(mod, [inp], trace_joint=False) |
| |
| def test_aot_export_input_mutation_on_parameter_banned(self): |
| def fn(p, x): |
| p.mul_(2) |
| return (p + x,) |
| mod = TestMod(fn) |
| inp = torch.randn(2) |
| with self.assertRaisesRegex( |
| RuntimeError, "Found a graph input that requires gradients, and received a mutation" |
| ): |
| aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) |
| aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) |
| aot_export_module(mod, [inp], trace_joint=False) |
| |
| def test_aot_export_synthetic_bases_banned(self): |
| def fn(p, x, y): |
| x.mul_(2) |
| return (x + y,) |
| mod = TestMod(fn) |
| inp = torch.randn(2) |
| inp2 = inp.view(-1) |
| with self.assertRaisesRegex( |
| RuntimeError, "Encountered aliased inputs that are mutated" |
| ): |
| aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=False) |
| aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=True) |
| aot_export_module(mod, [inp, inp2], trace_joint=False) |
| |
| def test_aot_export_input_dupes_banned(self): |
| def fn(p, x, y): |
| x.mul_(2) |
| return (x + y,) |
| mod = TestMod(fn) |
| inp = torch.randn(2) |
| with self.assertRaisesRegex( |
| RuntimeError, "Encountered duplicated inputs that are mutated in the graph" |
| ): |
| aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=False) |
| aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=True) |
| aot_export_module(mod, [inp, inp], trace_joint=False) |
| |
| def test_aot_export_multiple_outputs_require_grad_banned(self): |
| def fn(p, x): |
| out = p * x |
| return out, out.sum() |
| mod = TestMod(fn) |
| inp = torch.randn(2) |
| with self.assertRaisesRegex( |
| RuntimeError, "Found an output of the forward that requires gradients, that was not" |
| ): |
| aot_export_module(mod, [inp], trace_joint=True, output_loss_index=1) |
| |
| @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") |
| @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "Cond needs dynamo to run") |
| def test_aot_export_with_torch_cond(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x): |
| def true_fn(x): |
| y = x + 4 |
| y.add_(5) |
| return x.cos() |
| |
| def false_fn(x): |
| y = x + 5 |
| y.add_(6) |
| return x.sin() |
| |
| a = torch.cond(x.shape[0] > 4, true_fn, false_fn, [x]) |
| return (a + 3, a + 4) |
| |
| inp = torch.randn(3, 4) |
| gm, _ = aot_export_module(M(), (inp,), trace_joint=False) |
| self.assertExpectedInline(gm.code.strip(), """\ |
| def forward(self, arg0_1): |
| true_graph_0 = self.true_graph_0 |
| false_graph_0 = self.false_graph_0 |
| conditional = torch.ops.higher_order.cond(False, true_graph_0, false_graph_0, [arg0_1]); true_graph_0 = false_graph_0 = arg0_1 = None |
| getitem = conditional[0]; conditional = None |
| add = torch.ops.aten.add.Tensor(getitem, 3) |
| add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None |
| return (add, add_1)""") # noqa: B950 |
| |
| self.assertExpectedInline(gm.true_graph_0.code.strip(), """\ |
| def forward(self, arg0_1): |
| add = torch.ops.aten.add.Tensor(arg0_1, 4) |
| add_1 = torch.ops.aten.add.Tensor(add, 5); add = None |
| cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None |
| return (cos,)""") |
| |
| self.assertExpectedInline(gm.false_graph_0.code.strip(), """\ |
| def forward(self, arg0_1): |
| add = torch.ops.aten.add.Tensor(arg0_1, 5) |
| add_1 = torch.ops.aten.add.Tensor(add, 6); add = None |
| sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None |
| return (sin,)""") |
| |
| def test_aot_export_simplified_pytrees_banned(self): |
| def fn(inps): |
| return (inps[0] + inps[1],) |
| inp1 = torch.randn(2) |
| inp2 = torch.randn(2) |
| inps = [inp1, inp2] |
| with self.assertRaisesRegex( |
| RuntimeError, "aot_export_joint_simple requires individual inputs not to be pytrees" |
| ): |
| aot_export_joint_simple(fn, [inps], trace_joint=False) |
| aot_export_joint_simple(fn, [inps], trace_joint=True) |
| |
| def test_aot_export_functionalized_rng_banned(self): |
| def fn(p, x): |
| return (p + x,) |
| mod = TestMod(fn) |
| inp = torch.randn(2) |
| with patch("functorch.compile.config.functionalize_rng_ops", True), self.assertRaisesRegex( |
| RuntimeError, "Functionalized RNG is not currently supported in the aot_export" |
| ): |
| aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) |
| aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) |
| aot_export_module(mod, [inp], trace_joint=False) |
| |
| |
| class TestPartitioning(AOTTestCase): |
| @unittest.skipIf(not USE_NETWORKX, "networkx not available") |
| def test_recompute_partitioning(self): |
| def fn(a, b): |
| return torch.sin(torch.sin(a)) + b |
| |
| # Reference calculation |
| ref_a = torch.rand(10, 10, requires_grad=True) |
| ref_b = torch.rand(10, 10, requires_grad=True) |
| ref = fn(ref_a, ref_b) |
| ref.sum().backward() |
| |
| # Compiled function calculation |
| res_a = ref_a.clone().detach().requires_grad_(True) |
| res_b = ref_b.clone().detach().requires_grad_(True) |
| |
| def compile_fn(x, _): |
| return x |
| |
| compiled_fn = compiled_function(fn, compile_fn, compile_fn, min_cut_rematerialization_partition) |
| res = compiled_fn(res_a, res_b) |
| res.sum().backward() |
| assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3) |
| assert torch.allclose(ref_a.grad, res_a.grad, atol=1e-3, rtol=1e-3) |
| assert torch.allclose(ref_b.grad, res_b.grad, atol=1e-3, rtol=1e-3) |
| |
| def test_meta_tensor_inplace_op(self): |
| # Following module results in inplace ops while tracing. The test checks |
| # that the meta tensor information is stored for inplace ops. |
| class MockModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.weight = torch.nn.Parameter(torch.randn(3072, 768, requires_grad=True)) |
| self.bias = torch.nn.Parameter(torch.randn(3072, requires_grad=True)) |
| |
| def forward(self, add_4): |
| linear_4 = torch.nn.functional.linear(add_4, self.weight, bias=self.bias) |
| gelu = torch.nn.functional.gelu(linear_4) |
| return gelu |
| |
| def check_meta_tensor(fx_g, _): |
| for node in fx_g.graph.nodes: |
| if node.op != 'output': |
| assert 'tensor_meta' in node.meta |
| return fx_g |
| |
| inp0 = torch.randn(16, 128, 768, requires_grad=True) |
| inputs = [inp0, ] |
| mod = MockModule().to(device="cpu") |
| aot_mod = aot_module(mod, fw_compiler=check_meta_tensor) |
| aot_mod(*inputs) |
| |
| def test_default_partitioner_getitem(self): |
| mod = nn.LayerNorm([10]) |
| |
| def f(x, mod_weight, mod_bias): |
| return torch.nn.functional.layer_norm(x, [10], mod_weight, mod_bias, eps=1e-6) |
| |
| fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, 10, requires_grad=True), mod.weight, mod.bias], |
| partitioner=default_partition) |
| self.assertEqual(get_num_ins_outs(fw_graph), (3, 6)) |
| self.assertEqual(get_num_ins_outs(bw_graph), (6, 3)) |
| |
| @unittest.skipIf(not USE_NETWORKX, "networkx not available") |
| def test_min_cut_partitioner_save_shape(self): |
| |
| def f(x): |
| s = x.sum(dim=1) |
| return s |
| |
| inp = [torch.ones([10, 10], requires_grad=True)] |
| fw_graph, bw_graph = get_fw_bw_graph(f, inp, dynamic=True) |
| _, fw_output = get_ins_outs(fw_graph) |
| self.assertEqual(get_num_ins_outs(fw_graph), (1, 3)) |
| self.assertEqual(get_num_ins_outs(bw_graph), (3, 1)) |
| self.assertEqual(str(fw_output[0]), "sum_1") |
| # make sure we don't do the suboptimal thing of saving the bigger primals input to sum, |
| # rather than saving the sizes of the primals input for use in backward expand |
| self.assertEqual(str(fw_output[1]), "sym_size_int") |
| self.assertEqual(str(fw_output[2]), "sym_size_int_1") |
| |
| inp = [ |
| torch.randn(10, requires_grad=True), |
| torch.randn((3, 10), requires_grad=True), |
| torch.randn((2, 10), requires_grad=True), |
| ] |
| |
| def f(a, b, c): |
| # tried to test what happens if we save a size tuple in the graph; |
| # turns out we never will due to how we trace, but this is probably |
| # still a good test case for various size manipulations |
| sb = torch.ops.aten.sym_size(b) |
| sc = c.size() |
| x = sb[0] + sc[0] |
| a_sz = (x, a.size(0)) |
| return torch.cat([a.expand(a_sz), b, c]) |
| fw_graph, bw_graph = get_fw_bw_graph(f, inp, dynamic=True) |
| self.assertEqual(get_num_ins_outs(fw_graph), (3, 4)) |
| self.assertEqual(get_num_ins_outs(bw_graph), (4, 3)) |
| _, outs = get_ins_outs(fw_graph) |
| self.assertTrue(all(is_sym_node(n) for n in outs[1:])) |
| |
| def test_default_partitioner_output_tensor_shape_tensor(self): |
| |
| inp = [ |
| torch.randn(10, requires_grad=True), |
| torch.randn((3, 10), requires_grad=True), |
| torch.randn((2, 10), requires_grad=True), |
| torch.randn((10, 1), requires_grad=True), |
| ] |
| |
| def f(a, b, c, d): |
| # Try to force symints intermixed with outputs in the function's returns |
| sb = b.size() |
| sc = c.size() |
| x = sb[0] + sc[0] |
| a_sz = (x, a.size(0)) |
| cat = torch.cat([a.expand(a_sz), b, c]) |
| mm = torch.mm(cat, d) |
| mm2 = torch.mm(mm, a.view(mm.size(1), a.size(0))) # this saves 4 new ints for backward. why? |
| # and what do i have to do to make it save a tensor for backward? |
| return cat, sb, c, mm2 |
| |
| fw_graph_cell = [None] |
| bw_graph_cell = [None] |
| compiled_outs = aot_function( |
| f, |
| fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), |
| bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), |
| partition_fn=default_partition, |
| decompositions=default_decompositions, |
| dynamic=True)(*inp) |
| fw_graph = fw_graph_cell[0] |
| (compiled_outs[0].sum() + compiled_outs[2].sum()).backward() |
| bw_graph = bw_graph_cell[0] |
| |
| # in the fwd graph, 13 outs because: |
| # - 5 original outputs (sb is a tuple, gets expanded to 2 symints) |
| # - 8 saved outputs for backward: 5 tensors, 3 symints |
| self.assertEqual(get_num_ins_outs(fw_graph), (4, 13)) |
| # in the bwd graph, 10 inputs (grad outs) because: |
| # - The fwd graph had 13 outputs |
| # - 1 was a view of an input, which gets regenerated outside of the graph |
| # and doesn't participate in the backward |
| # - 2 user outs were symints (b.size()), which don't get tangents in the backward |
| self.assertEqual(get_num_ins_outs(bw_graph), (10, 4)) |
| _, fw_graph_out_nodes = get_ins_outs(fw_graph) |
| self.assertEqual( |
| # fw outputs include b.size() which expands to 2 symints, |
| # |
| # TODO(whc)- are the saved-tensors/saved-symints correct here? |
| # i just made the test pass based on what default partition did |
| # Of the 5 original forward outputs, the 4th (c) is an input, |
| # which won't show up in the compiled forward graph |
| [False, True, True, False, False] + [False] * 4 + [True] * 4, |
| [is_sym_node(n) for n in fw_graph_out_nodes] |
| ) |
| |
| real_outs = f(*inp) |
| self.assertEqual(compiled_outs, real_outs) |
| self.assertTrue(isinstance(real_outs[1], torch.Size)) |
| |
| # TODO(whc) we should learn to return torch.Sizes |
| self.assertFalse(isinstance(compiled_outs[1], torch.Size)) |
| |
| @unittest.skipIf(not USE_NETWORKX, "networkx not available") |
| def test_min_cut_partitioner_output_tensor_shape_tensor(self): |
| |
| inp = [ |
| torch.randn(10, requires_grad=True), |
| torch.randn((3, 10), requires_grad=True), |
| torch.randn((2, 10), requires_grad=True), |
| torch.randn((10, 1), requires_grad=True), |
| ] |
| |
| def f(a, b, c, d): |
| # Try to force symints intermixed with outputs in the function's returns |
| sb = b.size() |
| sc = c.size() |
| x = sb[0] + sc[0] |
| a_sz = (x, a.size(0)) |
| cat = torch.cat([a.expand(a_sz), b, c]) |
| mm = torch.mm(cat, d) |
| mm2 = torch.mm(mm, a.view(mm.size(1), a.size(0))) # this saves 4 new ints for backward. why? |
| # and what do i have to do to make it save a tensor for backward? |
| return cat, sb, c, mm2 |
| |
| fw_graph_cell = [None] |
| bw_graph_cell = [None] |
| compiled_outs = aot_function( |
| f, |
| fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), |
| bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), |
| partition_fn=min_cut_rematerialization_partition, |
| decompositions=default_decompositions, |
| dynamic=True)(*inp) |
| fw_graph = fw_graph_cell[0] |
| (compiled_outs[0].sum() + compiled_outs[2].sum()).backward() |
| bw_graph = bw_graph_cell[0] |
| |
| self.assertEqual(get_num_ins_outs(fw_graph), (4, 12)) |
| self.assertEqual(get_num_ins_outs(bw_graph), (9, 4)) |
| _, fw_graph_out_nodes = get_ins_outs(fw_graph) |
| self.assertEqual( |
| # fw outputs include b.size() which expands to 2 symints, |
| # then 4 tensors (transposes of matricies used for mm) are saved |
| # finally 3 symints are saved |
| [False, True, True, False, False] + [False] * 4 + [True] * 3, |
| [is_sym_node(n) for n in fw_graph_out_nodes] |
| ) |
| |
| real_outs = f(*inp) |
| self.assertEqual(compiled_outs, real_outs) |
| self.assertTrue(isinstance(real_outs[1], torch.Size)) |
| |
| # TODO(whc) we should learn to return torch.Sizes |
| self.assertFalse(isinstance(compiled_outs[1], torch.Size)) |
| |
| @unittest.skipIf(not USE_NETWORKX, "networkx not available") |
| def test_min_cut_partitioner(self): |
| def f(x): |
| return x.cos().cos().cos() |
| |
| fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)]) |
| self.assertEqual(get_num_ins_outs(fw_graph), (1, 2)) |
| self.assertEqual(get_num_ins_outs(bw_graph), (2, 1)) |
| |
| def f(a, b, c, d): |
| x = a + b + c + d |
| return x.cos().cos() |
| |
| fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True) for _ in range(4)]) |
| self.assertEqual(get_num_ins_outs(fw_graph), (4, 2)) |
| self.assertEqual(get_num_ins_outs(bw_graph), (2, 4)) |
| |
| @unittest.skipIf(not USE_NETWORKX, "networkx not available") |
| def test_min_cut_partitioner_recomputable_ops(self): |
| def f(x): |
| return x * x * x |
| |
| recomputable_ops = [] |
| partition_fn = partial(min_cut_rematerialization_partition, recomputable_ops=recomputable_ops) |
| |
| fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)], partition_fn) |
| # Expected forward graph: |
| # opcode name target args kwargs |
| # ------------- --------- --------------- -------------------------- -------- |
| # placeholder primals_1 primals_1 () {} |
| # call_function mul aten.mul.Tensor (primals_1, primals_1) {} |
| # call_function mul_1 aten.mul.Tensor (mul, primals_1) {} |
| # output output output ([mul_1, primals_1, mul],) {} |
| self.assertEqual(get_num_ins_outs(fw_graph), (1, 3)) |
| # Expected backward graph: |
| # opcode name target args kwargs |
| # ------------- ---------- --------------- ----------------------- -------- |
| # placeholder primals_1 primals_1 () {} |
| # placeholder mul mul () {} |
| # placeholder tangents_1 tangents_1 () {} |
| # call_function mul_2 aten.mul.Tensor (tangents_1, mul) {} |
| # call_function mul_3 aten.mul.Tensor (tangents_1, primals_1) {} |
| # call_function mul_4 aten.mul.Tensor (mul_3, primals_1) {} |
| # call_function add aten.add.Tensor (mul_2, mul_4) {} |
| # call_function add_1 aten.add.Tensor (add, mul_4) {} |
| # output output output ([add_1],) {} |
| self.assertEqual(get_num_ins_outs(bw_graph), (3, 1)) |
| |
| recomputable_ops = [torch.ops.aten.mul] |
| partition_fn = partial(min_cut_rematerialization_partition, recomputable_ops=recomputable_ops) |
| fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)], partition_fn) |
| # Expected forward graph: |
| # opcode name target args kwargs |
| # ------------- --------- --------------- ---------------------- -------- |
| # placeholder primals_1 primals_1 () {} |
| # call_function mul aten.mul.Tensor (primals_1, primals_1) {} |
| # call_function mul_1 aten.mul.Tensor (mul, primals_1) {} |
| # output output output ([mul_1, primals_1],) {} |
| self.assertEqual(get_num_ins_outs(fw_graph), (1, 2)) |
| # Expected backward graph: |
| # opcode name target args kwargs |
| # ------------- ---------- --------------- ----------------------- -------- |
| # placeholder primals_1 primals_1 () {} |
| # placeholder tangents_1 tangents_1 () {} |
| # call_function mul aten.mul.Tensor (primals_1, primals_1) {} # RECOMPUTED |
| # call_function mul_2 aten.mul.Tensor (tangents_1, mul) {} |
| # call_function mul_3 aten.mul.Tensor (tangents_1, primals_1) {} |
| # call_function mul_4 aten.mul.Tensor (mul_3, primals_1) {} |
| # call_function add aten.add.Tensor (mul_2, mul_4) {} |
| # call_function add_1 aten.add.Tensor (add, mul_4) {} |
| # output output output ([add_1],) {} |
| self.assertEqual(get_num_ins_outs(bw_graph), (2, 1)) |
| |
| def test_contiguous(self): |
| # The test simulates the condition where transpose followed by view |
| # happens in the backward pass. |
| # https://discuss.pytorch.org/t/error-on-transpose-and-view/434 |
| def f(x): |
| return x.view(2, 3).t() |
| |
| inp = torch.randn(6, requires_grad=True) |
| out = aot_function(f, nop)(inp) |
| torch.autograd.grad(out, inp, torch.randn(3, 2)) |
| |
| def test_preserve_random(self): |
| def fn(x): |
| return torch.nn.functional.dropout(x, 0.5) + x |
| |
| x = torch.randn(4) |
| |
| torch.manual_seed(0) |
| ref = fn(x) |
| |
| torch.manual_seed(0) |
| aot_fn = aot_function(fn, nop) |
| res = aot_fn(x) |
| |
| assert torch.allclose(ref, res) |
| |
| # https://github.com/pytorch/pytorch/issues/110666 |
| def test_generate_gives_inference_graph(self): |
| # We expect this to give an inference graph |
| def generate(x): |
| with torch.no_grad(): |
| return torch.mul(x, x) |
| |
| inference_graph_cell = [None] |
| inference_compiler = make_boxed_compiler(partial(extract_graph, graph_cell=inference_graph_cell)) |
| aot_fn = aot_function(generate, nop, inference_compiler=inference_compiler) |
| # Even though x requires grad, we should still get an inference graph |
| x = torch.randn(4, requires_grad=True) |
| res = aot_fn(x) |
| self.assertTrue(inference_graph_cell[0] is not None) |
| |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") |
| @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") |
| def test_autocast(self): |
| mod = torchvision.models.resnet18().cuda() |
| mod.train() |
| |
| x = torch.randn(16, 3, 32, 32, device="cuda") |
| aot_mod = memory_efficient_fusion(mod) |
| |
| # Ensure that AOT Autograd works with AMP |
| with torch.cuda.amp.autocast(True): |
| res = aot_mod(x) |
| res.sum().backward() |
| |
| |
| class TestAOTDispatch(AOTTestCase): |
| |
| # Tests to add cases for (non-exhaustive list, mostly for my notes): |
| # - subclass / mode introduced in the middle of the compiled fn |
| # - various input mutation / intermediate base tests |
| # - input mutation that changes a tensor into a subclass |
| # - metadata mutation? (TBD) |
| # - guard tests (fw guards *and* bw guards) |
| # - subclass test involving _indices_of_inps_to_detach |
| def test_aot_dispatch_simple(self): |
| # a is a subclass, b is not |
| def f(a, b): |
| aa = torch.mul(a, 6) |
| bb = torch.div(b, 2) |
| return aa + bb |
| |
| a1_ref = torch.ones(3, 3, requires_grad=True) |
| a2_ref = torch.ones(3, 3, requires_grad=True) |
| a_ref = TwoTensor(a1_ref, a2_ref) |
| b_ref = torch.ones(3, 3, requires_grad=True) |
| |
| a1_test = a1_ref.clone().detach().requires_grad_(True) |
| a2_test = a2_ref.clone().detach().requires_grad_(True) |
| a_test = TwoTensor(a1_test, a2_test) |
| b_test = b_ref.clone().detach().requires_grad_(True) |
| |
| fw_graph_cell = [None] |
| bw_graph_cell = [None] |
| |
| compiled_f = aot_function( |
| f, |
| fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), |
| bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), |
| partition_fn=min_cut_rematerialization_partition |
| ) |
| out_ref = f(a_ref, b_ref) |
| out_test = compiled_f(a_test, b_test) |
| |
| # Output is a TwoTensor (check both inner tensors) |
| self.assertEqual(out_ref.a, out_test.a) |
| self.assertEqual(out_ref.b, out_test.b) |
| |
| out_ref.sum().backward() |
| out_test.sum().backward() |
| # Both grad_inputs are TwoTensor |
| self.assertEqual(a_ref.grad.a, a_test.grad.a) |
| self.assertEqual(a_ref.grad.b, a_test.grad.b) |
| self.assertEqual(b_ref.grad.a, b_test.grad.a) |
| self.assertEqual(b_ref.grad.b, b_test.grad.b) |
| |
| # Important pieces of the graph: |
| # - mul() and div() show up twice, because we called them on a TwoTensor |
| # - add() shows up once, because we called it on a plain Tensor |
| # - The user forward() fn returns 1 output (the result of add), |
| # while the graph itself returns two outputs (add, add_1) |
| # - add, add_1 correspond to the two inner dense tensors that will be wrapped |
| # - into a single TwoTensor output. |
| self.assertExpectedInline(fw_graph_cell[0].code.strip(), """\ |
| def forward(self, primals_1, primals_2, primals_3): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 6); primals_1 = None |
| mul_1 = torch.ops.aten.mul.Tensor(primals_2, 6); primals_2 = None |
| div = torch.ops.aten.div.Tensor(primals_3, 2); primals_3 = None |
| add = torch.ops.aten.add.Tensor(mul, div); mul = None |
| add_1 = torch.ops.aten.add.Tensor(mul_1, div); mul_1 = div = None |
| return [add, add_1]""") |
| |
| # Important pieces of the graph: |
| # - 4 total dense outputs. |
| # This corresponds to the fact that each user fwd inpt (a, b) |
| # will get a gradient that is a TwoTensor subclass, |
| # so (mul_2, mul_3) will be wrapped into a.grad |
| # and (div_1, div_2) will be wrapped into b.grad |
| # - 4 total dense outputs, |
| self.assertExpectedInline(bw_graph_cell[0].code.strip(), """\ |
| def forward(self, tangents_1, tangents_2): |
| div_1 = torch.ops.aten.div.Tensor(tangents_1, 2) |
| div_2 = torch.ops.aten.div.Tensor(tangents_2, 2) |
| mul_2 = torch.ops.aten.mul.Tensor(tangents_1, 6); tangents_1 = None |
| mul_3 = torch.ops.aten.mul.Tensor(tangents_2, 6); tangents_2 = None |
| return [mul_2, mul_3, div_1, div_2]""") |
| |
| def test_aot_dispatch_inference(self): |
| # a is a subclass, b is not |
| def f(a, b): |
| aa = torch.mul(a, 6) |
| bb = torch.div(b, 2) |
| return aa + bb |
| |
| a1_ref = torch.ones(3, 3) |
| a2_ref = torch.ones(3, 3) |
| a_ref = TwoTensor(a1_ref, a2_ref) |
| b_ref = torch.ones(3, 3) |
| |
| a1_test = a1_ref.clone() |
| a2_test = a2_ref.clone() |
| a_test = TwoTensor(a1_test, a2_test) |
| b_test = b_ref.clone() |
| |
| compiled_f = aot_function( |
| f, |
| fw_compiler=nop, |
| bw_compiler=nop, |
| partition_fn=min_cut_rematerialization_partition |
| ) |
| out_ref = f(a_ref, b_ref) |
| out_test = compiled_f(a_test, b_test) |
| |
| # Output is a TwoTensor (check both inner tensors) |
| self.assertEqual(out_ref.a, out_test.a) |
| self.assertEqual(out_ref.b, out_test.b) |
| |
| def test_aot_dispatch_incorrect_backward(self): |
| # a is a subclass, b is not |
| def f(a, b): |
| aa = torch.mul(a, 2) |
| bb = torch.add(b, 3) |
| out_subclass = torch.div(aa, bb) |
| out_reg = torch.add(b, b) |
| # When creating the joint, we assume that the second grad_out |
| # is not a subclass. |
| # In the below test case though, we end up being wrong. |
| # This would require re-tracing and recompiling the backward. |
| return out_subclass, out_reg |
| |
| a1_ref = torch.ones(3, 3, requires_grad=True) |
| a2_ref = torch.ones(3, 3, requires_grad=True) |
| a_ref = TwoTensor(a1_ref, a2_ref) |
| b_ref = torch.ones(3, 3, requires_grad=True) |
| |
| a1_test = a1_ref.clone().detach().requires_grad_(True) |
| a2_test = a2_ref.clone().detach().requires_grad_(True) |
| a_test = TwoTensor(a1_test, a2_test) |
| b_test = b_ref.clone().detach().requires_grad_(True) |
| |
| compiled_f = aot_function( |
| f, |
| fw_compiler=nop, |
| bw_compiler=nop, |
| partition_fn=min_cut_rematerialization_partition |
| ) |
| out_ref = f(a_ref, b_ref) |
| out_test = compiled_f(a_test, b_test) |
| # First out is a TwoTensor, second is an ordinary tensor |
| self.assertEqual(out_ref[0].a, out_test[0].a) |
| self.assertEqual(out_ref[0].b, out_test[0].b) |
| self.assertEqual(out_ref[1], out_test[1]) |
| |
| # We compiled our graph assuming type(grad_out[1]) == torch.Tensor, |
| # but we were wrong: in the below tests, it is a subclass. |
| # This will eventually require a repartition + recompile |
| with self.assertRaisesRegex( |
| AssertionError, |
| "incorrectly attempted to compile the backward with incorrect subclass metadata" |
| ): |
| (out_test[0] + out_test[1]).sum().backward() |
| |
| def test_aot_dispatch_output_alias(self): |
| # a is a tensor, b is a TwoTensor |
| def f(a, b): |
| return b.view(b.shape), a * b |
| |
| b1_ref = torch.ones(3, 3, requires_grad=True) |
| b2_ref = torch.ones(3, 3, requires_grad=True) |
| b_ref = TwoTensor(b1_ref, b2_ref) |
| a_ref = torch.ones(3, 3, requires_grad=True) |
| |
| b1_test = b1_ref.clone().detach().requires_grad_(True) |
| b2_test = b2_ref.clone().detach().requires_grad_(True) |
| b_test = TwoTensor(b1_test, b2_test) |
| a_test = a_ref.clone().detach().requires_grad_(True) |
| |
| compiled_f = aot_function( |
| f, |
| fw_compiler=nop, |
| bw_compiler=nop, |
| partition_fn=min_cut_rematerialization_partition |
| ) |
| out_ref1, out_ref2 = f(a_ref, b_ref) |
| out_test1, out_test2 = compiled_f(a_test, b_test) |
| self.assertEqual(out_ref1, out_test1) |
| self.assertEqual(out_ref2.a, out_test2.a) |
| self.assertEqual(out_ref2.b, out_test2.b) |
| |
| (out_ref1 + out_ref2).sum().backward() |
| (out_test1 + out_test2).sum().backward() |
| # Both grad_inputs are TwoTensor |
| self.assertEqual(a_ref.grad.a, a_test.grad.a) |
| self.assertEqual(a_ref.grad.b, a_test.grad.b) |
| self.assertEqual(b_ref.grad.a, b_test.grad.a) |
| self.assertEqual(b_ref.grad.b, b_test.grad.b) |
| |
| def test_aot_dispatch_input_mutation(self): |
| def f(a, b): |
| a.mul_(2) |
| b.mul_(3) |
| return a + b |
| |
| b1_ref = torch.ones(3, 3, requires_grad=True) |
| b2_ref = torch.ones(3, 3, requires_grad=True) |
| b_ref_base = TwoTensor(b1_ref, b2_ref) |
| a_ref_base = torch.ones(3, 3, requires_grad=True) |
| b_ref = b_ref_base + 1 |
| a_ref = a_ref_base + 1 |
| |
| b1_test = b1_ref.clone().detach().requires_grad_(True) |
| b2_test = b2_ref.clone().detach().requires_grad_(True) |
| b_test_base = TwoTensor(b1_test, b2_test) |
| a_test_base = a_ref_base.clone().detach().requires_grad_(True) |
| b_test = b_test_base + 1 |
| a_test = a_test_base + 1 |
| |
| compiled_f = aot_function( |
| f, |
| fw_compiler=nop, |
| bw_compiler=nop, |
| partition_fn=min_cut_rematerialization_partition |
| ) |
| out_ref = f(a_ref, b_ref) |
| out_test = compiled_f(a_test, b_test) |
| self.assertEqual(out_ref.a, out_test.a) |
| self.assertEqual(out_ref.b, out_test.b) |
| |
| # confirm input mutations worked |
| self.assertEqual(a_test, a_ref) |
| self.assertEqual(b_test.a, b_ref.a) |
| self.assertEqual(b_test.b, b_ref.b) |
| |
| # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile teh backward. |
| (b_ref * out_ref).sum().backward() |
| (b_test * out_test).sum().backward() |
| # Both grad_inputs are TwoTensor |
| self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a) |
| self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b) |
| self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a) |
| self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b) |
| |
| # NB: Metadata mutation for subclasses is currently broken and disabled |
| # See https://github.com/pytorch/pytorch/issues/114975 |
| @unittest.expectedFailure |
| def test_aot_dispatch_input_metadata_mutation(self): |
| def f(a, b): |
| a.t_() |
| b.unsqueeze_(0) |
| return a + b |
| |
| b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3) |
| b2_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3) |
| b_ref_base = TwoTensor(b1_ref, b2_ref) |
| a_ref_base = torch.arange(9, dtype=torch.float32).reshape(3, 3).detach().requires_grad_(True) |
| b_ref = b_ref_base + 1 |
| a_ref = a_ref_base + 1 |
| |
| b1_test = b1_ref.clone().detach().requires_grad_(True) |
| b2_test = b2_ref.clone().detach().requires_grad_(True) |
| b_test_base = TwoTensor(b1_test, b2_test) |
| a_test_base = a_ref_base.clone().detach().requires_grad_(True) |
| b_test = b_test_base + 1 |
| a_test = a_test_base + 1 |
| |
| compiled_f = aot_function( |
| f, |
| fw_compiler=nop, |
| bw_compiler=nop, |
| partition_fn=min_cut_rematerialization_partition |
| ) |
| out_ref = f(a_ref, b_ref) |
| out_test = compiled_f(a_test, b_test) |
| self.assertEqual(out_ref.a, out_test.a) |
| self.assertEqual(out_ref.b, out_test.b) |
| |
| # confirm input mutations worked |
| self.assertEqual(a_test, a_ref) |
| self.assertEqual(b_test.a, b_ref.a) |
| self.assertEqual(b_test.b, b_ref.b) |
| |
| # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile the backward. |
| (b_ref * out_ref).sum().backward() |
| (b_test * out_test).sum().backward() |
| # Both grad_inputs are TwoTensor |
| self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a) |
| self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b) |
| self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a) |
| self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b) |
| |
| # NB: Metadata mutation for subclasses is currently broken and disabled |
| # See https://github.com/pytorch/pytorch/issues/114975 |
| @unittest.expectedFailure |
| def test_aot_dispatch_input_data_and_metadata_mutation(self): |
| def f(a, b): |
| a.t_() |
| b.unsqueeze_(0) |
| a.mul_(2) |
| b.mul_(3) |
| return a + b |
| |
| b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3) |
| b2_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3) |
| b_ref_base = TwoTensor(b1_ref, b2_ref) |
| a_ref_base = torch.arange(9, dtype=torch.float32).reshape(3, 3).detach().requires_grad_(True) |
| b_ref = b_ref_base + 1 |
| a_ref = a_ref_base + 1 |
| |
| b1_test = b1_ref.clone().detach().requires_grad_(True) |
| b2_test = b2_ref.clone().detach().requires_grad_(True) |
| b_test_base = TwoTensor(b1_test, b2_test) |
| a_test_base = a_ref_base.clone().detach().requires_grad_(True) |
| b_test = b_test_base + 1 |
| a_test = a_test_base + 1 |
| |
| compiled_f = aot_function( |
| f, |
| fw_compiler=nop, |
| bw_compiler=nop, |
| partition_fn=min_cut_rematerialization_partition |
| ) |
| out_ref = f(a_ref, b_ref) |
| out_test = compiled_f(a_test, b_test) |
| self.assertEqual(out_ref.a, out_test.a) |
| self.assertEqual(out_ref.b, out_test.b) |
| |
| # confirm input mutations worked |
| self.assertEqual(a_test, a_ref) |
| self.assertEqual(b_test.a, b_ref.a) |
| self.assertEqual(b_test.b, b_ref.b) |
| |
| # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile the backward. |
| (b_ref * out_ref).sum().backward() |
| (b_test * out_test).sum().backward() |
| # Both grad_inputs are TwoTensor |
| self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a) |
| self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b) |
| self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a) |
| self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b) |
| |
| def test_aot_dispatch_input_mutation_and_output_alias(self): |
| def f(a, b): |
| a.mul_(2) |
| b.mul_(3) |
| return b.view(b.shape), a + b |
| |
| b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3) |
| b2_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3) |
| b_ref_base = TwoTensor(b1_ref, b2_ref) |
| a_ref_base = torch.arange(9, dtype=torch.float32).reshape(3, 3).detach().requires_grad_(True) |
| b_ref = b_ref_base + 1 |
| a_ref = a_ref_base + 1 |
| |
| b1_test = b1_ref.clone().detach().requires_grad_(True) |
| b2_test = b2_ref.clone().detach().requires_grad_(True) |
| b_test_base = TwoTensor(b1_test, b2_test) |
| a_test_base = a_ref_base.clone().detach().requires_grad_(True) |
| b_test = b_test_base + 1 |
| a_test = a_test_base + 1 |
| |
| compiled_f = aot_function( |
| f, |
| fw_compiler=nop, |
| bw_compiler=nop, |
| partition_fn=min_cut_rematerialization_partition |
| ) |
| out_ref1, out_ref2 = f(a_ref, b_ref) |
| out_test1, out_test2 = compiled_f(a_test, b_test) |
| self.assertEqual(out_ref1.a, out_test1.a) |
| self.assertEqual(out_ref1.b, out_test1.b) |
| self.assertEqual(out_ref2.a, out_test2.a) |
| self.assertEqual(out_ref2.b, out_test2.b) |
| |
| # confirm input mutations worked |
| self.assertEqual(a_test, a_ref) |
| self.assertEqual(b_test.a, b_ref.a) |
| self.assertEqual(b_test.b, b_ref.b) |
| |
| (out_ref1 * out_ref2).sum().backward() |
| (out_test1 * out_test2).sum().backward() |
| # Both grad_inputs are TwoTensors |
| self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a) |
| self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b) |
| |
| |
| class TestAOTModuleSimplified(AOTTestCase): |
| def test_aot_module_simplified(self): |
| class MockModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(20, 30) |
| |
| def forward(self, x, y): |
| return (self.linear(x) + y, ) |
| |
| mod = MockModule() |
| mod.zero_grad() |
| |
| x = torch.randn(128, 20, requires_grad=True) |
| y = torch.randn(128, 30, requires_grad=True) |
| inputs = [x, y] |
| cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs] |
| |
| ref = mod(*inputs) |
| ref[0].sum().backward() |
| |
| compiled_f = aot_module_simplified(mod, cloned_inputs, nop) |
| mod.zero_grad() |
| res = compiled_f(*cloned_inputs) |
| res[0].sum().backward() |
| |
| assert torch.allclose(ref[0], res[0]) |
| assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad) |
| assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad) |
| |
| def test_aot_module_simplified_dynamic(self): |
| class MockModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(20, 30) |
| |
| def forward(self, x, y): |
| return (self.linear(x) + y, ) |
| |
| mod = MockModule() |
| |
| shape_env = ShapeEnv() |
| fake_mode = FakeTensorMode(shape_env=shape_env) |
| |
| x = torch.randn(128, 20, requires_grad=True) |
| y = torch.randn(128, 30, requires_grad=True) |
| |
| inputs = [x, y] |
| fake_inputs = [fake_mode.from_tensor(x) for x in inputs] |
| compiled_f = aot_module_simplified(mod, fake_inputs, nop) |
| |
| ref = mod(*inputs) |
| ref[0].sum().backward() |
| |
| cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs] |
| res = compiled_f(*cloned_inputs) |
| res[0].sum().backward() |
| |
| self.assertExpectedInline(shape_env.format_guards(), """\ |
| - Eq(s1, 20) |
| - Eq(s2, 30)""") |
| |
| assert torch.allclose(ref[0], res[0]) |
| assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad) |
| assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad) |
| |
| # https://github.com/pytorch/pytorch/issues/105327 |
| def test_lift_fresh_copy_in_graph(self): |
| class MyMod(torch.nn.Module): |
| def forward(self, x): |
| _tensor_constant0 = torch.tensor([1]) |
| lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0) |
| y = x.mul(lift_fresh_copy) |
| return (y,) |
| |
| mod = MyMod() |
| shape_env = ShapeEnv() |
| fake_mode = FakeTensorMode(shape_env=shape_env) |
| x = torch.ones(4, requires_grad=True) |
| inputs = [x] |
| fake_inputs = [fake_mode.from_tensor(x) for x in inputs] |
| compiled_f = aot_module_simplified(mod, fake_inputs, nop) |
| |
| out_ref = mod(x) |
| out_test = compiled_f(x) |
| self.assertEqual(out_ref[0].detach(), out_test[0].detach()) |
| |
| def test_inference_python_dispatcher(self): |
| # Extracted from unet |
| class MockModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.upsample = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) |
| |
| def forward(self, x): |
| return (self.upsample(x), ) |
| |
| mod = MockModule() |
| shape_env = ShapeEnv() |
| fake_mode = FakeTensorMode(shape_env=shape_env) |
| x = torch.randn(2, 512, 40, 59) # NB: must not require grad |
| inputs = [x] |
| fake_inputs = [fake_mode.from_tensor(x) for x in inputs] |
| compiled_f = aot_module_simplified(mod, fake_inputs, nop) |
| |
| def test_aot_module_simplified_preserves_stack_trace(self): |
| class MockModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(20, 30) |
| |
| def forward(self, x, y): |
| z = self.linear(x) |
| z = z + y |
| z = z.relu() |
| return (z, ) |
| |
| tracer = torch.fx.Tracer() |
| tracer.record_stack_traces = True |
| graph = tracer.trace(MockModule()) |
| mod = torch.fx.GraphModule(tracer.root, graph) |
| |
| for node in mod.graph.nodes: |
| if node.op == 'output': |
| continue |
| self.assertTrue(node.stack_trace is not None) |
| assert 'test_aotdispatch.py' in node.stack_trace |
| |
| def assert_compiler(gm: torch.fx.GraphModule, _): |
| for node in gm.graph.nodes: |
| if node.op == 'output' or node.op == 'placeholder': |
| continue |
| self.assertTrue(node.stack_trace is not None) |
| assert 'test_aotdispatch.py' in node.stack_trace |
| return gm.forward # return a python callable |
| |
| x = torch.randn(128, 20, requires_grad=True) |
| y = torch.randn(128, 30, requires_grad=True) |
| inputs = [x, y] |
| |
| compiled_f = aot_module_simplified(mod, inputs, fw_compiler=assert_compiler, bw_compiler=assert_compiler) |
| res = compiled_f(*inputs) |
| res[0].sum().backward() |
| |
| def test_aot_module_simplified_preserves_stack_trace_from_mutation(self): |
| class MockModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x): |
| x_view = x[0] |
| x_view.mul_(2) |
| return (x + x, ) |
| |
| tracer = torch.fx.Tracer() |
| tracer.record_stack_traces = True |
| graph = tracer.trace(MockModule()) |
| mod = torch.fx.GraphModule(tracer.root, graph) |
| |
| for node in mod.graph.nodes: |
| if node.op == 'output': |
| continue |
| self.assertTrue(node.stack_trace is not None) |
| assert 'test_aotdispatch.py' in node.stack_trace |
| |
| def assert_compiler(gm: torch.fx.GraphModule, _): |
| assert torch.ops.aten.copy_.default in [x.target for x in gm.graph.nodes] |
| for node in gm.graph.nodes: |
| if node.target == torch.ops.aten.copy_.default: |
| assert 'stack_trace' in node.meta |
| assert 'x_view.mul_(2)' in node.meta['stack_trace'] |
| return gm.forward # return a python callable |
| |
| x = torch.randn(128, 20) |
| inputs = [x] |
| |
| aot_module_simplified( |
| mod, |
| inputs, |
| fw_compiler=assert_compiler, |
| bw_compiler=assert_compiler, |
| keep_inference_input_mutations=True, |
| ) |
| |
| def test_aot_module_simplified_fake_tensor_gm_raises(self): |
| fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() |
| real_x = torch.randn(4, requires_grad=True) |
| fake_x = fake_mode.from_tensor(real_x) |
| real_z = torch.randn(4) |
| fake_z = fake_mode.from_tensor(real_z) |
| |
| class MockModule(torch.nn.Module): |
| def forward(self, x): |
| # Accessing a free variable fake tensor will look like a |
| # constant to make_fx, and result in the tensor being traced |
| # into the graph, which is an error condition. Make sure we |
| # report adequately in this case. |
| return (x + fake_z, ) |
| |
| with self.assertRaisesRegex( |
| AssertionError, "Unexpected fake" |
| ): |
| aot_module_simplified(MockModule(), (fake_x,), nop) |
| |
| |
| # entries in here don't work and need to be fixed. |
| # Each one of these is a bug (or needs to be investigated) |
| aot_autograd_failures = { |
| # data-dependent control flow |
| xfail('cov'), |
| xfail('nn.functional.gaussian_nll_loss'), |
| xfail('tensor_split'), |
| xfail('corrcoef'), |
| xfail('quantile'), |
| xfail('nanquantile'), |
| xfail('narrow'), |
| xfail('istft'), |
| xfail('linalg.eig'), |
| |
| skip('as_strided_scatter'), |
| skip('as_strided', 'partial_views'), # flaky |
| |
| # Given input size: (s0xs1x2). Calculated output size: ... |
| skip('max_pool2d_with_indices_backward'), |
| |
| skip('nn.functional.nll_loss', ''), # UBSAN failure! |
| |
| # Misc |
| xfail('to_sparse'), |
| xfail('corrcoef'), |
| xfail('cov'), |
| xfail('chalf'), # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' |
| xfail('sparse.sampled_addmm'), |
| xfail('sparse.mm', 'reduce'), |
| skip('nn.functional.binary_cross_entropy_with_logits'), # seems to fail sometimes? |
| skip('nn.functional.margin_ranking_loss'), # seems flaky |
| skip('linalg.lu_solve'), # flaky |
| decorate('matmul', decorator=unittest.skipIf(IS_ARM64, 'flaky')), |
| decorate('__rmatmul__', decorator=unittest.skipIf(IS_ARM64, 'flaky')), |
| # overrides atol=1e-4, rtol=1e-5 would do as well |
| decorate('svd_lowrank', decorator=toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-05)})), |
| decorate('linalg.householder_product', decorator=unittest.skipIf(IS_MACOS and IS_X86, 'flaky')), |
| decorate('linalg.pinv', 'singular', decorator=toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-05)})), |
| decorate('nn.functional.interpolate', 'bicubic', decorator=toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-05)})), |
| # conv2d sometimes nondeterministic in this config? |
| decorate('nn.functional.conv2d', decorator=unittest.skipIf(IS_ARM64, "flaky")), |
| } |
| |
| symbolic_aot_autograd_failures = { |
| xfail('combinations', ''), # aten.masked_select.default |
| xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition |
| xfail('index_fill', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('linalg.eigvals', ''), # aten.linalg_eig.default - couldn't find symbolic meta function/decomposition |
| xfail('linalg.lstsq', ''), # aten.linalg_lstsq.default - couldn't find symbolic meta function/decomposition |
| xfail('linalg.lstsq', 'grad_oriented'), # aten.linalg_lstsq.default - couldn't find symbolic meta funct... |
| xfail('linalg.lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/deco... |
| skip('nn.functional.batch_norm', ''), # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te.. |
| xfail('nn.functional.binary_cross_entropy', ''), # aten.fill_.Scalar - couldn't find symbolic meta funct... |
| xfail('nn.functional.cross_entropy', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('nn.functional.ctc_loss', ''), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/deco... |
| xfail('nn.functional.embedding_bag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('nn.functional.fractional_max_pool2d', ''), # rand() received an invalid combination of arguments - g... |
| xfail('nn.functional.fractional_max_pool3d', ''), # rand() received an invalid combination of arguments - g... |
| xfail('nn.functional.group_norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('nn.functional.interpolate', 'linear'), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('nn.functional.interpolate', 'trilinear'), # Cannot call sizes() on tensor with symbolic sizes/st... |
| xfail('nn.functional.nll_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('_segment_reduce', 'lengths'), # aten.segment_reduce.default - couldn't find symbolic meta functio... |
| xfail('_segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta functio... |
| xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('_upsample_bilinear2d_aa'), # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList |
| decorate('linalg.householder_product', decorator=unittest.skipIf(IS_MACOS and IS_X86, 'flaky')), |
| |
| # many complex operators incorrect striding, metadata |
| xfail('fft.fft', ''), |
| xfail('fft.hfft2', ''), |
| xfail('fft.hfft', ''), |
| xfail('fft.hfftn', ''), |
| xfail('fft.ifft', ''), |
| xfail('fft.ihfft2', ''), |
| xfail('fft.ihfft', ''), |
| xfail('fft.ihfftn', ''), |
| xfail('fft.irfft2', ''), |
| xfail('fft.irfft', ''), |
| xfail('fft.irfftn', ''), |
| xfail('fft.rfft2', ''), |
| xfail('fft.rfft', ''), |
| xfail('fft.rfftn', ''), |
| |
| xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| } |
| |
| def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False): |
| if not op.supports_autograd: |
| self.skipTest("Op does not support autograd") |
| |
| # aot_autograd_check is able to check data specialization by |
| # randomizing the inputs. Here's a list of ops that really do not |
| # like random inputs for which we want to disable that. |
| cant_check_data_specialization = set({ |
| 'nn.functional.max_unpool1d', |
| 'nn.functional.max_unpool2d', |
| 'nn.functional.max_unpool3d', |
| }) |
| try_check_data_specialization = op.name not in cant_check_data_specialization |
| |
| sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True) |
| for sample_input in sample_inputs_itr: |
| t_args = [sample_input.input] + list(sample_input.args) |
| t_kwargs = sample_input.kwargs |
| try: |
| aot_autograd_check( |
| op.op, t_args, t_kwargs, dynamic, |
| self.assertRaisesRegex, self.assertEqual, |
| check_gradients=True, |
| try_check_data_specialization=try_check_data_specialization) |
| except DynamicOutputShapeException: |
| self.skipTest("Dynamic output shape operation in trace") |
| except GuardOnDataDependentSymNode: |
| # Carveout for getitem; I don't want to xfail the entire test |
| # because that will reject known to be good tests see |
| # https://github.com/pytorch/pytorch/issues/94705 |
| if op.name == "__getitem__": |
| self.skipTest("Dynamic output shape operation in trace") |
| else: |
| raise |
| |
| def _test_aot_autograd_module_helper(self, device, dtype, training, module_info, *, dynamic=False): |
| module_cls = module_info.module_cls |
| module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, |
| requires_grad=True, training=training) |
| for module_input in module_inputs: |
| if module_input.forward_input is None: |
| continue |
| |
| args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs |
| m = module_cls(*args, **kwargs) |
| m.to(device).to(dtype) |
| m.train(training) |
| |
| # Lazy modules need to see an input first to initialize params. |
| args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs |
| flat_args, args_spec = pytree.tree_flatten((args, kwargs)) |
| |
| # PackedSequence is only used for RNNs. It might be possible to fake-ify if they're pytrees but |
| # torchdynamo already doesn't support RNNs |
| if any(tuple(isinstance(flat_arg, PackedSequence) for flat_arg in flat_args)): |
| continue |
| |
| if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin): |
| with torch.no_grad(): |
| m(*args, **kwargs) |
| |
| sentinel_val = -42 |
| is_tensor_spec = [sentinel_val if isinstance(arg, torch.Tensor) |
| else arg for arg in flat_args] |
| args = [arg for arg in flat_args if isinstance(arg, torch.Tensor)] |
| |
| def f(params_buffers_args): |
| named_params, named_buffers, args = params_buffers_args |
| cur_flat_args = list(is_tensor_spec) |
| args = iter(args) |
| for idx, v in enumerate(cur_flat_args): |
| if v == sentinel_val: |
| cur_flat_args[idx] = next(args) |
| c_args, c_kwargs = pytree.tree_unflatten(cur_flat_args, args_spec) |
| params_and_buffers = {**named_params, **named_buffers} |
| return torch.func.functional_call(m, params_and_buffers, c_args, c_kwargs) |
| |
| named_params = dict(m.named_parameters(remove_duplicate=False)) |
| named_buffers = dict(m.named_buffers(remove_duplicate=False)) |
| num_params_buffers = len(named_params) + len(named_buffers) |
| compiled_f = aot_function(f, nop, num_params_buffers=num_params_buffers, dynamic=dynamic) |
| params_buffers_args = [named_params, named_buffers, args] |
| _test_aot_autograd_forwards_backwards_helper( |
| f, compiled_f, params_buffers_args, |
| self.assertRaisesRegex, self.assertEqual, True) |
| |
| |
| class TestEagerFusionOpInfo(AOTTestCase): |
| @ops(op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,)) |
| @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', aot_autograd_failures) |
| def test_aot_autograd_exhaustive(self, device, dtype, op): |
| _test_aot_autograd_helper(self, device, dtype, op) |
| |
| @ops(op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,)) |
| @patch("functorch.compile.config.debug_assert", True) |
| @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive', |
| aot_autograd_failures | symbolic_aot_autograd_failures) |
| def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op): |
| _test_aot_autograd_helper(self, device, dtype, op, dynamic=True) |
| |
| |
| aot_autograd_module_failures = set({ |
| torch.nn.CTCLoss, # torch._subclasses.fake_tensor.DynamicOutputShapeException: aten._ctc_loss.default |
| torch.nn.GaussianNLLLoss, # RuntimeError: It appears that you're trying to get value out |
| # of a tracing tensor with aten._local_scalar_dense.default - |
| # erroring out! It's likely that this is caused by data-dependent |
| # control flow or similar. |
| torch.nn.MultiLabelMarginLoss, # AssertionError: The values for attribute 'shape' do not match: |
| # torch.Size([1]) != torch.Size([]). Outputs of the operator are different in |
| # eager-mode PyTorch vs AOTAutograd. This means the operator will have incorrect |
| # output underneath torch.compile. This could be because the operator's |
| # implementation not traceable or that there is a bug in AOTAutograd. |
| torch.nn.TransformerEncoder, # DataDependentOutputException: aten.eq compares a mask input |
| # to a causal mask tensor, to see if Boolean is_causal should be set |
| # for TrnasformerEncoder layers, MHA and sdp custom kernels |
| torch.nn.Transformer, # DataDependentOutputException: aten.equal compares a mask input |
| # to a causal mask tensor, to see if Boolean is_causal should be set |
| # for TransformerEncoder layers, MHA and sdp custom kernels |
| # (this bubbles up to Transformer) |
| }) |
| |
| symbolic_aot_autograd_module_failures = { |
| torch.nn.Transformer, # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool |
| torch.nn.TransformerEncoder, # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool |
| torch.nn.GaussianNLLLoss, # NotImplementedError: local_scalar_dense/item NYI for torch.bool |
| torch.nn.GroupNorm, # in native_group_norm_backward cpg, _rem = divmod(C, group) |
| # TypeError: unsupported operand type(s) for divmod(): 'SymInt' and 'int' |
| torch.nn.FractionalMaxPool2d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat' |
| torch.nn.FractionalMaxPool3d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat' |
| torch.nn.BCELoss, # new_size = _infer_size(target.size(), weight.size()) |
| # RuntimeError: expected int at position 0, but got: SymInt |
| torch.nn.CrossEntropyLoss, # RuntimeError: Cannot call numel() on tensor with symbolic sizes/strides |
| torch.nn.NLLLoss, # RuntimeError: Cannot call numel() on tensor with symbolic sizes/strides |
| } |
| |
| |
| class TestEagerFusionModuleInfo(AOTTestCase): |
| @modules(module_db, allowed_dtypes=(torch.float,)) |
| @decorateForModules(unittest.expectedFailure, aot_autograd_module_failures) |
| def test_aot_autograd_module_exhaustive(self, device, dtype, training, module_info): |
| _test_aot_autograd_module_helper(self, device, dtype, training, module_info) |
| |
| @modules(module_db, allowed_dtypes=(torch.float,)) |
| @decorateForModules(unittest.expectedFailure, |
| aot_autograd_module_failures | symbolic_aot_autograd_module_failures) |
| def test_aot_autograd_symbolic_module_exhaustive(self, device, dtype, training, module_info): |
| _test_aot_autograd_module_helper(self, device, dtype, training, module_info, dynamic=True) |
| |
| |
| instantiate_parametrized_tests(TestAOTAutograd) |
| only_for = ("cpu") |
| instantiate_device_type_tests( |
| TestPythonKey, |
| globals(), |
| only_for=only_for, |
| ) |
| instantiate_device_type_tests(TestEagerFusionOpInfo, globals(), only_for=only_for) |
| instantiate_device_type_tests(TestEagerFusionModuleInfo, globals(), only_for=only_for) |
| |
| |
| if __name__ == '__main__': |
| run_tests() |