| # Owner(s): ["module: dynamo"] |
| |
| import collections |
| import re |
| import sys |
| import time |
| from io import StringIO |
| |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| from torch._dynamo.comptime import comptime |
| |
| |
# comptime callbacks cannot capture free variables at the moment, so the
# tests below hand data to them through these module-level globals instead:
# FILE receives the StringIO that callbacks print into, and SELF holds the
# running TestCase so callbacks can make assertions.  Because the state is
# shared at module level, these tests must not run concurrently within a
# single process.
FILE, SELF = None, None
| |
| |
| class ComptimeTests(torch._dynamo.test_case.TestCase): |
| def test_print_single(self): |
| global FILE |
| FILE = StringIO() |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| def comptime_print(e): |
| @comptime |
| def _(ctx): |
| ctx.print(ctx.get_local("e"), file=FILE) |
| |
| Employee = collections.namedtuple("Employee", ["name", "id"]) |
| |
| class mylist(list): |
| pass |
| |
| @torch._dynamo.optimize(cnt, dynamic=True) |
| def f(x): |
| y = x * 2 |
| comptime_print(y) |
| comptime_print(2) |
| comptime_print([y, 2]) |
| comptime_print((y, 2)) |
| comptime_print({"foo": y}) |
| comptime_print(range(1, 3)) |
| comptime_print(Employee("foo", 2)) |
| comptime_print(mylist([1, 2])) |
| comptime_print(collections.defaultdict(lambda: None)) |
| comptime_print(set()) |
| comptime_print({"a", "b"}) |
| comptime_print(x.size(0)) |
| return y + 3 |
| |
| f(torch.randn(2)) |
| self.assertEqual(cnt.frame_count, 1) |
| self.assertExpectedInline( |
| FILE.getvalue().strip(), |
| """\ |
| FakeTensor(..., size=(s0,)) |
| 2 |
| [FakeTensor(..., size=(s0,)), 2] |
| (FakeTensor(..., size=(s0,)), 2) |
| {'foo': FakeTensor(..., size=(s0,))} |
| range(1, 3, 1) |
| Employee(name='foo', id=2) |
| [1, 2] |
| defaultdict(NestedUserFunctionVariable(), {}) |
| set() |
| {'a','b'} |
| s0""", |
| ) |
| |
| def test_print_graph(self): |
| global FILE |
| FILE = StringIO() |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnt) |
| def f(x): |
| y = x * 2 |
| |
| @comptime |
| def _(ctx): |
| ctx.print_graph(verbose=False, file=FILE) |
| |
| # Test the compact notation doesn't error or graph break; |
| # you'll have to visually inspect to see that it printed |
| comptime.print_graph() |
| |
| return y + 3 |
| |
| f(torch.randn(2)) |
| self.assertEqual(cnt.frame_count, 1) |
| self.assertExpectedInline( |
| FILE.getvalue().strip(), |
| """\ |
| def forward(self, L_x_ : torch.Tensor): |
| l_x_ = L_x_ |
| y = l_x_ * 2; l_x_ = y = None""", |
| ) |
| |
| def test_print_disas(self): |
| global FILE |
| FILE = StringIO() |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnt) |
| def f(x): |
| y = x * 2 |
| |
| @comptime |
| def _(ctx): |
| ctx.print_disas(file=FILE) |
| |
| comptime.print_disas() |
| |
| return y + 3 |
| |
| def munge_disas(s): |
| re.sub( |
| r"^(?: +\d+)?(?: +(-->)) \+\d+ ([A-Za-z0-9_]+)", |
| "\1 \3", |
| s, |
| flags=re.MULTILINE, |
| ) |
| |
| f(torch.randn(2)) |
| self.assertEqual(cnt.frame_count, 1) |
| out = FILE.getvalue() |
| # Check that the instruction offset is working |
| self.assertIn("-->", out) |
| # Check that the bytecode resembles what we expect |
| self.assertIn("STORE_FAST", out) |
| if sys.version_info < (3, 11): |
| self.assertIn("BINARY_MULTIPLY", out) |
| else: |
| self.assertIn("BINARY_OP", out) |
| |
| def test_print_value_stack(self): |
| global FILE |
| FILE = StringIO() |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| def g(x): |
| @comptime |
| def _(ctx): |
| ctx.print_value_stack(file=FILE, stacklevel=1) |
| |
| return x |
| |
| @torch._dynamo.optimize(cnt) |
| def f(x): |
| y = x + g(x) |
| |
| return y + comptime.print_value_stack_and_return(y * 2) |
| |
| f(torch.randn(2)) |
| self.assertEqual(cnt.frame_count, 1) |
| self.assertExpectedInline( |
| FILE.getvalue(), |
| """\ |
| - FakeTensor(..., size=(2,)) |
| """, |
| ) |
| |
| def test_print_locals(self): |
| global FILE |
| FILE = StringIO() |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnt) |
| def f(x): |
| y = x * 2 |
| |
| @comptime |
| def _(ctx): |
| ctx.print_locals(file=FILE) |
| |
| comptime.print_locals() |
| |
| return y + 3 |
| |
| f(torch.randn(2)) |
| self.assertEqual(cnt.frame_count, 1) |
| self.assertExpectedInline( |
| FILE.getvalue(), |
| """\ |
| x = FakeTensor(..., size=(2,)) |
| y = FakeTensor(..., size=(2,)) |
| """, |
| ) |
| |
| # Just make sure it doesn't crash |
| def test_print_direct(self): |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnt) |
| def f(x, z): |
| y = x * 2 |
| lambda: z |
| comptime.print(z) |
| return y + 3 |
| |
| f(torch.randn(2), torch.randn(2)) |
| |
| def test_sleep(self): |
| sleep_time = 5 |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnt) |
| def f(x, z, should_sleep): |
| if should_sleep: |
| comptime.sleep(sleep_time) |
| y = x * 2 |
| return y + 3 |
| |
| start = time.time() |
| f(torch.randn(2), torch.randn(2), False) |
| total_no_sleep = time.time() - start |
| |
| start = time.time() |
| f(torch.randn(2), torch.randn(2), True) |
| total_with_sleep = time.time() - start |
| |
| self.assertTrue(total_with_sleep > sleep_time) |
| # Hopefully this won't be flaky |
| self.assertTrue(abs(total_with_sleep - sleep_time - total_no_sleep) < 3) |
| |
| # Just make sure it doesn't crash |
| def test_get_local_closure_variable(self): |
| global SELF |
| SELF = self |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnt) |
| def f(x): |
| z = 3 |
| |
| def g(): |
| @comptime |
| def _(ctx): |
| r = ctx.get_local("z") |
| SELF.assertEqual(repr(r), "3") |
| |
| comptime.print(z) |
| return 2 |
| |
| y = x * g() |
| return y + 3 |
| |
| f(torch.randn(2)) |
| |
| def test_print_bt(self): |
| global FILE |
| FILE = StringIO() |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| def g(x): |
| @comptime |
| def _(ctx): |
| ctx.print_bt(file=FILE) |
| |
| comptime.print_bt() |
| |
| return x + 3 |
| |
| @torch._dynamo.optimize(cnt) |
| def f(x): |
| y = x * 2 |
| y = g(y) |
| return y + 3 |
| |
| def munge_filenames(s): |
| return re.sub(r'File "[^"]+", line \d+', 'File "X", line X', s) |
| |
| f(torch.randn(2)) |
| self.assertEqual(cnt.frame_count, 1) |
| bt = FILE.getvalue() |
| self.assertIn("y = g(y)", bt) |
| |
| def test_print_guards(self): |
| global FILE |
| FILE = StringIO() |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnt) |
| def f(x): |
| y = x * 2 |
| |
| @comptime |
| def _(ctx): |
| ctx.print_guards(file=FILE) |
| |
| comptime.print_guards() |
| |
| return y + 3 |
| |
| f(torch.randn(2)) |
| self.assertEqual(cnt.frame_count, 1) |
| self.assertExpectedInline( |
| re.sub(r"\s+$", "", FILE.getvalue().rstrip(), flags=re.MULTILINE), |
| """\ |
| |
| local "L['x']" TENSOR_MATCH |
| { |
| 'guard_types': None, |
| 'code': None, |
| 'obj_weakref': None |
| 'guarded_class': None |
| } |
| global '' GRAD_MODE |
| { |
| 'guard_types': None, |
| 'code': None, |
| 'obj_weakref': None |
| 'guarded_class': None |
| } |
| global '' DETERMINISTIC_ALGORITHMS |
| { |
| 'guard_types': None, |
| 'code': None, |
| 'obj_weakref': None |
| 'guarded_class': None |
| } |
| global '' TORCH_FUNCTION_STATE |
| { |
| 'guard_types': None, |
| 'code': None, |
| 'obj_weakref': None |
| 'guarded_class': None |
| } |
| global '' DEFAULT_DEVICE |
| { |
| 'guard_types': None, |
| 'code': None, |
| 'obj_weakref': None |
| 'guarded_class': None |
| } |
| shape_env '' SHAPE_ENV |
| { |
| 'guard_types': None, |
| 'code': None, |
| 'obj_weakref': None |
| 'guarded_class': None |
| }""", |
| ) |
| |
| def test_graph_break(self): |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnt) |
| def f(x): |
| y = x * 2 |
| |
| @comptime |
| def _(ctx): |
| pass |
| |
| return y + 3 |
| |
| f(torch.randn(2)) |
| self.assertEqual(cnt.frame_count, 1) |
| cnt.frame_count = 0 |
| |
| @torch._dynamo.optimize(cnt) |
| def g(x): |
| y = x * 2 |
| |
| @comptime |
| def _(ctx): |
| ctx.graph_break() |
| |
| y = y + 2 |
| |
| comptime.graph_break() |
| |
| return y * 3 |
| |
| g(torch.randn(2)) |
| self.assertEqual(cnt.frame_count, 3) |
| |
| def test_get_local(self): |
| global SELF, FILE |
| SELF = self |
| FILE = StringIO() |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| @torch._dynamo.optimize(cnt) |
| def f(x): |
| y = x * 2 |
| lit = 2 |
| |
| @comptime |
| def _(ctx): |
| y = ctx.get_local("y") |
| SELF.assertEqual(y.as_fake().size(0), 2) |
| SELF.assertEqual(y.size(0), 2) |
| # Trigger a graph write (TODO: this is not so |
| # useful right now as there's no way to make use |
| # of the output proxy; maybe it's useful for inserting |
| # side-effectful operations into the graph) |
| y.as_proxy() + 4 |
| ctx.print_graph(verbose=False, file=FILE) |
| SELF.assertIs(y.python_type(), torch.Tensor) |
| lit = ctx.get_local("lit") |
| SELF.assertEqual(lit.as_python_constant(), 2) |
| |
| return y + 3 |
| |
| f(torch.randn(2)) |
| self.assertEqual(cnt.frame_count, 1) |
| self.assertExpectedInline( |
| FILE.getvalue().strip(), |
| """\ |
| def forward(self, L_x_ : torch.Tensor): |
| l_x_ = L_x_ |
| y = l_x_ * 2; l_x_ = None |
| add = y + 4; y = add = None""", |
| ) |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| |
| run_tests() |