| # Owner(s): ["module: dynamo"] |
| import abc |
| import collections |
| import copy |
| import dataclasses |
| import dis |
| import enum |
| import logging |
| import math |
| import operator |
| import os |
| import sys |
| import typing |
| import unittest |
| import unittest.mock as mock |
| import weakref |
| from unittest.mock import patch |
| |
| import numpy as np |
| import torch |
| |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| import torch.onnx.operators |
| from torch._C import FileCheck |
| from torch._dynamo import bytecode_analysis, bytecode_transformation |
| from torch._dynamo.output_graph import OutputGraph |
| from torch._dynamo.source import GetItemSource, LocalSource |
| from torch._dynamo.testing import ( |
| CompileCounter, |
| requires_static_shapes, |
| same, |
| skipIfNotPy311, |
| unsupported, |
| ) |
| |
| from torch._dynamo.utils import CompileProfiler, ifdyn, ifunspec |
| from torch.ao.quantization import MinMaxObserver |
| from torch.ao.quantization.fake_quantize import FakeQuantize |
| from torch.ao.quantization.qconfig import QConfig |
| from torch.ao.quantization.quantize_fx import prepare_qat_fx |
| from torch.autograd.profiler import _enable_dynamo_cache_lookup_profiler |
| from torch.fx.experimental.symbolic_shapes import ConstraintViolationError |
| from torch.nn import functional as F |
| from torch.testing._internal.common_cuda import ( |
| PLATFORM_SUPPORTS_FUSED_SDPA, |
| SM80OrLater, |
| ) |
| from torch.testing._internal.common_utils import freeze_rng_state |
| from torch.testing._internal.jit_utils import JitTestCase |
| |
| mytuple = collections.namedtuple("mytuple", ["a", "b", "ab"]) |
| |
| |
| class MyPickledModule(torch.nn.Module): |
| def __init__(self, z): |
| super().__init__() |
| self.z = z |
| |
| def forward(self, x, y): |
| return x * x * x + y + self.z |
| |
| |
| # These are used for test_{cond/map}_with_quantization |
| default_symmetric_fake_quant = FakeQuantize.with_args( |
| observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.quint8 |
| ) |
| default_weight_symmetric_fake_quant = FakeQuantize.with_args( |
| observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.qint8 |
| ) |
| uniform_qconfig_8bit = QConfig( |
| activation=default_symmetric_fake_quant, |
| weight=default_weight_symmetric_fake_quant.with_args, |
| ) |
| qconfig_dict = {"object_type": [(torch.nn.Linear, uniform_qconfig_8bit)]} |
| |
| |
| class MiscTests(torch._dynamo.test_case.TestCase): |
| def test_boolarg(self): |
| def boolarg(aa, bb, flag): |
| if flag: |
| return aa - bb |
| else: |
| return bb - aa |
| |
| a = torch.randn(10, 10) |
| b = torch.randn(10, 10) |
| correct1 = boolarg(a, b, True) |
| correct2 = boolarg(a, b, False) |
| correct3 = boolarg(a, b, None) |
| counter = CompileCounter() |
| opt_boolarg = torch._dynamo.optimize_assert(counter)(boolarg) |
| val1 = opt_boolarg(a, b, True) |
| val2 = opt_boolarg(a, b, False) |
| val3 = opt_boolarg(a, b, None) |
| val4 = opt_boolarg(a, b, True) |
| self.assertTrue(same(val1, correct1)) |
| self.assertTrue(same(val2, correct2)) |
| self.assertTrue(same(val3, correct3)) |
| self.assertTrue(same(val4, correct1)) |
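| # expect one frame per distinct flag value (True, False, None); the repeated True reuses the cache |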
| self.assertEqual(counter.frame_count, 3) |
| |
| def test_callpacked(self): |
| def call_packed(args): |
| a, b, c = args |
| return a - b * c |
| |
| counter = CompileCounter() |
| a = torch.randn(10, 10) |
| b = torch.randn(10, 10) |
| c = torch.randn(10, 10) |
| correct = call_packed([a, b, c]) |
| opt_call_packed = torch._dynamo.optimize_assert(counter)(call_packed) |
| val1 = opt_call_packed([a, b, c]) |
| val2 = opt_call_packed((a, b, c)) |
| val3 = opt_call_packed([a, b, c]) |
| val4 = opt_call_packed((a, b, c)) |
| self.assertTrue(same(val1, correct)) |
| self.assertTrue(same(val2, correct)) |
| self.assertTrue(same(val3, correct)) |
| self.assertTrue(same(val4, correct)) |
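| # expect one frame for the list input and one for the tuple input; repeat calls hit the cache |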
| self.assertEqual(counter.frame_count, 2) |
| |
| def test_raises(self): |
| def fn(a, b, c, cls): |
| x = a + b - c * 10 |
| raise cls(str(x)) |
| |
| counter = CompileCounter() |
| a = torch.randn(10, 10) |
| b = torch.randn(10, 10) |
| c = torch.randn(10, 10) |
| opt_fn = torch._dynamo.optimize(counter)(fn) |
| self.assertRaises(AssertionError, lambda: opt_fn(a, b, c, AssertionError)) |
| self.assertEqual(counter.frame_count, 1) |
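| # expect add, mul, and sub to be traced; the raise itself runs eagerly after the graph break |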
| self.assertEqual(counter.op_count, 3) |
| |
| def test_inplace(self): |
| def inplace1(a, b): |
| o = torch.empty((10, 10)) |
| o.copy_(a) |
| o -= b |
| return o |
| |
| torch._dynamo.testing.standard_test(self, inplace1, 2, expected_ops=3) |
| |
| def test_unpack4(self): |
| def unpack4(a, b): |
| a = a[:5, :] |
| b = b[:5, :] |
| x, y = a.size() |
| o = torch.empty((x, y)) |
| o.copy_(a / b) |
| return o |
| |
| torch._dynamo.testing.standard_test( |
| self, unpack4, 2, expected_ops=5, expected_ops_dynamic=8 |
| ) |
| |
| def test_unpack5(self): |
| def unpack5(a, b): |
| a = a[:5, :] |
| b = b[:5, :] |
| x, y = a.shape |
| o = torch.empty((x, y)) |
| o.copy_(a / b) |
| return o |
| |
| torch._dynamo.testing.standard_test( |
| self, unpack5, 2, expected_ops=5, expected_ops_dynamic=8 |
| ) |
| |
| def test_matmul1(self): |
| def matmul_op1(a, b): |
| return a @ b |
| |
| # TODO(jansel): FX doesn't support this, should add upstream support |
| torch._dynamo.testing.standard_test(self, matmul_op1, 2, expected_ops=1) |
| |
| def test_int_shape_binops(self): |
| def fn(x): |
| # Test reversal by putting int arg first. |
| y = 15 - x.shape[0] |
| y = 4 + y |
| y = 5 * y |
| y = 2 % y |
| y = 3**y |
| y = 10 // y |
| y = pow(2, y) |
| y = 10 / y |
| return x + y |
| |
| torch._dynamo.testing.standard_test( |
| self, fn, 1, expected_ops=1, expected_ops_dynamic=11 |
| ) |
| |
| def test_shape_int_inplace_binops(self): |
| def fn(x): |
| p = x.shape[0] |
| p += 2 |
| p -= 2 |
| p **= 2 |
| p /= 2 |
| p *= 2 |
| p //= 2 |
| p %= 2 |
| return x + p |
| |
| torch._dynamo.testing.standard_test( |
| self, fn, 1, expected_ops=1, expected_ops_dynamic=10 |
| ) |
| |
| def test_int_shape_inplace_binops(self): |
| def fn(x): |
| p = x.shape[0] |
| # Test reversal by putting constant first |
| y = 2 |
| y += p |
| y = 2 |
| y -= p |
| y = 2 |
| y **= p |
| y = 2 |
| y /= p |
| y = 2 |
| y *= p |
| y = 2 |
| y //= p |
| y = 2 |
| y %= p |
| return x + y |
| |
| torch._dynamo.testing.standard_test( |
| self, fn, 1, expected_ops=1, expected_ops_dynamic=10 |
| ) |
| |
| def test_int_int_comparisons(self): |
| def fn(x): |
| if 2 != 2: |
| out = 1 |
| elif 2 < 1: |
| out = 1 |
| elif 1 > 2: |
| out = 1 |
| elif 1 >= 2: |
| out = 1 |
| elif 2 <= 1: |
| out = 1 |
| elif 2 == 2: |
| out = 2 |
| else: |
| out = 1 |
| return x + out |
| |
| torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1) |
| |
| def test_shape_int_comparisons(self): |
| def fn(x): |
| a = x.shape[0] |
| # Ensure support for constant on right side |
| if a != 10: |
| out = 1 |
| elif a < 2: |
| out = 1 |
| elif a > 12: |
| out = 1 |
| elif a >= 12: |
| out = 1 |
| elif a <= 2: |
| out = 1 |
| elif a == 10: |
| out = 2 |
| else: |
| out = 1 |
| return x + out |
| |
| # expect for dynamic: size, index, 6 comparison ops, add |
| torch._dynamo.testing.standard_test( |
| self, fn, 1, expected_ops=1, expected_ops_dynamic=9 |
| ) |
| |
| def test_int_shape_comparisons(self): |
| def fn(x): |
| a = x.shape[0] |
| # Ensure support for constant on left side |
| if 10 != a: |
| out = 1 |
| elif 12 < a: |
| out = 1 |
| elif 2 > a: |
| out = 1 |
| elif 2 >= a: |
| out = 1 |
| elif 12 <= a: |
| out = 1 |
| elif 10 == a: |
| out = 2 |
| else: |
| out = 1 |
| return x + out |
| |
| # expect for dynamic: size, index, 6 comparison ops, add |
| torch._dynamo.testing.standard_test( |
| self, fn, 1, expected_ops=1, expected_ops_dynamic=9 |
| ) |
| |
| def test_param_shape_binops(self): |
| class MyModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.param = torch.nn.Parameter(torch.randn(15)) |
| |
| def forward(self, x): |
| # Test reversal by putting param shape arg first. |
| p = self.param.shape[0] |
| y = p - x.shape[0] |
| y = p + y |
| y = p * y |
| y = p % y |
| y = p**y |
| y = p // y |
| y = pow(p, y) |
| y = p / y |
| return x + y |
| |
| counts = torch._dynamo.testing.CompileCounter() |
| mod = MyModule() |
| optimized_mod = torch._dynamo.optimize(counts, nopython=True)(mod) |
| |
| x = torch.randn(3) |
| ref = mod(x) |
| res = optimized_mod(x) |
| |
| self.assertTrue(same(ref, res)) |
| self.assertEqual(counts.frame_count, 1) |
| expected_op_count = 13 if torch._dynamo.config.dynamic_shapes else 1 |
| self.assertEqual(counts.op_count, expected_op_count) |
| |
| def test_user_defined_binop(self): |
| class MyClass: |
| def __init__(self, value): |
| self.value = value |
| |
| def __radd__(self, other): |
| return self.value + other |
| |
| def fn(x, c): |
| y = x.shape[0] + c |
| return x + y |
| |
| counts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(counts)(fn) |
| |
| x = torch.randn(3) |
| c = MyClass(4) |
| ref = fn(x, c) |
| res = opt_fn(x, c) |
| |
| self.assertTrue(same(ref, res)) |
| self.assertEqual(counts.frame_count, 1) |
| expected_op_count = 4 if torch._dynamo.config.dynamic_shapes else 1 |
| self.assertEqual(counts.op_count, expected_op_count) |
| |
| def test_compare_shapes_eq(self): |
| def compare_shapes(a, b, to_list): |
| x = list(a.unsqueeze(-1).shape) if to_list else a.shape |
| y = list(b.unsqueeze(-1).shape) if to_list else b.shape |
| if x == y: |
| return a + 1 |
| else: |
| return a + 2 |
| |
| # Test both ListVariable and ShapeVariable |
| torch._dynamo.testing.standard_test( |
| self, lambda a, b: compare_shapes(a, b, to_list=True), 2 |
| ) |
| torch._dynamo.testing.standard_test( |
| self, lambda a, b: compare_shapes(a, b, to_list=False), 2 |
| ) |
| |
| def test_compare_shapes_tuple_eq(self): |
| def compare_shapes(a, b): |
| x = tuple(a.unsqueeze(-1).shape) |
| y = tuple(b.unsqueeze(-1).shape) |
| if x == y: |
| return a + 1 |
| else: |
| return a + 2 |
| |
| torch._dynamo.testing.standard_test(self, lambda a, b: compare_shapes(a, b), 2) |
| |
| def test_compare_shapes_tuple_neq(self): |
| def compare_shapes(a, b): |
| x = tuple(a.unsqueeze(-1).shape) |
| y = tuple(b.unsqueeze(-1).shape) |
| if x != y: |
| return a + 1 |
| else: |
| return a + 2 |
| |
| torch._dynamo.testing.standard_test(self, lambda a, b: compare_shapes(a, b), 2) |
| |
| def test_compare_shapes_neq(self): |
| def compare_shapes(a, b, to_list): |
| x = list(a.unsqueeze(-1).shape) if to_list else a.shape |
| y = list(b.unsqueeze(-1).shape) if to_list else b.shape |
| if x != y: |
| return a + 1 |
| else: |
| return a + 2 |
| |
| # Test both ListVariable and ShapeVariable |
| torch._dynamo.testing.standard_test( |
| self, lambda a, b: compare_shapes(a, b, to_list=True), 2 |
| ) |
| torch._dynamo.testing.standard_test( |
| self, lambda a, b: compare_shapes(a, b, to_list=False), 2 |
| ) |
| |
| @patch.object(torch._dynamo.config, "dynamic_shapes", True) |
| def test_compare_shapes_with_constant(self): |
| def compare_shapes(a): |
| x = a.shape |
| if x[0] != 3: |
| return a * 4 |
| return a * 3 |
| |
| guard_failure = None |
| |
| def guard_failures(failure): |
| nonlocal guard_failure |
| guard_failure = failure |
| |
| opt_fn = torch._dynamo.optimize( |
| "eager", nopython=True, guard_fail_fn=guard_failures |
| )(compare_shapes) |
| opt_fn(torch.randn([3, 4])) |
| opt_fn(torch.randn([4, 3])) |
| self.assertExpectedInline( |
| guard_failure.reason, |
| """tensor 'L['a']' size mismatch at index 0. expected 3, actual 4""", |
| ) |
| |
| def test_builtin_isinstance(self): |
| def fn(x): |
| t = torch.arange(1, 3) |
| a = isinstance(x, torch.Tensor) |
| b = isinstance(t, torch.Tensor) |
| c = isinstance(x, int) |
| d = isinstance(3, int) |
| e = isinstance([1, 2, 3], list) |
| f = isinstance({"foo": 1, "bar": 2}, dict) |
| res = [a, b, c, d, e, f] |
| # Can't run yet due to other unimplemented instructions |
| # res += [isinstance(torch.nn.LazyLinear(2, 3), torch.nn.Linear)] |
| return res |
| |
| torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1) |
| |
| def test_fold(self): |
| def fn(a): |
| return a + math.sqrt(63) |
| |
| torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1) |
| |
| def test_shape_unpack(self): |
| def fn(x): |
| a, b = x.size() |
| return x * b |
| |
| i = torch.randn(5, 10) |
| r1 = fn(i) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| r2 = opt_fn(i) |
| self.assertTrue(same(r1, r2)) |
| |
| def test_tensor_iter(self): |
| def fn(x): |
| for y in x: |
| y.add_(1.0) |
| return y |
| |
| # expect extra size node for dynamic |
| torch._dynamo.testing.standard_test( |
| self, fn, 1, expected_ops=20, expected_ops_dynamic=21 |
| ) |
| |
| def test_empty_list(self): |
| def fn(x, ll): |
| if len(ll) == 0 and not ll and ll is not None: |
| return x + 1 |
| |
| i = torch.randn(5, 10) |
| r1 = fn(i, []) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| r2 = opt_fn(i, []) |
| r3 = opt_fn(i, tuple()) |
| self.assertTrue(same(r1, r2)) |
| self.assertTrue(same(r1, r3)) |
| |
| def test_min_max_over_iterable(self): |
| def get_test_fn(func): |
| def _fn(a, b, func=func): |
| # try all of list, iterator, tuple, vararg. |
| lst = [a.shape[0] + 1, 8, a.shape[0]] |
| x = func(lst) |
| y = func(iter(lst)) |
| z = func(tuple(lst)) |
| w = func(*lst) |
| return a + (x + y + z + w) |
| |
| return _fn |
| |
| # expect for dynamic: |
| # 2 * (size, getitem) ops + |
| # 1 add op + |
| # 4 * 2 min / max ops + |
| # 4 final add ops = 17 |
| torch._dynamo.testing.standard_test( |
| self, get_test_fn(func=min), 2, expected_ops=1, expected_ops_dynamic=17 |
| ) |
| torch._dynamo.testing.standard_test( |
| self, get_test_fn(func=max), 2, expected_ops=1, expected_ops_dynamic=17 |
| ) |
| |
| def test_config_obj(self): |
| class Cfg: |
| def __init__(self): |
| self.val = 0.5 |
| self.count = 3 |
| |
| def fn(x, cfg): |
| for i in range(cfg.count): |
| x = x + cfg.val |
| return x |
| |
| cfg1 = Cfg() |
| cfg1.val = 1.0 |
| cfg2 = Cfg() |
| v = torch.zeros(1) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| v = opt_fn(v, cfg1) # 3 |
| v = opt_fn(v, cfg2) # 4.5 |
| cfg2.count = 1 |
| v = opt_fn(v, cfg2) # 5 |
| cfg2.val = 2.0 |
| v = opt_fn(v, cfg2) # 7 |
| self.assertEqual(v[0], 7) |
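| # expect 3 + 3 + 1 + 1 = 8 adds: the loop is unrolled and cfg.val / cfg.count are specialized, so each config change recompiles |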
| self.assertEqual(cnts.op_count, 8) |
| |
| def test_config_getattr_default(self): |
| class Cfg: |
| def __init__(self): |
| self.val = 0.5 |
| self.count = 10 |
| |
| def fn(x, cfg): |
| if getattr(cfg, "just_add_7", False): |
| return x + 7 |
| for i in range(cfg.count): |
| x = x + cfg.val |
| return x |
| |
| cfg1 = Cfg() |
| v = torch.zeros(1) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(v, cfg1)[0], 5) |
| self.assertEqual(opt_fn(v, cfg1)[0], 5) |
| cfg1.just_add_7 = True |
| self.assertEqual(opt_fn(v, cfg1)[0], 7) |
| self.assertEqual(opt_fn(v, cfg1)[0], 7) |
| cfg1.just_add_7 = False |
| self.assertEqual(opt_fn(v, cfg1)[0], 5) |
| self.assertEqual(opt_fn(v, cfg1)[0], 5) |
| self.assertEqual(cnts.frame_count, 3) |
| |
| def test_size_input(self): |
| def fn(x, s): |
| a, b = s |
| return x + (a - b) |
| |
| v = torch.zeros(10, 20) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(v, v.size())[0, 0], -10) |
| self.assertEqual(opt_fn(v, (10, 20))[0, 0], -10) |
| self.assertEqual(opt_fn(v, [10, 20])[0, 0], -10) |
| # One recompile per differing input type |
| self.assertEqual(cnts.frame_count, 3) |
| |
| def test_cell_output1(self): |
| out = None |
| |
| def fn(a, b): |
| nonlocal out |
| out = a + b * 10 |
| |
| v = torch.Tensor([100]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertIsNone(opt_fn(v, v)) |
| self.assertEqual(out[0], 1100) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_cell_output2(self): |
| out = None |
| |
| def fn(a, b): |
| nonlocal out |
| c = unsupported(a, b) |
| out = a + b * 10 + c |
| |
| v = torch.Tensor([100]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertIsNone(opt_fn(v, v)) |
| self.assertEqual(out[0], 1200) |
| self.assertEqual(cnts.op_count, 3) |
| |
| def test_return_nested_function(self): |
| out = None |
| |
| def fn(a, b): |
| nonlocal out |
| c = a + b |
| d = a + 1.0 |
| |
| def fn2(f: int = 7, g: float = 9.0): |
| nonlocal out |
| out = a + b * 10 |
| return c * f - d * g |
| |
| return fn2 |
| |
| v1 = torch.Tensor([100]) |
| v2 = torch.Tensor([200]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| opt_fn_ret = torch._dynamo.optimize(cnts)(opt_fn(v1, v2)) |
| self.assertEqual(opt_fn_ret(1.5)[0], -459) |
| self.assertEqual(out[0], 2100) |
| self.assertEqual(cnts.frame_count, 2) |
| self.assertEqual(cnts.op_count, 7) |
| |
| def test_tensor_dict1(self): |
| def fn(inputs): |
| return inputs["a"] - inputs["b"] * 1.5 |
| |
| v1 = torch.Tensor([100]) |
| v2 = torch.Tensor([200]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn({"a": v1, "b": v2})[0], -200) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_tensor_dict2(self): |
| def fn1(inputs): |
| total = torch.zeros(1) |
| for k, v in inputs.items(): |
| total += v |
| return total |
| |
| def fn2(inputs): |
| total = torch.zeros(1) |
| for v in inputs.values(): |
| total += v |
| return total |
| |
| def fn3(inputs): |
| total = torch.zeros(1) |
| for k in inputs.keys(): |
| total += inputs[k] |
| return total |
| |
| v1 = torch.Tensor([100]) |
| v2 = torch.Tensor([200]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn1 = torch._dynamo.optimize(cnts)(fn1) |
| opt_fn2 = torch._dynamo.optimize(cnts)(fn2) |
| opt_fn3 = torch._dynamo.optimize(cnts)(fn3) |
| self.assertEqual(opt_fn1({"a": v1, "b": v2})[0], 300) |
| self.assertEqual(opt_fn2({"a": v1, "b": v2})[0], 300) |
| self.assertEqual(opt_fn3({"a": v1, "b": v2})[0], 300) |
| self.assertEqual(cnts.frame_count, 3) |
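| # expect 3 ops per variant (zeros plus two accumulating adds) across the 3 frames |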
| self.assertEqual(cnts.op_count, 9) |
| |
| def test_dictcomp(self): |
| def fn1(inputs): |
| return {k: v + 1 for k, v in inputs.items()} |
| |
| v1 = torch.Tensor([100]) |
| v2 = torch.Tensor([200]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn1 = torch._dynamo.optimize(cnts)(fn1) |
| self.assertEqual(opt_fn1({"a": v1, "b": v2})["a"], 101) |
| self.assertEqual(opt_fn1({"a": v1, "b": v2})["b"], 201) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_listcomp(self): |
| def fn2(inputs): |
| return torch.sum(torch.cat([v + 1 for k, v in inputs.items()], 0)) |
| |
| v1 = torch.Tensor([100]) |
| v2 = torch.Tensor([200]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn2 = torch._dynamo.optimize(cnts)(fn2) |
| self.assertEqual(opt_fn2({"a": v1, "b": v2}), 302) |
| self.assertEqual(cnts.frame_count, 1) |
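| # expect 4 ops: one add per dict value, plus cat and sum |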
| self.assertEqual(cnts.op_count, 4) |
| |
| def test_is_floating_point(self): |
| def fn(a, b): |
| x = a + 1.0 |
| if torch.is_floating_point(b): |
| x = x + b |
| return x + 2.0 |
| |
| return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) |
| |
| def test_is_floating_point2(self): |
| def fn(a, b): |
| x = a + 1.0 |
| if b.is_floating_point(): |
| x = x + b |
| return x + 2.0 |
| |
| return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) |
| |
| def test_is_tensor(self): |
| def fn(a, b): |
| x = a + 1.0 |
| if torch.is_tensor(b): |
| x = x + b |
| return x + 2.0 |
| |
| return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) |
| |
| def test_is_tensor2(self): |
| def fn(x): |
| if torch.is_tensor(x): |
| return x + 1 |
| else: |
| return torch.ones([2, 3]) |
| |
| x1 = {"input": torch.rand(2, 3)} |
| x2 = torch.rand(2, 3) |
| ref1 = fn(x1) |
| ref2 = fn(x2) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res1 = opt_fn(x1) |
| res2 = opt_fn(x2) |
| self.assertEqual(ref1, res1) |
| self.assertEqual(ref2, res2) |
| |
| def test_numel(self): |
| def fn(a): |
| return (a + a.numel() + torch.numel(a), a + a.nelement()) |
| |
| return torch._dynamo.testing.standard_test( |
| self, fn=fn, nargs=1, expected_ops=3, expected_ops_dynamic=6 |
| ) |
| |
| def test_pair(self): |
| def fn(a): |
| return ( |
| torch.zeros(torch.nn.modules.utils._pair(a.size())) |
| + a |
| + torch.ones(torch.nn.modules.utils._ntuple(3)(3)).sum() |
| ) |
| |
| return torch._dynamo.testing.standard_test( |
| self, fn=fn, nargs=1, expected_ops=5, expected_ops_dynamic=8 |
| ) |
| |
| @patch.object(torch._dynamo.config, "dynamic_shapes", True) |
| @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) |
| def test_tensor_item_capture(self): |
| def fn(a, b): |
| return (a + b).sum().item() |
| |
| v1 = torch.randn((10, 10)) |
| v2 = torch.randn((10, 10)) |
| correct = fn(v1, v2) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(v1, v2), correct) |
| self.assertEqual(cnts.frame_count, 1) |
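| # with capture_scalar_outputs, .item() is captured in the graph: add, sum, item |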
| self.assertEqual(cnts.op_count, 3) |
| |
| @patch.object(torch._dynamo.config, "dynamic_shapes", True) |
| @patch.object(torch._dynamo.config, "capture_scalar_outputs", False) |
| def test_tensor_item_no_capture(self): |
| def fn(a, b): |
| return (a + b).sum().item() |
| |
| v1 = torch.randn((10, 10)) |
| v2 = torch.randn((10, 10)) |
| correct = fn(v1, v2) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(v1, v2), correct) |
| self.assertEqual(cnts.frame_count, 1) |
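| # without capture_scalar_outputs, .item() causes a graph break, so only add and sum are captured |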
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_namedtuple1(self): |
| def fn(a, b): |
| tmp = mytuple(a, b, a + b) |
| return mytuple(tmp.a, tmp[1], tmp.ab + b) |
| |
| v1 = torch.Tensor([10]) |
| v2 = torch.Tensor([20]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(v1, v2).ab, 50) |
| self.assertEqual(cnts.frame_count, 1) |
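| # expect 2 adds (a + b and tmp.ab + b); the namedtuple construction adds no ops |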
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_namedtuple2(self): |
| def fn(packed): |
| a, b, c = packed |
| if hasattr(packed, "b"): |
| b = packed.b + 1 |
| c = packed[2] |
| return a + b + c |
| |
| v1 = torch.Tensor([1]) |
| v2 = torch.Tensor([2]) |
| v3 = torch.Tensor([3]) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(mytuple(v1, v2, v3))[0], 7) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 3) |
| |
| def test_namedtuple3(self): |
| def fn(x, packed): |
| if isinstance(packed, mytuple): |
| return x + 1 |
| else: |
| return x - 1 |
| |
| x = torch.rand([2, 3]) |
| packed = mytuple(1, 2, 3) |
| ref = fn(x, packed) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res = opt_fn(x, packed) |
| self.assertTrue(same(ref, res)) |
| |
| def test_range_input(self): |
| def fn(a, rng): |
| x = a |
| for i in rng: |
| x = x + i |
| return x |
| |
| def fn1(a): |
| return fn(a, rng=range(3)) |
| |
| return torch._dynamo.testing.standard_test( |
| self, fn=fn1, nargs=1, expected_ops=3 |
| ) |
| |
| def test_range_with_shape(self): |
| def fn(a): |
| for i in range(1, a.shape[0]): |
| a += 1 |
| return a |
| |
| # expect 1 more op (size call) for dynamic |
| return torch._dynamo.testing.standard_test( |
| self, fn=fn, nargs=1, expected_ops=9, expected_ops_dynamic=10 |
| ) |
| |
| def test_build_tuple_unpack(self): |
| def fn1(a, b, c): |
| return a - b / c |
| |
| def fn2(a, b, c): |
| tmp1 = (a,) |
| tmp2 = (b, c) |
| args = (*tmp1, *tmp2) |
| return fn1(*args) |
| |
| def fn3(a, *args): |
| return fn1(a, *args) |
| |
| torch._dynamo.testing.standard_test(self, fn=fn2, nargs=3, expected_ops=2) |
| torch._dynamo.testing.standard_test(self, fn=fn3, nargs=3, expected_ops=2) |
| |
| def test_list_mul(self): |
| def fn(count): |
| head_mask = count * [None] * count |
| return head_mask |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(2), [None] * 4) |
| # TODO: the captured frame here is a bit goofy, because we don't |
| # output anything and none of the traced operations have side |
| # effects. We probably need a better heuristic for bailing out of |
| # Dynamo when there are no outputs. |
| self.assertEqual(cnts.frame_count, ifunspec(1, 0)) |
| self.assertEqual(cnts.op_count, ifunspec(2, 0)) |
| |
| def test_list_slice_mul(self): |
| def fn(count): |
| a = [1, 2, 3] |
| head_mask = count * a[1:] * count |
| return head_mask |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(2), [2, 3] * 4) |
| self.assertEqual(cnts.frame_count, ifunspec(1, 0)) |
| self.assertEqual(cnts.op_count, ifunspec(14, 0)) |
| |
| def test_tuple_mul(self): |
| def fn(count): |
| head_mask = count * (2, 3) * count |
| return head_mask |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertEqual(opt_fn(2), (2, 3) * 4) |
| self.assertEqual(cnts.frame_count, ifunspec(1, 0)) |
| self.assertEqual(cnts.op_count, ifunspec(14, 0)) |
| |
| def test_tuple_mul_with_shape(self): |
| def fn(a): |
| x = a.shape[0] |
| y = 2 * (x, 3) * 2 |
| return a + y[4] |
| |
| # expect 3 ops post folding for dynamic case: size, index, add |
| torch._dynamo.testing.standard_test( |
| self, fn, 1, expected_ops=1, expected_ops_dynamic=3 |
| ) |
| |
| def test_tuple_iadd_with_shape(self): |
| def fn(a): |
| output = (a + a.shape[0], a - a.shape[0]) |
| # tuple += tuple |
| output += (a - a.shape[0], a + a.shape[0]) |
| # tuple += constant tuple |
| output += (2, 3) |
| return output |
| |
| # expect 4 add / subs for static, 4 * 3 (size, index, math op) for dynamic |
| torch._dynamo.testing.standard_test( |
| self, fn, 1, expected_ops=4, expected_ops_dynamic=12 |
| ) |
| |
| def test_list_iadd_with_shape(self): |
| def fn(a): |
| output = [a + a.shape[0], a - a.shape[0]] |
| # list += list |
| output += [a - a.shape[0], a + a.shape[0]] |
| # list += tuple |
| output += (a + a.shape[0], a - a.shape[0]) |
| return output |
| |
| # expect 6 add / subs for static, 6 * 3 (size, index, math op) for dynamic |
| torch._dynamo.testing.standard_test( |
| self, fn, 1, expected_ops=6, expected_ops_dynamic=18 |
| ) |
| |
| def test_user_getattr1(self): |
| class MyConfig(dict): |
| def __getattr__(self, name): |
| return self[name] |
| |
| def fn(cfg, x, y): |
| return x + y + cfg.offset |
| |
| x = torch.randn(10) |
| cfg = MyConfig(offset=5) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_user_getattr2(self): |
| class MyConfig: |
| defined_on_class = 1 |
| |
| def __init__(self): |
| self.defined_on_object = 2 |
| |
| def __getattr__(self, name): |
| return 3 |
| |
| def fn(cfg, x): |
| return x + cfg.defined_on_class - cfg.defined_on_object + cfg.not_defined |
| |
| x = torch.randn(10) |
| cfg = MyConfig() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(cfg, x), x + 1 - 2 + 3)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 3) |
| |
| def test_user_getattribute(self): |
| class MyObject: |
| def __init__(self): |
| self.custom_dict = {"a": torch.rand((2, 2))} |
| self.my_number = 42 |
| |
| def __getattribute__(self, name): |
| custom_dict = super().__getattribute__("custom_dict") |
| if name in custom_dict: |
| return custom_dict[name] |
| return super().__getattribute__(name) |
| |
| def run(self, x): |
| return self.my_number * x + self.a * x |
| |
| def fn(obj, x): |
| return obj.run(x) |
| |
| obj = MyObject() |
| x = torch.rand((2, 2)) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(obj, x), fn(obj, x))) |
| |
| def test_nn_module_getattr(self): |
| class MyMod(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]} |
| self.other_attr = torch.rand((2, 2)) |
| |
| def __getattr__(self, name): |
| custom_dict = self.custom_dict |
| if name in custom_dict: |
| return custom_dict[name] |
| return super().__getattr__(name) |
| |
| def forward(self, x): |
| return x @ self.other_attr + self.queue[-1] |
| |
| x = torch.rand((2, 2)) |
| mod = MyMod() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_mod = torch._dynamo.optimize(cnts)(mod) |
| self.assertTrue(same(opt_mod(x), mod(x))) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_nn_module_getattribute(self): |
| class MyMod(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.my_number = 42 |
| |
| def __getattribute__(self, name): |
| if name == "special_attr": |
| return torch.tensor([[1, 2], [3, 4]]) |
| return super().__getattribute__(name) |
| |
| def forward(self, x): |
| return self.my_number * x + self.special_attr * x |
| |
| def fn(mod, x): |
| return mod(x) |
| |
| mod = MyMod() |
| x = torch.rand((2, 2)) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(mod, x), fn(mod, x))) |
| |
| def test_constant_getattr(self): |
| # https://github.com/pytorch/pytorch/issues/97480 |
| def fn(): |
| return getattr(None, "arg", 3) |
| |
| cnt = torch._dynamo.testing.CompileCounter() |
| optimized_fn = torch._dynamo.optimize(cnt)(fn) |
| res = optimized_fn() |
| self.assertTrue(same(res, 3)) |
| |
| def test_user_property(self): |
| class MyConfig: |
| @property |
| def prop5(self): |
| return 5 |
| |
| def fn(cfg, x, y): |
| return x + y + cfg.prop5 |
| |
| x = torch.randn(10) |
| cfg = MyConfig() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_dataclass_fields(self): |
| @dataclasses.dataclass |
| class MyDataClass: |
| a: torch.Tensor |
| b: torch.Tensor = None |
| c: torch.Tensor = None |
| d: torch.Tensor = None |
| e: torch.Tensor = None |
| |
| def fn(obj): |
| class_fields = dataclasses.fields(obj) |
| assert len(class_fields) |
| assert all(field.default is None for field in class_fields[1:]) |
| other_fields_are_none = all( |
| getattr(obj, field.name) is None for field in class_fields[1:] |
| ) |
| assert not other_fields_are_none |
| |
| total = getattr(obj, class_fields[0].name) |
| for field in class_fields[1:]: |
| v = getattr(obj, field.name) |
| if v is not None: |
| total += v |
| |
| return total |
| |
| obj1 = MyDataClass(torch.randn(10), torch.randn(10), torch.randn(10)) |
| obj2 = MyDataClass(torch.randn(10), e=torch.randn(10)) |
| correct1 = fn(obj1) |
| correct2 = fn(obj2) |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(obj1), correct1)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| torch._dynamo.reset() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(obj2), correct2)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 1) |
| |
| @requires_static_shapes |
| def test_tensor_build_list_unpack(self): |
| def fn(x): |
| # seen in fastNLP_Bert |
| return torch.cat([*x], dim=-1) |
| |
| val = torch.randn([1, 1, 473, 768]) |
| correct = fn(val) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(val), correct)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_numpy_int_constant(self): |
| def fn(x, a, b): |
| return x + (a % b) |
| |
| args = [torch.randn(10), 4096, np.int64(8)] |
| correct = fn(*args) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(*args), correct)) |
| self.assertTrue(same(opt_fn(*args), correct)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_inplace_resize_on_graph_input(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| # graph break when calling resize_() on graph input |
| def f1(x): |
| x.resize_(6) |
| x.mul_(2) |
| return x |
| |
| @torch.compile(backend=cnts) |
| def f2(x): |
| x.resize_(6) |
| x.mul_(2) |
| return x |
| |
| x = torch.ones(4) |
| y = torch.ones(4) |
| self.assertTrue(same(f1(x).shape, f2(y).shape)) |
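| # expect no compiled frames: resize_() on a graph input is unsupported, so the frame falls back to eager |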
| self.assertEqual(cnts.frame_count, 0) |
| |
| def test_dict_mutation_side_effect(self): |
| def fn(d): |
| d["c"] = d["a"] + d.pop("b") |
| return d |
| |
| args1 = {"a": torch.randn(10), "b": torch.randn(10)} |
| args2 = dict(args1) |
| assert fn(args1) is args1 |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertIs(opt_fn(args2), args2) |
| self.assertTrue(same(args1, args2)) |
| self.assertEqual(cnts.frame_count, 1) |
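| # expect a single add; the dict mutations are replayed as side effects outside the graph |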
| self.assertEqual(cnts.op_count, 1) |
| |
| def test_module_deepcopy(self): |
| m1 = torch.nn.Sequential( |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| ) |
| m2 = torch.nn.Sequential( |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| ) |
| |
| def fn(m, x): |
| m_copy = copy.deepcopy(m) |
| return m_copy(x) |
| |
| v = torch.randn(10) |
| correct1 = fn(m1, v) |
| correct2 = fn(m2, v) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| for _ in range(10): |
| self.assertTrue(same(opt_fn(m1, v), correct1)) |
| for _ in range(10): |
| self.assertTrue(same(opt_fn(m2, v), correct2)) |
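| # expect one reused graph with 4 ops from the copied Sequential: linear, relu, linear, relu |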
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 4) |
| |
| def test_type_copy(self): |
| def fn(seq): |
| a, b = seq |
| return type(seq)([a + 1, b + 2, a + b]) |
| |
| args1 = [torch.randn(10), torch.randn(10)] |
| args2 = (torch.randn(10), torch.randn(10)) |
| correct1 = fn(args1) |
| correct2 = fn(args2) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertTrue(same(opt_fn(args1), correct1)) |
| self.assertTrue(same(opt_fn(args2), correct2)) |
| self.assertIsInstance(opt_fn(args1), list) |
| self.assertIsInstance(opt_fn(args2), tuple) |
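| # expect one frame per sequence type (list vs tuple), each tracing 3 adds |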
| self.assertEqual(cnts.frame_count, 2) |
| self.assertEqual(cnts.op_count, 6) |
| |
| def test_setattr_mutation1(self): |
| class MyObj: # noqa: B903 |
| def __init__(self, a, b): |
| self.a = a |
| self.b = b |
| |
| def fn(obj): |
| obj.c = obj.a * obj.b + 1 |
| obj.b = obj.a * obj.c + 2 |
| obj.a = obj.b * obj.c + 3 |
| obj.c = obj.a * obj.b + 4 |
| obj.b = obj.a * obj.c + 5 |
| obj.a = obj.b * obj.c + 6 |
| return obj |
| |
| x1 = torch.randn(10) |
| x2 = torch.randn(10) |
| obj1 = MyObj(x1, x2) |
| obj2 = MyObj(x1, x2) |
| fn(obj2) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| self.assertIs(opt_fn(obj1), obj1) |
| self.assertTrue(same(obj1.a, obj2.a)) |
| self.assertTrue(same(obj1.b, obj2.b)) |
| self.assertTrue(same(obj1.c, obj2.c)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 12) |
| |
| def test_setattr_mutation2(self): |
| class MyObj: |
| def __init__(self, x): |
| self.a = x + 1 |
| self.b = x + 2 |
| |
| def fn(x): |
| x = x / 3.0 |
| obj = MyObj(x) |
| obj.c = obj.a * obj.b + 1 |
| obj.b = obj.a * obj.c + 2 |
| obj.a = obj.b * obj.c + 3 |
| return obj |
| |
| x1 = torch.randn(10) |
| obj2 = fn(x1) |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| obj1 = opt_fn(x1) |
| self.assertTrue(same(obj1.a, obj2.a)) |
| self.assertTrue(same(obj1.b, obj2.b)) |
| self.assertTrue(same(obj1.c, obj2.c)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 9) |
| |
| def test_setattr_mutation3(self): |
| # TODO(jansel): dead code eliminate the object creation |
| class MyObj: |
| def __init__(self, x): |
| super().__init__() |
| self.a = x + 1 |
| self.b = x + 2 |
| |
| def fn(x): |
| x = x / 3.0 |
| obj = MyObj(x) |
| obj.c = obj.a * obj.b + 1 |
| obj.b = obj.a * obj.c + 2 |
| obj.a = obj.b * obj.c + 3 |
| return obj.a, obj.b, obj.c |
| |
| x1 = torch.randn(10) |
| obj2 = fn(x1) |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| obj1 = opt_fn(x1) |
| self.assertTrue(same(obj1, obj2)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 9) |
| |
| def test_user_defined_class_name(self): |
| class MyClassFoo: |
| pass |
| |
| def fn1(a, b, c): |
| tmp = MyClassFoo() |
| if tmp.__class__.__name__ == "MyClassFoo": |
| return a - b / c |
| |
| torch._dynamo.testing.standard_test(self, fn=fn1, nargs=3) |
| |
| def test_user_defined_class_python_type(self): |
| class MyClass1: |
| pass |
| |
| class ExampleMeta(type): |
| pass |
| |
| class MyClass2(metaclass=ExampleMeta): |
| pass |
| |
| def fn(x, c): |
| if isinstance(c, MyClass1): |
| return x + 1 |
| elif isinstance(c, MyClass2): |
| return x + 2 |
| else: |
| return x + 3 |
| |
| x = torch.rand(3) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| for c in [MyClass1, MyClass2]: |
| ref = fn(x, c) |
| res = opt_fn(x, c) |
| self.assertTrue(same(ref, res)) |
| |
| def test_super_calling_with_metaclass(self): |
| class ExampleMeta(type): |
| pass |
| |
| class MyClass1(metaclass=ExampleMeta): |
| @classmethod |
| def add(cls, x): |
| return x + 1 |
| |
| class MyClass2(MyClass1): |
| @classmethod |
| def add(cls, x): |
| torch._dynamo.graph_break() |
| return x + super().add(x) |
| |
| def fn(x, obj): |
| return x + obj.add(x) |
| |
| x = torch.rand(3) |
| obj = MyClass2() |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| ref = fn(x, obj) |
| res = opt_fn(x, obj) |
| self.assertTrue(same(ref, res)) |
| |
| def test_manual_seed(self): |
| def fn(a, b): |
| x = a + b |
| torch.manual_seed(9000) |
| return x + 1 |
| |
| torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) |
| |
| def test_usr_cls_staticmethod(self): |
| class Foo: |
| @staticmethod |
| def bar(a, b): |
| return a + b |
| |
| def fn(a, b): |
| return Foo.bar(a, b) - 1 |
| |
| torch._dynamo.testing.standard_test(self, fn=fn, nargs=2) |
| |
| def test_usr_cls_classmethod(self): |
| class Foo: |
| @classmethod |
| def bar(cls, a, b): |
| return a + b |
| |
| def fn(a, b): |
| return Foo.bar(a, b) - 1 |
| |
| torch._dynamo.testing.standard_test(self, fn=fn, nargs=2) |
| |
| def test_dunder_methods(self): |
| class Foo: |
| def __init__(self, val): |
| super().__init__() |
| self.val = val |
| |
| def __add__(self, other): |
| return Foo(self.val + other.val) |
| |
| def __mul__(self, other): |
| return Foo(self.val * other.val) |
| |
| def __truediv__(self, other): |
| return Foo(self.val / other.val) |
| |
| def __sub__(self, other): |
| return Foo(self.val - other.val) |
| |
| def fn(a, b, c): |
| return Foo(a) + Foo(b) * Foo(c) / Foo(a) - Foo(b) |
| |
| torch._dynamo.testing.standard_test(self, fn=fn, nargs=3, expected_ops=4) |
| |
| def test_function_annotation(self): |
| class Variable: |
| pass |
| |
| def fn(x): |
| x = x / 3.0 |
| |
| def inner(y: typing.List[Variable]): |
| return x + 1 |
| |
| return inner |
| |
| x1 = torch.randn(10) |
| obj2 = fn(x1)([]) |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize_assert(cnts)(fn) |
| opt_fn_inner = torch._dynamo.optimize_assert(cnts)(opt_fn(x1)) |
| obj1 = opt_fn_inner([]) |
| self.assertTrue(same(obj1, obj2)) |
| self.assertEqual(cnts.frame_count, 2) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def test_nested_closure(self): |
| v0 = torch.randn(10) |
| |
| def fn1(): |
| v1 = torch.randn(10) |
| |
| def fn2(*args, **kwargs): |
| assert len(args) == 1 |
| assert len(kwargs) == 1 |
| v2 = torch.randn(10) + args[0] + kwargs["b"] |
| |
| def fn3(v3=torch.randn(10)): |
| def fn4(): |
| return v0 + v1 + v2 + v3 + 1 |
| |
| return fn4 |
| |
| return fn3 |
| |
| return fn2(1, b=2)() |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1) |
| tmp1 = torch._dynamo.optimize_assert(cnts)(opt_fn1()) |
| tmp2 = torch._dynamo.optimize_assert(cnts)(opt_fn1()) |
| self.assertEqual(tmp1().shape, (10,)) |
| self.assertTrue(same(tmp1(), tmp1())) |
| self.assertFalse(same(tmp1(), tmp2())) |
| self.assertEqual(cnts.frame_count, 2) |
| self.assertEqual(cnts.op_count, 9) |
| |
| def test_nested_closure_mutation(self): |
| def fn1(): |
| v1 = torch.randn(10) |
| |
| def fn2(): |
| v2 = torch.randn(10) |
| |
| def fn3(): |
| nonlocal v1, v2 |
| v1 += 1 |
| v2 += 2 |
| return v1 + v2 |
| |
| return fn3 |
| |
| rv = fn2() |
| rv() |
| rv() |
| return rv |
| |
| torch.manual_seed(9000) |
| counter1 = fn1() |
| result1 = [counter1(), counter1(), counter1()] |
| |
| torch.manual_seed(9000) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1) |
| counter2 = torch._dynamo.optimize_assert(cnts)(opt_fn1()) |
| result2 = [counter2(), counter2(), counter2()] |
| result1.append(counter1()) |
| result2.append(counter2()) |
| |
| self.assertTrue(same(result1, result2)) |
| self.assertEqual(cnts.frame_count, 2) |
| self.assertEqual(cnts.op_count, 11) |
| |
| def test_write_to_closures_in_inlining(self): |
| out = [] |
| for use_dynamo in [False, True]: |
| |
| def make_counter(): |
| x = torch.randn(10) |
| |
| def counter(): |
| nonlocal x |
| x = x + 1 |
| return x |
| |
| return counter |
| |
| torch.manual_seed(0) |
| counter = make_counter() |
| if not use_dynamo: |
| out.append(counter() + counter()) |
| else: |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnts, nopython=True) |
| def fn(counter): |
| return counter() + counter() |
| |
| out.append(fn(counter)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 3) |
| self.assertFalse(same(counter() + counter(), out[-1])) |
| |
| self.assertTrue(same(out[0], out[1])) |
| |
| def test_top_package_import(self): |
| def fn(x): |
| import torch.fx |
| |
| assert not isinstance(x, torch.fx.Proxy) |
| return torch.sin(x) |
| |
| x = torch.randn(4, 5) |
| ref = fn(x) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize_assert(cnts)(fn) |
| res = opt_fn(x) |
| self.assertTrue(same(ref, res)) |
| |
| def test_typing_union_and_optional(self): |
| def fn(x): |
| a = torch.jit.annotate(typing.Dict[str, typing.Optional[torch.Tensor]], {}) |
| b = torch.jit.annotate( |
| typing.Dict[str, typing.Union[torch.Tensor, None]], {} |
| ) |
| return a, b, x + 1 |
| |
| x = torch.randn(3) |
| ref = fn(x) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res = opt_fn(x) |
| self.assertTrue(same(ref, res)) |
| |
| def test_optimize_on_module(self): |
| class MockModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.relu = torch.nn.ReLU() |
| |
| def custom_member(self): |
| # Just for checking that the module object returned by Dynamo can |
| # still dispatch to this method |
| pass |
| |
| def forward(self, x): |
| return self.relu(x) |
| |
| cnts1 = torch._dynamo.testing.CompileCounter() |
| mod = MockModule() |
| optimized_mod = torch._dynamo.optimize(cnts1, nopython=True)(mod) |
| |
| a = torch.randn(10) |
| ref = mod(a) |
| res = optimized_mod(a) |
| |
| optimized_mod.custom_member() |
| |
| self.assertTrue(same(ref, res)) |
| |
| def test_nested_optimize_decorator(self): |
| cnts2 = torch._dynamo.testing.CompileCounter() |
| cnts3 = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.run() |
| def fn1(x): |
| return torch.sin(x) * 10 |
| |
| @torch._dynamo.optimize(cnts2, nopython=True) |
| def fn2(x): |
| return fn1(x) + 1 |
| |
| @torch._dynamo.optimize(cnts3, nopython=True) |
| def fn3(x): |
| return torch.relu(fn2(x)) |
| |
| fn3(torch.randn(4, 5)) |
| self.assertEqual(cnts2.frame_count, 0) |
| self.assertEqual(cnts3.frame_count, 1) |
| self.assertEqual(cnts3.op_count, 4) |
| |
| def test_nested_optimize_run(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnts, nopython=True) |
| def fn(x): |
| return torch.relu(torch.cos(x) + torch.sin(x)) |
| |
| fn(torch.randn(4)) |
| self.assertEqual(cnts.frame_count, 1) |
| |
| fn(torch.randn(4, 4)) |
| self.assertEqual(cnts.frame_count, 2) |
| |
| # Test that run works on a decorated fn |
| fn = torch._dynamo.run(fn) |
| fn(torch.randn(4, 4, 4)) |
| self.assertEqual(cnts.frame_count, 2) |
| |
| def test_nested_optimize(self): |
| cnts1 = torch._dynamo.testing.CompileCounter() |
| cnts2 = torch._dynamo.testing.CompileCounter() |
| |
| def fn(x): |
| return torch.relu(torch.cos(x) + torch.sin(x)) |
| |
| fn1 = torch._dynamo.optimize(cnts1, nopython=True)(fn) |
| fn2 = torch._dynamo.optimize(cnts2, nopython=True)(fn1) |
| |
| # The first optimize in the nesting should be ignored |
| fn2(torch.randn(4)) |
| self.assertEqual(cnts2.frame_count, 1) |
| self.assertEqual(cnts1.frame_count, 0) |
| |
| # Since the fn code object is already compiled, calling fn1 should |
| # directly call the compiled_fn callable. |
| torch._dynamo.run()(fn1)(torch.randn(4)) |
| self.assertEqual(cnts1.frame_count, 0) |
| |
| # Test same behavior by reversing the calls |
| torch._dynamo.reset() |
| cnts1 = torch._dynamo.testing.CompileCounter() |
| cnts2 = torch._dynamo.testing.CompileCounter() |
| fn1 = torch._dynamo.optimize(cnts1, nopython=True)(fn) |
| fn2 = torch._dynamo.optimize(cnts2, nopython=True)(fn1) |
| fn1(torch.randn(4)) |
| self.assertEqual(cnts1.frame_count, 1) |
| torch._dynamo.run()(fn2)(torch.randn(4)) |
| self.assertEqual(cnts2.frame_count, 0) |
| |
| def test_torch_size(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| def fn(x): |
| output_size = torch.Size([10, 10]) |
| x = x.view(*output_size) |
| return (x,) |
| |
| x = torch.randn(100, requires_grad=True) |
| x_clone = x.clone() |
| ref = fn(x) |
| |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| res = opt_fn(x_clone) |
| |
| self.assertTrue(same(ref, res)) |
| |
| def test_size_dim(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| def fn(x, dim): |
| return x.size(dim=dim) |
| |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| x = torch.empty([4, 9, 8]) |
| self.assertTrue(opt_fn(x, 1) == 9) |
| self.assertTrue(opt_fn(x, -2) == 9) |
| |
| def test_stride_dim(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| def fn(x, dim): |
| return x.stride(dim=dim) |
| |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| x = torch.empty([4, 9, 8]) |
| self.assertTrue(opt_fn(x, 0) == 72) |
| self.assertTrue(opt_fn(x, -2) == 8) |
| |
| def test_torch_seed(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| def fn(x): |
| attention_seed = int(torch.seed() % sys.maxsize) |
| torch.manual_seed(attention_seed) |
| return (x,) |
| |
| x = torch.randn(100, requires_grad=True) |
| ref = fn(x) |
| |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| res = opt_fn(x) |
| |
| self.assertTrue(same(ref, res)) |
| |
| def test_is_tensor_like(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| def f(x): |
| if torch.overrides.is_tensor_like(x): |
| return (x * 2,) |
| return (torch.ones(10) + x,) |
| |
| x = torch.randn(10) |
| ref0 = f(x) |
| ref1 = f(4) |
| opt_f = torch._dynamo.optimize(cnts, nopython=True)(f) |
| res0 = opt_f(x) |
| res1 = opt_f(4) |
| self.assertTrue(same(ref0, res0)) |
| self.assertTrue(same(ref1, res1)) |
| |
| def test_is_tensor_like2(self): |
| class MyTensor: |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| if kwargs is None: |
| kwargs = {} |
| |
| if func is torch.max: |
| return torch.tensor(123) |
| return func(*args, **kwargs) |
| |
| def fn(x): |
| if torch.overrides.is_tensor_like(x): |
| return torch.max(x) |
| else: |
| return torch.zeros(1) |
| |
| x = MyTensor() |
| ref0 = fn(x) |
| ref1 = fn(4) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res0 = opt_fn(x) |
| res1 = opt_fn(4) |
| self.assertTrue(same(ref0, res0)) |
| self.assertTrue(same(ref1, res1)) |
| |
| def test_tensor_data(self): |
| def fn(x, y): |
| return x[y.data] |
| |
| x = torch.rand(8) |
| y = torch.ones(8).to(torch.int) |
| ref = fn(x, y) |
| opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) |
| res = opt_fn(x, y) |
| self.assertTrue(same(ref, res)) |
| |
| def test_tensor_layout(self): |
| def fn(x): |
| return torch.zeros( |
| [x.size()[0], x.size()[1]], |
| dtype=x.dtype, |
| layout=x.layout, |
| device=x.device, |
| ) |
| |
| x = torch.rand(2, 3) |
| ref = fn(x) |
| opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) |
| res = opt_fn(x) |
| self.assertTrue(same(ref, res)) |
| |
| def test_version_ci(self): |
| # temporary test to check that the CI torch version is set correctly |
| self.assertTrue(hasattr(torch, "_subclasses")) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") |
| def test_rand(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| device = "cuda" |
| |
| def fn(): |
| return torch.randn(10, device=device) |
| |
| torch.manual_seed(10) |
| ref_run1 = fn() |
| |
| torch.manual_seed(10) |
| ref_run2 = fn() |
| self.assertTrue(same(ref_run1, ref_run2)) |
| |
| torch.manual_seed(10) |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| res = opt_fn() |
| |
| self.assertTrue(same(res, ref_run1)) |
| |
| def test_slice_input(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| def getitem(a, idx): |
| if isinstance(idx, slice): |
| return ( |
| torch.zeros(1), |
| a[idx] |
| + [ |
| 100, |
| ], |
| ) |
| else: |
| return (torch.zeros(1), a[idx]) |
| |
| layers = list(range(10)) |
| ref0 = getitem(layers, slice(0, 2, 1)) |
| ref1 = getitem(layers, 2) |
| ref2 = getitem(layers, slice(3, 8, 2)) |
| opt_getitem = torch._dynamo.optimize(cnts, nopython=True)(getitem) |
| res0 = opt_getitem(layers, slice(0, 2, 1)) |
| res1 = opt_getitem(layers, 2) |
| res2 = opt_getitem(layers, slice(3, 8, 2)) |
| |
| self.assertTrue(ref0 == res0) |
| self.assertTrue(ref1 == res1) |
| self.assertTrue(ref2 == res2) |
| |
| def test_grad(self): |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| def fn(a, b): |
| out = a * b |
| out.sum().backward() |
| real_out = torch.sigmoid(a.grad + b) |
| return real_out |
| |
| inps = [torch.randn(4, requires_grad=True) for _ in range(2)] |
| for inp in inps: |
| inp.grad = None |
| ref = fn(*inps) |
| |
| for inp in inps: |
| inp.grad = None |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| res = opt_fn(*inps) |
| |
| self.assertTrue(same(ref, res)) |
| |
| @skipIfNotPy311 |
| def test_linetable_311_writer1(self): |
| def fn(): |
| a = 10 |
| b = 20 |
| c = a + b |
| f = "linetable_writer" |
| return f"Test if {f} generates correct co_linetable: {c}" |
| |
| # Dynamo doesn't deal with column locations or end line numbers, |
| # so we only check that start line numbers in the linetables match. |
| keys = bytecode_transformation.get_code_keys() |
| code_options = {k: getattr(fn.__code__, k) for k in keys} |
| result = bytecode_transformation.clean_and_assemble_instructions( |
| bytecode_transformation.cleaned_instructions(fn.__code__), |
| keys, |
| code_options, |
| ) |
| l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions()) |
| self.assertEqual(len(l1), len(l2)) |
| for p1, p2 in zip(l1, l2): |
| # check that start line numbers match |
| self.assertEqual(p1[0], p2[0]) |
| self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab) |
| |
| @skipIfNotPy311 |
| def test_linetable_311_writer2(self): |
| """ |
| test handling of LOAD_METHOD and EXTENDED_ARG instructions |
| fn_str is in the form: |
| def fn(): |
| ... |
| x0 = 1 |
| x1 = 1 |
| ... |
| l = [x0, x1, ...] |
| """ |
| fn_str = f"""\ |
| def fn(): |
| foo.bar(1, 2, 3) |
| {str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))} |
| l = [{str(' ').join('x' + str(i) + ',' for i in range(1 << 9))}] |
| """ |
| locals = {} |
| exec(fn_str, {}, locals) |
| fn = locals["fn"] |
| orig_inst_str = "\n".join(list(map(str, dis.get_instructions(fn)))) |
| self.assertIn("EXTENDED_ARG", orig_inst_str) |
| self.assertIn("LOAD_METHOD", orig_inst_str) |
| keys = bytecode_transformation.get_code_keys() |
| code_options = {k: getattr(fn.__code__, k) for k in keys} |
| result = bytecode_transformation.clean_and_assemble_instructions( |
| bytecode_transformation.cleaned_instructions(fn.__code__), |
| keys, |
| code_options, |
| ) |
| new_inst_str = "\n".join(list(map(str, result[0]))) |
| self.assertIn("EXTENDED_ARG", new_inst_str) |
| self.assertIn("LOAD_METHOD", new_inst_str) |
| l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions()) |
| self.assertEqual(len(l1), len(l2)) |
| for p1, p2 in zip(l1, l2): |
| # check that start line numbers match |
| self.assertEqual(p1[0], p2[0]) |
| self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab) |
| |
| @unittest.skipIf( |
| sys.version_info < (3, 10) or sys.version_info >= (3, 11), |
| "linetable test for Python 3.10", |
| ) |
| def test_linetable_310_writer(self): |
| def fn(): |
| a = 10 |
| b = 20 |
| c = a + b |
| f = "linetable_writer" |
| return f"Test if {f} generates correct co_linetable: {c}" |
| |
| inst = dis.get_instructions(fn) |
| result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno) |
| self.assertTrue(result[1] == fn.__code__.co_linetable) |
| |
| @unittest.skipIf(sys.version_info >= (3, 10), "use lnotab when python < 3.10") |
| def test_lnotab_writer(self): |
| def fn(): |
| a = 10 |
| b = 20 |
| c = a + b |
| f = "lnotab_writer" |
| return f"Test if {f} generates correct co_lnotab: {c}" |
| |
| inst = dis.get_instructions(fn) |
| result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno) |
| self.assertTrue(result[1] == fn.__code__.co_lnotab) |
| |
| def test_profiler_cache_lookup(self): |
| def fn(x): |
| y = x**2 |
| y = y + 2 |
| z = y**3 |
| return z |
| |
| for profiler, get_events in ( |
| (torch.autograd.profiler.profile, lambda prof: prof.function_events), |
| (torch.profiler.profiler.profile, lambda prof: prof.events()), |
| ): |
| x = torch.randn((2, 2), requires_grad=True) |
| ref = fn(x) |
| opt_fn = torch.compile(fn, backend="aot_eager") |
| |
| # warmup |
| opt_fn(x) |
| |
| # whenever we enter the profiler context, hooks are automatically registered |
| with profiler() as prof: |
| res = opt_fn(x) |
| events = list( |
| filter( |
| lambda event: event.name == "TorchDynamo Cache Lookup", |
| get_events(prof), |
| ) |
| ) |
| |
| self.assertTrue(same(ref, res)) |
| self.assertTrue( |
| len(events) == 1, |
| "Expected one lookup profiler event for one opt_fn run", |
| ) |
| |
| with profiler() as prof: |
| # just make sure the disable functionality works |
| _enable_dynamo_cache_lookup_profiler(False) |
| res = opt_fn(x) |
| events = list( |
| filter( |
| lambda event: event.name == "TorchDynamo Cache Lookup", |
| get_events(prof), |
| ) |
| ) |
| |
| self.assertTrue(same(ref, res)) |
| self.assertTrue(len(events) == 0, "Expected disabled profiling") |
| |
| def test_tensor_is_contiguous(self): |
| def fn(x): |
| input = torch.randn((1, 16, 1, 1)) |
| weight = torch.randn((8, 16, 3, 3)) |
| weight = weight.to(memory_format=x) |
| output = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1) |
| return output.is_contiguous(memory_format=x) |
| |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| for x in [torch.contiguous_format, torch.channels_last]: |
| self.assertEqual(fn(x), opt_fn(x)) |
| |
| def test_python_slice(self): |
| def f1(input): |
| y = 0 |
| for i, x in enumerate(input[2:], 1): |
| y = y + x |
| return y |
| |
| def f2(input): |
| y = 0 |
| for i, x in enumerate(input.shape[2:], 1): |
| y = y + x |
| return y |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_f1 = torch._dynamo.optimize(cnts)(f1) |
| opt_f2 = torch._dynamo.optimize(cnts)(f2) |
| res1 = opt_f1([1, 2, 3, 5]) |
| res2 = opt_f2(torch.rand([2, 3, 4, 5])) |
| |
| self.assertEqual(res1, 8) |
| self.assertEqual(res2, 9) |
| |
| def test_enum_as_dict_key(self): |
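| # Enum members (mixed with string and int keys) used as dict keys should
| # survive an explicit graph break; the break splits fn into two compiled
| # frames, hence the expected frame_count of 2.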
| class MyEnum(enum.Enum): |
| FOO = 10 |
| BAR = 20 |
| |
| def fn(x): |
| y = x + 2 |
| z = { |
| MyEnum.FOO: torch.tensor(1), |
| MyEnum.BAR: 10, |
| "MyEnum.BAR": torch.tensor(8), |
| 5: torch.rand(3), |
| } |
| torch._dynamo.graph_break() |
| a = z[MyEnum.FOO] + z["MyEnum.BAR"] |
| b = y * 2 |
| return a, b |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| for _ in range(10): |
| x = torch.rand(3) |
| ref = fn(x) |
| res = opt_fn(x) |
| self.assertTrue(same(ref, res)) |
| self.assertEqual(cnts.frame_count, 2) |
| |
| def test_const_dict_variable_python_type(self): |
| from torch._dynamo.variables import ConstantVariable, ConstDictVariable |
| |
| d1 = {"a": ConstantVariable(10), "b": ConstantVariable(20)} |
| d2 = collections.OrderedDict( |
| [("x", ConstantVariable(12)), ("y", ConstantVariable(22))] |
| ) |
| self.assertEqual(ConstDictVariable(d1, dict).python_type(), dict) |
| self.assertEqual( |
| ConstDictVariable(d2, collections.OrderedDict).python_type(), |
| collections.OrderedDict, |
| ) |
| |
| def test_builtin_subclasses_as_method_on_class_type(self): |
| class Foo: |
| def __init__(self, name): |
| self.name_ = name
| |
| def get_name(self): |
| return "Foo " + self.name_ |
| |
| class Bar(Foo): |
| def __init__(self, name): |
| self.name_ = name |
| |
| def get_name(self): |
| return "Bar " + self.name_ |
| |
| class Baz(Foo): |
| def __init__(self, name): # noqa: B903 |
| self.name_ = name |
| |
| def get_name(self): |
| return "Baz " + self.name_ |
| |
| subs_of_foo_reg = Foo.__subclasses__() |
| |
| counter = CompileCounter() |
| |
| @torch._dynamo.optimize_assert(counter) |
| def fn(): |
| return Foo.__subclasses__() |
| |
| subs_of_foo_optim = fn() |
| |
| self.assertEqual(len(subs_of_foo_reg), 2) |
| self.assertEqual(subs_of_foo_reg, subs_of_foo_optim) |
| |
| def test_builtin_subclasses_as_method_on_var(self): |
| class Foo: |
| def __init__(self, name): |
| self.name_ = name |
| |
| def get_name(self): |
| return "Foo " + self.name_ |
| |
| class Bar(Foo): |
| def __init__(self, name): |
| self.name_ = name |
| |
| def get_name(self): |
| return "Bar " + self.name_ |
| |
| class Baz(Bar): |
| def __init__(self, name): |
| self.name_ = name |
| |
| def get_name(self): |
| return "Baz " + self.name_ |
| |
| subs_of_foo_reg = Foo.__subclasses__() |
| sub_of_foo_subclass_var_reg = subs_of_foo_reg[0].__subclasses__() |
| |
| sub_of_foo_subclass_var_optim = list() |
| counter = CompileCounter() |
| |
| @torch._dynamo.optimize_assert(counter) |
| def fn(): |
| return Foo.__subclasses__() |
| |
| @torch._dynamo.optimize_assert(counter) |
| def fn_single(subs_of_foo_optim): |
| return subs_of_foo_optim[0].__subclasses__() |
| |
| subs_of_foo_optim = fn() |
| sub_of_foo_subclass_var_optim = fn_single(subs_of_foo_optim) |
| |
| self.assertEqual(len(sub_of_foo_subclass_var_optim), 1) |
| self.assertEqual(sub_of_foo_subclass_var_optim, sub_of_foo_subclass_var_reg) |
| |
| def test_enum_no_graphbreaks(self): |
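| # The enum identity check should be resolved at trace time, so with
| # nopython=True the True branch compiles to 2 ops and the False branch
| # to 1 op, with no graph break in either case.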
| class Foo(enum.Enum): |
| FOO = 0 |
| BAR = 1 |
| |
| def fn(x, foo): |
| if foo is Foo.FOO: |
| x = torch.add(x, 1.0) |
| x = torch.mul(x, 1.0) |
| return x |
| |
| x = torch.randn(1) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| opt_fn(x, Foo.FOO) |
| self.assertEqual(cnts.op_count, 2) |
| |
| torch._dynamo.reset() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| opt_fn(x, Foo.BAR) |
| self.assertEqual(cnts.op_count, 1) |
| |
| def test_id_of_nn_module(self): |
| class M(torch.nn.Module): |
| def forward(self, x, ref_id): |
| self_id = id(self) |
| if self_id == ref_id: |
| x = torch.mul(x, 1.0) |
| x = torch.add(x, 1.0) |
| return x |
| |
| m = M().eval() |
| data = torch.randn(1) |
| cnts = torch._dynamo.testing.CompileCounter() |
| correct_ref_id = id(m) |
| opt_m = torch._dynamo.optimize(cnts, nopython=True)(m) |
| opt_m(data, correct_ref_id) |
| # The extra op is the recorded equality test (although once
| # the trace is flattened this is dead code!)
| self.assertEqual(cnts.op_count, ifunspec(3, 2)) |
| |
| torch._dynamo.reset() |
| cnts = torch._dynamo.testing.CompileCounter() |
| incorrect_ref_id = id(m) + 1 |
| opt_m = torch._dynamo.optimize(cnts, nopython=True)(m) |
| opt_m(data, incorrect_ref_id) |
| self.assertEqual(cnts.op_count, ifunspec(2, 1)) |
| |
| def test_inline_func_jump_on_tensor_condition(self): |
| def f1(input): |
| if input == 0: |
| return input + 1 |
| else: |
| return input + 2 |
| |
| def f2(input): |
| return f1(input) |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_f2 = torch._dynamo.optimize(cnts)(f2) |
| res1 = opt_f2(torch.tensor([1.0])) |
| res2 = opt_f2(torch.tensor([0.0])) |
| |
| self.assertEqual(res1, 3) |
| self.assertEqual(res2, 1) |
| |
| def test_frozenset_torch_func_contains(self): |
| funcs = frozenset([torch.add]) |
| |
| def fn(x, func): |
| if func in funcs: |
| x = torch.add(x, 1.0) |
| x = torch.mul(x, 1.0) |
| return x |
| |
| x = torch.randn(1) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| opt_fn(x, torch.add) |
| self.assertEqual(cnts.op_count, 2) |
| |
| torch._dynamo.reset() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| opt_fn(x, torch.mul) |
| self.assertEqual(cnts.op_count, 1) |
| |
| def test_inline_list_mutation(self): |
| def f1(x): |
| x.append(torch.ones(8)) |
| return x |
| |
| def f2(): |
| x = [torch.ones(6)] |
| f1(x) |
| return x |
| |
| res1 = f2() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_f2 = torch._dynamo.optimize(cnts)(f2) |
| res2 = opt_f2() |
| self.assertTrue(same(res1, res2)) |
| |
| def test_inline_dict_mutation(self): |
| def f1(d): |
| d["c"] = d["a"] + d.pop("b") |
| return d |
| |
| def f2(): |
| d = {"a": torch.ones(5), "b": torch.ones(5)} |
| f1(d) |
| return d |
| |
| res1 = f2() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_f2 = torch._dynamo.optimize(cnts)(f2) |
| res2 = opt_f2() |
| self.assertTrue(same(res1, res2)) |
| |
| def test_recursive_inline_list_mutation(self): |
| def f1(x, y): |
| x.append(torch.tensor([1.1])) |
| y.append(torch.tensor([1.2])) |
| return x, y |
| |
| def f2(x, y): |
| x.append(torch.tensor([2.1])) |
| y.append(torch.tensor([2.2])) |
| f1(x, y) |
| return x, y |
| |
| def f3(x): |
| x.append(torch.tensor([3.1])) |
| y = [torch.tensor([3.2])] |
| f2(x, y) |
| return x, y |
| |
| def f4(): |
| x = [torch.tensor([4.1])] |
| return f3(x) |
| |
| res1 = f4() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_f4 = torch._dynamo.optimize(cnts)(f4) |
| res2 = opt_f4() |
| self.assertTrue(same(res1, res2)) |
| |
| def test_sample_input(self): |
| from torch.testing._internal.common_methods_invocations import SampleInput |
| |
| def fn(sample): |
| if isinstance(sample.input, torch.Tensor): |
| return sample.input * 2 |
| return torch.zeros(()) |
| |
| sample = SampleInput(torch.ones(2)) |
| ref = fn(sample) |
| |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res = opt_fn(sample) |
| |
| self.assertTrue(same(ref, res)) |
| |
| def test_release_input_memory(self): |
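| # The compiled function should not keep its input tensor alive: once the
| # caller drops the last reference, the weakref must be dead.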
| x = torch.rand([4]) |
| x_ref = weakref.ref(x) |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnts) |
| def foo(x): |
| return x + x |
| |
| out = foo(x) |
| self.assertTrue(same(out, x + x)) |
| del x |
| self.assertIs(x_ref(), None) |
| |
| def test_release_module_memory(self): |
| mod = torch.nn.Linear(10, 10) |
| x = torch.rand([10, 10]) |
| mod_weight_ref = weakref.ref(mod.weight) |
| mod_ref = weakref.ref(mod) |
| |
| # Modules that are passed into torch._dynamo optimized functions |
| # will normally be held onto through the generated GraphModule, |
| # which contains the modules. Remove the reference in this backend
| # and test that no additional references are being held. |
| class NoLeakBackend: |
| def __call__(self, gm: torch.fx.GraphModule, example_inputs): |
| gm.mod = None |
| |
| def foo(*args, **kwargs): |
| return (1,) |
| |
| return foo |
| |
| no_leak_backend = NoLeakBackend() |
| |
| @torch._dynamo.optimize(no_leak_backend) |
| def foo(mod, x): |
| return mod(x) |
| |
| foo(mod, x) |
| del mod |
| del x |
| self.assertIsNone(mod_ref())
| self.assertIsNone(mod_weight_ref())
| |
| def test_update_locals_and_stack_uses_shared_cache(self): |
| def fn(x): |
| perm = [0, 3, 5] |
| perm = list(range(min(perm))) + perm |
| perm.extend(i for i in range(x.dim()) if i not in perm) |
| return perm |
| |
| x = torch.rand([2, 2, 2, 2, 2, 2]) |
| res1 = fn(x) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| res2 = opt_fn(x) |
| self.assertTrue(same(res1, res2)) |
| |
| def test_dict_reconstruct_keeps_original_order(self): |
| def fn(): |
| modules = collections.OrderedDict([("act", torch.nn.ReLU())]) |
| module_dict = torch.nn.ModuleDict(modules) |
| |
| next_modules = {"fc4": torch.nn.Linear(5, 6), "act3": torch.nn.Sigmoid()} |
| modules.update(next_modules.items()) |
| module_dict.update(next_modules) |
| return modules, module_dict |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| modules, module_dict = opt_fn() |
| |
| self.assertEqual(len(module_dict), len(modules)) |
| for k1, m2 in zip(modules, module_dict.children()): |
| self.assertTrue(modules[k1] is m2) |
| |
| def test_side_effects_codegen_update_mutated(self): |
| # Codegen that updates mutated variables with side effects
| # should run after the stack values' codegen.
| def f1(x): |
| alist = [x] |
| alist.append(x + 1) |
| alist[0].sum().item() # graph break |
| res = alist.pop() |
| res.sum().item() # graph break |
| return res |
| |
| def f2(a, b): |
| d = {"a": a + 1, "b": b + 2} |
| x = d.pop("b") |
| x.sum().item() # graph break |
| y = d["a"] + x |
| y.sum().item() # graph break |
| d["c"] = y |
| return d |
| |
| x = torch.rand([2, 3]) |
| a = torch.rand([5, 6]) |
| b = torch.rand([5, 6]) |
| res11 = f1(x) |
| res21 = f2(a, b) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_f1 = torch._dynamo.optimize(cnts)(f1) |
| opt_f2 = torch._dynamo.optimize(cnts)(f2) |
| res12 = opt_f1(x) |
| res22 = opt_f2(a, b) |
| self.assertTrue(same(res11, res12)) |
| self.assertTrue(same(res21, res22)) |
| |
| def test_list_append_return_none(self): |
| def fn(x): |
| alist = [] |
| blist = alist.append(x + 1) |
| return alist, blist |
| |
| x = torch.tensor([2.3]) |
| res = fn(x) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| res2 = opt_fn(x) |
| self.assertEqual(res, res2) |
| |
| def test_tensor_types(self): |
| def fn(dtype, tensor_type): |
| x = torch.empty(4, dtype=dtype) |
| assert isinstance(x, tensor_type) |
| |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| opt_fn(torch.float32, torch.FloatTensor) |
| opt_fn(torch.float64, torch.DoubleTensor) |
| opt_fn(torch.float16, torch.HalfTensor) |
| opt_fn(torch.bfloat16, torch.BFloat16Tensor) |
| opt_fn(torch.uint8, torch.ByteTensor) |
| opt_fn(torch.int8, torch.CharTensor) |
| opt_fn(torch.int64, torch.LongTensor) |
| opt_fn(torch.int, torch.IntTensor) |
| opt_fn(torch.int16, torch.ShortTensor) |
| opt_fn(torch.bool, torch.BoolTensor) |
| |
| def test_nan(self): |
| def f(x, n): |
| return x * 2 + n |
| |
| x = torch.randn(4) |
| n = float("nan") |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_f = torch._dynamo.optimize(cnts)(f) |
| opt_f(x, n) |
| opt_f(x, n) |
| self.assertEqual(cnts.frame_count, 1) |
| |
| @patch.object(torch._dynamo.config, "dynamic_shapes", True) |
| @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) |
| def test_item(self): |
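| # With dynamic_shapes and capture_scalar_outputs patched on, .item() should
| # be traceable without a graph break, which nopython=True enforces here.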
| class MyMod(torch.nn.Module): |
| def forward(self, x): |
| z = torch.max(x) |
| return z.int().item() |
| |
| x = torch.tensor([[10.6763, 11.7445, -2.2369]]) |
| model = MyMod() |
| y = torch._dynamo.optimize("eager", nopython=True)(model)(x) |
| |
| self.assertEqual(y, 11) |
| |
| @patch.object(torch._dynamo.config, "dynamic_shapes", True) |
| @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) |
| def test_item_changes(self): |
| class MyMod(torch.nn.Module): |
| def forward(self, x): |
| z = torch.max(x) |
| return z.int().item() |
| |
| x = torch.tensor([[10.6763, 11.7445, -2.2369]]) |
| model = MyMod() |
| opt_model = torch._dynamo.optimize("eager", nopython=True)(model) |
| y = opt_model(x) |
| z = opt_model(torch.tensor([[y - 5, y + 10, y + 50]])) |
| |
| self.assertEqual(y, 11) |
| self.assertEqual(z, 61) |
| |
| @patch.object(torch._dynamo.config, "dynamic_shapes", True) |
| @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) |
| def test_item_changes_new_shape(self): |
| class MyMod(torch.nn.Module): |
| def forward(self, x): |
| z = torch.max(x) |
| return z.int().item() |
| |
| x = torch.tensor([[10.6763, 11.7445, -2.2369]]) |
| model = MyMod() |
| opt_model = torch._dynamo.optimize("eager", nopython=True)(model) |
| y = opt_model(x) |
| z = opt_model(torch.tensor([[y - 5, y + 50], [y + 5, y - 50]])) |
| |
| self.assertEqual(y, 11) |
| self.assertEqual(z, 61) |
| |
| @unittest.skip("https://github.com/pytorch/pytorch/issues/99726") |
| def test_cross_entropy_loss_fancy_ctor1(self): |
| rand_5 = torch.randn(5) |
| rand_3_5 = torch.randn(3, 5) |
| target = torch.empty(3, dtype=torch.long).random_(5) |
| |
| loss = torch.nn.CrossEntropyLoss( |
| weight=rand_5, reduce=False, label_smoothing=0.5 |
| ) |
| opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss) |
| input = rand_3_5 |
| dynamo_output = opt_loss(input, target) |
| |
| loss = torch.nn.CrossEntropyLoss( |
| weight=rand_5, reduce=False, label_smoothing=0.5 |
| ) |
| input = rand_3_5 |
| output = loss(input, target) |
| |
| self.assertTrue(torch.allclose(dynamo_output, output)) |
| |
| @requires_static_shapes |
| def test_cross_entropy_loss_fancy_ctor2(self): |
| rand_3_5 = torch.randn(3, 5) |
| target = torch.empty(3, dtype=torch.long).random_(5) |
| |
| loss = torch.nn.CrossEntropyLoss(reduce=False, label_smoothing=0.5) |
| opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss) |
| input = rand_3_5 |
| dynamo_output = opt_loss(input, target) |
| |
| loss = torch.nn.CrossEntropyLoss(reduce=False, label_smoothing=0.5) |
| input = rand_3_5 |
| output = loss(input, target) |
| |
| self.assertTrue(torch.allclose(dynamo_output, output)) |
| |
| def test_cross_entropy_loss_simple_ctor(self): |
| output = None |
| rand_3_5 = torch.randn(3, 5) |
| target = torch.empty(3, dtype=torch.long).random_(5) |
| |
| loss = torch.nn.CrossEntropyLoss() |
| opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss) |
| input = rand_3_5 |
| dynamo_output = opt_loss(input, target) |
| |
| loss = torch.nn.CrossEntropyLoss() |
| input = rand_3_5 |
| output = loss(input, target) |
| |
| self.assertTrue(torch.allclose(dynamo_output, output)) |
| |
| def test_nn_functional_reduction(self): |
| def fn(loss, reduction): |
| reduction_enum = F._Reduction.get_enum(reduction) |
| if reduction_enum == 0: |
| return loss |
| elif reduction_enum == 1: |
| return loss.mean() |
| elif reduction_enum == 2: |
| return loss.sum() |
| |
| x = torch.rand([3, 5]) |
| y = "mean" |
| ref = fn(x, y) |
| opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) |
| res = opt_fn(x, y) |
| self.assertTrue(torch.allclose(ref, res)) |
| |
| def test_large_reduction_list(self): |
| dtype = torch.float32 |
| device = "cpu" |
| |
| def check_sum_all(tensor: torch.Tensor) -> None: |
| pylist = tensor.reshape(-1).tolist() |
| self.assertTrue(same(tensor.sum(), torch.tensor(sum(pylist)))) |
| |
| check_sum_all(torch.randn(200000, dtype=dtype, device=device)) |
| |
| def test_raise_on_backend_error(self): |
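| # An exception raised by the backend compiler should surface as
| # torch._dynamo.exc.BackendCompilerFailed rather than escaping raw.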
| def my_compiler(gm, _): |
| raise RuntimeError("duck!") |
| |
| @torch._dynamo.optimize(my_compiler) |
| def fn(a, b): |
| return a + b / (a - b) |
| |
| self.assertRaises( |
| torch._dynamo.exc.BackendCompilerFailed, |
| lambda: fn(torch.randn(10), torch.randn(10)), |
| ) |
| |
| def test_named_parameters(self): |
| n_embd = 768 |
| block_size = 128 |
| vocab_size = 65 |
| embd_pdrop = 0.1 |
| |
| class MyModel2(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.tok_emb = torch.nn.Embedding(vocab_size, n_embd) |
| self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd)) |
| self.drop = torch.nn.Dropout(embd_pdrop) |
| |
| def forward(self, x): |
| return x |
| |
| class MyModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.tok_emb = torch.nn.Embedding(vocab_size, n_embd) |
| self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd)) |
| self.drop = torch.nn.Dropout(embd_pdrop) |
| self.submod2 = MyModel2() |
| |
| def forward(self, x): |
| return x |
| |
| # Regular |
| params = [] |
| mod = MyModel() |
| actual_params = list(mod.named_parameters()) |
| |
| @torch._dynamo.optimize("eager", nopython=True) |
| def fn(): |
| return list(mod.named_parameters()) |
| |
| params = fn() |
| |
| self.assertEqual(len(actual_params), len(params)) |
| for idx in range(len(params)): |
| k_a, v_a = actual_params[idx] |
| k, v = params[idx] |
| self.assertEqual(k_a, k) |
| self.assertTrue(torch.allclose(v_a, v)) |
| |
| # Prefix |
| params = [] |
| mod = MyModel() |
| actual_params = list(mod.named_parameters(prefix="foo")) |
| |
| @torch._dynamo.optimize("eager", nopython=True) |
| def fn1(): |
| return list(mod.named_parameters(prefix="foo")) |
| |
| params = fn1() |
| |
| self.assertEqual(len(actual_params), len(params)) |
| for idx in range(len(params)): |
| k_a, v_a = actual_params[idx] |
| k, v = params[idx] |
| self.assertEqual(k_a, k) |
| self.assertTrue(torch.allclose(v_a, v)) |
| |
| def test_module_complex_iter(self): |
| n_embd = 768 |
| block_size = 128 |
| vocab_size = 65 |
| embd_pdrop = 0.1 |
| |
| class FakeGPT(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.tok_emb = torch.nn.Embedding(vocab_size, n_embd) |
| self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd)) |
| self.drop = torch.nn.Dropout(embd_pdrop) |
| self.ln_f = torch.nn.LayerNorm(n_embd) |
| self.head = torch.nn.Linear(n_embd, vocab_size, bias=False) |
| |
| self.block_size = block_size |
| self.names = [] |
| |
| def forward(self, idx, targets=None): |
| b, t = idx.size() |
| assert ( |
| t <= self.block_size |
| ), "Cannot forward, model block size is exhausted." |
| |
| # forward the GPT model |
| token_embeddings = self.tok_emb( |
| idx |
| ) # each index maps to a (learnable) vector |
| position_embeddings = self.pos_emb[ |
| :, :t, : |
| ] # each position maps to a (learnable) vector |
| x = self.drop(token_embeddings + position_embeddings) |
| x = self.blocks(x) |
| x = self.ln_f(x) |
| logits = self.head(x) |
| |
| # if we are given some desired targets also calculate the loss |
| loss = None |
| if targets is not None: |
| loss = F.cross_entropy( |
| logits.view(-1, logits.size(-1)), targets.view(-1) |
| ) |
| |
| return logits, loss |
| |
| def foo(self, memo=None, prefix="", remove_duplicate=False): |
| for mn, m in self.named_modules( |
| memo=memo, prefix=prefix, remove_duplicate=remove_duplicate |
| ): |
| for pn, p in self.named_parameters(): |
| fpn = "%s.%s" % (mn, pn) if mn else pn |
| self.names.append(fpn) |
| |
| # Test plain recurse |
| model_a = FakeGPT() |
| model_a.foo() |
| a_names = model_a.names |
| |
| model_b = FakeGPT() |
| opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b) |
| opt_model_b.foo() |
| |
| self.assertEqual(a_names, model_b.names) |
| |
| # Test with prefix |
| model_a = FakeGPT() |
| model_a.foo(prefix="abc") |
| a_names = model_a.names |
| |
| model_b = FakeGPT() |
| opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b) |
| opt_model_b.foo(prefix="abc") |
| |
| self.assertEqual(a_names, model_b.names) |
| |
| def test_numpy_variable_isinstance(self): |
| def fn(x, m): |
| if isinstance(m, np.ndarray): |
| return x + 1 |
| else: |
| return x - 1 |
| |
| x = torch.tensor([2.3]) |
| m = np.array([1, 2, 3]) |
| ref = fn(x, m) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| res = opt_fn(x, m) |
| self.assertEqual(ref, res) |
| |
| def test_tensor_dot_grad_no_graph_break(self): |
| def fn(a, b): |
| y = 3 * a**3 - b**2 |
| y.backward(gradient=torch.tensor([1.0, 1.0])) |
| b.grad.zero_() |
| return a.grad, b.grad |
| |
| a = torch.tensor([2.0, 3.0], requires_grad=True) |
| b = torch.tensor([6.0, 4.0], requires_grad=True) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts)(fn) |
| _, b_grad = opt_fn(a, b) |
| self.assertTrue(same(b_grad, torch.tensor([0.0, 0.0]))) |
| self.assertEqual(cnts.frame_count, 2) |
| |
| def test_torch_nn_parameter_isinstance(self): |
| def fn(x): |
| a = torch.nn.Parameter(torch.rand(2, 3)) |
| if isinstance(a, torch.Tensor): |
| return x + 1 |
| else: |
| return x - 1 |
| |
| x = torch.tensor([2.5]) |
| ref = fn(x) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res = opt_fn(x) |
| self.assertEqual(ref, res) |
| |
| @torch._dynamo.config.patch(raise_on_backend_change=True) |
| def test_change_backends(self): |
| @torch._dynamo.optimize("eager", nopython=True) |
| def fn1(): |
| return x + 1 |
| |
| @torch._dynamo.optimize("ts") |
| def fn2(): |
| return x + 2 |
| |
| @torch._dynamo.optimize("eager", nopython=False) |
| def fn3(): |
| return x + 1 |
| |
| x = torch.tensor([3, 5]) |
| |
| fn1() |
| fn1() |
| fn3() |
| self.assertRaises(torch._dynamo.exc.ResetRequired, fn2) |
| fn1() |
| torch._dynamo.reset() |
| fn2() |
| fn2() |
| self.assertRaises(torch._dynamo.exc.ResetRequired, fn1) |
| self.assertRaises(torch._dynamo.exc.ResetRequired, fn3) |
| fn2() |
| |
| def test_dynamo_min_operator_with_shape(self): |
| @torch._dynamo.optimize("eager", nopython=True) |
| def f(x, a): |
| return min(x.shape[0], a) |
| |
| result = f(torch.ones(6), 3) |
| self.assertEqual(result, 3) |
| |
| @patch.object(torch._dynamo.config, "dynamic_shapes", True) |
| def test_onnx_shape_as_tensor(self): |
| @torch._dynamo.optimize("eager", nopython=True) |
| def f(x): |
| return 1 + torch._shape_as_tensor(x)[0] |
| |
| gm, _ = torch._dynamo.export(f, torch.ones(6)) |
| |
| input_one_dim = torch.ones(6) |
| input_two_dims = torch.ones(7, 4) |
| self.assertEqual(f(input_one_dim), 7) |
| self.assertEqual(f(input_two_dims), 8) |
| self.assertEqual(f(input_two_dims), 8) |
| |
| @torch._dynamo.optimize("eager", nopython=True) |
| def f_onnx(x): |
| return 1 + torch.onnx.operators.shape_as_tensor(x)[0] |
| |
| self.assertEqual(f_onnx(input_one_dim), 7) |
| self.assertEqual(f_onnx(input_two_dims), 8) |
| self.assertEqual(f_onnx(input_two_dims), 8) |
| |
| def test_cond(self): |
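| # control_flow.cond takes a predicate tensor plus both branch functions;
| # the compiled version should return cos(x) for a False predicate and
| # sin(x) for a True one.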
| from functorch.experimental.control_flow import cond |
| |
| def true_fn(x): |
| return x.sin() |
| |
| def false_fn(x): |
| return x.cos() |
| |
| def f(pred, x): |
| return cond(pred, true_fn, false_fn, [x]) |
| |
| opt_fn = torch._dynamo.optimize("eager")(f) |
| a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25])) |
| self.assertTrue(same(torch.cos(torch.tensor([0.25, 0.25])), a)) |
| b = opt_fn(torch.tensor(True), torch.tensor([0.25, 0.25])) |
| self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), b)) |
| |
| def test_nonzero_static(self): |
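| # torch.nonzero_static returns exactly `size` rows: missing rows are padded
| # with `fill_value` (default -1), surplus nonzero indices are truncated,
| # `size` must be non-negative, and any `out` tensor must be Long and is
| # resized as needed. The cases below cover 0-d through 3-d inputs.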
| # invalid size |
| with self.assertRaisesRegex( |
| RuntimeError, "nonzero_static: 'size' must be an non-negative integer" |
| ): |
| torch.nonzero_static(torch.tensor([8]), size=-2) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, "nonzero_static: 'size' must be an non-negative integer" |
| ): |
| torch.nonzero_static(torch.tensor([8]), size=-2, out=torch.tensor(0)) |
| |
| # nonzero_static.out: out dtype mismatch |
| input_tensor = torch.tensor([8]) |
| static_size = 1 |
| out_tensor = torch.empty((static_size, input_tensor.dim()), dtype=torch.float) |
| with self.assertRaisesRegex( |
| RuntimeError, "nonzero_static: Expected out tensor to have scalar type Long" |
| ): |
| torch.nonzero_static(input_tensor, size=static_size, out=out_tensor) |
| |
| # nonzero_static.out: out resize (shrink) |
| input_tensor = torch.tensor([8]) |
| static_size = 1 |
| out_tensor = torch.empty((10, 10, 10, 10), dtype=torch.long) |
| self.assertTrue( |
| same( |
| torch.nonzero_static(input_tensor, size=static_size, out=out_tensor), |
| torch.tensor([0]), |
| ) |
| ) |
| self.assertTrue( |
| same( |
| out_tensor, |
| torch.tensor([0]), |
| ) |
| ) |
| |
| # nonzero_static.out: out resize (enlarge) |
| input_tensor = torch.tensor([8]) |
| static_size = 1 |
| out_tensor = torch.empty((0), dtype=torch.long) |
| self.assertTrue( |
| same( |
| torch.nonzero_static(input_tensor, size=static_size, out=out_tensor), |
| torch.tensor([0]), |
| ) |
| ) |
| self.assertTrue( |
| same( |
| out_tensor, |
| torch.tensor([0]), |
| ) |
| ) |
| |
| # 0 rank |
| input_tensor = torch.tensor(6) |
| static_size = 2 |
| self.assertTrue( |
| same( |
| torch.nonzero_static(input_tensor, size=static_size), |
| torch.empty((static_size, input_tensor.dim()), dtype=torch.long), |
| ) |
| ) |
| |
| # 0 size |
| input_tensor = torch.tensor([[[1]]]) |
| static_size = 0 |
| self.assertTrue( |
| same( |
| torch.nonzero_static(input_tensor, size=static_size), |
| torch.empty((static_size, input_tensor.dim()), dtype=torch.long), |
| ) |
| ) |
| |
| # 1D input |
| input_tensor = torch.tensor([0, 8]) |
| static_size = 1 |
| self.assertTrue( |
| same( |
| torch.nonzero_static(input_tensor, size=static_size), |
| torch.tensor([1]), |
| ) |
| ) |
| |
| input_tensor = torch.tensor([8, 0]) |
| static_size = 2 |
| self.assertTrue( |
| same( |
| torch.nonzero_static(input_tensor, size=static_size), |
| torch.tensor([[0], [-1]]), # padded with default fill_value "-1" |
| ) |
| ) |
| |
| # 2D input |
| input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]]) |
| static_size = 5 |
| fill_value = -100 |
| self.assertTrue( |
| torch._dynamo.utils.same( |
| torch.nonzero_static( |
| input_tensor, size=static_size, fill_value=fill_value |
| ), |
| torch.tensor( |
| [ |
| [0, 0], |
| [1, 0], |
| [1, 1], |
| [fill_value, fill_value], |
| [fill_value, fill_value], |
| ] |
| ), |
| ) |
| ) |
| input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]]) |
| static_size = 2 |
| fill_value = -100 |
| self.assertTrue( |
| torch._dynamo.utils.same( |
| torch.nonzero_static( |
| input_tensor, size=static_size, fill_value=fill_value |
| ), |
| torch.tensor([[0, 0], [1, 0]]), |
| ) |
| ) |
| |
| # 3D input |
| input_tensor = torch.tensor([[[0, 0], [0, -3]], [[0, 0], [5, 0]]]) |
| static_size = 4 |
| fill_value = -999 |
| self.assertTrue( |
| torch._dynamo.utils.same( |
| torch.nonzero_static( |
| input_tensor, |
| size=static_size, |
| fill_value=fill_value, |
| ), |
| torch.tensor( |
| [ |
| [0, 1, 1], |
| [1, 1, 0], |
| [fill_value, fill_value, fill_value], |
| [fill_value, fill_value, fill_value], |
| ] |
| ), |
| ) |
| ) |
| |
| def test_cond_with_quantization(self): |
| from functorch.experimental.control_flow import cond |
| |
| class MyModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| example_inputs = (torch.randn(5, 5),) |
| self.model = torch.nn.Linear(5, 5) |
| self.quantized_model = prepare_qat_fx( |
| self.model, qconfig_dict, example_inputs=example_inputs |
| ) |
| |
| def forward(self, pred, x): |
| def true_fn(x): |
| return x.sin() + self.quantized_model(x) |
| |
| def false_fn(x): |
| return x.cos() + self.model(x) |
| |
| return cond(pred, true_fn, false_fn, [x]) |
| |
| module = MyModule() |
| opt_m = torch._dynamo.optimize("eager", nopython=True)(module) |
| x = torch.rand((5, 5)) |
| pred = torch.tensor(True) |
| self.assertTrue(same(module(pred, x), opt_m(pred, x))) |
| pred = torch.tensor(False) |
| self.assertTrue(same(module(pred, x), opt_m(pred, x))) |
| |
| def test_map_with_quantization(self): |
| from functorch.experimental.control_flow import map |
| |
| class MyModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| example_inputs = (torch.randn(5, 5),) |
| self.model = torch.nn.Linear(5, 5) |
| self.quantized_model = prepare_qat_fx( |
| self.model, qconfig_dict, example_inputs=example_inputs |
| ) |
| |
| def forward(self, x): |
| def body(x): |
| return x.sin() + self.quantized_model(x) |
| |
| return map(body, x) |
| |
| module = MyModule() |
| opt_m = torch._dynamo.optimize("eager", nopython=True)(module) |
| x = torch.rand((5, 5)) |
| self.assertTrue(same(module(x), opt_m(x))) |
| |
| def test_cond_side_effects(self): |
| from functorch.experimental.control_flow import cond |
| |
| c = 0 |
| |
| def true_fn(x): |
| return x - c |
| |
| def false_fn(x): |
| return x + c |
| |
| def f(pred, x): |
| nonlocal c |
| c = 1 |
| return cond(pred, true_fn, false_fn, [x]) |
| |
| opt_fn = torch._dynamo.optimize("eager")(f) |
| c = 0 |
| a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25])) |
| self.assertTrue(same(torch.tensor([1.25, 1.25]), a)) |
| |
| def test_map_side_effects(self): |
| from functorch.experimental.control_flow import map |
| |
| class Module(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.tensor(1) |
| |
| def forward(self, xs): |
| def body(x): |
| self.w += 1 |
| return x |
| |
| return map(body, xs) |
| |
| mod = Module() |
| with self.assertRaisesRegex( |
| TypeError, "missing 1 required positional argument" |
| ): |
| opt_fn = torch._dynamo.optimize("eager", nopython=True)(mod) |
| opt_fn(torch.randn(3, 2)) |
| |
| def test_cond_nested(self): |
| from functorch.experimental.control_flow import cond |
| |
| def true_fn_nested(x): |
| return x * 10 |
| |
| def false_fn_nested(x): |
| return x * -1 |
| |
| def true_fn(pred2, x): |
| return x.sin() |
| |
| def false_fn(pred2, x): |
| return x + cond(pred2, true_fn_nested, false_fn_nested, [x]) |
| |
| def f(pred, pred2, x): |
| return cond(pred, true_fn, false_fn, [pred2, x]) |
| |
| cc = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cc)(f) |
| true_true_sin = opt_fn( |
| torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin)) |
| |
| true_false_sin = opt_fn( |
| torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin)) |
| |
| false_true_sum_mult = opt_fn( |
| torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue( |
| same(torch.tensor([2.75, 2.75]), false_true_sum_mult) |
| ) # * 10 then add x |
| |
| false_false_sum_neg = opt_fn( |
| torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue( |
| same(torch.tensor([0.0, 0.0]), false_false_sum_neg) |
| ) # * -1 then add x |
| self.assertTrue(cc.frame_count, 2) |
| |
| def test_cond_export(self): |
| from functorch.experimental.control_flow import cond |
| |
| def true_fn_nested(x): |
| return x * 10 |
| |
| def false_fn_nested(x): |
| return x * -1 |
| |
| def true_fn(pred2, x): |
| return x.sin() |
| |
| def false_fn(pred2, x): |
| return x + cond(pred2, true_fn_nested, false_fn_nested, [x]) |
| |
| def f(pred, pred2, x): |
| return cond(pred, true_fn, false_fn, [pred2, x]) |
| |
| graph, guard = torch._dynamo.export( |
| f, torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) |
| ) |
| true_true_sin = graph( |
| torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin)) |
| |
| true_false_sin = graph( |
| torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin)) |
| |
| false_true_sum_mult = graph( |
| torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue( |
| same(torch.tensor([2.75, 2.75]), false_true_sum_mult) |
| ) # * 10 then add x |
| |
| false_false_sum_neg = graph( |
| torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25]) |
| ) |
| self.assertTrue( |
| same(torch.tensor([0.0, 0.0]), false_false_sum_neg) |
| ) # * -1 then add x |
| |
| def test_cond_export_single_arg(self): |
| from functorch.experimental.control_flow import cond |
| |
| def true_fn(x): |
| return x |
| |
| def false_fn(x): |
| return x.sin() |
| |
| def f(pred, x): |
| return cond(pred, true_fn, false_fn, [x]) |
| |
| graph, guard = torch._dynamo.export( |
| f, torch.tensor(False), torch.tensor([0.25, 0.25]) |
| ) |
| true_mirror = graph(torch.tensor(True), torch.tensor([0.25, 0.25])) |
| self.assertTrue(same(torch.tensor([0.25, 0.25]), true_mirror)) |
| true_mirror_2 = graph(torch.tensor(True), torch.tensor([0.33, 0.33, 0.33])) |
| self.assertTrue(same(torch.tensor([0.33, 0.33, 0.33]), true_mirror_2)) |
| |
| false_sin = graph(torch.tensor(False), torch.tensor([0.5, 0.5])) |
| self.assertTrue(same(torch.sin(torch.tensor([0.5, 0.5])), false_sin)) |
| |
| def test_enum_guards(self): |
| class MyEnum(enum.Enum): |
| FOO = 10 |
| BAR = 20 |
| |
| def fn(x, y): |
| if y == MyEnum.FOO: |
| return x + 1 |
| else: |
| return x - 1 |
| |
| x = torch.rand(3) |
| y = MyEnum.BAR |
| ref = fn(x, y) |
| opt_fn = torch.compile(backend="eager")(fn) |
| res = opt_fn(x, y) |
| self.assertTrue(same(ref, res)) |
| |
| @patch.object(torch._dynamo.config, "print_graph_breaks", True) |
| def test_duplicate_graph_break_warning(self): |
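| # With config.verbose the same graph-break message may be logged repeatedly;
| # without it, duplicate graph-break warnings should collapse to one.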
| @torch._dynamo.optimize("eager") |
| def f1(a, b): |
| f2(a, b) |
| |
| def f2(a, b): |
| c = a + b |
| print("break") |
| return a + b + c |
| |
| @torch._dynamo.optimize("eager") |
| def g1(a, b): |
| g2(a, b) |
| |
| def g2(a, b): |
| c = a + b |
| print("break") |
| return a + b + c |
| |
| def count_graph_break_msgs(msgs): |
| return sum(msg.find("Graph break") != -1 for msg in msgs) |
| |
| with self.assertLogs(logger="torch._dynamo", level=logging.WARNING) as log: |
| torch._dynamo.config.verbose = True |
| f1(torch.randn(10), torch.randn(10)) |
| self.assertGreater(count_graph_break_msgs(log.output), 1) |
| |
| with self.assertLogs(logger="torch._dynamo", level=logging.WARNING) as log: |
| torch._dynamo.config.verbose = False |
| g1(torch.randn(10), torch.randn(10)) |
| self.assertEqual(count_graph_break_msgs(log.output), 1) |
| |
| def test_inplace_param_update(self): |
| def fn(param, y): |
| prev_grad = torch.is_grad_enabled() |
| try: |
| torch.set_grad_enabled(False) |
| torch.set_grad_enabled(True) |
| torch.set_grad_enabled(False) |
| param.add_(y) |
| finally: |
| torch.set_grad_enabled(prev_grad) |
| |
| y = torch.randn(4) |
| x = torch.nn.Parameter(torch.randn(4)) |
| fn(x, y) |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| opt_fn(x, y) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 3) |
| |
| @unittest.skipIf( |
| not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, |
| "Can't run fused SDPA on this platform", |
| ) |
| def test_parsing_sdpa(self): |
| class MyModule(torch.nn.Module): |
| def forward(self, query, key, value): |
| out = F.scaled_dot_product_attention(query, key, value, None, 0, True) |
| out = F.scaled_dot_product_attention( |
| query, key, value, None, 0, True, scale=8 |
| ) |
| out = F.scaled_dot_product_attention( |
| query=query, |
| key=key, |
| value=value, |
| attn_mask=None, |
| dropout_p=0, |
| is_causal=True, |
| ) |
| out = F.scaled_dot_product_attention( |
| query, |
| key=key, |
| value=value, |
| attn_mask=None, |
| dropout_p=0, |
| is_causal=True, |
| ) |
| out = F.scaled_dot_product_attention( |
| query, key, value, None, dropout_p=0, is_causal=True |
| ) |
| out = F.scaled_dot_product_attention(query, key, value, None, scale=8) |
| return out |
| |
| device = "cuda" |
| dtype = torch.float16 |
| seq_len_q = 1 |
| seq_len_k = 1 |
| head_dim = 8 |
| query = torch.ones( |
| 1, 8, seq_len_q, head_dim, device=device, dtype=dtype, requires_grad=True |
| ) |
| key = torch.ones( |
| 1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True |
| ) |
| value = torch.ones( |
| 1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True |
| ) |
| module = MyModule() |
| opt_mod = torch._dynamo.optimize("inductor")(module) |
| opt_mod(query, key, value) |
| |
| def test_generate_tensor_from_list_of_numpy_primitive_type(self): |
| # Test something like torch.LongTensor([np.int64, np.int64, ...])
| def fn(): |
| x = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64) |
| y = [x[0], x[2], x[4]] |
| z = torch.LongTensor(y) |
| return z |
| |
| ref = fn() |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res = opt_fn() |
| self.assertTrue(same(ref, res)) |
| |
| def test_autograd_function_equivalence(self): |
| for i in range(1, 5): |
| model = globals()[f"Module{i}"]() |
| opt_model = torch._dynamo.optimize("eager", nopython=True)(model) |
| self.assertTrue( |
| torch.allclose(opt_model(torch.ones(2, 3)), torch.tensor([2.0])) |
| ) |
| |
| def test_autograd_function_has_graph_break(self): |
| x = torch.randn(10) |
| for model in [Module5(), Module6()]: |
| torch._dynamo.reset() |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_model = torch._dynamo.optimize(cnts)(model) |
| for _ in range(3): |
| ref = model(x) |
| res = opt_model(x) |
| self.assertTrue(torch.allclose(ref, res)) |
| self.assertEqual(cnts.frame_count, 2) |
| |
| def test_object_classmethod(self): |
| class C: |
| @classmethod |
| def fn(cls, x): |
| return x + x |
| |
| @torch._dynamo.optimize("eager", nopython=True) |
| def f(): |
| return C().fn(torch.ones(2, 3)) |
| |
| self.assertTrue(torch.allclose(f(), torch.tensor([2.0]))) |
| |
| def test_object_staticmethod(self): |
| class C: |
| @staticmethod |
| def fn(x): |
| return x + x |
| |
| @torch._dynamo.optimize("eager", nopython=True) |
| def f(): |
| return C().fn(torch.ones(2, 3)) |
| |
| self.assertTrue(torch.allclose(f(), torch.tensor([2.0]))) |
| |
| def test_user_function_variable_supports_enum_argument(self): |
| class Foo(enum.Enum): |
| FOO = 0 |
| BAR = 1 |
| |
| def gn(x, y=Foo.FOO): |
| if y is Foo.FOO: |
| return x |
| else: |
| return x + 1 |
| |
| def fn(x): |
| return gn(x) |
| |
| x = torch.randn(2, 3) |
| ref = fn(x) |
| opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) |
| res = opt_fn(x) |
| self.assertTrue(torch.allclose(ref, res)) |
| |
| def test_user_function_variable_supports_type_abcmeta_argument(self): |
| class Foo(metaclass=abc.ABCMeta): |
| @abc.abstractclassmethod |
| def read(self): |
| pass |
| |
| class Bar(Foo): |
| def read(self): |
| return "Hello World!" |
| |
| class Baz: |
| pass |
| |
| def gn(x, tys=(Bar, Baz)): |
| if Bar in tys: |
| return x - 1 |
| else: |
| return x + 1 |
| |
| def fn(x): |
| return gn(x) |
| |
| x = torch.randn(2, 3) |
| ref = fn(x) |
| opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) |
| res = opt_fn(x) |
| self.assertTrue(torch.allclose(ref, res)) |
| |
| def test_user_function_variable_supports_function_argument(self): |
| # Test that user-defined function default arguments can be:
| # 1. user-defined functions (e.g., add1)
| # 2. torch functions (e.g., torch.sin)
| # 3. python builtin functions (e.g., operator.neg)
| def add1(x): |
| return x + 1 |
| |
| def gn(x, f1=add1, f2=torch.sin, f3=operator.neg): |
| return f3(f2(f1(x))) |
| |
| def fn(x): |
| return gn(x) |
| |
| x = torch.randn(2, 3) |
| ref = fn(x) |
| opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) |
| res = opt_fn(x) |
| self.assertTrue(torch.allclose(ref, res)) |
| |
| def test_typing_variable_isinstance(self): |
| def fn(x, m): |
| if isinstance(m, typing.Mapping): |
| return x + 1 |
| else: |
| return x - 1 |
| |
| x = torch.randn(2, 3) |
| m = {"x": torch.randn(3)} |
| ref = fn(x, m) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res = opt_fn(x, m) |
| self.assertTrue(torch.allclose(ref, res)) |
| |
| def test_repro_graph_breaks_in__get_item_by_idx(self): |
| class Mod(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mod = torch.nn.Sequential( |
| torch.nn.Linear(3, 3), torch.nn.Linear(3, 3) |
| ) |
| |
| def forward(self, x): |
| return self.mod[0](x) |
| |
| m = Mod() |
| graph, _ = torch._dynamo.export(m, torch.randn(3, 3)) |
| |
| def test_nn_sequential_invocation(self): |
| with freeze_rng_state(): |
| |
| class TestModel(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linears = torch.nn.Sequential( |
| torch.nn.Linear(2, 2), |
| torch.nn.Linear(2, 2), |
| torch.nn.Linear(2, 2), |
| torch.nn.Linear(2, 2), |
| ) |
| |
| def forward(self, x): |
| all_but_last = self.linears[:-1] |
| return all_but_last(x) |
| |
| m = TestModel() |
| x = torch.rand((2, 2)) |
| real = m(x) |
| graph, _ = torch._dynamo.export(m, x) |
| dynamo_result = graph(x) |
| self.assertTrue(same(real, dynamo_result)) |
| |
| def test_nn_sequential_invocation_reposition_indices(self): |
| with freeze_rng_state(): |
| |
| class TestModel(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linears = torch.nn.Sequential( |
| torch.nn.Linear(2, 2), |
| torch.nn.Linear(2, 2), |
| torch.nn.Linear(2, 2), |
| torch.nn.Linear(2, 2), |
| ) |
| |
| def forward(self, x): |
| all_but_last = self.linears[1:3] |
| return all_but_last(x) |
| |
| m = TestModel() |
| x = torch.rand((2, 2)) |
| real = m(x) |
| graph, _ = torch._dynamo.export(m, x) |
| dynamo_result = graph(x) |
| self.assertTrue(same(real, dynamo_result)) |
| |
| def test_error_on_nested_fx_trace(self): |
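| # Running torch.fx.symbolic_trace over a dynamo-optimized callable should
| # raise by default; the next test disables error_on_nested_fx_trace.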
| input = torch.rand(2, 3) |
| |
| def f(x): |
| x + x |
| |
| real = f(input) |
| |
| optimized = torch._dynamo.optimize("eager")(f) |
| self.assertTrue(same(optimized(input), real)) |
| |
| with self.assertRaisesRegex(RuntimeError, "Detected that you are using FX"): |
| gm = torch.fx.symbolic_trace(optimized) |
| |
| @patch.object(torch._dynamo.config, "error_on_nested_fx_trace", False) |
| def test_no_error_on_nested_fx_trace(self): |
| input = torch.rand(2, 3) |
| |
| def f(x): |
| x + x |
| |
| real = f(input) |
| |
| optimized = torch._dynamo.optimize("eager")(f) |
| self.assertTrue(same(optimized(input), real)) |
| |
| # should not error |
| gm = torch.fx.symbolic_trace(optimized) |
| self.assertTrue(same(gm(input), real)) |
| |
| def test_not_dynamic_scope(self): |
| def f(y): |
| x = 1 |
| |
| def g(): |
| x = 2 |
| return lambda: x |
| |
| return y + g()() |
| |
| input = torch.zeros(1) |
| real = f(input) |
| optimized = torch._dynamo.optimize("eager")(f) |
| opt = optimized(input) |
| self.assertTrue(same(opt, real)) |
| |
| def test_inference_mode(self): |
| @torch.inference_mode() |
| def func(x, y): |
| return x.add(1.0) + y |
| |
| x = torch.ones(4, requires_grad=True) |
| y = torch.ones(4, requires_grad=True) |
| ref = func(x, y) |
| opt_func = torch._dynamo.optimize("eager")(func) |
| |
| x1 = torch.ones(4, requires_grad=True) |
| res = opt_func(x1, y) |
| self.assertTrue(same(ref, res)) |
| self.assertTrue(same(x, x1)) |
| |
| def test_if_cond_nn_mod(self): |
| class MockModule(torch.nn.Module): |
| def __init__(self, output_relu=True): |
| super().__init__() |
| self.relu = torch.nn.ReLU() if output_relu else None |
| |
| def forward(self, x): |
| x = torch.sin(x) |
| if self.relu: |
| x = self.relu(x) |
| return x |
| |
| model = MockModule() |
| opt_model = torch._dynamo.optimize("eager", nopython=True)(model) |
| |
| x = torch.rand(4) |
| ref = model(x) |
| res = opt_model(x) |
| self.assertTrue(same(ref, res)) |
| |
| model = MockModule(output_relu=False) |
| opt_model = torch._dynamo.optimize("eager", nopython=True)(model) |
| |
| x = torch.rand(4) |
| ref = model(x) |
| res = opt_model(x) |
| self.assertTrue(same(ref, res)) |
| |
| def test_if_cond_user_defined_object(self): |
| # obj.__bool__ does not exist
| class A: # noqa: B903 |
| def __init__(self, x): |
| self.x = x |
| |
| # obj.__bool__ is a function and returns a bool
| class B: |
| def __init__(self, x): |
| self.x = x |
| |
| def __bool__(self): |
| return self.x > 0 |
| |
| # obj.__bool__ is not a function
| class C: |
| def __init__(self, x): |
| self.x = x |
| self.__bool__ = False |
| |
| def fn(x, obj): |
| if not obj: |
| return x + 1 |
| else: |
| return x - 1 |
| |
| x = torch.rand(4) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) |
| obj1 = A(0.5) |
| obj2 = B(0.5) |
| obj3 = B(-0.5) |
| obj4 = C(0.5) |
| for obj in [obj1, obj2, obj3, obj4, obj3, obj2]: |
| ref = fn(x, obj) |
| res = opt_fn(x, obj) |
| self.assertTrue(same(ref, res)) |
| self.assertEqual(cnts.frame_count, 4) |
| |
| def test_if_cond_user_defined_object2(self): |
| # obj.__bool__ is a function and returns a non-bool type
| class MyObj: |
| def __init__(self, x): |
| self.x = x |
| |
| def __bool__(self): |
| self.x = 1 |
| return self.x |
| |
| def fn(a, obj): |
| if not obj: |
| return a + obj.x |
| else: |
| return a - obj.x |
| |
| x = torch.rand(4) |
| obj = MyObj(0.5) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| try: |
| opt_fn(x, obj) |
| self.fail("expected TypeError from a non-bool __bool__ return")
| except TypeError as e: |
| self.assertIn("__bool__ should return bool, returned int", str(e)) |
| |
| def test_class_has_instancecheck_method(self): |
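| # B's metaclass overrides __instancecheck__ to always return True, so
| # isinstance(obj, B) holds even for an unrelated A instance; the compiled
| # function must agree with eager on that.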
| class A: |
| pass |
| |
| class ExampleMeta(type): |
| def __instancecheck__(cls, instance): |
| return True |
| |
| class B(metaclass=ExampleMeta): |
| pass |
| |
| def fn(x, obj): |
| if isinstance(obj, B): |
| return x + 1 |
| else: |
| return x - 1 |
| |
| x = torch.rand(4) |
| obj = A() |
| ref = fn(x, obj) |
| opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) |
| res = opt_fn(x, obj) |
| self.assertTrue(same(ref, res)) |
| |
| def test_torch_cuda_is_available(self): |
| def fn(x): |
| if torch.cuda.is_available(): |
| return x + 1 |
| else: |
| return x - 1 |
| |
| x = torch.rand(4) |
| ref = fn(x) |
| opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) |
| res = opt_fn(x) |
| self.assertTrue(same(ref, res)) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") |
| @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") |
| def test_torch_cudnn_is_acceptable(self): |
| def fn(x): |
| if torch.backends.cudnn.is_acceptable(tensor=x): |
| return x + 1 |
| return x |
| |
| x = torch.rand(4).cuda() |
| ref = fn(x) |
| opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) |
| res = opt_fn(x) |
| self.assertTrue(same(ref, res)) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") |
| @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") |
| def test_torch_cudnn_is_acceptable_bad_inputs(self): |
| def fn1(x): |
| if torch.backends.cudnn.is_acceptable("invalid"): |
| return x + 1 |
| return x |
| |
| def fn2(x): |
| if torch.backends.cudnn.is_acceptable(x, 3.14): |
| return x + 1 |
| return x |
| |
| with self.assertRaisesRegex( |
| AssertionError, "Expect input to cudnn.is_acceptable to be a tensor" |
| ): |
| x1 = torch.rand(4).cuda() |
| opt_fn1 = torch._dynamo.optimize("eager", nopython=True)(fn1) |
| res1 = opt_fn1(x1) |
| |
| with self.assertRaisesRegex( |
| AssertionError, "Expect 1 input to cudnn.is_acceptable" |
| ): |
| x2 = torch.rand(4).cuda() |
| opt_fn2 = torch._dynamo.optimize("eager", nopython=True)(fn2) |
| res = opt_fn2(x2) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") |
| def test_get_device(self): |
| def fn(x, y): |
| x = x + 1 |
| y = y + 1 |
| return x.get_device(), y.get_device() |
| |
| x = torch.rand(4, device="cuda") |
| y = torch.rand(4, device="cpu") |
| ref = fn(x, y) |
| opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) |
| res = opt_fn(x, y) |
| self.assertTrue(same(ref, res)) |
| |
| def test_disable_flag(self): |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| with patch.dict(os.environ, {"TORCH_COMPILE_DISABLE": "1"}): |
| |
| def fn(x, y): |
| x = x + 1 |
| y = y + 1 |
| |
| opt_fn = torch._dynamo.optimize(cnt) |
| |
| self.assertEqual(cnt.frame_count, 0) |
| |
| def test_is_compiling(self): |
| def f(): |
| if torch._dynamo.is_compiling(): |
| return torch.ones(2, 2) |
| else: |
| return torch.zeros(2, 2) |
| |
| opt_f = torch._dynamo.optimize("eager")(f) |
| |
| self.assertEqual(f(), torch.zeros(2, 2)) |
| self.assertEqual(opt_f(), torch.ones(2, 2)) |
| |
| def test_torch_generator_set_state(self): |
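| # Saving the default generator state, taking a graph break, and restoring
| # the state should make the second torch.rand call reproduce the first.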
| def fn(): |
| default_state = torch.default_generator.get_state() |
| x = torch.rand([2, 3]) |
| torch._dynamo.graph_break() |
| torch.default_generator.set_state(default_state) |
| y = torch.rand([2, 3]) |
| return x, y |
| |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| x, y = opt_fn() |
| self.assertEqual(x, y) |
| |
| def test_guard_failure_fn(self): |
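| # guard_fail_fn is invoked with the failing guard when a later call no
| # longer satisfies the guards of an existing compiled entry, i.e. when the
| # differently-shaped second call triggers a recompile.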
| def fn(x, y, k): |
| x = x + 1 |
| y = y + 1 |
| return x * y * k |
| |
| x = torch.tensor([0.5, 0.5]) |
| y = torch.tensor([1.0, 1.0]) |
| |
| guard_failure = None |
| |
| def guard_failures(failure): |
| nonlocal guard_failure |
| guard_failure = failure |
| |
| opt_fn = torch._dynamo.optimize( |
| "eager", nopython=True, guard_fail_fn=guard_failures |
| )(fn) |
| |
| x2 = torch.tensor([0.5, 0.5, 1.0]) |
| y2 = torch.tensor([0.5, 0.5, 0.5]) |
| |
| opt_fn(x, y, 3) |
| opt_fn(x2, y2, 5) |
| |
| if ( |
| torch._dynamo.config.dynamic_shapes |
| and not torch._dynamo.config.specialize_int |
| and not torch._dynamo.config.assume_static_by_default |
| ): |
| # We did not directly exercise guard_failure_fn here, but it is
| # nice to see no guard failure on the test.
| self.assertTrue(guard_failure is None) |
| else: |
| self.assertTrue(guard_failure is not None) |
| if not torch._dynamo.config.dynamic_shapes: |
| self.assertExpectedInline(guard_failure[0], """L['k'] == 3""") |
| |
| @patch.object(torch._dynamo.config, "dynamic_shapes", True) |
| def test_guard_failure_fn_shape_control(self): |
| def fn(x, y): |
| if x.shape[0] < 3: |
| if y.shape[0] < 3: |
| return x * y |
| else: |
| return x + y |
| else: |
| return -1 |
| |
| x = torch.randn([2, 2]) |
| y = torch.randn([2, 2]) |
| |
| guard_failure = None |
| |
| def guard_failures(failure): |
| nonlocal guard_failure |
| guard_failure = failure |
| |
| opt_fn = torch._dynamo.optimize( |
| "eager", nopython=True, guard_fail_fn=guard_failures |
| )(fn) |
| |
| x2 = torch.randn([5, 5]) |
| y2 = torch.randn([5, 5]) |
| |
| opt_fn(x, y) |
| opt_fn(x2, y2) |
| |
| self.assertTrue(guard_failure is not None) |
| if torch._dynamo.config.assume_static_by_default: |
| self.assertExpectedInline( |
| guard_failure[0], |
| """tensor 'L['x']' size mismatch at index 0. expected 2, actual 5""", |
| ) |
| else: |
| self.assertExpectedInline(guard_failure[0], """L['x'].size()[0] < 3""") |
| |
| def test_guard_failure_fn2(self): |
| def fn(x, y): |
| x = x + 1 |
| y = y + 1 |
| return x * y |
| |
| x = torch.tensor([0.5, 0.5]) |
| y = torch.tensor([1.0, 1.0]) |
| |
| guard_failure = None |
| |
| def guard_failures(failure): |
| nonlocal guard_failure |
| guard_failure = failure |
| |
| opt_fn = torch._dynamo.optimize( |
| "eager", nopython=True, guard_fail_fn=guard_failures |
| )(fn) |
| |
| x2 = torch.tensor([0.5, 0.5, 1.0]) |
| y2 = torch.tensor([0.5, 0.5, 0.5]) |
| |
| opt_fn(x, y) |
| opt_fn(x2, y2) |
| |
| if torch._dynamo.config.dynamic_shapes: |
| if torch._dynamo.config.assume_static_by_default: |
| self.assertExpectedInline( |
| guard_failure[0], |
| """tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""", |
| ) |
| else: |
| self.assertTrue(guard_failure is None) |
| else: |
| self.assertTrue(guard_failure is not None) |
| self.assertExpectedInline( |
| guard_failure[0], |
| """tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""", |
| ) |
| |
| def test_guard_failure_fn_tensor_iter(self): |
| def fn(x): |
| for y in x: |
| y.add_(1.0) |
| return y |
| |
| guard_failure = None |
| |
| def guard_failures(failure): |
| nonlocal guard_failure |
| guard_failure = failure |
| |
| opt_fn = torch._dynamo.optimize( |
| "eager", nopython=True, guard_fail_fn=guard_failures |
| )(fn) |
| |
| args1 = torch.randn(10, 10) |
| out = fn(args1) |
| opt_out = opt_fn(args1) |
| self.assertTrue(same(out, opt_out)) |
| |
| args2 = torch.randn(9, 10) |
| out = fn(args2) |
| opt_out = opt_fn(args2) |
| self.assertTrue(same(out, opt_out)) |
| |
| # guard is expected for both static and dynamic shapes |
| self.assertTrue(guard_failure is not None) |
| self.assertExpectedInline(guard_failure[0], """len(L['x']) == 10""") |
| |
| def test_restore_graphstate(self): |
| # This function does some guard accumulation,
| # and then rolls back due to control flow.
| # The idea is that if one were printing guards as they appear,
| # they would see this insert a guard that does not show up in the
| # final set of guards, because we rolled back from it.
| def nested_fn(s): |
| if x[0] < 10: |
| return s * s |
| return s |
| |
| def fn(x, y): |
| x = x + 1 |
| y = nested_fn(y) |
| y = y + 10 |
| return x * y |
| |
| all_guards = [] |
| |
| def guard_export_print(guards): |
| nonlocal all_guards |
| all_guards.extend(guards) |
| |
| opt_fn = torch._dynamo.optimize("eager", guard_export_fn=guard_export_print)(fn) |
| |
| x = torch.tensor([0.5, 0.5]) |
| y = torch.tensor([1.0, 1.0]) |
| opt_fn(x, y) |
| |
| for guard in all_guards: |
| # This guard was created |
| self.assertTrue(guard.name != "nested_fn.__closure__[0].cell_contents") |
| |
| # Note - here be mild dragons.
| # This test relies heavily on internal implementation details. Future refactor
| # efforts are welcome to delete it if necessary; rewriting this test constantly
| # is a chore, not a feature. We kept it around with some amount of sadness,
| # as it was extremely useful in debugging.
| def test_restore_graphstate_internals(self): |
| def fn(x, y): |
| x = x + 1 |
| y = y + 1 |
| return x * y |
| |
| _, guards = torch._dynamo.export( |
| fn, torch.tensor([0.25, 0.25]), torch.tensor([0.25, 0.25]) |
| ) |
| # Dummy ctor |
| graph = OutputGraph( |
| f_globals={}, |
| code_options={}, |
| compiler_fn=None, |
| root_tx=None, |
| export=False, |
| export_constraints=None, |
| frame_state={"_id": 0}, |
| ) |
| # Contrived property so as not to have it be None |
| graph.nn_modules = {} |
| graph.nn_modules_sources = {} |
| # Contrived generation timestamp |
| graph.timestamp = 4 |
| # Contrived guards |
| graph.tracing_context.guards_context.dynamo_guards = guards |
| |
| # Save the state |
| state = graph.copy_graphstate() |
| # Saving increments the generation |
| self.assertEqual(graph.timestamp, 5) |
| |
| # Assure that the saved state is valid |
| self.assertEqual(state.timestamp, 4) |
| |
| # Ensure that the guards reflect the expected state |
| self.assertEqual(graph.tracing_context.guards_context.dynamo_guards, guards) |
| self.assertEqual(graph.guards, guards) |
| |
| # Mess around with the state |
| graph.tracing_context.guards_context.dynamo_guards = set() |
| self.assertEqual(graph.guards, set()) |
| |
| # Restore the state |
| graph.restore_graphstate(state) |
| |
| # Make sure it restored correctly |
| self.assertEqual(graph.timestamp, 4) |
| self.assertEqual(graph.guards, guards) |
| self.assertEqual(graph.tracing_context.guards_context.dynamo_guards, guards) |
| |
| def test_call_parent_non_class_methods_from_child(self): |
| class A: |
| def add(self, x): |
| return x + 10 |
| |
| def mul(self, x): |
| return x * 0.1 |
| |
| class B(A): |
| def add(self, x): |
| return x + 20 |
| |
| def mul(self, x): |
| return x * 0.2 |
| |
| class C(B): |
| def add(self, x): |
| y = A.add(self, x) |
| z = B.mul(self, y) |
| return z + 30 |
| |
| x = torch.rand(4) |
| fn = C().add |
| ref = fn(x) |
| opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) |
| res = opt_fn(x) |
| self.assertTrue(same(ref, res)) |
| |
| def test_builder_for_class_with_metaclass(self): |
| class ExampleMeta(type): |
| pass |
| |
| class MyClass(metaclass=ExampleMeta): |
| pass |
| |
| def fn(x, y): |
| if isinstance(y, MyClass): |
| return x + 1 |
| else: |
| return x - 1 |
| |
| x = torch.rand([4, 4]) |
| y = MyClass() |
| ref = fn(x, y) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res = opt_fn(x, y) |
| self.assertTrue(same(ref, res)) |
| |
| def test_tuple_from_tuple_iter(self): |
| def inner_fn(*args): |
| acc = torch.ones(10, 10) |
| for arg in args: |
| acc.add_(arg) |
| |
| return acc |
| |
| @torch._dynamo.optimize("eager") |
| def fn(inputs, params): |
| y = tuple(inputs) + tuple(params) |
| return inner_fn(*y) |
| |
| inputs = [torch.randn(10, 10) for _ in range(3)] |
| |
| fn(inputs, iter(tuple(inputs))) |
| |
| def test_torch_package_working_with_trace(self): |
| |
| inputs = [torch.randn([2, 2]), torch.randn([2, 2])] |
| |
| optimized_model = torch._dynamo.optimize(backend="eager")( |
| MyPickledModule(torch.randn([2, 2])) |
| ) |
| from torch import package |
| |
| path = "/tmp/MyPickledModule.pt" |
| package_name = "MyPickledModule" |
| resource_name = "MyPickledModule.pkl" |
| |
| model = MyPickledModule(torch.randn([2, 2])) |
| |
| with package.PackageExporter(path) as exp: |
| exp.extern("**") |
| exp.save_pickle(package_name, resource_name, model) |
| |
| imp = package.PackageImporter(path) |
| loaded_model = imp.load_pickle(package_name, resource_name) |
| |
| optimized_loaded_model = torch._dynamo.optimize("eager")(loaded_model)(*inputs) |
| |
| def test_shape_and_tuple_equality(self): |
| def fn(x, y, t): |
| z = x * y |
| if x.size() == t: |
| return z.cos() |
| return z.sin() |
| |
| torch._dynamo.optimize("eager", nopython=True)(fn)( |
| torch.randn([4, 4]), torch.randn([4, 4]), (4, 4) |
| ) |
| |
| def test_int_list(self): |
| # if dynamic_shapes == True: the int list is unspecialized (ints traced symbolically) |
| # if dynamic_shapes == False: the int list is specialized (ints guarded as constants) |
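| # When ints are unspecialized (and shapes are dynamic), a single compiled frame |
| # serves every call; otherwise each new int list triggers a recompile, giving 5 |
| # frames for the 5 calls below - hence the ifunspec/ifdyn expression on frame_count. |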
| def fn(x, y): |
| return torch.sin(x + y[1] % 2) |
| |
| x = torch.randn(6) |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize(cnt)(fn) |
| for i in range(10, 25, 3): |
| y = [i, i + 1, i + 2] |
| ref = fn(x, y) |
| res = opt_fn(x, y) |
| self.assertTrue(same(ref, res)) |
| self.assertEqual(cnt.frame_count, ifunspec(ifdyn(1, 5), 5)) |
| |
| # Specifically test the tensor.attribute -> torch.something() pattern |
| def test_real_imag_tensor_attribute(self): |
| def fn(x, y): |
| a = x.real |
| b = x.imag |
| return torch.mul(torch.add(a, y), b) |
| |
| x_real = torch.rand((4, 4)) |
| x_imag = torch.rand((4, 4)) |
| x = torch.complex(x_real, x_imag) |
| y = torch.rand((4, 4)) |
| |
| ref = fn(x, y) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res = opt_fn(x, y) |
| self.assertTrue(same(ref, res)) |
| |
| def test_T_tensor_attribute(self): |
| def fn(x, y): |
| a = x.T |
| return torch.add(a, y) |
| |
| x = torch.rand((4, 4)) |
| y = torch.rand((4, 4)) |
| |
| ref = fn(x, y) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res = opt_fn(x, y) |
| self.assertTrue(same(ref, res)) |
| |
| def test_recursive_tensor_attribute(self): |
| def fn(x, y): |
| a = x.real.T |
| b = x.imag |
| return torch.mul(torch.add(a, y), b) |
| |
| x_real = torch.rand((4, 4)) |
| x_imag = torch.rand((4, 4)) |
| x = torch.complex(x_real, x_imag) |
| y = torch.rand((4, 4)) |
| |
| ref = fn(x, y) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res = opt_fn(x, y) |
| self.assertTrue(same(ref, res)) |
| |
| def test_tagging_tensors_simple(self): |
| def foo(x, y): |
| return x * y, x, y |
| |
| a = torch.randn([3, 3]) |
| a.tag = "a" |
| a.frog = "ribbity ribbit" |
| b = torch.randn([3, 3]) |
| b.tag = "b" |
| b.frog = "ribbit" |
| |
| exported = torch._dynamo.export(foo, a, b) |
| out_graph = exported[0] |
| |
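| # Export records attributes set on input tensors under each placeholder's |
| # meta["tensor_dict"]; collect them to confirm the tags survived export. |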
| nodes = list(out_graph.graph.nodes) |
| placeholders = [node for node in nodes if node.op == "placeholder"] |
| all_tags = [] |
| all_frogs = [] |
| for placeholder in placeholders: |
| if "tensor_dict" in placeholder.meta: |
| all_tags.append(placeholder.meta["tensor_dict"]["tag"]) |
| all_frogs.append(placeholder.meta["tensor_dict"]["frog"]) |
| |
| self.assertEqual(all_tags, ["a", "b"]) |
| self.assertEqual(all_frogs, ["ribbity ribbit", "ribbit"]) |
| |
| def test_tagging_tensors_mix_used_unused_structure(self): |
| def pre_attention_state_ops(input, mems, state): |
| lc_key = state[0] |
| lc_val = state[1] |
| bar = [] |
| for i in range(0, 4): |
| bar2 = [] |
| for j in range(0, 3): |
| bar2.append( |
| lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1]) |
| ) |
| bar.append(bar2) |
| |
| return bar |
| |
| mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]]) |
| state = [ |
| torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), |
| torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), |
| ] |
| i = torch.tensor( |
| [ |
| [0.0313, -0.1487, -0.3846, -0.5321], |
| [-1.7073, 1.3331, -0.0890, -1.4935], |
| [-0.8314, -0.1862, -0.5935, 1.5232], |
| ] |
| ) |
| |
| mems.tag = "MEMS" |
| i.tag = "FOO" |
| state[0].tag = "STATE_0" |
| state[1].tag = "HMMM" |
| |
| exported = torch._dynamo.export(pre_attention_state_ops, i, mems, state) |
| out_graph = exported[0] |
| |
| nodes = list(out_graph.graph.nodes) |
| placeholders = [node for node in nodes if node.op == "placeholder"] |
| all_tags = [] |
| for placeholder in placeholders: |
| if "tensor_dict" in placeholder.meta: |
| all_tags.append(placeholder.meta["tensor_dict"]["tag"]) |
| |
| self.assertEqual(all_tags, ["STATE_0", "HMMM"]) |
| |
| def test_get_custom_tensor_attribute(self): |
| def fn(x): |
| return x.custom_attr * x |
| |
| x = torch.rand((2, 2)) |
| x.custom_attr = 3.14 |
| ref = fn(x) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res = opt_fn(x) |
| self.assertTrue(same(ref, res)) |
| |
| def test_set_custom_tensor_attribute(self): |
| def fn(x): |
| x.custom_attr = 3.14 |
| return x.custom_attr * x |
| |
| x = torch.rand((2, 2)) |
| ref = fn(x) |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| res = opt_fn(x) |
| self.assertTrue(same(ref, res)) |
| |
| def test_if_tensor_is_none(self): |
| """ |
| Python 3.11 adds new jump instructions that check if |
| TOS is None. We do not support these instructions. |
| """ |
| |
| def f(x, y): |
| z = 1 |
| if x is None: |
| z *= 2 |
| if y is not None: |
| z *= 3 |
| return z |
| |
| opt_f = torch._dynamo.optimize("eager", nopython=True)(f) |
| self.assertEqual(opt_f(None, torch.ones(2)), 6) |
| |
| if sys.version_info >= (3, 11): |
| insts = bytecode_transformation.cleaned_instructions(f.__code__) |
| for inst in insts: |
| self.assertNotIn("_NONE", inst.opname) |
| |
| @skipIfNotPy311 |
| def test_py311_jump_offset(self): |
| new_inst = bytecode_transformation.create_instruction |
| load_global = bytecode_transformation.create_load_global |
| consts = (None, 1, 2, 3, 4) |
| |
| def create_test_code(jump_opname, target_idx): |
| targets = [ |
| new_inst("LOAD_CONST", argval=1), |
| new_inst("LOAD_CONST", argval=3), |
| ] |
| jump_to_target_inst = new_inst(jump_opname, target=targets[target_idx]) |
| """ |
| pseudocode of generated bytecode: |
| def test_py311_fn(): |
| goto target1 |
| target0: |
| return 1 |
| target1: |
| goto [target0/target2] (via fwd or bwd jump) |
| return 2 |
| target2: |
| return 3 |
| return 4 |
| """ |
| # test with LOAD_GLOBAL since it has a different instruction size |
| insts = [ |
| new_inst("RESUME", arg=0), |
| new_inst("JUMP_FORWARD", target=jump_to_target_inst), |
| targets[0], |
| load_global("print", False), |
| new_inst("POP_TOP"), |
| new_inst("RETURN_VALUE"), |
| jump_to_target_inst, |
| new_inst("LOAD_CONST", argval=2), |
| load_global("print", False), |
| new_inst("POP_TOP"), |
| new_inst("RETURN_VALUE"), |
| targets[1], |
| new_inst("RETURN_VALUE"), |
| new_inst("LOAD_CONST", argval=4), |
| new_inst("RETURN_VALUE"), |
| ] |
| code_options = collections.OrderedDict( |
| [ |
| ("co_argcount", 0), |
| ("co_posonlyargcount", 0), |
| ("co_kwonlyargcount", 0), |
| ("co_nlocals", 0), |
| ("co_stacksize", 2), |
| ("co_flags", 3), |
| ("co_code", b""), |
| ("co_consts", consts), |
| ("co_names", ("print",)), |
| ("co_varnames", ()), |
| ("co_filename", __file__), |
| ("co_name", "test_py311_fn"), |
| ("co_qualname", "test_py311_fn"), |
| ("co_firstlineno", 1), |
| ("co_linetable", b""), |
| ("co_exceptiontable", b""), |
| ("co_freevars", ()), |
| ("co_cellvars", ()), |
| ] |
| ) |
| return bytecode_transformation.clean_and_assemble_instructions( |
| insts, |
| list(code_options.keys()), |
| code_options, |
| ) |
| |
| # format: jump_opname, target_idx, expected forward jump, expected return value |
| test_args = ( |
| ("JUMP_FORWARD", 0, False, 1), |
| ("JUMP_FORWARD", 1, True, 3), |
| ("JUMP_BACKWARD", 0, False, 1), |
| ("JUMP_BACKWARD", 1, True, 3), |
| ) |
| |
| for test in test_args: |
| insts, code = create_test_code(test[0], test[1]) |
| # check whether the last jump instruction ends up as a forward or backward jump |
| for inst in reversed(insts): |
| if inst.opname.startswith("JUMP"): |
| if test[2]: |
| self.assertIn("FORWARD", inst.opname) |
| else: |
| self.assertIn("BACKWARD", inst.opname) |
| break |
| # run the code and check result |
| |
| def dummy_fn(): |
| pass |
| |
| dummy_fn.__code__ = code |
| self.assertEqual(dummy_fn(), test[3]) |
| |
| dummy_opt = torch._dynamo.optimize("eager")(dummy_fn) |
| self.assertEqual(dummy_opt(), test[3]) |
| |
| def test_exception_table_encode_varint(self): |
| # these numbers have no particular meaning - they just exercise the encoding |
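| # The varint format packs values 6 bits per byte with a continuation flag |
| # (mirroring CPython 3.11's exception table encoding), so encoding followed |
| # by decoding should round-trip the values exactly. |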
| nums = [ |
| 0b111_101010_000000, |
| 0b1100_111000_010101_101010, |
| ] |
| b = bytecode_transformation.encode_exception_table_varint( |
| nums[0] |
| ) + bytecode_transformation.encode_exception_table_varint(nums[1]) |
| nums_new = [] |
| b_iter = iter(bytes(b)) |
| while True: |
| try: |
| nums_new.append( |
| bytecode_transformation.decode_exception_table_varint(b_iter) |
| ) |
| except StopIteration: |
| break |
| self.assertEqual(nums, nums_new) |
| |
| @skipIfNotPy311 |
| def test_exception_table_parsing(self): |
| def fn(): |
| try: |
| with a(): |
| b() |
| c() |
| except Exception: |
| d() |
| finally: |
| e() |
| f() |
| |
| tab = bytecode_transformation.parse_exception_table( |
| fn.__code__.co_exceptiontable |
| ) |
| b = bytecode_transformation.assemble_exception_table(tab) |
| self.assertEqual(b, fn.__code__.co_exceptiontable) |
| |
| @skipIfNotPy311 |
| def test_exception_table_e2e(self): |
| def fn(): |
| try: |
| with a(): |
| b() |
| c() |
| except Exception: |
| d() |
| finally: |
| e() |
| f() |
| |
| def nothing(*args): |
| pass |
| |
| code = bytecode_transformation.transform_code_object(fn.__code__, nothing) |
| self.assertEqual(code.co_exceptiontable, fn.__code__.co_exceptiontable) |
| |
| @skipIfNotPy311 |
| def test_exception_table_e2e_2(self): |
| # The last instruction of an exn_tab entry is a "large" instruction |
| # (spanning multiple code units), i.e., LOAD_GLOBAL a |
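| # (In 3.11, LOAD_GLOBAL is followed by inline-cache entries, which is what |
| # makes it span multiple code units.) |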
| def fn(): |
| try: |
| return a |
| except Exception: |
| pass |
| |
| def nothing(*args): |
| pass |
| |
| code = bytecode_transformation.transform_code_object(fn.__code__, nothing) |
| self.assertEqual(code.co_exceptiontable, fn.__code__.co_exceptiontable) |
| |
| @skipIfNotPy311 |
| def test_exception_table_entry_propagation(self): |
| insts = [] |
| for _ in range(10): |
| insts.append(bytecode_transformation.create_instruction("NOP")) |
| insts[8].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( |
| insts[0], insts[9], insts[0], 0, True |
| ) |
| insts[0].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( |
| insts[0], insts[0], insts[1], 0, True |
| ) |
| insts[1].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( |
| insts[0], insts[2], insts[2], 0, True |
| ) |
| insts[5].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( |
| insts[4], insts[6], insts[3], 0, True |
| ) |
| insts[9].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( |
| insts[9], insts[9], insts[4], 0, True |
| ) |
| insts[7].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( |
| insts[7], insts[9], insts[5], 0, True |
| ) |
| bytecode_transformation.propagate_inst_exn_table_entries(insts) |
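| # expected[i] gives the index of the instruction that insts[i]'s exception |
| # entry should target once entries have been propagated across their ranges. |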
| expected = [1, 2, 2, 0, 3, 3, 3, 5, 5, 4] |
| for inst, exp in zip(insts, expected): |
| self.assertIsNotNone(inst.exn_tab_entry) |
| self.assertIs(inst.exn_tab_entry.target, insts[exp]) |
| |
| @skipIfNotPy311 |
| def test_compute_exception_table_nested(self): |
| insts = [] |
| for _ in range(20): |
| insts.append(bytecode_transformation.create_instruction("NOP")) |
| insts[10].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( |
| insts[1], insts[10], insts[0], 0, True |
| ) |
| insts[0].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( |
| insts[1], insts[1], insts[1], 0, True |
| ) |
| insts[1].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( |
| insts[1], insts[3], insts[2], 0, True |
| ) |
| insts[5].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( |
| insts[5], insts[7], insts[3], 0, True |
| ) |
| insts[9].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( |
| insts[10], insts[10], insts[4], 0, True |
| ) |
| insts[7].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( |
| insts[8], insts[10], insts[5], 0, True |
| ) |
| insts[14].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( |
| insts[13], insts[17], insts[6], 0, True |
| ) |
| insts[16].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( |
| insts[15], insts[16], insts[7], 0, True |
| ) |
| bytecode_transformation.update_offsets(insts) |
| tab = bytecode_transformation.compute_exception_table(insts) |
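| # expected entries are (start, end, target) in instruction indices; the |
| # assertions below scale by 2 since each instruction here is one 2-byte code unit. |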
| expected = [ |
| (1, 1, 1), |
| (2, 3, 2), |
| (4, 4, 0), |
| (5, 7, 3), |
| (8, 9, 5), |
| (10, 10, 4), |
| (13, 14, 6), |
| (15, 16, 7), |
| (17, 17, 6), |
| ] |
| self.assertEqual(len(tab), len(expected)) |
| for entry, exp in zip(tab, expected): |
| self.assertEqual(entry.start, exp[0] * 2) |
| self.assertEqual(entry.end, exp[1] * 2) |
| self.assertEqual(entry.target, exp[2] * 2) |
| |
| @skipIfNotPy311 |
| def test_remove_dead_code_with_exn_table_entries(self): |
| create_instruction = bytecode_transformation.create_instruction |
| target1 = create_instruction("NOP") |
| target2 = create_instruction("NOP") |
| target3 = create_instruction("NOP") |
| exn_start = create_instruction("NOP") |
| exn_end = create_instruction("NOP") |
| insts = [ |
| create_instruction("JUMP_FORWARD", target=target1), |
| exn_start, # dead |
| target1, |
| create_instruction("JUMP_FORWARD", target=target3), |
| exn_end, # dead |
| target2, |
| target3, |
| ] |
| exn_start.exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( |
| exn_start, exn_end, target2, 0, True |
| ) |
| bytecode_transformation.propagate_inst_exn_table_entries(insts) |
| insts = bytecode_analysis.remove_dead_code(insts) |
| self.assertEqual(len(insts), 5) |
| self.assertNotIn(exn_start, insts) |
| self.assertNotIn(exn_end, insts) |
| self.assertIn(target2, insts) |
| self.assertIn(target3, insts) |
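| # The propagated entry survives on the live instructions between the two jumps, |
| # so the recomputed table should contain a single entry with start 2, end 4, |
| # and the handler target2 at offset 6. |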
| bytecode_transformation.update_offsets(insts) |
| tab = bytecode_transformation.compute_exception_table(insts) |
| self.assertEqual(len(tab), 1) |
| self.assertEqual(tab[0].start, 2) |
| self.assertEqual(tab[0].end, 4) |
| self.assertEqual(tab[0].target, 6) |
| |
| def test_ordered_dict_alias_reconstruct(self): |
| od = collections.OrderedDict |
| |
| def fn(): |
| d1 = dict() |
| d1["a"] = 1 |
| d2 = od(d1) |
| d2["b"] = 2 |
| torch._dynamo.graph_break() |
| if isinstance(d2, od): |
| return d2["a"] + d2["b"] |
| else: |
| return 0 |
| |
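| # The graph break forces dynamo to reconstruct the locals for the resumed |
| # frame; the isinstance check verifies that d2 comes back as a genuine |
| # OrderedDict rather than a plain dict. |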
| dis.dis(fn) |
| self.assertEqual(torch._dynamo.optimize("eager")(fn)(), 3) |
| |
| @torch._dynamo.config.patch(dynamic_shapes=True) |
| def test_raise_guard_full_constraint(self): |
| y = torch.randn([3, 3, 3]) |
| |
| def my_dyn_fn(x): |
| if x.shape[0] == 3: |
| return x.sin() |
| return x.cos() |
| |
| torch._dynamo.mark_dynamic(y, 0) |
| with self.assertRaises(ConstraintViolationError): |
| torch._dynamo.optimize("eager")(my_dyn_fn)(y) |
| |
| def test_mark_static(self): |
| counter = CompileCounter() |
| |
| def my_dyn_fn(x): |
| return x.cos() |
| |
| y = torch.randn([3]) |
| torch._dynamo.mark_static(y, 0) |
| torch._dynamo.optimize(counter)(my_dyn_fn)(y) |
| |
| z = torch.randn([4]) |
| torch._dynamo.optimize(counter)(my_dyn_fn)(z) |
| |
| self.assertEqual(counter.frame_count, 2) |
| |
| @torch._dynamo.config.patch(dynamic_shapes=True) |
| def test_no_raise_guard_partial_constraint(self): |
| y = torch.randn([3, 3, 3]) |
| |
| def my_dyn_fn(x): |
| if x.shape[0] > 3: |
| return x.sin() |
| return x.cos() |
| |
| torch._dynamo.optimize("eager")(my_dyn_fn)(y) |
| torch._dynamo.mark_dynamic(y, 0) |
| torch._dynamo.reset() |
| torch._dynamo.optimize("eager")(my_dyn_fn)(y) |
| |
| @torch._dynamo.config.patch(dynamic_shapes=True) |
| def test_no_raise_guard_partial_constraint_across_break(self): |
| y = torch.randn([3, 3, 3]) |
| |
| def my_dyn_fn(x, y): |
| z = x * y |
| |
| torch._dynamo.graph_break() |
| if z.shape[0] > 2: |
| return z.cos() |
| |
| return x.cos() |
| |
| torch._dynamo.optimize("eager")(my_dyn_fn)(y, y) |
| torch._dynamo.mark_dynamic(y, 0) |
| torch._dynamo.reset() |
| torch._dynamo.optimize("eager")(my_dyn_fn)(y, y) |
| |
| # Sadly, this does not throw - we do not propagate the mark_dynamic constraint correctly across the graph break |
| @unittest.expectedFailure |
| @torch._dynamo.config.patch(dynamic_shapes=True) |
| def test_raise_guard_partial_constraint_across_break(self): |
| y = torch.randn([3, 3, 3]) |
| |
| def my_dyn_fn(x, y): |
| z = x * y |
| |
| torch._dynamo.graph_break() |
| if z.shape[0] == 3: |
| return z.cos() |
| |
| return x.cos() |
| |
| torch._dynamo.optimize("eager")(my_dyn_fn)(y, y) |
| torch._dynamo.mark_dynamic(y, 0) |
| torch._dynamo.reset() |
| with self.assertRaisesRegex( |
| Exception, |
| ): |
| torch._dynamo.optimize("eager")(my_dyn_fn)(y, y) |
| |
| @torch._dynamo.config.patch(dynamic_shapes=True) |
| def test_raise_guard_partial_constraint_no_graph_break(self): |
| y = torch.randn([3, 3, 3]) |
| |
| def my_dyn_fn(x, y): |
| z = x * y |
| |
| if z.shape[0] == 3: |
| return z.cos() |
| |
| return x.cos() |
| |
| torch._dynamo.mark_dynamic(y, 0) |
| with self.assertRaises(ConstraintViolationError): |
| torch._dynamo.optimize("eager")(my_dyn_fn)(y, y) |
| |
| def test_cannot_trace_mark_dynamic(self): |
| y = torch.randn([3, 3, 3]) |
| |
| def my_dyn_fn(x): |
| torch._dynamo.mark_dynamic(x, 0) |
| return x * x |
| |
| with self.assertRaisesRegex( |
| AssertionError, "Attempt to trace forbidden callable" |
| ): |
| torch._dynamo.optimize("eager")(my_dyn_fn)(y) |
| |
| def test_cannot_trace_mark_dynamic_safe_unreached(self): |
| y = torch.randn([3, 3, 3]) |
| |
| def my_dyn_fn(x): |
| if x.shape[0] == 3: |
| return x |
| print("Running", torch._dynamo.mark_dynamic(x, 0)) |
| return x * x |
| |
| torch._dynamo.optimize("eager")(my_dyn_fn)(y) |
| |
| @torch._dynamo.config.patch(dynamic_shapes=True) |
| def test_py_guards_mark_dynamic(self): |
| def my_dyn_fn(a): |
| if a.shape[0] > 2: |
| return a.cos() |
| return a.sin() |
| |
| counter = CompileCounter() |
| |
| # Run with dim 0 marked dynamic |
| x0 = torch.randn([3, 3, 3]) |
| torch._dynamo.mark_dynamic(x0, 0) |
| torch._dynamo.optimize(counter)(my_dyn_fn)(x0) |
| self.assertEqual(counter.frame_count, 1) |
| |
| # Run without dynamic, no recompile |
| x = torch.randn([3, 3, 3]) |
| torch._dynamo.optimize(counter)(my_dyn_fn)(x) |
| self.assertEqual(counter.frame_count, 1) |
| |
| # Mark a new dim, 1, as dynamic |
| x1 = torch.randn([3, 3, 3]) |
| torch._dynamo.mark_dynamic(x1, 1) |
| torch._dynamo.optimize(counter)(my_dyn_fn)(x1) |
| # Recompile triggered because we marked a new dim as dynamic |
| self.assertEqual(counter.frame_count, 2) |
| |
| # Reset |
| torch._dynamo.reset() |
| # Reset counter |
| counter = CompileCounter() |
| |
| # Run with dim 1 dynamic |
| torch._dynamo.optimize(counter)(my_dyn_fn)(x1) |
| self.assertEqual(counter.frame_count, 1) |
| |
| # Run with dim 0 dynamic - not a subset of the dynamic dims seen so far, so recompile |
| torch._dynamo.optimize(counter)(my_dyn_fn)(x0) |
| self.assertEqual(counter.frame_count, 2) |
| |
| # Run with dims 0, 1, 2 dynamic - again not a subset, so recompile |
| x012 = torch.randn([3, 3, 3]) |
| torch._dynamo.mark_dynamic(x012, 0) |
| torch._dynamo.mark_dynamic(x012, 1) |
| torch._dynamo.mark_dynamic(x012, 2) |
| torch._dynamo.optimize(counter)(my_dyn_fn)(x012) |
| self.assertEqual(counter.frame_count, 3) |
| |
| def test_torch_compile_ctx_on_forward_and_training_step(self): |
| class MyModel(torch.nn.Module): |
| def forward(self): |
| ... |
| |
| def training_step(self): |
| self() |
| |
| model = MyModel() |
| compiled_model = torch.compile(model) |
| |
| model.forward = compiled_model.dynamo_ctx(model.forward) |
| model.training_step = compiled_model.dynamo_ctx(model.training_step) |
| |
| model.training_step() |
| |
| def test_torch_guards_stack_frame_register_inlining(self): |
| x = torch.tensor([0.5, 0.5]) |
| y = torch.tensor([0.75, 0.75, 0.75, 0.75]) |
| z = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]) |
| |
| def uwu_inline_me(x, y, z): |
| r = torch.cat((x, x)) + y |
| r2 = torch.cat((y, y)) + z |
| return r, r2 |
| |
| def fn(x, y, z): |
| r, r2 = uwu_inline_me(x, y, z) |
| return torch.mul(r, r), torch.mul(r2, r2) |
| |
| seen_frames = [] |
| import contextlib |
| |
| @contextlib.contextmanager |
| def global_context_capture_fn(frame_summary): |
| seen_frames.append(frame_summary) |
| yield |
| |
| with mock.patch( |
| "torch._guards.TracingContext.current_frame", |
| side_effect=global_context_capture_fn, |
| ): |
| torch._dynamo.optimize("eager")(fn)(x, y, z) |
| |
| self.assertEqual(len(seen_frames), 1) |
| self.assertEqual(seen_frames[0].name, "fn") |
| self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)") |
| |
| def test_torch_guards_stack_frame_register_inlining_deep(self): |
| x = torch.tensor([0.5, 0.5]) |
| y = torch.tensor([0.75, 0.75, 0.75, 0.75]) |
| z = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]) |
| |
| def uwu_inline_me_deep(x, y): |
| return torch.cat((x, x)) + y |
| |
| def uwu_inline_me(x, y, z): |
| r = uwu_inline_me_deep(x, y) |
| r2 = uwu_inline_me_deep(y, z) |
| return r, r2 |
| |
| def fn(x, y, z): |
| r, r2 = uwu_inline_me(x, y, z) |
| return torch.mul(r, r), torch.mul(r2, r2) |
| |
| seen_frames = [] |
| import contextlib |
| |
| @contextlib.contextmanager |
| def global_context_capture_fn(frame_summary): |
| seen_frames.append(frame_summary) |
| yield |
| |
| with mock.patch( |
| "torch._guards.TracingContext.current_frame", |
| side_effect=global_context_capture_fn, |
| ): |
| torch._dynamo.optimize("eager")(fn)(x, y, z) |
| |
| self.assertEqual(len(seen_frames), 3) |
| self.assertEqual(seen_frames[0].name, "fn") |
| self.assertEqual(seen_frames[1].name, "uwu_inline_me") |
| self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)") |
| |
| def test_error_on_recompile(self): |
| @torch._dynamo.optimize("eager") |
| def fn(a, b): |
| return a + b |
| |
| with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): |
| with self.assertRaises(torch._dynamo.exc.RecompileError): |
| fn(torch.rand(2, 3), torch.rand(2, 3)) |
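| # The second call passes a tuple where a tensor was seen before, forcing a |
| # recompile, which raises because error_on_recompile is set. |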
| fn(torch.rand(2, 3), (1, 2, 3)) |
| |
| def test_compile_profiler(self): |
| class Model(torch.nn.Module): |
| def forward(self, input): |
| return input + input |
| |
| model = Model() |
| prof = CompileProfiler() |
| compiled = torch.compile(model, backend=prof) |
| base_checker = ( |
| lambda: FileCheck() |
| .check("Torchdynamo Profiler Report") |
| .check("Graph Breaks") |
| .check("No graph breaks detected.") |
| .check("Recompilation") |
| ) |
| input = torch.rand((2, 3, 4)) |
| _ = compiled(input) |
| base_checker().check("No recompilation detected.").run(prof.report()) |
| |
| new_shape_input = torch.rand((3, 3, 4)) |
| _ = compiled(new_shape_input) |
| |
| # Not an exhaustive test of dynamic shapes behavior, but a basic sanity check |
| if ( |
| not torch._dynamo.config.dynamic_shapes |
| or torch._dynamo.config.assume_static_by_default |
| ): |
| base_checker().check("Recompile Reasons").check("'forward'").check( |
| "cache_size_limit to 1" |
| ).run(prof.report()) |
| else: |
| base_checker().check("No recompilation detected.").run(prof.report()) |
| |
| # Ensure the correct guard-failure message is selected to show to the user |
| if not torch._dynamo.config.dynamic_shapes: |
| new_shape_input = torch.rand((4, 3, 4)) |
| _ = compiled(new_shape_input) |
| |
| base_checker().check("Recompile Reasons").check("'forward'").check( |
| "tensor 'L['input']' size mismatch at index 0. expected 2, actual 3" |
| ).check( |
| "tensor 'L['input']' size mismatch at index 0. expected 3, actual 4" |
| ).run( |
| prof.report() |
| ) |
| |
| def test_guards_strip_function_call(self): |
| from torch._dynamo.guards import strip_function_call |
| |
| test_case = [ |
| ("___odict_getitem(a, 1)", "a"), |
| ("a.layers[slice(2)][0]._xyz", "a"), |
| ("getattr(a.layers[slice(2)][0]._abc, '0')", "a"), |
| ("getattr(getattr(a.x[3], '0'), '3')", "a"), |
| ("a.layers[slice(None, -1, None)][0]._xyz", "a"), |
| ("a.layers[func('offset', -1, None)][0]._xyz", "a"), |
| ] |
| # strip_function_call should extract the object from the string. |
| for name, expect_obj in test_case: |
| self.assertEqual(strip_function_call(name), expect_obj) |
| |
| def test_int_neg(self): |
| def int_neg(a, b): |
| x = a.shape[0] |
| y = b.shape[0] |
| return -x * -y * a * b |
| |
| torch._dynamo.testing.standard_test(self, int_neg, 2) |
| |
| def test_hash_getitem_slice(self): |
| s = GetItemSource(LocalSource("foo"), slice(None, -1, None)) |
| s2 = GetItemSource(LocalSource("foo"), slice(None, -1, None)) |
| s3 = GetItemSource(LocalSource("foo"), slice(None, -1, 2)) |
| some_set = set() |
| |
| self.assertTrue(s not in some_set) |
| self.assertTrue(s2 not in some_set) |
| self.assertTrue(s3 not in some_set) |
| |
| some_set.add(s) |
| |
| self.assertTrue(s in some_set) |
| # s and s2 should hash the same |
| self.assertTrue(s2 in some_set) |
| # s3 should be different |
| self.assertTrue(s3 not in some_set) |
| |
| self.assertTrue(s == s2) |
| self.assertTrue(s != s3) |
| |
| |
| class CustomFunc1(torch.autograd.Function): |
| @staticmethod |
| def forward(ctx, foo): |
| return foo + foo |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| return grad_output |
| |
| |
| class CustomFunc2(torch.autograd.Function): |
| # The forward function can be a staticmethod or a classmethod |
| @classmethod |
| def forward(cls, ctx, foo): |
| return foo + foo |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| return grad_output |
| |
| |
| class CustomFunc3(torch.autograd.Function): |
| # Test that there is a graph break in the forward function |
| @staticmethod |
| def forward(ctx, foo): |
| result = foo + foo |
| torch._dynamo.graph_break() |
| result = result + foo |
| ctx.save_for_backward(result) |
| return result |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| (result,) = ctx.saved_tensors |
| return grad_output * math.sqrt(result.numel()) |
| |
| |
| class Module1(torch.nn.Module): |
| def forward(self, foo): |
| return CustomFunc1().apply(foo) |
| |
| |
| class Module2(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.fn = CustomFunc1.apply |
| |
| def forward(self, foo): |
| return self.fn(foo) |
| |
| |
| class Module3(torch.nn.Module): |
| def forward(self, foo): |
| return CustomFunc2().apply(foo) |
| |
| |
| class Module4(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.fn = CustomFunc2.apply |
| |
| def forward(self, foo): |
| return self.fn(foo) |
| |
| |
| class Module5(torch.nn.Module): |
| def forward(self, foo): |
| return CustomFunc3().apply(foo) |
| |
| |
| class Module6(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.fn = CustomFunc3.apply |
| |
| def forward(self, foo): |
| return self.fn(foo) |
| |
| |
| class TestTracer(JitTestCase): |
| def test_jit_save(self): |
| def fn(): |
| class Foo(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = 3 |
| |
| @torch.jit.export |
| def __getstate__(self): |
| return (3, self.training) |
| |
| @torch.jit.export |
| def __setstate__(self, state): |
| self.a = state[0] |
| self.training = state[1] |
| |
| def forward(self, x): |
| return x + self.a |
| |
| f = Foo() |
| |
| return torch.jit.trace(f, (torch.rand(3, 4),)) |
| |
| fn() |
| opt_fn = torch._dynamo.optimize("eager")(fn) |
| opt_fn() |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| |
| run_tests() |