| # Owner(s): ["oncall: pt2"] |
| |
| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| from typing import Union, Callable, List, Any, Optional, Dict |
| from unittest.mock import patch |
| from torch.testing._internal.common_utils import ( |
| TestCase, |
| run_tests, |
| IS_ARM64, |
| IS_MACOS, |
| IS_X86, |
| compare_equal_outs_and_grads, |
| outs_and_grads, |
| skipIfRocm, |
| ) |
| import torch |
| import torch.nn as nn |
| import torch.utils._pytree as pytree |
| import unittest |
| import warnings |
| import itertools |
| from functools import partial |
| from torch.nn.utils.rnn import PackedSequence |
| from torch.testing._internal.common_device_type import instantiate_device_type_tests, toleranceOverride, tol |
| from torch.testing._internal.common_methods_invocations import op_db |
| from torch.testing._internal.common_modules import module_db, modules |
| from torch.testing._internal.control_flow_opinfo_db import control_flow_opinfo_db |
| from torch.testing._internal.optests import _test_aot_autograd_forwards_backwards_helper, aot_autograd_check |
| from functorch import ( |
| grad, vjp, vmap, jacrev, |
| make_fx |
| ) |
| from torch._functorch.aot_autograd import aot_module_simplified, aot_export_module, aot_export_joint_simple |
| from functorch.compile import ( |
| nnc_jit, compiled_function, compiled_module, |
| min_cut_rematerialization_partition, aot_function, aot_module, |
| nop, default_partition, default_decompositions, |
| memory_efficient_fusion, get_aot_compilation_context |
| ) |
| from torch._decomp import decomposition_table |
| |
| from torch.testing._internal.common_device_type import ops |
| from common_utils import ( |
| decorate, |
| xfail, |
| skip, |
| skipOps, |
| decorateForModules, |
| ) |
| from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode |
| from torch.fx.experimental.proxy_tensor import is_sym_node |
| from torch.fx.experimental.symbolic_shapes import ShapeEnv, GuardOnDataDependentSymNode |
| |
| USE_TORCHVISION = False |
| try: |
| import torchvision |
| USE_TORCHVISION = True |
| except ImportError: |
| warnings.warn("Couldn't import torchvision. Some of our tests use it, try " |
| "to install it with commands from pytorch.org, post-fixed with " |
| "`--no-deps` to avoid overwriting the pytorch installation", |
| UserWarning) |
| |
| USE_NETWORKX = False |
| try: |
| import networkx # noqa: F401 |
| USE_NETWORKX = True |
| except ImportError: |
| warnings.warn("Some tests use networkx but it was not installed", |
| UserWarning) |
| |
| # NB: numpy is a testing dependency! |
| |
| class AOTTestCase(TestCase): |
| def setUp(self): |
| super().setUp() |
| |
| class TestPythonKey(AOTTestCase): |
| def test_make_fx(self, device): |
| def f(x): |
| return torch.sin(x) |
| inp = torch.randn(3) |
| fx_f = make_fx(f)(inp) |
| |
| new_inp = torch.randn(3) |
| self.assertEqual(fx_f(new_inp), f(new_inp)) |
| |
| def test_make_fx_grad(self, device): |
| def f(x): |
| return torch.sin(x).sum() |
| inp = torch.randn(3) |
| f = grad(f) |
| fx_f = make_fx(f)(inp) |
| |
| new_inp = torch.randn(3) |
| self.assertEqual(fx_f(new_inp), f(new_inp)) |
| |
| def test_scalar_device(self, device): |
| def f(a, b): |
| return a + b |
| inps = [torch.randn(3, device=device), torch.tensor(5)] |
| fx_f = make_fx(f)(*inps) |
| self.assertEqual(fx_f(*inps), f(*inps)) |
| |
| def test_make_fx_vmap(self, device): |
| def f(x): |
| return torch.sin(x) |
| inp = torch.randn(5, 3) |
| f = vmap(f) |
| fx_f = make_fx(f)(inp) |
| new_inp = torch.randn(5, 3) |
| self.assertEqual(fx_f(new_inp), f(new_inp)) |
| |
| def test_make_fx_jacrev(self, device): |
| def f(x): |
| return x.sin().sum() |
| inp = torch.randn(3) |
| f = jacrev(jacrev(f)) |
| fx_f = make_fx(f)(inp) |
| new_inp = torch.randn(3) |
| self.assertEqual(fx_f(new_inp), f(new_inp)) |
| |
| def test_make_fx_vjp(self, device): |
| def f(x): |
| return torch.sin(x).sum() |
| |
| primals = torch.randn(3) |
| _, vjp_fn = vjp(f, primals) |
| cotangent = torch.randn(()) |
| fx_f = make_fx(vjp_fn)(cotangent, True, True) |
| new_cotangent = torch.randn(()) |
| self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent)) |
| |
| def test_make_fx_functionalize(self, device): |
| from functorch.experimental import functionalize |
| |
| def fn(a): |
| a = a * 2 |
| a.relu_() |
| return a |
| |
| a = torch.randn(3, device=device) |
| symbolic_gm = torch.fx.symbolic_trace(fn) |
| includes_method_relu_ = any( |
| str(n.target) == "relu_" for n in symbolic_gm.graph.nodes |
| ) |
| self.assertTrue(includes_method_relu_) |
| # Also verifies fix for https://github.com/pytorch/pytorch/issues/84570 |
| gm = make_fx(functionalize(symbolic_gm))(a) |
| includes_aten_relu = any( |
| n.target == torch.ops.aten.relu.default for n in gm.graph.nodes |
| ) |
| self.assertTrue(includes_aten_relu) |
| |
| def test_make_fx_no_decompose(self, device): |
| # FIXME |
| return self.skipTest("error: maximum recursion reached") |
| |
| def f(x): |
| return torch.tanh(x).sum() |
| |
| fx_f = make_fx(grad(f))(torch.randn(5)) |
| ops = {i.target for i in fx_f.graph.nodes} |
| |
| self.assertEqual(torch.ops.aten.tanh_backward in ops, True) |
| |
| fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5)) |
| ops = {i.target for i in fx_f.graph.nodes} |
| self.assertEqual(torch.ops.aten.tanh_backward in ops, False) |
| |
| def test_nnc_jit(self, device): |
| def f(x): |
| return torch.sin(x) |
| |
| jit_f = nnc_jit(f) |
| |
| inp = torch.randn(3) |
| self.assertEqual(jit_f(inp), f(inp)) |
| |
| def test_nnc_scalar(self, device): |
| def f(x): |
| return torch.sin(x) |
| |
| jit_f = nnc_jit(f) |
| |
| inp = torch.randn(()) |
| self.assertEqual(jit_f(inp), f(inp)) |
| |
| def test_nnc_pytrees(self, device): |
| def f(x): |
| return [torch.sin(x[0])] |
| |
| jit_f = nnc_jit(f) |
| |
| inp = [torch.randn(3)] |
| self.assertEqual(jit_f(inp), f(inp)) |
| |
| def test_external_calls(self, device): |
| def f(a, b): |
| return torch.mv(a, b) |
| jit_f = nnc_jit(f) |
| inp = [torch.randn(3, 3), torch.randn(3)] |
| self.assertEqual(jit_f(*inp), f(*inp)) |
| |
| def test_nnc_passthrough(self, device): |
| def f(x, y): |
| return x + y, y |
| inp = (torch.randn(3), torch.randn(3)) |
| jit_f = nnc_jit(f) |
| self.assertEqual(jit_f(*inp), f(*inp)) |
| |
| def f(x): |
| x['a'] = x['a'] * 2 |
| return x |
| inp = ({'a': torch.randn(3), 'b': torch.randn(3)},) |
| jit_f = nnc_jit(f) |
| self.assertEqual(jit_f(*inp), f(*inp)) |
| |
| @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") |
| def test_resnet18_backward_trace(self, device): |
| mod = torchvision.models.resnet18() |
| |
| def f(x): |
| out = mod(x) |
| out.sum().backward() |
| return [a.grad for a in mod.parameters()] |
| |
| inp = torch.randn(3, 3, 250, 250, requires_grad=True) |
| grads = f(inp) |
| |
| mod.zero_grad() |
| mod(inp).sum().backward() |
| grads2 = [a.grad for a in mod.parameters()] |
| self.assertEqual(grads, grads2) |
| |
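| # Helpers: `is_in_base(t, maybe_tensors)` checks whether `t` is a view of (shares a base |
| # with) any tensor in `maybe_tensors`; used below to verify that AOTAutograd preserves |
| # view relationships between outputs and inputs/other outputs. |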
| def get_base(t): |
| return t._base if t._is_view() else t |
| |
| def is_in_base(t, maybe_tensors): |
| t_base = get_base(t) |
| for maybe_tensor in maybe_tensors: |
| if isinstance(maybe_tensor, torch.Tensor): |
| if t_base is get_base(maybe_tensor): |
| return True |
| return False |
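| |
| # NB: `extract_graph` is used below as a compiler hook that stashes the compiled FX graph |
| # into a cell. If it is not already defined elsewhere in this file, a minimal sketch |
| # (assuming the usual (fx_graph, example_inputs) compiler-callback signature) is: |
| def extract_graph(fx_g, _, graph_cell): |
|     graph_cell[0] = fx_g |
|     return fx_g |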
| |
| class TestAOTAutograd(AOTTestCase): |
| # When test_mutation=True, verify_aot_autograd will: |
| # - ensure that inputs are non-leaves, so our graphs can mutate them |
| # - try to mutate outputs of the graph (to ensure that autograd meta is set properly on outputs) |
| @patch("functorch.compile.config.debug_assert", True) |
| def verify_aot_autograd( |
| self, |
| f, |
| inp_: Union[Callable, List[Any]], |
| *, |
| test_mutation: bool = False, |
| decompositions: Optional[Dict] = None, |
| dynamic: bool = False, |
| ): |
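| # Run each case twice: once with keep_inference_input_mutations=True and once with False, |
| # since that flag changes how AOTAutograd handles input mutations. |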
| for keep_input_mutations in [True, False]: |
| # Some tests pass in a callable for inp, to generate the inputs |
| # (useful if we want to generate complicated aliasing inputs) |
| if isinstance(inp_, Callable): |
| inp_callable = inp_ |
| # The callable should return a tuple of f_inputs, f_graph_inputs |
| # (The idea is that we might want to compile a function with the graph inputs, |
| # but test autograd backprop all the way through the actual inputs) |
| inp_copy, graph_inps_copy = inp_callable() |
| inp, graph_inps = inp_callable() |
| else: |
| inp_copy = [] |
| inp = [] |
| # Our input clones need to mimic when inputs are duplicates of one another |
| dupes_map = {} |
| for i, x in enumerate(inp_): |
| if x in dupes_map: |
| x_dupe_idx = dupes_map[x] |
| inp_copy.append(inp_copy[x_dupe_idx]) |
| inp.append(inp[x_dupe_idx]) |
| else: |
| dupes_map[x] = i |
| if not isinstance(x, torch.Tensor): |
| x_copy = x |
| x_copy2 = x |
| else: |
| x_copy = x.clone().detach().requires_grad_(x.requires_grad) |
| x_copy2 = x.clone().detach().requires_grad_(x.requires_grad) |
| if x.requires_grad and not x.is_leaf: |
| x_copy = x_copy.clone() |
| x_copy2 = x_copy2.clone() |
| inp_copy.append(x_copy) |
| inp.append(x_copy2) |
| |
| if test_mutation: |
| # For graphs where we mutate inputs, need our test to make sure inputs aren't leaves |
| graph_inps = [x.add(1) for x in inp] |
| graph_inps_copy = [x.add(1) for x in inp_copy] |
| else: |
| graph_inps = inp |
| graph_inps_copy = inp_copy |
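| # The fw_compiler hook (extract_graph) stashes the compiled forward graph into |
| # fw_graph_cell; it is returned below so callers can assertExpectedInline on its code. |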
| fw_graph_cell = [None] |
| if isinstance(f, nn.Module): |
| compiled_f = aot_module( |
| f, |
| fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), |
| bw_compiler=nop, |
| decompositions=decompositions, |
| keep_inference_input_mutations=keep_input_mutations, |
| dynamic=dynamic |
| ) |
| else: |
| compiled_f = aot_function( |
| f, |
| fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), |
| bw_compiler=nop, |
| decompositions=decompositions, |
| keep_inference_input_mutations=keep_input_mutations, |
| dynamic=dynamic |
| ) |
| ref_out, ref_grad = outs_and_grads(f, graph_inps, inp) |
| test_out, test_grad = outs_and_grads(compiled_f, graph_inps_copy, inp_copy) |
| self.assertEqual(ref_grad, test_grad) |
| |
| if isinstance(ref_out, torch.Tensor): |
| self.assertTrue(isinstance(test_out, torch.Tensor)) |
| ref_out, test_out = [ref_out], [test_out] |
| for ref_o, test_o in zip(ref_out, test_out): |
| if isinstance(ref_o, torch.Tensor): |
| self.assertEqual(ref_o.requires_grad, test_o.requires_grad) |
| self.assertEqual(ref_o.is_leaf, test_o.is_leaf) |
| ref_is_view_of_non_interm = is_in_base(ref_o, graph_inps) or is_in_base(ref_o, ref_out) |
| test_is_view_of_non_interm = is_in_base(test_o, graph_inps_copy) or is_in_base(test_o, test_out) |
| self.assertEqual(ref_is_view_of_non_interm, test_is_view_of_non_interm) |
| self.assertEqual(ref_o, test_o) |
| if test_mutation: |
| # This tests that autograd meta is set properly on the output, so that we can |
| # mutate it. |
| ref_o.mul_(2) |
| test_o.mul_(2) |
| self.assertEqual(ref_o, test_o) |
| for ref_i, test_i in zip(inp, inp_copy): |
| if isinstance(ref_i, torch.Tensor): |
| self.assertEqual(ref_i.requires_grad, test_i.requires_grad) |
| self.assertEqual(ref_i, test_i) |
| return fw_graph_cell[0] |
| |
| def test_non_tensor_and_none_inputs(self): |
| # int, None, Tensor |
| def f(a, b, c): |
| return a * c |
| inp = [2, None, torch.ones(3, 3, dtype=torch.float32, requires_grad=True)] |
| self.verify_aot_autograd(f, inp) |
| inp = [2, None, torch.ones(3, 3, dtype=torch.float32, requires_grad=False)] |
| self.verify_aot_autograd(f, inp) |
| |
| def test_single_output(self): |
| def f(a, b): |
| return a + b |
| inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] |
| self.verify_aot_autograd(f, inp) |
| inp = [torch.randn(3, 3, requires_grad=False), torch.randn(3, 3)] |
| self.verify_aot_autograd(f, inp) |
| |
| def test_multi_output(self): |
| def f(a, b): |
| return a + b, a - b |
| inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] |
| self.verify_aot_autograd(f, inp) |
| inp = [torch.randn(3, 3, requires_grad=False), torch.randn(3, 3)] |
| self.verify_aot_autograd(f, inp) |
| |
| def test_multi_output_list(self): |
| def f(a, b): |
| return [a + b, a - b] |
| inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] |
| self.verify_aot_autograd(f, inp) |
| inp = [torch.randn(3, 3, requires_grad=False), torch.randn(3, 3)] |
| self.verify_aot_autograd(f, inp) |
| |
| # Test for bug occurring at the intersection of fake tensors & functionalization. |
| def test_squeeze_mutation(self): |
| def f(a): |
| b = a.clone().squeeze(-1) |
| b.add_(1.) |
| return a + b |
| |
| inp = [torch.randn(3, 1, requires_grad=True)] |
| self.verify_aot_autograd(f, inp, dynamic=True) |
| inp = [torch.randn(3, 1, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, dynamic=True) |
| |
| def test_complex_linear(self): |
| # https://github.com/pytorch/pytorch/issues/93424 |
| inp = [torch.randn(1, 10, 10, dtype=torch.complex64)] |
| |
| class F(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = nn.Linear(10, 10, dtype=torch.complex64) |
| |
| def forward(self, x): |
| return self.linear(x).sum().abs() |
| |
| self.verify_aot_autograd(F(), inp) |
| |
| def test_embedding_bag_view_dynamic(self): |
| # Backwards pass tries to wrap a sparse tensor in a FunctionalTensorWrapper; |
| # test that this works even though the sparse tensor has no storage. |
| |
| class F(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.emb = torch.nn.EmbeddingBag(100, 8, sparse=True) |
| |
| def forward(self, x, y): |
| return self.emb(x, y).view(-1) |
| |
| x = torch.arange(3) |
| y = torch.arange(3) |
| self.verify_aot_autograd(F(), [x, y], dynamic=False) |
| self.verify_aot_autograd(F(), [x, y], dynamic=True) |
| |
| def test_input_mutation_simple(self): |
| def f(a): |
| a.mul_(2) |
| return a * 3 |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| # Things to note: |
| # - the extra clone is because we need to pass the pre-mutated input to grad(), |
| # but autograd operates above functionalization so we need to manually clone. |
| # Hopefully backends can optimize this easily. |
| # - The extra return arg is because the compiled forward returns (mutated inputs + outputs) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None |
| mul_1 = torch.ops.aten.mul.Tensor(mul, 3) |
| return [mul, mul_1]""") |
| |
| def test_input_mutation_simple_with_none_and_nontensor(self): |
| # Tensor, None, int |
| def f(a, b, c): |
| return a * c |
| f_compiled = aot_function(f, nop) |
| for req_grad in [True, False]: |
| inp = [torch.ones(3, 3, requires_grad=req_grad), None, 3] |
| out_ref = f(*inp) |
| out_test = f_compiled(*inp) |
| self.assertEqual(out_ref, out_test) |
| |
| # https://github.com/pytorch/pytorch/issues/93363 |
| def test_mutates_input_noncontiguous(self): |
| def f(a): |
| a.add_(1) |
| return () |
| |
| f_compiled = aot_function(f, nop) |
| ref = torch.ones(4, requires_grad=True) + 0 |
| ref_view = ref[0::2] |
| |
| test = torch.ones(4, requires_grad=True) + 0 |
| test_view = test[0::2] |
| |
| out_ref = f(ref_view) |
| out_test = f_compiled(test_view) |
| self.assertEqual(ref, test) |
| |
| def test_outputs_are_aliased(self): |
| def f(a): |
| b = a.mul(2) |
| c = b.view(-1) |
| return b, c |
| f_compiled = aot_function(f, nop) |
| for req_grad in [True, False]: |
| inp = torch.ones(3, requires_grad=req_grad) |
| out_ref = f(inp) |
| out_test = f_compiled(inp) |
| self.assertEqual(out_ref[0], out_test[0]) |
| self.assertEqual(out_ref[1], out_test[1]) |
| # Try mutating one of the outputs, which is aliased. |
| out_ref[0].mul_(3) |
| out_test[0].mul_(3) |
| # Assert that the aliasing relationship was preserved |
| self.assertEqual(out_ref[0], out_test[0]) |
| self.assertEqual(out_ref[1], out_test[1]) |
| |
| def test_input_mutation_is_output(self): |
| def f(a): |
| a.mul_(2) |
| return a |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None |
| return [mul, mul]""") |
| |
| def test_input_mutation_multiple(self): |
| def f(a, b, c): |
| a.mul_(2) |
| c.mul_(2) |
| return a + b + c |
| |
| def create_inp(req_grad): |
| return [ |
| torch.ones(3, 3, requires_grad=req_grad), |
| torch.ones(3, 3, requires_grad=req_grad), |
| torch.ones(3, 3, requires_grad=req_grad), |
| ] |
| |
| self.verify_aot_autograd(f, create_inp(False), test_mutation=True) |
| |
| fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2, primals_3): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| clone_1 = torch.ops.aten.clone.default(primals_3); primals_3 = None |
| mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None |
| mul_1 = torch.ops.aten.mul.Tensor(clone_1, 2); clone_1 = None |
| add = torch.ops.aten.add.Tensor(mul, primals_2); primals_2 = None |
| add_1 = torch.ops.aten.add.Tensor(add, mul_1); add = None |
| return [mul, mul_1, add_1]""") |
| |
| def test_input_mutation_metadata(self): |
| def f(a, b): |
| a.transpose_(1, 0) |
| return a + b |
| |
| def create_inp(req_grad): |
| return [ |
| torch.ones(3, 3, requires_grad=req_grad), |
| torch.ones(3, 3, requires_grad=req_grad), |
| ] |
| |
| self.verify_aot_autograd(f, create_inp(True), test_mutation=True) |
| self.verify_aot_autograd(f, create_inp(False), test_mutation=True) |
| |
| def test_input_output_alias_custom_autograd_function(self): |
| |
| class Foo(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| return x |
| |
| @staticmethod |
| def backward(ctx, gx): |
| return gx * 0.5 |
| |
| def f(x): |
| return Foo.apply(x) |
| |
| inp = [torch.ones(2, 2, requires_grad=True)] |
| self.verify_aot_autograd(f, inp, test_mutation=False) |
| |
| def test_input_mutation_requires_grad_detach(self): |
| # Here, "a" requires grad, and gets mutated, so we append a copy_() to the end of the graph. |
| # Its mutation doesn't take part in autograd though, because we mutated a detach'd view. |
| # Need to make sure that this copy_() doesn't error, and doesn't participate in autograd either. |
| def f(a): |
| a.detach().mul_(2) |
| return a + 3 |
| inp = [torch.ones(4, requires_grad=True)] |
| self.verify_aot_autograd(f, inp, test_mutation=False) |
| inp = [torch.ones(4, requires_grad=True)] |
| # test_mutation=True will first do some compute on inp, so it is no longer an autograd leaf |
| # by the time it becomes a graph input. Good to test both cases. |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_input_mutation_requires_grad_no_grad(self): |
| def f(a): |
| with torch.no_grad(): |
| a.mul_(2) |
| return a + 3 |
| inp = [torch.ones(4, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=False) |
| |
| def test_input_mutation_requires_grad_no_grad_detach_mixed(self): |
| # Perform a mix of mutations on a: |
| # 1 normal, 1 in no_grad, 1 on a detach'd tensor. |
| # Only the first should participate in gradient computation. |
| def f(a): |
| a.detach().mul_(2) |
| a.mul_(3) |
| with torch.no_grad(): |
| a.mul_(4) |
| return a + 5 |
| inp = [torch.ones(4, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_input_mutation_metadata2(self): |
| def f(a): |
| a.transpose_(1, 0) |
| a.mul_(2) |
| return a + 1 |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_input_mutation_resize_smaller(self): |
| def f(a, b): |
| a.resize_(2, 2) |
| return a + b |
| # tensors that require gradients cannot be resized, so we only test the requires_grad=False case for the resized input |
| inp = [ |
| torch.ones(3, 3), |
| torch.ones(2, 2, requires_grad=True), |
| ] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| inp = [ |
| torch.ones(3, 3), |
| torch.ones(2, 2), |
| ] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_input_mutation_batchnorm(self): |
| def f(inpt, weight, bias, running_mean, running_var): |
| # This is additionally a good test, because the input tensors that we mutate |
| # are *also* saved for backwards. |
| # This tests that what we save for the backward is actually cloned inputs, |
| # and not the original inputs that got mutated. |
| return torch._native_batch_norm_legit(inpt, weight, bias, running_mean, running_var, True, 0.5, 1e-5) |
| |
| def create_inp(req_grad): |
| return [ |
| torch.ones(2, 5, 5, 5, requires_grad=req_grad), |
| torch.ones(5, requires_grad=req_grad), |
| torch.ones(5, requires_grad=req_grad), |
| torch.ones(5), |
| torch.ones(5), |
| ] |
| |
| from torch._decomp import get_decompositions |
| # This simulates what inductor does (running the fw + bw decompositions) |
| decompositions = get_decompositions([ |
| torch.ops.aten._native_batch_norm_legit_functional, |
| torch.ops.aten.native_batch_norm_backward, |
| ]) |
| self.verify_aot_autograd(f, create_inp(True), test_mutation=True, decompositions=decompositions) |
| self.verify_aot_autograd(f, create_inp(False), test_mutation=True, decompositions=decompositions) |
| |
| def test_batchnorm_inference(self): |
| m = torch.nn.BatchNorm2d(4, 4) |
| m.eval() |
| fw_graph_cell = [None] |
| compiled_m = aot_module( |
| m, |
| fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), |
| bw_compiler=nop, |
| keep_inference_input_mutations=True, |
| ) |
| inp = torch.ones(4, 4, 4, 4) |
| with torch.no_grad(): |
| out = compiled_m(inp) |
| # expectation: there are no copy_() calls in the decomposed batch norm when running under training=False (eval mode) |
| code = fw_graph_cell[0].code.strip() |
| self.assertTrue("copy_" not in str(code)) |
| |
| def test_input_output_view_simple(self): |
| def f(a): |
| return a.view(-1) |
| inp = [torch.ones(2, 2, requires_grad=False).add(1)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(2, 2, requires_grad=True).add(1)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| # Outputs that alias inputs are pulled out of the graph entirely, so we don't compile anything here |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| view = torch.ops.aten.view.default(primals_1, [-1]); primals_1 = None |
| return [view]""") |
| |
| def test_input_output_view_mutate_multiple(self): |
| def f(a, b, c): |
| a.mul_(2) |
| c.mul_(3) |
| return b.view(2, 2), c.view(2, 2) |
| |
| def create_inp(req_grad): |
| return [ |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| ] |
| |
| self.verify_aot_autograd(f, create_inp(False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) |
| # The original function returned two outputs, both of which aliased inputs. |
| # We expect two outputs in the functional graph, a_updated and c_updated. |
| # The actual aliased outputs themselves aren't in the compiled forward graph; |
| # Instead, they're generated outside of the graph. |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2, primals_3): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| clone_1 = torch.ops.aten.clone.default(primals_3); primals_3 = None |
| mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None |
| mul_1 = torch.ops.aten.mul.Tensor(clone_1, 3); clone_1 = None |
| view = torch.ops.aten.view.default(primals_2, [2, 2]); primals_2 = None |
| view_2 = torch.ops.aten.view.default(mul_1, [2, 2]) |
| return [mul, mul_1, view, view_2]""") |
| |
| def test_input_output_view_metadata_mutate_multiple(self): |
| def f(a, b, c): |
| b.mul_(3) |
| c.t_() |
| return a.view(2, 2), b.view(2, 2), c.view(2, 2) |
| |
| def create_inp(req_grad): |
| return [ |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| ] |
| |
| self.verify_aot_autograd(f, create_inp(False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) |
| # Important thing to check here: of the three inputs, |
| # only the b.mul_(3) mutation should show up in the graph (we functionalize it and return it). |
| # Everything else does not show up in the graph: |
| # - The metadata mutation on c (we do it outside the graph) |
| # - All 3 original fw outputs, which are aliases of inputs (we regenerate them outside of the graph) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2, primals_3): |
| clone = torch.ops.aten.clone.default(primals_2); primals_2 = None |
| view = torch.ops.aten.view.default(primals_3, [2, 2]); primals_3 = None |
| mul = torch.ops.aten.mul.Tensor(clone, 3); clone = None |
| t = torch.ops.aten.t.default(view); view = None |
| view_1 = torch.ops.aten.view.default(primals_1, [2, 2]); primals_1 = None |
| view_3 = torch.ops.aten.view.default(t, [2, 2]) |
| view_4 = torch.ops.aten.view.default(mul, [2, 2]) |
| return [mul, t, view_1, view_4, view_3]""") |
| |
| def test_input_mutation_and_output_view(self): |
| def f(a): |
| a.add_(1) |
| return a.view(-1) |
| inp = [torch.ones(2, 2, requires_grad=False).add(1)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(2, 2, requires_grad=True).add(1)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| # Here, total # of outputs is 1 because: |
| # - num_mutated_inps = 1 (a_updated) |
| # - num_fw_outputs = 0 (the output is an alias of the input, so we move it outside the compiled fw) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| add = torch.ops.aten.add.Tensor(clone, 1); clone = None |
| view_1 = torch.ops.aten.view.default(add, [-1]) |
| return [add, view_1]""") |
| |
| |
| def test_input_mutation_output_view_multiple(self): |
| def f(a, b, c, d): |
| b.transpose_(1, 0) |
| c.add_(1) |
| return d + 1, b.diagonal(), a + c |
| |
| def create_inp(req_grad): |
| return [ |
| torch.arange(4, requires_grad=req_grad, dtype=torch.float32).view(2, 2).add(1), |
| torch.arange(4, requires_grad=req_grad, dtype=torch.float32).view(2, 2).add(1), |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| torch.ones(2, 2, requires_grad=req_grad).add(1), |
| ] |
| |
| self.verify_aot_autograd(f, create_inp(False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2, primals_3, primals_4): |
| view = torch.ops.aten.view.default(primals_2, [2, 2]); primals_2 = None |
| clone = torch.ops.aten.clone.default(primals_3); primals_3 = None |
| transpose = torch.ops.aten.transpose.int(view, 1, 0); view = None |
| add = torch.ops.aten.add.Tensor(clone, 1); clone = None |
| add_1 = torch.ops.aten.add.Tensor(primals_4, 1); primals_4 = None |
| diagonal = torch.ops.aten.diagonal.default(transpose) |
| add_2 = torch.ops.aten.add.Tensor(primals_1, add); primals_1 = None |
| return [transpose, add, add_1, diagonal, add_2]""") |
| |
| def test_output_aliases_intermediate_single(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| return out.view(-1) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| # In AOTAutograd, we are obligated to make the compiled forward directly return `out`, |
| # and reconstruct `out.view(-1)` as a fresh output. |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None |
| view = torch.ops.aten.view.default(mul, [-1]); mul = None |
| return [view]""") |
| |
| def test_output_aliases_intermediate_mutation_linear(self): |
| def f(x): |
| return (x + 1).view(-1) |
| |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| # use inductor's decomps (which will e.g. turn _unsafe_view() into view()) |
| from torch._inductor.decomposition import decompositions |
| f_compiled = aot_function(f, nop, decompositions=decompositions) |
| |
| out_ref = f(*inp) |
| out_test = f_compiled(*inp) |
| |
| out_ref.mul_(2) |
| out_test.mul_(2) |
| self.assertEqual(out_ref, out_test) |
| |
| def test_output_aliases_intermediate_no_grad(self): |
| def f(a, b): |
| out = torch.mul(a, 3) |
| # First output is an alias of an intermediate that doesn't require grad |
| return out.view(-1), b.add(1) |
| inp = [torch.ones(3, 3), torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3), torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| # important bit: we don't bother generating an intermediate base as an output in the graph, |
| # because the intermediate base itself didn't require gradients. |
| # (the only problematic case is when both the base and the aliased output require gradients). |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None |
| view = torch.ops.aten.view.default(mul, [-1]); mul = None |
| add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None |
| return [view, add]""") |
| |
| def test_output_aliases_intermediate_returned_multiple_times(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| out_view = out.view(-1) |
| return out, out_view, out |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_output_aliases_intermediate_multiple(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| # AOTAutograd should manually generate these two output views in the epilogue. |
| return out.view(-1), out.view(-1) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None |
| view = torch.ops.aten.view.default(mul, [-1]) |
| view_1 = torch.ops.aten.view.default(mul, [-1]) |
| return [view, view_1, mul]""") |
| |
| def test_output_aliases_intermediate_and_returned(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| # AOTAutograd should manually generate the first output (a view of an intermediate) |
| # but not the second (which is itself the intermediate for the first) |
| return out.view(-1), out |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None |
| view = torch.ops.aten.view.default(mul, [-1]) |
| return [view, mul]""") |
| |
| def test_output_aliases_intermediate_and_returned_flipped(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| # AOTAutograd should manually generate the first output (a view of an intermediate) |
| # but not the second (which is itself the intermediate for the first) |
| return out, out.view(-1) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None |
| view = torch.ops.aten.view.default(mul, [-1]) |
| return [mul, view]""") |
| |
| def test_output_aliases_intermediate_and_returned_different_grad(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| # AOTAutograd should manually generate the first output (a view of an intermediate) |
| # but not the second (which is itself the intermediate for the first) |
| return out.view(-1), out, out[0].detach() |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None |
| view = torch.ops.aten.view.default(mul, [-1]) |
| select = torch.ops.aten.select.int(mul, 0, 0) |
| detach = torch.ops.aten.detach.default(select); select = None |
| return [view, mul, detach]""") |
| |
| def test_output_aliases_intermediate_inplace_view(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| out.t_() |
| return out |
| inp = [torch.ones(2, 4, requires_grad=True)] |
| |
| # TODO: fix this test. |
| # See https://github.com/pytorch/pytorch/issues/90507 |
| # self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_output_aliases_intermediate_inplace_view_with_detach(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| out.t_() |
| out.detach_() |
| # Thanks to the detach_(), AOT Autograd doesn't need to do anything. |
| # `out` will show up as having OutputType.non_alias, |
| # and ._is_view() == False |
| return out |
| inp = [torch.ones(2, 4, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(2, 4, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None |
| t = torch.ops.aten.t.default(mul); mul = None |
| return [t]""") |
| |
| |
| def test_output_aliases_intermediate_inplace_view_and_view(self): |
| def f(a): |
| out = torch.mul(a, 3) |
| out_view = out.unsqueeze(0) |
| out.t_() |
| out_view2 = out.unsqueeze(0) |
| return out_view, out, out_view2 |
| inp = [torch.ones(2, 4, requires_grad=True)] |
| |
| # TODO: fix this test. |
| # See <github issue link> |
| # self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_output_aliases_intermediate_multiple_mixed(self): |
| def f(a): |
| out1 = torch.mul(a, 3) |
| out2 = torch.mul(a, 4) |
| # AOTAutograd should manually generate these two output views in the epilogue. |
| return out1.view(-1), out2.transpose(1, 0), out1.transpose(1, 0) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3) |
| mul_1 = torch.ops.aten.mul.Tensor(primals_1, 4); primals_1 = None |
| view = torch.ops.aten.view.default(mul, [-1]) |
| transpose = torch.ops.aten.transpose.int(mul_1, 1, 0); mul_1 = None |
| transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0) |
| return [view, transpose, transpose_1, mul]""") |
| |
| def test_output_all_alias_types(self): |
| # There are 3 types of aliasing that require us to return metadata in the compiled fw: |
| # (1) outputs that are views of inputs |
| # (2) outputs that are views of intermediates |
| # (3) inputs that get metadata mutations |
| # test all 3 of them here |
| def f(a): |
| a.transpose_(1, 0) |
| tmp = a.mul(2) |
| return tmp.squeeze(), tmp.transpose(1, 0), a.unsqueeze(0) |
| |
| def inp_callable(req_grad): |
| x = torch.ones(1, 2, 4, requires_grad=req_grad).clone() |
| return [(x,), (x,)] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| # TODO: make this test run with dynamic shapes so it is more meaningful |
| # metadata output order: (a_updated_meta, out1_meta, out2_meta, out3_meta) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| view = torch.ops.aten.view.default(primals_1, [1, 2, 4]); primals_1 = None |
| transpose = torch.ops.aten.transpose.int(view, 1, 0); view = None |
| mul = torch.ops.aten.mul.Tensor(transpose, 2) |
| squeeze = torch.ops.aten.squeeze.default(mul) |
| transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0) |
| unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0) |
| return [transpose, squeeze, transpose_1, unsqueeze, mul]""") |
| |
| def test_input_data_and_metadata_mutation(self): |
| def f(a): |
| a.t_() |
| a[0].mul_(2) |
| return a.view(a.shape) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| t = torch.ops.aten.t.default(clone) |
| select = torch.ops.aten.select.int(t, 0, 0); t = None |
| mul = torch.ops.aten.mul.Tensor(select, 2); select = None |
| t_1 = torch.ops.aten.t.default(clone); clone = None |
| select_scatter = torch.ops.aten.select_scatter.default(t_1, mul, 0, 0); t_1 = mul = None |
| t_2 = torch.ops.aten.t.default(select_scatter); select_scatter = None |
| t_4 = torch.ops.aten.t.default(t_2) |
| t_6 = torch.ops.aten.t.default(t_2); t_2 = None |
| view_1 = torch.ops.aten.view.default(t_6, [3, 3]); t_6 = None |
| return [t_4, view_1]""") |
| |
| def test_view_and_inplace_view(self): |
| def f(a, b): |
| a.t_() |
| return b.view(b.shape), a.view(a.shape) |
| |
| def create_inp(req_grad): |
| return [ |
| torch.ones(3, 3, requires_grad=req_grad), |
| torch.ones(3, 3, requires_grad=req_grad) |
| ] |
| |
| self.verify_aot_autograd(f, create_inp(False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| view = torch.ops.aten.view.default(primals_1, [3, 3]); primals_1 = None |
| t = torch.ops.aten.t.default(view); view = None |
| view_1 = torch.ops.aten.view.default(primals_2, [3, 3]); primals_2 = None |
| view_2 = torch.ops.aten.view.default(t, [3, 3]) |
| return [t, view_1, view_2]""") |
| |
| def test_view_detach(self): |
| def f(a): |
| tmp = a.detach() |
| a.mul_(2) |
| return a, tmp |
| inp = [torch.ones(3, 3, requires_grad=True)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| inp = [torch.ones(3, 3, requires_grad=False)] |
| self.verify_aot_autograd(f, inp, test_mutation=True) |
| |
| def test_input_inplace_requires_grad_true(self): |
| def f(a, b): |
| a.requires_grad_(True) |
| return a.mul(3), b.mul(4) |
| inp = [ |
| # First inp doesn't require grad, but we switch it on |
| torch.ones(3, 3, requires_grad=False), |
| torch.ones(3, 3, requires_grad=True), |
| ] |
| |
| fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None |
| mul_1 = torch.ops.aten.mul.Tensor(primals_2, 4); primals_2 = None |
| return [mul, mul_1]""") |
| |
| # This is a torture test: |
| # a and b get turned into a synthetic base in the compiled graph |
| # One gets a data mutation, the other gets a metadata mutation. |
| # We need to make sure that the metadata mutation gets propagated |
| # back to the original input. |
| def test_input_data_and_metadata_mutation_aliases_other_input(self): |
| # a and b are aliased |
| def f(a, b): |
| a.mul_(2) |
| b.t_() |
| return a.mul(b) |
| |
| def inp_callable(req_grad): |
| base = torch.ones(2, 2, requires_grad=req_grad) |
| # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. |
| x = base.add(1) |
| inp1 = x[0] |
| inp2 = x[1] |
| return [base], [inp1, inp2] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") |
| def test_mem_leak_from_save_for_bw(self): |
| # See a full diagnosis at this issue: https://github.com/pytorch/pytorch/issues/94990 |
| # Note [Detaching saved tensors in AOTAutograd] |
| # This program creates a ref-cycle. Long term, we should fix this ref cycle |
| # (since it can arise, naturally albeit rarely, from uses of autograd.Function). |
| # But AOTAutograd makes it more likely to show up from tracing user programs, |
| # so we deal with it by manually detaching the tensors that we save for backward. |
| # This is completely wrong and would give wrong results if we were to do double backward. |
| # Fortunately today, double backward is explicitly banned in AOTAutograd. |
| def f(a, b): |
| add = a + a |
| split = torch.functional.split(add, [4, 4], dim=1) |
| getitem_2 = split[1] |
| unsqueeze = getitem_2.unsqueeze(-1) |
| mul = unsqueeze * b |
| return (getitem_2, mul) |
| |
| f_compiled = aot_function(f, nop) |
| inps = [ |
| torch.ones(8, 8, device='cuda', requires_grad=True), |
| torch.ones(1, 4, 1, device='cuda', requires_grad=True), |
| ] |
| mem_before = torch.cuda.memory_allocated() |
| f_compiled(*inps) |
| mem_after = torch.cuda.memory_allocated() |
| self.assertTrue(mem_after == mem_before) |
| |
| def test_output_aliases_multiple_inputs_get_correct_one(self): |
| # a and b are aliased, but have different shapes |
| # The first output should view off the first input, and the 2nd output should view off the 2nd input |
| def f(a, b): |
| return a.view(a.shape), b.view(b.shape) |
| |
| def inp_callable(req_grad): |
| base = torch.ones(2, 2, requires_grad=req_grad) |
| # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. |
| x = base.mul(2) |
| inp1 = x.view(-1) |
| inp2 = x[0] |
| return [base], [inp1, inp2] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| |
| def test_input_mutation_aliases_other_input(self): |
| def f(a, b): |
| a.add_(1) |
| return a + b |
| |
| def inp_callable(req_grad): |
| base = torch.ones(2, 2, requires_grad=req_grad) |
| # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. |
| x = base.add(1) |
| inp1 = x[0] |
| inp2 = x[1] |
| return [base], [inp1, inp2] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| # Important parts of the graph: |
| # - the compiled graph takes in a base, and we generate a and b (the views) off of the base |
| # - clone() is still in the graph, because we need to call grad() on the original (non-mutated) inputs |
| # - We re-generate the views *after* the clone, to preserve view relationships. |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0) |
| add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None |
| as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None |
| as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0) |
| as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 2) |
| add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None |
| return [as_strided_scatter, add_1]""") # noqa: B950 |
| |
| def test_input_mutation_aliases_other_input2(self): |
| def f(a, b): |
| a.add_(1) |
| return a + b |
| |
| def inp_callable(req_grad): |
| base = torch.ones(2, 2, requires_grad=req_grad) |
| x = base.add(1) |
| inp1 = x[0] |
| # Here, one of the aliased inputs is the base itself |
| inp2 = x |
| return [base], [inp1, inp2] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0) |
| add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None |
| as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None |
| as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0) |
| as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0) |
| add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None |
| return [as_strided_scatter, add_1]""") # noqa: B950 |
| |
| def test_input_mutation_aliases_and_output_alias(self): |
| def f(a, b): |
| # Here we need to take care: since a and b are aliased, |
| # we generate the output view off of the "updated b". |
| a.add_(1) |
| return b.view(b.shape) |
| |
| def inp_callable(req_grad): |
| base = torch.ones(2, 2, requires_grad=req_grad) |
| x = base.add(1) |
| return [base], [x.view(-1), x.view(-1)] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0) |
| add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None |
| as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None |
| as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) |
| view_1 = torch.ops.aten.view.default(as_strided_8, [4]); as_strided_8 = None |
| return [as_strided_scatter, view_1]""") # noqa: B950 |
| |
| def test_input_aliased_with_mutation_output_alias(self): |
| def f(a, b, c): |
| # a and c alias |
| c.mul_(2) |
| # The main thing we're testing here is that |
| # (1) We need to reconstruct c.view(-1) from the 3rd input to the forward |
| # (2) But we need to be careful to do this *before* converting aliased inputs into synthetic bases. |
| # The original fw takes in 3 args, but the compiled fw takes in only 2 args. |
| return b.add(1), c.view(-1) |
| |
| def inp_callable(req_grad): |
| base1 = torch.ones(2, 2, requires_grad=req_grad) |
| base2 = torch.ones(2, 2, requires_grad=req_grad) |
| x = base1.add(1) |
| y = base2.add(1) |
| return [base1, base2], [x.view(-1), y, x.view(-1)] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0) |
| mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None |
| as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None |
| add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None |
| as_strided_7 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) |
| view_1 = torch.ops.aten.view.default(as_strided_7, [-1]); as_strided_7 = None |
| return [as_strided_scatter, add, view_1]""") # noqa: B950 |
| |
| def test_input_metadata_mutation_aliases(self): |
| def f(a, b): |
| # a and b alias, and we do a metadata mutation on a |
| # Since we're not mutating data, then b isn't affected at all. |
| # We expect aot autograd to not bother with constructing a synthetic base. |
| a.t_() |
| return a + b |
| |
| def inp_callable(req_grad): |
| base = torch.ones(2, 2, requires_grad=req_grad) |
| x = base.add(1) |
| return [base], [x.view(-1), x.view(-1)] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| # Expectation: fwd() takes in 2 args, and we don't construct a synthetic base. |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| view = torch.ops.aten.view.default(primals_1, [4]); primals_1 = None |
| t = torch.ops.aten.t.default(view); view = None |
| add = torch.ops.aten.add.Tensor(t, primals_2); primals_2 = None |
| return [t, add]""") |
| |
| def test_input_mutation_aliases_and_none_require_gradients(self): |
| def f(a, b, c): |
| # a and b alias, but neither require gradients (so they don't have a _base) |
| # aot autograd should construct the synthetic base from `torch.Tensor(a.storage())` |
| a.mul_(2) |
| return b + 1, c + 1 |
| |
| def inp_callable(req_grad): |
| base = torch.ones(2, 2) |
| c_arg = torch.ones(2, 2, requires_grad=req_grad) |
| x = base.add(1) |
| return [base, c_arg], [x.view(-1), x.view(-1), c_arg] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| as_strided = torch.ops.aten.as_strided.default(primals_1, [4], [1], 0) |
| mul = torch.ops.aten.mul.Tensor(as_strided, 2); as_strided = None |
| as_strided_scatter = torch.ops.aten.as_strided_scatter.default(primals_1, mul, [4], [1], 0); primals_1 = mul = None |
| as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) |
| add = torch.ops.aten.add.Tensor(as_strided_3, 1); as_strided_3 = None |
| add_1 = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None |
| return [as_strided_scatter, add, add_1]""") # noqa: B950 |
| |
| def test_input_mutation_aliases_bases_out_of_order(self): |
| # This tests our calling convention: if b and d are aliased, then the outer calling convention |
| # that we send to the compiled forward becomes: |
| # (b_d_base, a, c) |
| # Importantly, even though a and c alias in our test, neither input is mutated, |
| # so we don't need to do the base construction / deconstruction |
| def f(a, b, c, d): |
| b.add_(1) |
| d.t_() |
| return a + c + d, b.view(-1) |
| |
| def inp_callable(req_grad): |
| base1 = torch.ones(2, 2, requires_grad=req_grad) |
| base2 = torch.ones(2, 2, requires_grad=req_grad) |
| x1 = base1.add(1) |
| x2 = base2.add(1) |
| # a and c alias, b and d alias |
| return [base1, base2], [x1.view(-1), x2.view(-1), x1.view(-1), x2.view(-1)] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| # 3 graph inputs: (b_d_base, a, c) |
| # 2 returns: (b_updated, a+c+d) |
| # (there are 2 original fw outs, but one is a view of b so it's not part of the graph) |
| # (there are also 2 input mutations, but one is a metadata-only mutation so the compiled forward doesn't return it) |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2, primals_3): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0) |
| add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None |
| as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None |
| add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None |
| as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) |
| t_1 = torch.ops.aten.t.default(as_strided_5); as_strided_5 = None |
| add_2 = torch.ops.aten.add.Tensor(add_1, t_1); add_1 = None |
| as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) |
| view_1 = torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None |
| return [as_strided_scatter, add_2, view_1, t_1]""") # noqa: B950 |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") |
| def test_synthetic_base_base_attribute_is_none(self): |
| def f(a, b): |
| a.add_(1) |
| return a + b |
| |
| def inp_callable(): |
| base = torch.ones(4, 4, device='cuda') |
| # detach() so that none of the inputs have a ._base attribute. |
| a = base[0].detach() |
| b = base[1].detach() |
| base2 = torch.ones(2, 2, requires_grad=True) |
| return [base], [a, b] |
| |
| self.verify_aot_autograd(f, inp_callable, test_mutation=True) |
| |
| |
| def test_input_mutation_alias_everything(self): |
| # Mondo test that tests a combination of: |
| # an input is mutated that aliases another input (so we make a synthetic base) |
| # an output is an alias of another output |
| # an output is an alias of an intermediate |
| # a and c are aliased |
| def f(a, b, c): |
| c.mul_(2) # mutates c |
| b.t_() # metadata mutate b |
| tmp = a + c |
| out1 = tmp.view(-1) |
| out2 = b.t() |
| out3 = out1.unsqueeze(0) |
| # out1 and out3 are aliases of an intermediate, and alias each other! |
| # out2 aliases an input, so we don't return it |
| return out1, out2, out3 |
| |
| def inp_callable(req_grad): |
| base1 = torch.ones(2, 2, requires_grad=req_grad) |
| base2 = torch.ones(2, 2, requires_grad=req_grad) |
| # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. |
| base1_ = base1.add(1) |
| base2_ = base2.add(1) |
| a = base1_.view(-1) |
| b = base2_ |
| c = base1_.view(-1) |
| return [base1, base2], [a, b, c] |
| |
| self.verify_aot_autograd(f, partial(inp_callable, req_grad=False), test_mutation=True) |
| fw_graph = self.verify_aot_autograd(f, partial(inp_callable, req_grad=True), test_mutation=True) |
| # Expected: |
| # - 2 inputs in the forward: synthetic_base_a_c, b |
| # - 1 output in the forward: "tmp" |
| # out2 is an alias of an input, and will be generated off of b outside of the compiled fn |
| # out1 and out3 are aliases of tmp, that we generate outside of the compiled function |
| self.assertExpectedInline(fw_graph.code.strip(), """\ |
| def forward(self, primals_1, primals_2): |
| clone = torch.ops.aten.clone.default(primals_1); primals_1 = None |
| view = torch.ops.aten.view.default(primals_2, [2, 2]); primals_2 = None |
| as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0) |
| mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None |
| as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None |
| as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) |
| t = torch.ops.aten.t.default(view); view = None |
| as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) |
| add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2); as_strided_5 = as_strided_2 = None |
| view_1 = torch.ops.aten.view.default(add, [-1]) |
| t_1 = torch.ops.aten.t.default(t) |
| unsqueeze = torch.ops.aten.unsqueeze.default(view_1, 0) |
| return [as_strided_scatter, t, view_1, t_1, unsqueeze, add]""") # noqa: B950 |
| |
| def test_dynamic_shape_output_not_in_bw_graph(self): |
| def f(x): |
| return [x + 1, x.shape[0]] |
| inp = torch.ones(5, requires_grad=True) |
| bw_graph_cell = [None] |
| compiled_f = aot_function( |
| f, |
| fw_compiler=nop, |
| bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), |
| decompositions={}, |
| keep_inference_input_mutations=False, |
| dynamic=True, |
| ) |
| out = compiled_f(inp) |
| out[0].sum().backward() |
| # The important bit: the forward fn returns 2 outputs, |
| # but one of them is a symint so we should only see |
| # 1 grad_output as an input to the backward graph. |
| # (Otherwise, autograd will plumb a None as the value of the grad_output, |
| # which causes inductor to complain). |
| self.assertExpectedInline(bw_graph_cell[0].code.strip(), """\ |
| def forward(self, tangents_1): |
| return [tangents_1]""") |
| |
| def test_no_grad_input_output(self): |
| def f(a, b): |
| return a.cos(), b.cos(), a * b |
| |
| inp_thunks = [lambda: torch.randn(5, requires_grad=True), lambda: torch.randn(5, requires_grad=False)] |
| for inps in itertools.product(inp_thunks, repeat=2): |
| inps = [i() for i in inps] |
| self.verify_aot_autograd(f, inps) |
| |
| def test_some_output_requires_grad_input_doesnt(self): |
| def f(a, b): |
| a_view = a.view(-1) |
| a_view.requires_grad_(True) |
| return a_view |
| inp = [torch.randn(3, 3), torch.randn(3, 3, requires_grad=True)] |
| self.verify_aot_autograd(f, inp) |
| |
| def test_some_outputs_dont_require_grad_view(self): |
| def f(a, b): |
| return a.detach(), b |
| inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3, requires_grad=True)] |
| self.verify_aot_autograd(f, inp) |
| |
| def test_some_outputs_dont_require_grad_non_view(self): |
| def f(a, b): |
| return a.add(1).detach(), b |
| inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3, requires_grad=True)] |
| self.verify_aot_autograd(f, inp) |
| |
| def test_inner_grad(self): |
| def foo(x): |
| y = torch.exp(x) |
| z = torch.autograd.grad(y, x) |
| return z |
| inps = [torch.randn((), requires_grad=True)] |
| self.verify_aot_autograd(foo, inps) |
| |
| def test_grad_context(self): |
| def foo(x): |
| return x * 2 |
| inps = [torch.randn((), requires_grad=True)] |
| graph_size = None |
| |
| def get_graph_size(fx_g, _): |
| nonlocal graph_size |
| graph_size = len(fx_g.graph.nodes) |
| return fx_g |
| |
| f = aot_function(foo, nop, get_graph_size) |
| with torch.set_grad_enabled(False): |
| f(*inps) |
| self.assertIsNone(graph_size) |
| |
| f = aot_function(foo, nop, get_graph_size) |
| with torch.set_grad_enabled(True): |
| out = f(*inps) |
| self.assertIsNone(graph_size) |
| out.sum().backward() |
| self.assertTrue(graph_size > 2) |
| |
| def test_output_dict(self): |
| def f(x): |
| return {'a': x, 'b': x} |
| inp = [torch.randn(3, 3, requires_grad=True)] |
| self.verify_aot_autograd(f, inp) |
| |
| def f(x, y): |
| return {'a': x, 'b': y + x} |
| inp = [torch.randn(3, requires_grad=True), torch.randn(3)] |
| self.verify_aot_autograd(f, inp) |
| |
| def f(x): |
| new_d = {} |
| for k in x: |
| new_d[k] = x[k] * 2 |
| return new_d |
| |
| a = torch.randn(3, requires_grad=True) |
| b = torch.randn(3, requires_grad=True) |
| |
| def inp_callable(): |
| inps = [{'a': a, 'b': b}] |
| return inps, inps |
| |
| self.verify_aot_autograd(f, inp_callable) |
| |
| def test_module(self): |
| mod = nn.Sequential(nn.Linear(32, 32), nn.ReLU()) |
| compiled_mod = compiled_module(mod, nop, nop) |
| inp = torch.randn(32, 32) |
| ref_out = mod(inp) |
| ref_out.sum().backward() |
| ref_grads = sorted([(name, p.grad) for name, p in mod.named_parameters()]) |
| out = compiled_mod(inp) |
| out.sum().backward() |
| grads = sorted([(name, p.grad) for name, p in mod.named_parameters()]) |
| self.assertEqual((out, grads), (ref_out, ref_grads)) |
| |
| def test_batchnorm(self): |
| mod = compiled_module(nn.BatchNorm2d(4), nop, nop) |
| x = torch.ones(1, 4, 2, 2) |
| mod(x).sum().backward() |
| |
| def test_list_codegen(self): |
| def list_nop(f, _): |
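            # A boxed no-op "compiler": setting `_boxed_call = True` tells AOTAutograd
            # to invoke the compiled callable with a single list of inputs (the boxed
            # calling convention) instead of unpacked positional arguments.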
| def g(inps): |
| return f(*inps) |
| g._boxed_call = True |
| return g |
| |
| def f(a, b, c): |
| return a.sin() * b.cos() * c.sin() |
| f = aot_function(f, list_nop) |
| inp = [torch.randn(5, requires_grad=True) for _ in range(3)] |
| f(*inp).sum().backward() |
| |
| @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) |
| def test_compilation_context(self, counter): |
| def f(x): |
| return x.sin().sin() |
| count = [] |
| |
| def compiler(fx_g, _): |
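            # get_aot_compilation_context() exposes metadata about the graph currently
            # being compiled; its first element is the list of graph name(s)
            # (e.g. ['0_forward'] or ['1_inference']), which is what we record here.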
| context = get_aot_compilation_context() |
| count.append((context[0], len(fx_g.graph.nodes))) |
| return fx_g |
| |
| f = aot_function(f, compiler) |
| out = f(torch.randn(5, requires_grad=True)) |
| f = aot_function(f, compiler) |
| f(torch.randn(5)) |
| out.sum().backward() |
| self.assertExpectedInline(str(count), """[(['0_forward'], 4), (['1_inference'], 4), (['0_backward'], 8)]""") |
| |
| def test_dupe_arg(self): |
| def f(x, y): |
| return x + y |
| |
| x = torch.randn(3, 3, requires_grad=True) |
| self.verify_aot_autograd(f, [x, x]) |
| |
| def test_dupe_arg_torture(self): |
| def f(x, y): |
| x.t_() |
| y.t_() |
| return x + y |
| |
| x = torch.randn(3, 3, requires_grad=True).clone() |
| self.verify_aot_autograd(f, [x, x]) |
| |
| # See https://github.com/pytorch/pytorch/issues/100224 |
| def test_dupe_arg_returned_as_output(self): |
| def f(a, b, a_): |
| a[0].add_(1) |
| return a_ |
| f_compiled = aot_function(f, nop) |
| a = torch.ones(2) |
| b = torch.ones(2) |
| out_ref = f(a, b, a) |
| |
| a2 = torch.ones(2) |
| b2 = torch.ones(2) |
| out_test = f_compiled(a2, b2, a2) |
| |
| self.assertEqual(out_ref, out_test) |
| self.assertEqual(a, a2) |
| |
| @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) |
| @patch("torch._functorch.config.debug_assert", True) |
| def test_invalid_dupe_left_bias(self, counter): |
        # This test checks that even if only the first
        # argument received a metadata mutation, we still correctly
        # switch to strategy 2 (deduplicate)
| # See: https://github.com/pytorch/pytorch/pull/89896#discussion_r1036224447 |
| class F(torch.nn.Module): |
| def forward(self, x, y): |
| x.t_() |
| return (x + y,) |
| |
| x = torch.randn(3, 3, requires_grad=True).clone() |
| y = torch.randn(3, 3, requires_grad=True) |
| self.verify_aot_autograd(F(), [x, x]) |
| |
| fxx = aot_module_simplified(F(), (x, x), nop) |
| self.assertExpectedRaisesInline( |
| AssertionError, lambda: fxx(x, y), |
| """At compilation time, graph 2 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""" # noqa: B950 |
| ) |
| |
| |
| @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) |
| @patch("torch._functorch.config.debug_assert", True) |
| def test_invalid_dupe(self, counter): |
| self._test_invalid_dupe(counter, fake=False) |
| |
| # See Note: Dynamo recompilation guarding invalid grad for why this test exists |
| @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) |
| @patch("torch._functorch.config.debug_assert", True) |
| def test_invalid_dupe_fake(self, counter): |
| self._test_invalid_dupe(counter, fake=True) |
| |
| |
| def _test_invalid_dupe(self, counter, fake): |
| class F(torch.nn.Module): |
| def forward(self, x, y): |
| x.t_() |
| y.t_() |
| return (x + y,) |
| |
| x = torch.randn(3, 3, requires_grad=True).clone() |
| y = torch.randn(3, 3, requires_grad=True).clone() |
| |
| if fake: |
| shape_env = ShapeEnv() |
| fake_mode = FakeTensorMode(shape_env=shape_env) |
| |
| fake_x = fake_mode.from_tensor(x) |
| fake_y = fake_mode.from_tensor(y) |
| |
| if fake: |
| fxy = aot_module_simplified(F(), (fake_x, fake_y), nop) |
| else: |
| fxy = aot_module_simplified(F(), (x, y), nop) |
| |
| fxy(x, y) |
| fxy(x, x) # is ok! |
| |
| if fake: |
| fxx = aot_module_simplified(F(), (fake_x, fake_x), nop) |
| else: |
| fxx = aot_module_simplified(F(), (x, x), nop) |
| |
| fxx(x, x) |
        # Note: this should not raise! Once we have guards in place here,
        # we will have this working correctly, as it should recompile.
| self.assertExpectedRaisesInline( |
| AssertionError, lambda: fxx(x, y), |
| """At compilation time, graph 1 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""" # noqa: B950 |
| ) |
| |
| |
| @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) |
| @patch("torch._functorch.config.debug_assert", True) |
| def test_invalid_requires_grad(self, counter): |
| self._test_invalid_requires_grad(counter, fake=False) |
| |
| # See Note: Dynamo recompilation guarding invalid grad for why this test exists |
| @patch('torch._functorch.aot_autograd.AOT_COUNTER', new_callable=itertools.count) |
| @patch("torch._functorch.config.debug_assert", True) |
| def test_invalid_requires_grad_fake(self, counter): |
| self._test_invalid_requires_grad(counter, fake=True) |
| |
| def _test_invalid_requires_grad(self, counter, fake): |
| class F(torch.nn.Module): |
| def forward(self, x, y): |
| return (x + y,) |
| |
| x = torch.randn(3, 3, requires_grad=True) |
| y = torch.randn(3, 3, requires_grad=True) |
| z = torch.randn(3, 3, requires_grad=False) |
| |
| if fake: |
| shape_env = ShapeEnv() |
| fake_mode = FakeTensorMode(shape_env=shape_env) |
| |
| fake_x = fake_mode.from_tensor(x) |
| fake_y = fake_mode.from_tensor(y) |
| fake_z = fake_mode.from_tensor(z) |
| |
| if fake: |
| fxy = aot_module_simplified(F(), (fake_x, fake_y), nop) |
| else: |
| fxy = aot_module_simplified(F(), (x, y), nop) |
| |
| compare_equal_outs_and_grads(self, F(), fxy, (x, y)) |
| compare_equal_outs_and_grads(self, F(), fxy, (x, z)) |
| |
| if fake: |
| fxz = aot_module_simplified(F(), (fake_x, fake_z), nop) |
| else: |
| fxz = aot_module_simplified(F(), (x, z), nop) |
| |
| compare_equal_outs_and_grads(self, F(), fxz, (x, z)) |
| |
| self.assertExpectedRaisesInline( |
| AssertionError, lambda: fxz(x, y), |
| """At compilation time, graph 1 was compiled under the assumption that input 1 would not require grad, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""" # noqa: B950 |
| ) |
| |
| def test_resize_input(self): |
| def f(x, y): |
| y.resize_(4) |
| y.zero_() |
| self.assertEqual(x.shape, (4,)) |
| return y |
| |
| # NB: don't use verify_aot_autograd as the inputs get |
| # mutated and I don't trust verify to do it right |
| |
| compiled_f = aot_function(f, nop) |
| ref_x = torch.randn(0) |
| ref_out = f(ref_x, ref_x) |
| |
| test_x = torch.randn(0) |
| test_out = compiled_f(test_x, test_x) |
| |
| self.assertEqual(ref_out, test_out) |
| |
| def test_resize_input_smaller(self): |
| def f(x, y): |
| y.resize_(4) |
| y.zero_() |
| self.assertEqual(x.shape, (4,)) |
| return y |
| |
| # NB: don't use verify_aot_autograd as the inputs get |
| # mutated and I don't trust verify to do it right |
| |
| compiled_f = aot_function(f, nop) |
| ref_x = torch.randn(5) |
| ref_out = f(ref_x, ref_x) |
| |
| test_x = torch.randn(5) |
| test_out = compiled_f(test_x, test_x) |
| |
| self.assertEqual(ref_out, test_out) |
| |
| |
| def test_custom_autograd(self): |
| class CustomFn(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, x): |
| return x.clone() |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| return grad_output + 1 |
| |
| def f(x): |
| return CustomFn.apply(x) |
| |
| self.verify_aot_autograd(f, [torch.randn(3)]) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") |
| def test_autocast_disable_guard(self): |
| with torch._C._DisableAutocast(): |
| x = torch.rand([4, 4]).cuda() |
| y = x @ x |
| self.assertEqual(y.dtype, torch.float32) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") |
| def test_nonidempotent_amp(self): |
| def f(self_s_emb, add_3): |
| einsum_2 = torch.functional.einsum('ah,th->t', self_s_emb, add_3) |
| log_softmax_2 = einsum_2.log_softmax(-1) |
| return (log_softmax_2,) |
| |
| args = [torch.rand((1, 256), dtype=torch.float32, device='cuda'), torch.rand((30, 256), dtype=torch.float16, device='cuda')] |
| with torch.cuda.amp.autocast(enabled=True): |
| self.verify_aot_autograd(f, args) |
| |
| args = [e.requires_grad_(True) for e in args] |
| with torch.cuda.amp.autocast(enabled=True): |
| self.verify_aot_autograd(f, args) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") |
| @unittest.skipIf(not torch.backends.cudnn.is_available(), "CUDNN is unavailable") |
| @skipIfRocm # https://github.com/pytorch/pytorch/issues/96560 |
| def test_batch_norm_amp(self): |
| device = "cuda" |
| input_dtype = torch.float16 |
| param_dtype = torch.float32 |
| weight, bias = (torch.ones(64, device=device, dtype=param_dtype, requires_grad=True) for _ in range(2)) |
| running_mean, running_var = (torch.ones(64, device=device, dtype=param_dtype) for _ in range(2)) |
| |
| def bn(x): |
| return torch.ops.aten.cudnn_batch_norm( |
| x, |
| weight, |
| bias, |
| running_mean, |
| running_var, |
| False, |
| 0.1, |
| 1e-05, |
| ) |
| inp = torch.ones(torch.Size([16, 64, 112, 112]), dtype=input_dtype, device=device) |
| |
| ref = bn(inp) |
| cudnn_batch_norm_decomp = torch._decomp.get_decompositions({torch.ops.aten.cudnn_batch_norm}) |
| aot_fn = make_fx(bn, decomposition_table=cudnn_batch_norm_decomp)(inp) |
| res = aot_fn(inp) |
| for a, b in zip(ref, res): |
| assert torch.allclose(a, b) |
| |
| def test_output_op_depending_on_symint(self): |
| """ |
| It won't be obvious from reading this test what it's testing for. We should probably make it into a more |
| focused unit test. |
| |
| An issue with the following program was the expand op would end up depending on a symint whose proxy was |
| incorrectly associated with one of the grad tensors rather than input tensors. It broke partitioner logic |
| and the net result was aot_function failed to produce a function and threw an exception instead. |
| """ |
| inp = torch.randn(5, requires_grad=True) |
| |
| def f(x): |
| return x.expand(x.shape) |
| |
| # TODO(whc) make this work (test setup is wrong somehow) |
| # joint_forward_backward = create_joint_forward_backward(f) |
| # out = f(inp) |
| # joint_inputs = ([inp], [out.detach().contiguous()]) |
| # fx_g = make_fx(joint_forward_backward)(*joint_inputs) |
| # TODO: assert outputs of fwd graph trace to correct symint |
| |
| # e2e test that fails without symint clone fix |
| af = aot_function(f, nop, partition_fn=partial(min_cut_rematerialization_partition, compiler="inductor"), dynamic=True) |
| out = af(inp) |
| self.assertEqual(out, f(inp)) |
| |
| def test_inference_mode(self): |
| m = torch.nn.Linear(4, 4) |
| inp = torch.randn(4, 4) |
| |
| aot_mod = aot_module(m, fw_compiler=nop) |
| |
| with torch.inference_mode(): |
| out_ref = m(inp) |
| out_test = aot_mod(inp) |
| self.assertEqual(out_ref, out_test) |
| |
| def test_default_partitioner_saves_symints_not_tensors_for_bw(self): |
| """ |
| In this test, the important thing is that primals_1 is **only** needed in the backward |
| in order to grab its sizes. |
| We need to assert that what we save for the backward are the tensor's sizes, and not the tensor itself. |
| |
| The way this test is set up, it will actually fail if we try to save the input tensor for backward. |
| Why? |
| b.masked_fill_(c, 0) has a backward that requires knowing a's sizes |
| b.masked_fill_(c, 0) **also** mutates a (because b and a are aliased) |
| The autograd engine yells at us if we save "a" for backward, and then try to mutate it. |
| """ |
| inp = torch.randn(2, 2, requires_grad=True) |
| |
| def f(a): |
| b = a[0] |
| c = torch.ones_like(b, dtype=torch.bool) |
| d = b.masked_fill_(c, 0) |
| return d |
| |
| compiled_f = aot_function(f, nop, dynamic=True) |
| inp_ref = torch.ones(2, 2, requires_grad=True) |
| inp_test = torch.ones(2, 2, requires_grad=True) |
| |
| out_ref = f(inp_ref.clone()) |
| out_test = compiled_f(inp_test.clone()) |
| |
| self.assertEqual(out_ref, out_test) |
| |
| out_ref.sum().backward() |
| out_test.sum().backward() |
| |
| self.assertEqual(inp_ref.grad, inp_test.grad) |
| |
| def test_real_weights_in_symbolic_mode(self): |
| from functorch.experimental import functionalize |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 5) |
| |
| def forward(self, x): |
| x = self.linear(x) |
| return x |
| |
| m = M().eval() |
| |
| inp = torch.randn(2, 5) |
| |
| gm = make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp) |
| self.assertEqual(gm(torch.ones(2, 5)), m(torch.ones(2, 5))) |
| |
| gm_functionalized = make_fx(functionalize(gm,), tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp) |
| self.assertEqual(gm_functionalized(torch.ones(2, 5)), m(torch.ones(2, 5))) |
| |
| inp_count = 0 |
| for node in gm.graph.nodes: |
| if node.op == "placeholder": |
| inp_count += 1 |
| |
| # No more param lifting |
| self.assertEqual(inp_count, 1) |
| |
| inp_count = 0 |
| for node in gm_functionalized.graph.nodes: |
| if node.op == "placeholder": |
| inp_count += 1 |
| |
| # No more param lifting |
| self.assertEqual(inp_count, 1) |
| |
| with self.assertRaisesRegex(Exception, "Please convert all Tensors to FakeTensors"): |
| make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=False)(torch.randn(2, 5)) |
| |
| def test_real_weights_in_symbolic_mode_with_inplace_ops(self): |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.buffer = torch.nn.Buffer(torch.ones(4, 5)) |
| |
| def forward(self, x): |
| y = self.buffer.add_(3) |
| y.resize_([20]) |
| assert y.shape == self.buffer.shape |
| return x.sum() + self.buffer.sum() |
| |
| m = M().eval() |
| inp = torch.randn(2, 5) |
| # inplace mutation on attr is not allowed |
| with self.assertRaisesRegex(Exception, "Can't call metadata"): |
| make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp) |
| |
| |
| def extract_graph(fx_g, _, graph_cell): |
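    # Compiler hook that stashes the traced fx.GraphModule into graph_cell so tests
    # can inspect the captured forward/backward graphs after compilation.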
| graph_cell[0] = fx_g |
| return fx_g |
| |
| |
| def get_ins_outs(fx_g): |
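    # Collect the placeholder (input) nodes and the flattened output nodes of an fx graph.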
| ins = [] |
| outs = [] |
| for n in fx_g.graph.nodes: |
| if n.op == 'placeholder': |
| ins.append(n) |
| elif n.op == 'output': |
| outs = tuple(n.args[0]) |
| return ins, outs |
| |
| |
| def get_num_ins_outs(fx_g): |
| return tuple(len(i) for i in get_ins_outs(fx_g)) |
| |
| |
| def get_fw_bw_graph(f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False): |
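    # Run f end-to-end (forward + .sum().backward()) through aot_function, using
    # extract_graph as both compilers to capture the partitioned forward and
    # backward graphs for inspection.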
| fw_graph_cell = [None] |
| bw_graph_cell = [None] |
| aot_function(f, |
| fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), |
| bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), |
| partition_fn=partitioner, |
| decompositions=default_decompositions, |
| dynamic=dynamic)(*inps).sum().backward() |
| return (fw_graph_cell[0], bw_graph_cell[0]) |
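
# A minimal usage sketch for the helpers above (illustration only, not executed;
# the lambda and input shape are made up for this example):
#
#   fw_g, bw_g = get_fw_bw_graph(lambda x: x.sin().cos(),
#                                [torch.randn(3, requires_grad=True)])
#   print(get_num_ins_outs(fw_g), get_num_ins_outs(bw_g))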
| |
| class TestMod(torch.nn.Module): |
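    # Tiny wrapper module used by the aot_export tests below: it owns a single
    # parameter and forwards it (plus any extra args) to the wrapped fn.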
| def __init__(self, fn): |
| super().__init__() |
| self.p = torch.nn.Parameter(torch.ones(2, requires_grad=True)) |
| self.fn = fn |
| |
| def forward(self, *args): |
| return self.fn(self.p, *args) |
| |
| class TestAOTExport(AOTTestCase): |
| |
| def test_aot_export_module_joint(self): |
| class ConvBatchnormRelu(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(1, 3, 1, 1) |
| self.bn = torch.nn.BatchNorm2d(3) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.bn(x) |
| user_out = torch.nn.functional.relu(x) |
| loss = user_out.sum() |
| return loss, user_out.detach() |
| |
| mod = ConvBatchnormRelu() |
| mod.train() |
| inp = torch.randn(1, 1, 3, 3) |
| o_ref = mod(inp) |
| fx_g, signature = aot_export_module(mod, [inp], trace_joint=True, output_loss_index=0) |
| # Some important characteristics of the exported graph below: |
        # 8 arguments: 2 params from conv, 2 params from batchnorm, 3 buffers from batchnorm, 1 user input
| # 9 outputs: 3 mutated buffers (from batchnorm), 2 user outputs and 4 gradients (since there were 4 parameters) |
| self.assertExpectedInline(fx_g.print_readable(print_output=False), """\ |
| class <lambda>(torch.nn.Module): |
| def forward(self, arg0_1: f32[3, 1, 1, 1], arg1_1: f32[3], arg2_1: f32[3], arg3_1: f32[3], arg4_1: f32[3], arg5_1: f32[3], arg6_1: i64[], arg7_1: f32[1, 1, 3, 3]): |
| # No stacktrace found for following nodes |
| convolution: f32[1, 3, 3, 3] = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg1_1 = None |
| add: i64[] = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None |
| _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05); arg3_1 = arg4_1 = arg5_1 = None |
| getitem: f32[1, 3, 3, 3] = _native_batch_norm_legit_functional[0] |
| getitem_1: f32[3] = _native_batch_norm_legit_functional[1] |
| getitem_2: f32[3] = _native_batch_norm_legit_functional[2] |
| getitem_3: f32[3] = _native_batch_norm_legit_functional[3] |
| getitem_4: f32[3] = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None |
| relu: f32[1, 3, 3, 3] = torch.ops.aten.relu.default(getitem); getitem = None |
| detach: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu) |
| sum_1: f32[] = torch.ops.aten.sum.default(relu) |
| detach_1: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu) |
| detach_2: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(detach_1); detach_1 = None |
| ones_like: f32[] = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format) |
| expand: f32[1, 3, 3, 3] = torch.ops.aten.expand.default(ones_like, [1, 3, 3, 3]); ones_like = None |
| threshold_backward: f32[1, 3, 3, 3] = torch.ops.aten.threshold_backward.default(expand, relu, 0); expand = relu = None |
| native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(threshold_backward, convolution, arg2_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); threshold_backward = convolution = arg2_1 = getitem_1 = getitem_2 = None |
| getitem_5: f32[1, 3, 3, 3] = native_batch_norm_backward[0] |
| getitem_6: f32[3] = native_batch_norm_backward[1] |
| getitem_7: f32[3] = native_batch_norm_backward[2]; native_batch_norm_backward = None |
| convolution_backward = torch.ops.aten.convolution_backward.default(getitem_5, arg7_1, arg0_1, [3], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]); getitem_5 = arg7_1 = arg0_1 = None |
| getitem_8 = convolution_backward[0] |
| getitem_9: f32[3, 1, 1, 1] = convolution_backward[1] |
| getitem_10: f32[3] = convolution_backward[2]; convolution_backward = None |
| return (getitem_3, getitem_4, add, sum_1, detach_2, getitem_9, getitem_10, getitem_6, getitem_7) |
| """) # noqa: B950 |
| |
| |
| self.assertExpectedInline(str(signature.parameters), """['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias']""") |
| self.assertExpectedInline(str(signature.buffers), """['bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked']""") |
| self.assertExpectedInline(str(signature.user_inputs), """['arg7_1']""") |
| self.assertExpectedInline(str(signature.inputs_to_parameters), """{'arg0_1': 'conv.weight', 'arg1_1': 'conv.bias', 'arg2_1': 'bn.weight', 'arg3_1': 'bn.bias'}""") # noqa: B950 |
| self.assertExpectedInline(str(signature.inputs_to_buffers), """{'arg4_1': 'bn.running_mean', 'arg5_1': 'bn.running_var', 'arg6_1': 'bn.num_batches_tracked'}""") # noqa: B950 |
| self.assertExpectedInline(str(signature.buffers_to_mutate), """{'getitem_3': 'bn.running_mean', 'getitem_4': 'bn.running_var', 'add': 'bn.num_batches_tracked'}""") # noqa: B950 |
| self.assertExpectedInline(str(signature.backward_signature.gradients_to_parameters), """{'getitem_9': 'conv.weight', 'getitem_10': 'conv.bias', 'getitem_6': 'bn.weight', 'getitem_7': 'bn.bias'}""") # noqa: B950 |
| self.assertExpectedInline(str(signature.backward_signature.gradients_to_user_inputs), """{}""") |
| self.assertExpectedInline(str(signature.backward_signature.loss_output), """getitem_3""") |
| |
| # Also check the inference graph |
| # Main important thing here is that there are 5 total outputs: 3 total mutated buffers (from batchnorm), 2 user outputs. |
| fx_g_inference, signature_inference = aot_export_module(mod, [inp], trace_joint=False) |
| self.assertExpectedInline(fx_g_inference.print_readable(print_output=False), """\ |
| class <lambda>(torch.nn.Module): |
| def forward(self, arg0_1: f32[3, 1, 1, 1], arg1_1: f32[3], arg2_1: f32[3], arg3_1: f32[3], arg4_1: f32[3], arg5_1: f32[3], arg6_1: i64[], arg7_1: f32[1, 1, 3, 3]): |
| # No stacktrace found for following nodes |
| convolution: f32[1, 3, 3, 3] = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg7_1 = arg0_1 = arg1_1 = None |
| add: i64[] = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None |
| _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05); convolution = arg2_1 = arg3_1 = arg4_1 = arg5_1 = None |
| getitem: f32[1, 3, 3, 3] = _native_batch_norm_legit_functional[0] |
| getitem_3: f32[3] = _native_batch_norm_legit_functional[3] |
| getitem_4: f32[3] = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None |
| relu: f32[1, 3, 3, 3] = torch.ops.aten.relu.default(getitem); getitem = None |
| sum_1: f32[] = torch.ops.aten.sum.default(relu) |
| detach: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu); relu = None |
| return (getitem_3, getitem_4, add, sum_1, detach) |
| """) # noqa: B950 |
        # Some important characteristics of the exported inference graph above:
        # 8 arguments: 2 params from conv, 2 params from batchnorm, 3 buffers from batchnorm, 1 user input
        # 5 outputs: 3 mutated buffers (from batchnorm) and 2 user outputs (no gradients, since trace_joint=False)
| |
| def test_aot_export_simplified_basic(self): |
| def f(x, y): |
| return x * y, y * y.detach() |
| |
| x = torch.randn(2, requires_grad=True) |
| y = torch.randn(2, requires_grad=True) |
| |
| f_graph_fw = aot_export_joint_simple(f, [x, y], trace_joint=False) |
| out_ref = f(x, y) |
| # No calling convention changes necessary to invoke the traced graph |
| out_test = f_graph_fw(x, y) |
| self.assertEqual(out_ref, out_test) |
| |
| # Now test the backward |
| x = torch.randn(2, requires_grad=True) |
| y = torch.randn(2, requires_grad=True) |
| x2 = x.clone().detach().requires_grad_(True) |
| y2 = y.clone().detach().requires_grad_(True) |
| x3 = x.clone().detach().requires_grad_(True) |
| y3 = y.clone().detach().requires_grad_(True) |
| f_graph_joint = aot_export_joint_simple(f, [x, y], trace_joint=True) |
| num_fw_outputs = 2 |
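        # default_partition splits the traced joint graph into a forward graph
        # (returning the num_fwd_outputs user outputs followed by saved activations)
        # and a backward graph (taking those activations plus the grad_outputs).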
| fw_g, bw_g = default_partition(f_graph_joint, [x, y], num_fwd_outputs=num_fw_outputs) |
| out_ref2 = f(x2, y2) |
| fw_outs = fw_g(x3, y3) |
| out_test2, activations = fw_outs[:num_fw_outputs], fw_outs[num_fw_outputs:] |
| self.assertEqual(out_ref2, out_test2) |
| |
| # Test running the traced backward graph with a mocked-up grad_output |
| grad_outs = [torch.ones_like(x) for x in out_ref2] |
| grads_ref = torch.autograd.grad(out_ref2, [x2, y2], grad_outputs=grad_outs) |
| grads_test = bw_g(*activations, *grad_outs) |
| for g_ref, g_test in zip(grads_ref, grads_test): |
| self.assertEqual(g_ref, g_test) |
| |
| def test_aot_export_metadata_mutation_banned(self): |
| def fn(p, x): |
| x.t_() |
| return (x * 2,) |
| mod = TestMod(fn) |
| inp = torch.randn(2) |
| with self.assertRaisesRegex( |
| RuntimeError, "Found an input that received a metadata mutation" |
| ): |
| aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) |
| aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) |
| aot_export_module(mod, [inp], trace_joint=False) |
| |
| def test_aot_export_forward_mutation_no_buffer_mut_banned(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.buffer1 = torch.nn.Buffer(torch.ones(6, 4)) |
| |
| def forward(self, x): |
| x.add_(4) |
| return (x.cos().sum() + self.buffer1.sum(),) |
| |
| with self.assertRaisesRegex(RuntimeError, "Found following user inputs located at \\[0\\] are mutated"): |
| aot_export_module(M(), [torch.ones(6, 4)], trace_joint=False) |
| |
| def test_aot_export_forward_mutation_multiple_mut_banned(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.buffer1 = torch.nn.Buffer(torch.ones(6, 4)) |
| |
| def forward(self, x, y): |
| y.add_(4) |
| self.buffer1.add_(5) |
| return (x.cos().sum() + y.sin().sum(), self.buffer1.sum(),) |
| |
| with self.assertRaisesRegex(RuntimeError, "Found following user inputs located at \\[1\\] are mutated"): |
| aot_export_module(M(), [torch.ones(6, 4), torch.zeros(6, 4)], trace_joint=False) |
| |
| def test_aot_export_input_mutation_on_parameter_banned(self): |
| def fn(p, x): |
| p.mul_(2) |
| return (p + x,) |
| mod = TestMod(fn) |
| inp = torch.randn(2) |
| with self.assertRaisesRegex( |
| RuntimeError, "Found a graph input that requires gradients, and received a mutation" |
| ): |
| aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) |
| aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) |
| aot_export_module(mod, [inp], trace_joint=False) |
| |
| def test_aot_export_synthetic_bases_banned(self): |
| def fn(p, x, y): |
| x.mul_(2) |
| return (x + y,) |
| mod = TestMod(fn) |
| inp = torch.randn(2) |
| inp2 = inp.view(-1) |
| with self.assertRaisesRegex( |
| RuntimeError, "Encountered aliased inputs that are mutated" |
| ): |
| aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=False) |
| aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=True) |
| aot_export_module(mod, [inp, inp2], trace_joint=False) |
| |
| def test_aot_export_input_dupes_banned(self): |
| def fn(p, x, y): |
| x.mul_(2) |
| return (x + y,) |
| mod = TestMod(fn) |
| inp = torch.randn(2) |
| with self.assertRaisesRegex( |
| RuntimeError, "Encountered duplicated inputs that are mutated in the graph" |
| ): |
| aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=False) |
| aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=True) |
| aot_export_module(mod, [inp, inp], trace_joint=False) |
| |
| def test_aot_export_multiple_outputs_require_grad_banned(self): |
| def fn(p, x): |
| out = p * x |
| return out, out.sum() |
| mod = TestMod(fn) |
| inp = torch.randn(2) |
| with self.assertRaisesRegex( |
| RuntimeError, "Found an output of the forward that requires gradients, that was not" |
| ): |
| aot_export_module(mod, [inp], trace_joint=True, output_loss_index=1) |
| |
| def test_aot_export_simplified_input_mutations_banned(self): |
| def fn(x): |
| x.mul_(2) |
| return (x + x,) |
| inp = torch.randn(2) |
| with self.assertRaisesRegex( |
| RuntimeError, "Found following user inputs located at \\[0\\] are mutated" |
| ): |
| aot_export_joint_simple(fn, [inp], trace_joint=False) |
| aot_export_joint_simple(fn, [inp], trace_joint=True) |
| |
| def test_aot_export_simplified_pytrees_banned(self): |
| def fn(inps): |
| return (inps[0] + inps[1],) |
| inp1 = torch.randn(2) |
| inp2 = torch.randn(2) |
| inps = [inp1, inp2] |
| with self.assertRaisesRegex( |
| RuntimeError, "aot_export_joint_simple requires individual inputs not to be pytrees" |
| ): |
| aot_export_joint_simple(fn, [inps], trace_joint=False) |
| aot_export_joint_simple(fn, [inps], trace_joint=True) |
| |
| def test_aot_export_functionalized_rng_banned(self): |
| def fn(p, x): |
| return (p + x,) |
| mod = TestMod(fn) |
| inp = torch.randn(2) |
| with patch("functorch.compile.config.functionalize_rng_ops", True), self.assertRaisesRegex( |
| RuntimeError, "Functionalized RNG is not currently supported in the aot_export" |
| ): |
| aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) |
| aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) |
| aot_export_module(mod, [inp], trace_joint=False) |
| |
| |
| class TestPartitioning(AOTTestCase): |
| @unittest.skipIf(not USE_NETWORKX, "networkx not available") |
| def test_recompute_partitioning(self): |
| def fn(a, b): |
| return torch.sin(torch.sin(a)) + b |
| |
| # Reference calculation |
| ref_a = torch.rand(10, 10, requires_grad=True) |
| ref_b = torch.rand(10, 10, requires_grad=True) |
| ref = fn(ref_a, ref_b) |
| ref.sum().backward() |
| |
| # Compiled function calculation |
| res_a = ref_a.clone().detach().requires_grad_(True) |
| res_b = ref_b.clone().detach().requires_grad_(True) |
| |
| def compile_fn(x, _): |
| return x |
| |
| compiled_fn = compiled_function(fn, compile_fn, compile_fn, min_cut_rematerialization_partition) |
| res = compiled_fn(res_a, res_b) |
| res.sum().backward() |
| assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3) |
| assert torch.allclose(ref_a.grad, res_a.grad, atol=1e-3, rtol=1e-3) |
| assert torch.allclose(ref_b.grad, res_b.grad, atol=1e-3, rtol=1e-3) |
| |
| def test_meta_tensor_inplace_op(self): |
| # Following module results in inplace ops while tracing. The test checks |
| # that the meta tensor information is stored for inplace ops. |
| class MockModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.weight = torch.nn.Parameter(torch.randn(3072, 768, requires_grad=True)) |
| self.bias = torch.nn.Parameter(torch.randn(3072, requires_grad=True)) |
| |
| def forward(self, add_4): |
| linear_4 = torch.nn.functional.linear(add_4, self.weight, bias=self.bias) |
| gelu = torch.nn.functional.gelu(linear_4) |
| return gelu |
| |
| def check_meta_tensor(fx_g, _): |
| for node in fx_g.graph.nodes: |
| if node.op != 'output': |
| assert 'tensor_meta' in node.meta |
| return fx_g |
| |
| inp0 = torch.randn(16, 128, 768, requires_grad=True) |
| inputs = [inp0, ] |
| mod = MockModule().to(device="cpu") |
| aot_mod = aot_module(mod, fw_compiler=check_meta_tensor) |
| aot_mod(*inputs) |
| |
| def test_default_partitioner_getitem(self): |
| mod = nn.LayerNorm([10]) |
| |
| def f(x, mod_weight, mod_bias): |
| return torch.nn.functional.layer_norm(x, [10], mod_weight, mod_bias, eps=1e-6) |
| |
| fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, 10, requires_grad=True), mod.weight, mod.bias], |
| partitioner=default_partition) |
| self.assertEqual(get_num_ins_outs(fw_graph), (3, 6)) |
| self.assertEqual(get_num_ins_outs(bw_graph), (6, 3)) |
| |
| @unittest.skipIf(not USE_NETWORKX, "networkx not available") |
| def test_min_cut_partitioner_save_shape(self): |
| |
| def f(x): |
| s = x.sum(dim=1) |
| return s |
| |
| inp = [torch.ones([10, 10], requires_grad=True)] |
| fw_graph, bw_graph = get_fw_bw_graph(f, inp, dynamic=True) |
| _, fw_output = get_ins_outs(fw_graph) |
| self.assertEqual(get_num_ins_outs(fw_graph), (1, 3)) |
| self.assertEqual(get_num_ins_outs(bw_graph), (3, 1)) |
| self.assertEqual(str(fw_output[0]), "sum_1") |
| # make sure we don't do the suboptimal thing of saving the bigger primals input to sum, |
| # rather than saving the sizes of the primals input for use in backward expand |
| self.assertEqual(str(fw_output[1]), "sym_size") |
| self.assertEqual(str(fw_output[2]), "sym_size_1") |
| |
| inp = [ |
| torch.randn(10, requires_grad=True), |
| torch.randn((3, 10), requires_grad=True), |
| torch.randn((2, 10), requires_grad=True), |
| ] |
| |
| def f(a, b, c): |
| # tried to test what happens if we save a size tuple in the graph; |
| # turns out we never will due to how we trace, but this is probably |
| # still a good test case for various size manipulations |
| sb = torch.ops.aten.sym_size(b) |
| sc = c.size() |
| x = sb[0] + sc[0] |
| a_sz = (x, a.size(0)) |
| return torch.cat([a.expand(a_sz), b, c]) |
| fw_graph, bw_graph = get_fw_bw_graph(f, inp, dynamic=True) |
| self.assertEqual(get_num_ins_outs(fw_graph), (3, 4)) |
| self.assertEqual(get_num_ins_outs(bw_graph), (4, 3)) |
| _, outs = get_ins_outs(fw_graph) |
| self.assertTrue(all(is_sym_node(n) for n in outs[1:])) |
| |
| def test_default_partitioner_output_tensor_shape_tensor(self): |
| |
| inp = [ |
| torch.randn(10, requires_grad=True), |
| torch.randn((3, 10), requires_grad=True), |
| torch.randn((2, 10), requires_grad=True), |
| torch.randn((10, 1), requires_grad=True), |
| ] |
| |
| def f(a, b, c, d): |
| # Try to force symints intermixed with outputs in the function's returns |
| sb = b.size() |
| sc = c.size() |
| x = sb[0] + sc[0] |
| a_sz = (x, a.size(0)) |
| cat = torch.cat([a.expand(a_sz), b, c]) |
| mm = torch.mm(cat, d) |
| mm2 = torch.mm(mm, a.view(mm.size(1), a.size(0))) # this saves 4 new ints for backward. why? |
| # and what do i have to do to make it save a tensor for backward? |
| return cat, sb, c, mm2 |
| |
| fw_graph_cell = [None] |
| bw_graph_cell = [None] |
| compiled_outs = aot_function( |
| f, |
| fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), |
| bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), |
| partition_fn=default_partition, |
| decompositions=default_decompositions, |
| dynamic=True)(*inp) |
| fw_graph = fw_graph_cell[0] |
| (compiled_outs[0].sum() + compiled_outs[2].sum()).backward() |
| bw_graph = bw_graph_cell[0] |
| |
| # in the fwd graph, 13 outs because: |
| # - 5 original outputs (sb is a tuple, gets expanded to 2 symints) |
| # - 8 saved outputs for backward: 5 tensors, 3 symints |
| self.assertEqual(get_num_ins_outs(fw_graph), (4, 13)) |
| # in the bwd graph, 10 inputs (grad outs) because: |
| # - The fwd graph had 13 outputs |
| # - 1 was a view of an input, which gets regenerated outside of the graph |
| # and doesn't participate in the backward |
| # - 2 user outs were symints (b.size()), which don't get tangents in the backward |
| self.assertEqual(get_num_ins_outs(bw_graph), (10, 4)) |
| _, fw_graph_out_nodes = get_ins_outs(fw_graph) |
| self.assertEqual( |
| # fw outputs include b.size() which expands to 2 symints, |
| # |
| # TODO(whc)- are the saved-tensors/saved-symints correct here? |
| # i just made the test pass based on what default partition did |
| # Of the 5 original forward outputs, the 4th (c) is an input, |
| # which won't show up in the compiled forward graph |
| [False, True, True, False, False] + [False] * 4 + [True] * 4, |
| [is_sym_node(n) for n in fw_graph_out_nodes] |
| ) |
| |
| real_outs = f(*inp) |
| self.assertEqual(compiled_outs, real_outs) |
| self.assertTrue(isinstance(real_outs[1], torch.Size)) |
| |
| # TODO(whc) we should learn to return torch.Sizes |
| self.assertFalse(isinstance(compiled_outs[1], torch.Size)) |
| |
| @unittest.skipIf(not USE_NETWORKX, "networkx not available") |
| def test_min_cut_partitioner_output_tensor_shape_tensor(self): |
| |
| inp = [ |
| torch.randn(10, requires_grad=True), |
| torch.randn((3, 10), requires_grad=True), |
| torch.randn((2, 10), requires_grad=True), |
| torch.randn((10, 1), requires_grad=True), |
| ] |
| |
| def f(a, b, c, d): |
| # Try to force symints intermixed with outputs in the function's returns |
| sb = b.size() |
| sc = c.size() |
| x = sb[0] + sc[0] |
| a_sz = (x, a.size(0)) |
| cat = torch.cat([a.expand(a_sz), b, c]) |
| mm = torch.mm(cat, d) |
| mm2 = torch.mm(mm, a.view(mm.size(1), a.size(0))) # this saves 4 new ints for backward. why? |
| # and what do i have to do to make it save a tensor for backward? |
| return cat, sb, c, mm2 |
| |
| fw_graph_cell = [None] |
| bw_graph_cell = [None] |
| compiled_outs = aot_function( |
| f, |
| fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), |
| bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), |
| partition_fn=min_cut_rematerialization_partition, |
| decompositions=default_decompositions, |
| dynamic=True)(*inp) |
| fw_graph = fw_graph_cell[0] |
| (compiled_outs[0].sum() + compiled_outs[2].sum()).backward() |
| bw_graph = bw_graph_cell[0] |
| |
| self.assertEqual(get_num_ins_outs(fw_graph), (4, 12)) |
| self.assertEqual(get_num_ins_outs(bw_graph), (9, 4)) |
| _, fw_graph_out_nodes = get_ins_outs(fw_graph) |
| self.assertEqual( |
| # fw outputs include b.size() which expands to 2 symints, |
            # then 4 tensors (transposes of matrices used for mm) are saved
| # finally 3 symints are saved |
| [False, True, True, False, False] + [False] * 4 + [True] * 3, |
| [is_sym_node(n) for n in fw_graph_out_nodes] |
| ) |
| |
| real_outs = f(*inp) |
| self.assertEqual(compiled_outs, real_outs) |
| self.assertTrue(isinstance(real_outs[1], torch.Size)) |
| |
| # TODO(whc) we should learn to return torch.Sizes |
| self.assertFalse(isinstance(compiled_outs[1], torch.Size)) |
| |
| @unittest.skipIf(not USE_NETWORKX, "networkx not available") |
| def test_min_cut_partitioner(self): |
| def f(x): |
| return x.cos().cos().cos() |
| |
| fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)]) |
| self.assertEqual(get_num_ins_outs(fw_graph), (1, 2)) |
| self.assertEqual(get_num_ins_outs(bw_graph), (2, 1)) |
| |
| def f(a, b, c, d): |
| x = a + b + c + d |
| return x.cos().cos() |
| |
| fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True) for _ in range(4)]) |
| self.assertEqual(get_num_ins_outs(fw_graph), (4, 2)) |
| self.assertEqual(get_num_ins_outs(bw_graph), (2, 4)) |
| |
| @unittest.skipIf(not USE_NETWORKX, "networkx not available") |
| def test_min_cut_partitioner_recomputable_ops(self): |
| def f(x): |
| return x * x * x |
| |
| recomputable_ops = [] |
| partition_fn = partial(min_cut_rematerialization_partition, recomputable_ops=recomputable_ops) |
| |
| fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)], partition_fn) |
| # Expected forward graph: |
| # opcode name target args kwargs |
| # ------------- --------- --------------- -------------------------- -------- |
| # placeholder primals_1 primals_1 () {} |
| # call_function mul aten.mul.Tensor (primals_1, primals_1) {} |
| # call_function mul_1 aten.mul.Tensor (mul, primals_1) {} |
| # output output output ([mul_1, primals_1, mul],) {} |
| self.assertEqual(get_num_ins_outs(fw_graph), (1, 3)) |
| # Expected backward graph: |
| # opcode name target args kwargs |
| # ------------- ---------- --------------- ----------------------- -------- |
| # placeholder primals_1 primals_1 () {} |
| # placeholder mul mul () {} |
| # placeholder tangents_1 tangents_1 () {} |
| # call_function mul_2 aten.mul.Tensor (tangents_1, mul) {} |
| # call_function mul_3 aten.mul.Tensor (tangents_1, primals_1) {} |
| # call_function mul_4 aten.mul.Tensor (mul_3, primals_1) {} |
| # call_function add aten.add.Tensor (mul_2, mul_4) {} |
| # call_function add_1 aten.add.Tensor (add, mul_4) {} |
| # output output output ([add_1],) {} |
| self.assertEqual(get_num_ins_outs(bw_graph), (3, 1)) |
| |
| recomputable_ops = [torch.ops.aten.mul] |
| partition_fn = partial(min_cut_rematerialization_partition, recomputable_ops=recomputable_ops) |
| fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)], partition_fn) |
| # Expected forward graph: |
| # opcode name target args kwargs |
| # ------------- --------- --------------- ---------------------- -------- |
| # placeholder primals_1 primals_1 () {} |
| # call_function mul aten.mul.Tensor (primals_1, primals_1) {} |
| # call_function mul_1 aten.mul.Tensor (mul, primals_1) {} |
| # output output output ([mul_1, primals_1],) {} |
| self.assertEqual(get_num_ins_outs(fw_graph), (1, 2)) |
| # Expected backward graph: |
| # opcode name target args kwargs |
| # ------------- ---------- --------------- ----------------------- -------- |
| # placeholder primals_1 primals_1 () {} |
| # placeholder tangents_1 tangents_1 () {} |
| # call_function mul aten.mul.Tensor (primals_1, primals_1) {} # RECOMPUTED |
| # call_function mul_2 aten.mul.Tensor (tangents_1, mul) {} |
| # call_function mul_3 aten.mul.Tensor (tangents_1, primals_1) {} |
| # call_function mul_4 aten.mul.Tensor (mul_3, primals_1) {} |
| # call_function add aten.add.Tensor (mul_2, mul_4) {} |
| # call_function add_1 aten.add.Tensor (add, mul_4) {} |
| # output output output ([add_1],) {} |
| self.assertEqual(get_num_ins_outs(bw_graph), (2, 1)) |
| |
| def test_contiguous(self): |
| # The test simulates the condition where transpose followed by view |
| # happens in the backward pass. |
| # https://discuss.pytorch.org/t/error-on-transpose-and-view/434 |
| def f(x): |
| return x.view(2, 3).t() |
| |
| inp = torch.randn(6, requires_grad=True) |
| out = aot_function(f, nop)(inp) |
| torch.autograd.grad(out, inp, torch.randn(3, 2)) |
| |
| def test_preserve_random(self): |
| def fn(x): |
| return torch.nn.functional.dropout(x, 0.5) + x |
| |
| x = torch.randn(4) |
| |
| torch.manual_seed(0) |
| ref = fn(x) |
| |
| torch.manual_seed(0) |
| aot_fn = aot_function(fn, nop) |
| res = aot_fn(x) |
| |
| assert torch.allclose(ref, res) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") |
| @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") |
| def test_autocast(self): |
| mod = torchvision.models.resnet18().cuda() |
| mod.train() |
| |
| x = torch.randn(16, 3, 32, 32, device="cuda") |
| aot_mod = memory_efficient_fusion(mod) |
| |
| # Ensure that AOT Autograd works with AMP |
| with torch.cuda.amp.autocast(True): |
| res = aot_mod(x) |
| res.sum().backward() |
| |
| |
| class TestAOTModuleSimplified(AOTTestCase): |
| def test_aot_module_simplified(self): |
| class MockModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(20, 30) |
| |
| def forward(self, x, y): |
| return (self.linear(x) + y, ) |
| |
| mod = MockModule() |
| mod.zero_grad() |
| |
| x = torch.randn(128, 20, requires_grad=True) |
| y = torch.randn(128, 30, requires_grad=True) |
| inputs = [x, y] |
| cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs] |
| |
| ref = mod(*inputs) |
| ref[0].sum().backward() |
| |
| compiled_f = aot_module_simplified(mod, cloned_inputs, nop) |
| mod.zero_grad() |
| res = compiled_f(*cloned_inputs) |
| res[0].sum().backward() |
| |
| assert torch.allclose(ref[0], res[0]) |
| assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad) |
| assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad) |
| |
| def test_aot_module_simplified_dynamic(self): |
| class MockModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(20, 30) |
| |
| def forward(self, x, y): |
| return (self.linear(x) + y, ) |
| |
| mod = MockModule() |
| |
| shape_env = ShapeEnv() |
| fake_mode = FakeTensorMode(shape_env=shape_env) |
| |
| x = torch.randn(128, 20, requires_grad=True) |
| y = torch.randn(128, 30, requires_grad=True) |
| |
| inputs = [x, y] |
| fake_inputs = [fake_mode.from_tensor(x) for x in inputs] |
| compiled_f = aot_module_simplified(mod, fake_inputs, nop) |
| |
| ref = mod(*inputs) |
| ref[0].sum().backward() |
| |
| cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs] |
| res = compiled_f(*cloned_inputs) |
| res[0].sum().backward() |
| |
| self.assertExpectedInline(shape_env.format_guards(), """\ |
| - Eq(s1, 20) |
| - Eq(s2, 30)""") |
| |
| assert torch.allclose(ref[0], res[0]) |
| assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad) |
| assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad) |
| |
| def test_inference_python_dispatcher(self): |
| # Extracted from unet |
| class MockModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.upsample = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) |
| |
| def forward(self, x): |
| return (self.upsample(x), ) |
| |
| mod = MockModule() |
| shape_env = ShapeEnv() |
| fake_mode = FakeTensorMode(shape_env=shape_env) |
| x = torch.randn(2, 512, 40, 59) # NB: must not require grad |
| inputs = [x] |
| fake_inputs = [fake_mode.from_tensor(x) for x in inputs] |
| compiled_f = aot_module_simplified(mod, fake_inputs, nop) |
| |
| def test_aot_module_simplified_preserves_stack_trace(self): |
| class MockModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(20, 30) |
| |
| def forward(self, x, y): |
| z = self.linear(x) |
| z = z + y |
| z = z.relu() |
| return (z, ) |
| |
| tracer = torch.fx.Tracer() |
| tracer.record_stack_traces = True |
| graph = tracer.trace(MockModule()) |
| mod = torch.fx.GraphModule(tracer.root, graph) |
| |
| for node in mod.graph.nodes: |
| if node.op == 'output': |
| continue |
| self.assertTrue(node.stack_trace is not None) |
| assert 'test_aotdispatch.py' in node.stack_trace |
| |
| def assert_compiler(gm: torch.fx.GraphModule, _): |
| for node in gm.graph.nodes: |
| if node.op == 'output' or node.op == 'placeholder': |
| continue |
| self.assertTrue(node.stack_trace is not None) |
| assert 'test_aotdispatch.py' in node.stack_trace |
| return gm.forward # return a python callable |
| |
| x = torch.randn(128, 20, requires_grad=True) |
| y = torch.randn(128, 30, requires_grad=True) |
| inputs = [x, y] |
| |
| compiled_f = aot_module_simplified(mod, inputs, fw_compiler=assert_compiler, bw_compiler=assert_compiler) |
| res = compiled_f(*inputs) |
| res[0].sum().backward() |
| |
| def test_aot_module_simplified_fake_tensor_gm_raises(self): |
| fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() |
| real_x = torch.randn(4, requires_grad=True) |
| fake_x = fake_mode.from_tensor(real_x) |
| real_z = torch.randn(4) |
| fake_z = fake_mode.from_tensor(real_z) |
| |
| class MockModule(torch.nn.Module): |
| def forward(self, x): |
| # Accessing a free variable fake tensor will look like a |
| # constant to make_fx, and result in the tensor being traced |
| # into the graph, which is an error condition. Make sure we |
| # report adequately in this case. |
| return (x + fake_z, ) |
| |
| with self.assertRaisesRegex( |
| AssertionError, "Unexpected fake" |
| ): |
| aot_module_simplified(MockModule(), (fake_x,), nop) |
| |
| |
# Entries in here don't work and need to be fixed.
# Each one of these is a bug (or needs to be investigated)
| aot_autograd_failures = { |
| # data-dependent control flow |
| xfail('cov'), |
| xfail('nn.functional.gaussian_nll_loss'), |
| xfail('tensor_split'), |
| xfail('corrcoef'), |
| xfail('quantile'), |
| xfail('nanquantile'), |
| xfail('narrow'), |
| xfail('index_reduce'), |
| xfail('istft'), |
| xfail('linalg.eig'), |
| xfail('scatter_reduce', 'prod'), |
| |
| skip('as_strided_scatter'), |
| skip('as_strided', 'partial_views'), # flaky |
| |
| # Given input size: (s0xs1x2). Calculated output size: ... |
| skip('max_pool2d_with_indices_backward'), |
| |
| # Worked with real but not with fake |
| xfail('_segment_reduce', 'lengths'), |
| skip('nn.functional.nll_loss', ''), # UBSAN failure! |
| |
| # Misc |
| xfail('to_sparse'), |
| xfail('corrcoef'), |
| xfail('cov'), |
| xfail('chalf'), # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' |
| xfail('sparse.sampled_addmm'), |
| xfail('normal', 'number_mean'), # TypeError: randn_like(): argument 'input' (position 1) must be Tensor, not float |
| xfail('sparse.mm', 'reduce'), |
| skip('nn.functional.binary_cross_entropy_with_logits'), # seems to fail sometimes? |
| skip('nn.functional.margin_ranking_loss'), # seems flaky |
| skip('linalg.lu_solve'), # flaky |
| decorate('matmul', decorator=unittest.skipIf(IS_ARM64, 'flaky')), |
| decorate('__rmatmul__', decorator=unittest.skipIf(IS_ARM64, 'flaky')), |
    # tolerance overrides; atol=1e-4, rtol=1e-5 would do as well
| decorate('svd_lowrank', decorator=toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-05)})), |
| decorate('linalg.householder_product', decorator=unittest.skipIf(IS_MACOS and IS_X86, 'flaky')), |
| decorate('linalg.pinv', 'singular', decorator=toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-05)})), |
| # conv2d sometimes nondeterministic in this config? |
| decorate('nn.functional.conv2d', decorator=unittest.skipIf(IS_ARM64, "flaky")), |
| } |
| |
| symbolic_aot_autograd_failures = { |
| xfail('block_diag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('combinations', ''), # aten.masked_select.default |
| xfail('diff', ''), # aten.zeros_like.default - couldn't find symbolic meta function/decomposition |
| xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition |
| xfail('gradient', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition |
| xfail('index_fill', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('kron', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('linalg.eigvals', ''), # aten.linalg_eig.default - couldn't find symbolic meta function/decomposition |
| xfail('linalg.lstsq', ''), # aten.linalg_lstsq.default - couldn't find symbolic meta function/decomposition |
| xfail('linalg.lstsq', 'grad_oriented'), # aten.linalg_lstsq.default - couldn't find symbolic meta funct... |
| xfail('linalg.lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/deco... |
| xfail('linalg.multi_dot', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('masked.prod', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('masked_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decompos... |
| xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbo... |
| xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2... |
| skip('nn.functional.batch_norm', ''), # '0 is not tracked with proxy for <torch.fx.experimental.proxy_te.. |
| xfail('nn.functional.binary_cross_entropy', ''), # aten.fill_.Scalar - couldn't find symbolic meta funct... |
| xfail('nn.functional.cross_entropy', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('nn.functional.ctc_loss', ''), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/deco... |
| xfail('nn.functional.embedding_bag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('nn.functional.fractional_max_pool2d', ''), # rand() received an invalid combination of arguments - g... |
| xfail('nn.functional.fractional_max_pool3d', ''), # rand() received an invalid combination of arguments - g... |
| xfail('nn.functional.group_norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('nn.functional.interpolate', 'linear'), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('nn.functional.interpolate', 'trilinear'), # Cannot call sizes() on tensor with symbolic sizes/st... |
| xfail('nn.functional.nll_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('nn.functional.pixel_shuffle', ''), # aten.pixel_shuffle.default - couldn't find symbolic meta fun... |
| xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta... |
| xfail('nn.functional.rrelu', ''), # aten.rrelu_with_noise.default - couldn't find symbolic meta function... |
| xfail('normal', 'number_mean'), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('prod', ''), # Cannot call numel() on tensor with symbolic sizes/strides |
| xfail('repeat_interleave', ''), # aten.repeat_interleave.Te... |
| xfail('_segment_reduce', 'lengths'), # aten.segment_reduce.default - couldn't find symbolic meta functio... |
| xfail('_segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta functio... |
| xfail('sgn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition |
| xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/de... |
| xfail('_upsample_bilinear2d_aa'), # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList |
| decorate('linalg.householder_product', decorator=unittest.skipIf(IS_MACOS and IS_X86, 'flaky')), |
| |
| # many complex operators have incorrect striding / metadata
| xfail('fft.fft', ''), |
| xfail('fft.hfft2', ''), |
| xfail('fft.hfft', ''), |
| xfail('fft.hfftn', ''), |
| xfail('fft.ifft', ''), |
| xfail('fft.ihfft2', ''), |
| xfail('fft.ihfft', ''), |
| xfail('fft.ihfftn', ''), |
| xfail('fft.irfft2', ''), |
| xfail('fft.irfft', ''), |
| xfail('fft.irfftn', ''), |
| xfail('fft.rfft2', ''), |
| xfail('fft.rfft', ''), |
| xfail('fft.rfftn', ''), |
| |
| xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
| } |
| |
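| # Runs aot_autograd_check over every sample input of the op, comparing eager
| # and AOT Autograd-compiled outputs and gradients; dynamic=True traces with
| # symbolic shapes.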
| def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False): |
| if not op.supports_autograd: |
| self.skipTest("Op does not support autograd") |
| |
| # aot_autograd_check can check for data specialization by randomizing the
| # inputs. Some ops do not behave well with random inputs, so we disable
| # that check for them.
| cant_check_data_specialization = {
| 'nn.functional.max_unpool1d',
| 'nn.functional.max_unpool2d',
| 'nn.functional.max_unpool3d',
| }
| try_check_data_specialization = op.name not in cant_check_data_specialization |
| |
| sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=True) |
| for sample_input in sample_inputs_itr: |
| t_args = [sample_input.input] + list(sample_input.args) |
| t_kwargs = sample_input.kwargs |
| try: |
| aot_autograd_check( |
| op.op, t_args, t_kwargs, dynamic, |
| self.assertRaisesRegex, self.assertEqual, |
| check_gradients=True, |
| try_check_data_specialization=try_check_data_specialization) |
| except DynamicOutputShapeException: |
| self.skipTest("Dynamic output shape operation in trace") |
| except GuardOnDataDependentSymNode: |
| # Carveout for getitem; we don't want to xfail the entire test
| # because that would also reject tests that are known to be good, see
| # https://github.com/pytorch/pytorch/issues/94705
| if op.name == "__getitem__": |
| self.skipTest("Dynamic output shape operation in trace") |
| else: |
| raise |
| |
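| # Builds each module from its ModuleInfo, rewrites it as a pure function over
| # (params, buffers, tensor args) via torch.func.functional_call, compiles that
| # function with aot_function, and checks forwards/backwards against eager.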
| def _test_aot_autograd_module_helper(self, device, dtype, training, module_info, *, dynamic=False): |
| module_cls = module_info.module_cls |
| module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, |
| requires_grad=True, training=training) |
| for module_input in module_inputs: |
| if module_input.forward_input is None: |
| continue |
| |
| args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs |
| m = module_cls(*args, **kwargs) |
| m.to(device).to(dtype) |
| m.train(training) |
| |
| # Lazy modules need to see an input first to initialize params. |
| args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs |
| flat_args, args_spec = pytree.tree_flatten((args, kwargs)) |
| |
| # PackedSequence is only used for RNNs. It might be possible to fake-ify it
| # if it were registered as a pytree, but torchdynamo doesn't support RNNs anyway.
| if any(isinstance(flat_arg, PackedSequence) for flat_arg in flat_args):
| continue |
| |
| if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin): |
| with torch.no_grad(): |
| m(*args, **kwargs) |
| |
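| # Only tensors should be passed to the compiled function, so replace each
| # tensor leaf in flat_args with a sentinel and keep the non-tensor leaves;
| # f() below splices the real tensors back into the sentinel slots.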
| sentinel_val = -42 |
| is_tensor_spec = [sentinel_val if isinstance(arg, torch.Tensor) |
| else arg for arg in flat_args] |
| args = [arg for arg in flat_args if isinstance(arg, torch.Tensor)] |
| |
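| # f takes (named_params, named_buffers, tensor_args), reconstructs the original
| # (args, kwargs) pytree, and calls the module functionally so that params and
| # buffers show up as explicit inputs to AOT Autograd.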
| def f(params_buffers_args): |
| named_params, named_buffers, args = params_buffers_args |
| cur_flat_args = list(is_tensor_spec) |
| args = iter(args) |
| for idx, v in enumerate(cur_flat_args): |
| if v == sentinel_val: |
| cur_flat_args[idx] = next(args) |
| c_args, c_kwargs = pytree.tree_unflatten(cur_flat_args, args_spec) |
| params_and_buffers = {**named_params, **named_buffers} |
| return torch.func.functional_call(m, params_and_buffers, c_args, c_kwargs) |
| |
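| # remove_duplicate=False keeps shared params/buffers as separate entries;
| # num_params_buffers tells aot_function how many of f's inputs are
| # parameters/buffers rather than user arguments.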
| named_params = dict(m.named_parameters(remove_duplicate=False)) |
| named_buffers = dict(m.named_buffers(remove_duplicate=False)) |
| num_params_buffers = len(named_params) + len(named_buffers) |
| compiled_f = aot_function(f, nop, num_params_buffers=num_params_buffers, dynamic=dynamic) |
| params_buffers_args = [named_params, named_buffers, args] |
| _test_aot_autograd_forwards_backwards_helper( |
| f, compiled_f, params_buffers_args, |
| self.assertRaisesRegex, self.assertEqual, True) |
| |
| |
| class TestEagerFusionOpInfo(AOTTestCase): |
| @ops(op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,)) |
| @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_exhaustive', aot_autograd_failures) |
| def test_aot_autograd_exhaustive(self, device, dtype, op): |
| _test_aot_autograd_helper(self, device, dtype, op) |
| |
| @ops(op_db + control_flow_opinfo_db, allowed_dtypes=(torch.float,)) |
| @patch("functorch.compile.config.debug_assert", True) |
| @skipOps('TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive', |
| aot_autograd_failures | symbolic_aot_autograd_failures) |
| def test_aot_autograd_symbolic_exhaustive(self, device, dtype, op): |
| _test_aot_autograd_helper(self, device, dtype, op, dynamic=True) |
| |
| |
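| # Modules expected to fail test_aot_autograd_module_exhaustive even with
| # static shapes.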
| aot_autograd_module_failures = {
| torch.nn.GaussianNLLLoss, # RuntimeError: It appears that you're trying to get value out
| # of a tracing tensor with aten._local_scalar_dense.default -
| # erroring out! It's likely that this is caused by data-dependent
| # control flow or similar.
| torch.nn.TransformerEncoder, # DataDependentOutputException: aten.eq compares a mask input
| # to a causal mask tensor, to see if Boolean is_causal should be set
| # for TransformerEncoder layers, MHA and sdp custom kernels
| torch.nn.Transformer, # DataDependentOutputException: aten.equal compares a mask input
| # to a causal mask tensor, to see if Boolean is_causal should be set
| # for TransformerEncoder layers, MHA and sdp custom kernels
| # (this bubbles up to Transformer)
| }
| |
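| # Additional failures that only show up when tracing with dynamic shapes;
| # unioned with aot_autograd_module_failures for the symbolic test below.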
| symbolic_aot_autograd_module_failures = { |
| torch.nn.Transformer, # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool |
| torch.nn.TransformerEncoder, # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool |
| torch.nn.GaussianNLLLoss, # NotImplementedError: local_scalar_dense/item NYI for torch.bool |
| torch.nn.AdaptiveMaxPool2d, # Cannot call sizes() on tensor with symbolic sizes/strides |
| torch.nn.AdaptiveMaxPool3d, # Cannot call sizes() on tensor with symbolic sizes/strides |
| torch.nn.GroupNorm, # in native_group_norm_backward cpg, _rem = divmod(C, group) |
| # TypeError: unsupported operand type(s) for divmod(): 'SymInt' and 'int' |
| torch.nn.FractionalMaxPool2d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat' |
| torch.nn.FractionalMaxPool3d, # int() argument must be a string, a bytes-like object or a number, not 'SymFloat' |
| } |
| |
| |
| class TestEagerFusionModuleInfo(AOTTestCase): |
| @modules(module_db, allowed_dtypes=(torch.float,)) |
| @decorateForModules(unittest.expectedFailure, aot_autograd_module_failures) |
| def test_aot_autograd_module_exhaustive(self, device, dtype, training, module_info): |
| _test_aot_autograd_module_helper(self, device, dtype, training, module_info) |
| |
| @modules(module_db, allowed_dtypes=(torch.float,)) |
| @decorateForModules(unittest.expectedFailure, |
| aot_autograd_module_failures | symbolic_aot_autograd_module_failures) |
| def test_aot_autograd_symbolic_module_exhaustive(self, device, dtype, training, module_info): |
| _test_aot_autograd_module_helper(self, device, dtype, training, module_info, dynamic=True) |
| |
| |
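| # instantiate_device_type_tests generates per-device subclasses of the
| # device-generic test classes (e.g. TestPythonKeyCPU) in this module's
| # globals; restrict them to CPU here.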
| only_for = ("cpu",)
| instantiate_device_type_tests( |
| TestPythonKey, |
| globals(), |
| only_for=only_for, |
| ) |
| instantiate_device_type_tests(TestEagerFusionOpInfo, globals(), only_for=only_for) |
| instantiate_device_type_tests(TestEagerFusionModuleInfo, globals(), only_for=only_for) |
| |
| |
| if __name__ == '__main__': |
| run_tests() |