# blob: f28855c1ae2548c3185109392000102bbdd5b7a6 [file] [log] [blame] [edit]
# Owner(s): ["module: dynamo"]
import inspect
import io
import os
import tempfile
from unittest.mock import patch
import torch
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter
class ToyModel(torch.nn.Module):
    """Tiny Linear(10, 10) -> ReLU network used as a compilation target."""

    def __init__(self) -> None:
        super().__init__()
        # Registration order (linear first, then relu) determines the
        # state_dict key order and named_children() order; keep it stable.
        self.linear = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply the affine projection, then the ReLU activation."""
        projected = self.linear(x)
        activated = self.relu(projected)
        return activated
class InPlaceCompilationTests(TestCase):
    """Tests for in-place compilation via ``nn.Module.compile()`` and for the
    dynamo compile-start / compile-end callback hooks."""

    def test_compilation(self):
        """``model.compile(backend=...)`` routes the module's calls through the backend."""
        torch._dynamo.reset()
        model = ToyModel()
        cnt = CompileCounter()
        model.compile(backend=cnt)
        x = torch.randn(10, 10)
        model(x)
        self.assertEqual(cnt.frame_count, 1)
        # Calling again with an identically-shaped input should reuse the
        # already-compiled frame rather than compiling a new one.
        model(x)
        self.assertEqual(cnt.frame_count, 1)

    def test_overwrite_call_impl(self):
        """``compile()`` fills in ``_compiled_call_impl``, which starts as None."""
        torch._dynamo.reset()
        model = ToyModel()
        self.assertIsNone(model._compiled_call_impl)
        model.compile()
        self.assertIsNotNone(model._compiled_call_impl)

    def test_save(self):
        """A compiled module can be saved whole with ``torch.save`` and reloaded."""
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        x = torch.randn(1, 10)
        expected = model(x)
        with tempfile.TemporaryDirectory() as tmpdirname:
            path = os.path.join(tmpdirname, "model.pt")
            torch.save(model, path)
            # Loading a whole nn.Module (not just tensors) requires full
            # unpickling, so weights_only must be disabled explicitly; newer
            # torch releases default torch.load to weights_only=True.
            loaded_model = torch.load(path, weights_only=False)
            self.assertEqual(loaded_model(x), expected)

    def test_state_dict_save(self):
        """The state_dict of a compiled module loads into a fresh, uncompiled module."""
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        x = torch.randn(1, 10)
        expected = model(x)
        with tempfile.TemporaryDirectory() as tmpdirname:
            path = os.path.join(tmpdirname, "model.pt")
            torch.save(model.state_dict(), path)
            loaded_model = ToyModel()
            # A state_dict only holds tensors, so the restricted loader suffices.
            loaded_model.load_state_dict(torch.load(path, weights_only=True))
            self.assertEqual(loaded_model(x), expected)

    def test_jit_save(self):
        """A compiled module can still be scripted, saved and reloaded with TorchScript."""
        torch._dynamo.reset()
        model = ToyModel()
        model.compile()
        x = torch.randn(1, 10)
        expected = model(x)
        scripted_model = torch.jit.script(model)
        with tempfile.TemporaryDirectory() as tmpdirname:
            path = os.path.join(tmpdirname, "model.pt")
            torch.jit.save(scripted_model, path)
            loaded_model = torch.jit.load(path)
            self.assertEqual(loaded_model(x), expected)

    def test_compilation_callback(self):
        """Registered start/end callbacks fire once each around a single compile."""
        torch._dynamo.reset()

        @torch._dynamo.on_compile_start
        def start_callback():
            print("Compilation started.")

        @torch._dynamo.on_compile_end
        def end_callback():
            print("Compilation ended.")

        mod = ToyModel()
        x = torch.randn(10, 10)
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            opt_mod = torch.compile(backend="eager", fullgraph=True)(mod)
            opt_mod(x)
        printed_output = mock_stdout.getvalue().strip()
        self.assertEqual(printed_output, "Compilation started.\nCompilation ended.")

    def test_compile_eager_options(self):
        """Passing ``options`` to the eager / aot_eager backends must not raise."""
        # Start from a clean dynamo state, consistent with the other tests.
        torch._dynamo.reset()

        @torch.compile(backend="eager", options={"foo": 2})
        def f(x):
            return x + x

        f(torch.randn(3))

        @torch.compile(backend="aot_eager", options={"foo": 2})
        def g(x):
            return x + x

        g(torch.randn(3))

    def test_compilation_callback_with_graph_break(self):
        """With a graph break, the callbacks fire for each compiled piece.

        The expected output encodes two start/end pairs (four increments).
        """
        torch._dynamo.reset()
        counter = 0

        @torch._dynamo.on_compile_start
        def start_callback():
            nonlocal counter
            counter += 1
            print(f"Counter = {counter}")

        @torch._dynamo.on_compile_end
        def end_callback():
            nonlocal counter
            counter += 1
            print(f"Counter = {counter}")

        @torch.compile(backend="eager")
        def fn(x):
            x = x + 1
            torch._dynamo.graph_break()
            return torch.sin(x)

        x = torch.randn(10, 10)
        with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
            fn(x)
        printed_output = mock_stdout.getvalue().strip()
        self.assertEqual(
            printed_output, "Counter = 1\nCounter = 2\nCounter = 3\nCounter = 4"
        )
# The private (torch._dynamo) variants of the functions below are tested
# extensively elsewhere, so it is enough to check that the public
# torch.compiler signatures match them.
class PublicTorchCompilerTests(TestCase):
def check_signature(self, public_fn_name, private_fn_name, private_namespace):
public_fn = getattr(torch.compiler, public_fn_name)
private_fn = getattr(private_namespace, private_fn_name)
public_sig = inspect.signature(public_fn)
private_sig = inspect.signature(private_fn)
self.assertEqual(
public_sig,
private_sig,
f"Signatures do not match for function {public_fn_name}() \n Public: {public_sig} \n Private: {private_sig}",
)
def test_dynamo_signatures(self):
function_names = [
"reset",
"allow_in_graph",
"list_backends",
"assume_constant_result",
"disable",
]
for fn_name in function_names:
self.check_signature(fn_name, fn_name, torch._dynamo)
if __name__ == "__main__":
    # Script entry point: hand control to the test runner imported above.
    run_tests()