# blob: 28e8f15c737eb2237c8577196ce7ae3368d1d535 [file] [log] [blame] [edit]
# Owner(s): ["module: dynamo"]
import collections
import re
import sys
import time
from io import StringIO
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.comptime import comptime
# comptime callbacks cannot capture free variables yet, so the tests hand
# data to them (an output buffer, the running test case) through these
# module-level globals instead. A consequence is that the tests in this file
# must not run concurrently inside a single process (which there is little
# reason to do anyway).
FILE, SELF = None, None
class ComptimeTests(torch._dynamo.test_case.TestCase):
def test_print_single(self):
global FILE
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()
def comptime_print(e):
@comptime
def _(ctx):
ctx.print(ctx.get_local("e"), file=FILE)
Employee = collections.namedtuple("Employee", ["name", "id"])
class mylist(list):
pass
@torch._dynamo.optimize(cnt, dynamic=True)
def f(x):
y = x * 2
comptime_print(y)
comptime_print(2)
comptime_print([y, 2])
comptime_print((y, 2))
comptime_print({"foo": y})
comptime_print(range(1, 3))
comptime_print(Employee("foo", 2))
comptime_print(mylist([1, 2]))
comptime_print(collections.defaultdict(lambda: None))
comptime_print(set())
comptime_print({"a", "b"})
comptime_print(x.size(0))
return y + 3
f(torch.randn(2))
self.assertEqual(cnt.frame_count, 1)
self.assertExpectedInline(
FILE.getvalue().strip(),
"""\
FakeTensor(..., size=(s0,))
2
[FakeTensor(..., size=(s0,)), 2]
(FakeTensor(..., size=(s0,)), 2)
{'foo': FakeTensor(..., size=(s0,))}
range(1, 3, 1)
Employee(name='foo', id=2)
[1, 2]
defaultdict(NestedUserFunctionVariable(), {})
set()
{'a','b'}
s0""",
)
def test_print_graph(self):
global FILE
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()
@torch._dynamo.optimize(cnt)
def f(x):
y = x * 2
@comptime
def _(ctx):
ctx.print_graph(verbose=False, file=FILE)
# Test the compact notation doesn't error or graph break;
# you'll have to visually inspect to see that it printed
comptime.print_graph()
return y + 3
f(torch.randn(2))
self.assertEqual(cnt.frame_count, 1)
self.assertExpectedInline(
FILE.getvalue().strip(),
"""\
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
y = l_x_ * 2; l_x_ = y = None""",
)
def test_print_disas(self):
global FILE
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()
@torch._dynamo.optimize(cnt)
def f(x):
y = x * 2
@comptime
def _(ctx):
ctx.print_disas(file=FILE)
comptime.print_disas()
return y + 3
def munge_disas(s):
re.sub(
r"^(?: +\d+)?(?: +(-->)) \+\d+ ([A-Za-z0-9_]+)",
"\1 \3",
s,
flags=re.MULTILINE,
)
f(torch.randn(2))
self.assertEqual(cnt.frame_count, 1)
out = FILE.getvalue()
# Check that the instruction offset is working
self.assertIn("-->", out)
# Check that the bytecode resembles what we expect
self.assertIn("STORE_FAST", out)
if sys.version_info < (3, 11):
self.assertIn("BINARY_MULTIPLY", out)
else:
self.assertIn("BINARY_OP", out)
def test_print_value_stack(self):
global FILE
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()
def g(x):
@comptime
def _(ctx):
ctx.print_value_stack(file=FILE, stacklevel=1)
return x
@torch._dynamo.optimize(cnt)
def f(x):
y = x + g(x)
return y + comptime.print_value_stack_and_return(y * 2)
f(torch.randn(2))
self.assertEqual(cnt.frame_count, 1)
self.assertExpectedInline(
FILE.getvalue(),
"""\
- FakeTensor(..., size=(2,))
""",
)
def test_print_locals(self):
global FILE
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()
@torch._dynamo.optimize(cnt)
def f(x):
y = x * 2
@comptime
def _(ctx):
ctx.print_locals(file=FILE)
comptime.print_locals()
return y + 3
f(torch.randn(2))
self.assertEqual(cnt.frame_count, 1)
self.assertExpectedInline(
FILE.getvalue(),
"""\
x = FakeTensor(..., size=(2,))
y = FakeTensor(..., size=(2,))
""",
)
# Just make sure it doesn't crash
def test_print_direct(self):
cnt = torch._dynamo.testing.CompileCounter()
@torch._dynamo.optimize(cnt)
def f(x, z):
y = x * 2
lambda: z
comptime.print(z)
return y + 3
f(torch.randn(2), torch.randn(2))
def test_sleep(self):
sleep_time = 5
cnt = torch._dynamo.testing.CompileCounter()
@torch._dynamo.optimize(cnt)
def f(x, z, should_sleep):
if should_sleep:
comptime.sleep(sleep_time)
y = x * 2
return y + 3
start = time.time()
f(torch.randn(2), torch.randn(2), False)
total_no_sleep = time.time() - start
start = time.time()
f(torch.randn(2), torch.randn(2), True)
total_with_sleep = time.time() - start
self.assertTrue(total_with_sleep > sleep_time)
# Hopefully this won't be flaky
self.assertTrue(abs(total_with_sleep - sleep_time - total_no_sleep) < 3)
# Just make sure it doesn't crash
def test_get_local_closure_variable(self):
global SELF
SELF = self
cnt = torch._dynamo.testing.CompileCounter()
@torch._dynamo.optimize(cnt)
def f(x):
z = 3
def g():
@comptime
def _(ctx):
r = ctx.get_local("z")
SELF.assertEqual(repr(r), "3")
comptime.print(z)
return 2
y = x * g()
return y + 3
f(torch.randn(2))
def test_print_bt(self):
global FILE
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()
def g(x):
@comptime
def _(ctx):
ctx.print_bt(file=FILE)
comptime.print_bt()
return x + 3
@torch._dynamo.optimize(cnt)
def f(x):
y = x * 2
y = g(y)
return y + 3
def munge_filenames(s):
return re.sub(r'File "[^"]+", line \d+', 'File "X", line X', s)
f(torch.randn(2))
self.assertEqual(cnt.frame_count, 1)
bt = FILE.getvalue()
self.assertIn("y = g(y)", bt)
def test_print_guards(self):
global FILE
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()
@torch._dynamo.optimize(cnt)
def f(x):
y = x * 2
@comptime
def _(ctx):
ctx.print_guards(file=FILE)
comptime.print_guards()
return y + 3
f(torch.randn(2))
self.assertEqual(cnt.frame_count, 1)
self.assertExpectedInline(
re.sub(r"\s+$", "", FILE.getvalue().rstrip(), flags=re.MULTILINE),
"""\
local "L['x']" TENSOR_MATCH
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
global '' GRAD_MODE
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
global '' DETERMINISTIC_ALGORITHMS
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
global '' TORCH_FUNCTION_STATE
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
global '' DEFAULT_DEVICE
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
shape_env '' SHAPE_ENV
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}""",
)
def test_graph_break(self):
cnt = torch._dynamo.testing.CompileCounter()
@torch._dynamo.optimize(cnt)
def f(x):
y = x * 2
@comptime
def _(ctx):
pass
return y + 3
f(torch.randn(2))
self.assertEqual(cnt.frame_count, 1)
cnt.frame_count = 0
@torch._dynamo.optimize(cnt)
def g(x):
y = x * 2
@comptime
def _(ctx):
ctx.graph_break()
y = y + 2
comptime.graph_break()
return y * 3
g(torch.randn(2))
self.assertEqual(cnt.frame_count, 3)
def test_get_local(self):
global SELF, FILE
SELF = self
FILE = StringIO()
cnt = torch._dynamo.testing.CompileCounter()
@torch._dynamo.optimize(cnt)
def f(x):
y = x * 2
lit = 2
@comptime
def _(ctx):
y = ctx.get_local("y")
SELF.assertEqual(y.as_fake().size(0), 2)
SELF.assertEqual(y.size(0), 2)
# Trigger a graph write (TODO: this is not so
# useful right now as there's no way to make use
# of the output proxy; maybe it's useful for inserting
# side-effectful operations into the graph)
y.as_proxy() + 4
ctx.print_graph(verbose=False, file=FILE)
SELF.assertIs(y.python_type(), torch.Tensor)
lit = ctx.get_local("lit")
SELF.assertEqual(lit.as_python_constant(), 2)
return y + 3
f(torch.randn(2))
self.assertEqual(cnt.frame_count, 1)
self.assertExpectedInline(
FILE.getvalue().strip(),
"""\
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
y = l_x_ * 2; l_x_ = None
add = y + 4; y = add = None""",
)
# Entry point when the file is executed directly: hand off to Dynamo's
# test runner.
if __name__ == "__main__":
    torch._dynamo.test_case.run_tests()