# Owner(s): ["module: dynamo"]
import unittest
import weakref
import torch
import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._logging
from torch.testing._internal.logging_utils import kwargs_to_settings, log_settings
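
# These tests exercise the recompilation user experience: what happens when a
# compiled function keeps hitting guard failures, how cache_size_limit is
# enforced, and what the guard-failure / recompile warnings look like.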


class RecompileUxTests(torch._dynamo.test_case.TestCase):
    # TODO(whc) dynamo actually recompiles one more time than the cache limit
    cache_limit = 1

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack.enter_context(
            torch._dynamo.config.patch("cache_size_limit", cls.cache_limit)
        )
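
    # For reference, outside of these tests the same knob can be set directly,
    # e.g. (a minimal sketch, not something the tests below rely on):
    #
    #     torch._dynamo.config.cache_size_limit = 8
    #     compiled = torch.compile(model)
    #
    # Here it is patched per-class so each test sees a deliberately tiny cache.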

    def test_drop_cache_on_skip(self):
        def model(x, i):
            return x + i

        attached = False
        triggered = False

        def trigger():
            nonlocal triggered
            triggered = True

        def compiler(gm, input):
            nonlocal attached
            f = gm.forward
            assert not attached
            # NB: making this a weakref.ref causes the cycle to no
            # longer be promptly GC'ed
            weakref.finalize(f, trigger)
            attached = True
            return f

        x = torch.randn(2)
        for i in range(2):
            opt_model = torch._dynamo.optimize(compiler)(model)
            opt_model(x, i)

        self.assertTrue(triggered)

    def test_loop_torture(self):
        def loop_torture(input, iters):
            out = input
            # randint itself causes one graph break
            for _ in range(iters):
                out += input
            return out

        compile_counter = torch._dynamo.testing.CompileCounter()
        for _ in range(10):
            x = torch.randn(3)
            iters = torch.randint(low=0, high=1000, size=())
            opt_loop_torture = torch._dynamo.optimize(compile_counter)(loop_torture)
            opt_loop_torture(x, iters)

        # Currently, we recompile each time.
        # We'd probably like to bail out quickly and warn.
        # TODO(whc) these checks fail on py37. Why?
        # self.assertEqual(counters["frames"]["total"], 2 + self.cache_limit)
        # self.assertEqual(counters["frames"]["ok"], 1 + self.cache_limit)

        # compile_counter only sees frames that were fed to the backend compiler,
        # which is a subset of counters["frames"]["ok"] -- probably because
        # counters["frames"]["ok"] includes frames not containing torch ops?
        self.assertEqual(compile_counter.frame_count, self.cache_limit)
@torch._dynamo.config.patch("automatic_dynamic_shapes", False)
def test_dynamic_input(self):
def model(input):
return input + input
expected_recompiles = 2
compile_counter = torch._dynamo.testing.CompileCounter()
with torch._dynamo.config.patch("cache_size_limit", expected_recompiles):
with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
for _ in range(10):
bsz = torch.randint(low=0, high=1000, size=())
x = torch.randn((bsz, 3, 4))
opt_model = torch._dynamo.optimize(compile_counter)(model)
opt_model(x)
self.assertEqual(compile_counter.frame_count, expected_recompiles)
self.assertEqual(len(logs.records), 1)
print(logs.records[0])
self.assertTrue(
logs.records[0]
.getMessage()
.startswith("torch._dynamo hit config.cache_size_limit")
)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_nvfuser_guards(self):
        # we may want to model dynamo's guards sufficiently after nvfuser's ProfilingExecutor guards
        # such that we ensure dynamo is in charge of all the recompilations at the top level,
        # and we could thus simplify the underlying torchscript executor
        def func(a, b, c):
            return a + b * c

        a = torch.rand(3, 4, 5, device="cuda")
        b = torch.rand(3, 4, 5, device="cuda")
        b_v = torch.rand(3, 5, 4, device="cuda").view(3, 4, 5)
        b_p = torch.rand(3, 5, 4, device="cuda").permute(0, 2, 1)
        c = torch.rand(3, 4, 5, device="cuda")
        compile_counter = torch._dynamo.testing.CompileCounter()
        with torch._dynamo.config.patch("cache_size_limit", 2):
            opt_func = torch._dynamo.optimize(compile_counter)(func)
            opt_func(a, b, c)  # warmup
            self.assertEqual(compile_counter.frame_count, 1)

            opt_func(a, b, c)  # no guard fail or recompile
            self.assertEqual(compile_counter.frame_count, 1)

            opt_func(a, b_v, c)  # a view should not cause nvfuser recompile
            self.assertEqual(compile_counter.frame_count, 1)

            opt_func(a, b_p, c)  # a permutation should cause recompile
            self.assertEqual(compile_counter.frame_count, 2)

    def assert_single_log_contains(self, logs, contains_str):
        """Assert exactly one log record was captured and that it contains contains_str."""
        self.assertEqual(len(logs.records), 1)
        self.assertTrue(
            logs.records[0].getMessage().find(contains_str) > 0,
            msg=f'Expected to find "{contains_str}" in log "{logs.records[0].getMessage()}"',
        )
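
    # The tests below exercise the human-readable guard-failure messages that
    # explain a recompile: size / stride / rank / dtype / dispatch key set /
    # requires_grad mismatches, plus a non-tensor argument type change.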

    def test_verbose_tensor_check(self):
        def func(a):
            # Warning: choose a function here whose meta implementation lives
            # entirely in C++. If you do a Python one, Dynamo will dive into
            # torch._refs which is OK but it will muddy up the warnings
            return torch.add(a, 4)

        def cache_fail_test(cached_input, missed_input, expected_failure):
            # TODO(whc) maybe it's hacky to have a 'test within a test', but this seemed convenient
            torch._dynamo.reset()
            torch._dynamo.utils.counters.clear()
            opt_func = torch._dynamo.optimize("eager")(func)
            # warmup
            opt_func(cached_input)

            with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
                opt_func = torch._dynamo.optimize("eager")(func)
                opt_func(missed_input)
            self.assert_single_log_contains(logs, expected_failure)

        a = torch.rand(3, 4, 5)
        cache_fail_test(
            a,
            a[0:2, :, :],
            "tensor 'L['a']' size mismatch at index 0. expected 3, actual 2",
        )
        cache_fail_test(
            a,
            a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)),
            "tensor 'L['a']' stride mismatch at index 0. expected 20, actual 1",
        )
        cache_fail_test(
            a, a[0, :, :], "tensor 'L['a']' rank mismatch. expected 3, actual 2"
        )
        cache_fail_test(a, a.to("meta"), "tensor 'L['a']' dispatch key set mismatch.")
        cache_fail_test(
            a,
            a.to(torch.float16),
            "tensor 'L['a']' dtype mismatch. expected Float, actual Half",
        )
        a_grad = a.clone()
        a_grad.requires_grad = True
        cache_fail_test(
            a,
            a_grad,
            "tensor 'L['a']' requires_grad mismatch. expected requires_grad=0",
        )

    def test_mismatched_type(self):
        a = torch.rand(3, 4, 5)
        b = torch.rand(3, 4, 5)

        def func(a, b):
            return a + b

        opt_func = torch._dynamo.optimize("eager")(func)
        # warmup
        opt_func(a, b)

        with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
            opt_func = torch._dynamo.optimize("eager")(func)
            opt_func(a, 1)
        self.assert_single_log_contains(
            logs,
            "expected type of 'L['b']' to be a tensor type, ' but found <class 'int'>",
        )
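
    # The remaining tests raise cache_size_limit to 32 and use guard_fail_fn to
    # collect every guard-failure reason, checking that all prior cache entries
    # are reported (and, with recompiles_verbose logging, reported at once).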
@torch._dynamo.config.patch("cache_size_limit", 32)
def test_multiple_guard_fails(self):
failure_reasons = []
def guard_fail_fn(failure):
failure_reasons.append(failure[0])
def f(x):
return torch.relu(x)
opt_f = torch._dynamo.optimize(
backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False
)(f)
for i in range(5):
failure_reasons.clear()
opt_f(torch.randn(8 + i))
failure_str = "\n".join(failure_reasons)
for line in """\
tensor 'L['x']' size mismatch at index 0. expected 11, actual 12
tensor 'L['x']' size mismatch at index 0. expected 10, actual 12
tensor 'L['x']' size mismatch at index 0. expected 9, actual 12
tensor 'L['x']' size mismatch at index 0. expected 8, actual 12""".split(
"\n"
):
self.assertIn(
line,
failure_str,
)
@torch._dynamo.config.patch("cache_size_limit", 32)
def test_multiple_guard_fails_report_all(self):
with log_settings(kwargs_to_settings(recompiles_verbose=True)):
failure_reasons = []
def guard_fail_fn(failure):
failure_reasons.append(failure[0])
def f(x):
return torch.ones(len(x), x[-1])
opt_f = torch._dynamo.optimize(
backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False
)(f)
opt_f([4, 5, 6])
def filter_reasons():
return "\n".join(
[
line
for line in "\n".join(failure_reasons).splitlines()
if not line.startswith("___check_type_id")
]
)
failure_reasons.clear()
opt_f([7, 8])
for line in """\
len(L['x']) == 3""".split(
"\n"
):
self.assertIn(line, filter_reasons())
failure_reasons.clear()
opt_f([9])
for line in """\
len(L['x']) == 2
len(L['x']) == 3""".split(
"\n"
):
self.assertIn(line, filter_reasons())


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()