| # Owner(s): ["module: dynamo"] |
| |
| import torch |
| import torch._dynamo as torchdynamo |
| from torch.testing._internal.common_utils import TestCase, run_tests, TEST_CUDA |
| import unittest |
| |
| try: |
| import tabulate # noqa: F401 # type: ignore[import] |
| from torch.utils.benchmark.utils.compile import bench_all |
| HAS_TABULATE = True |
| except ImportError: |
| HAS_TABULATE = False |
| |
| @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") |
| @unittest.skipIf(not HAS_TABULATE, "tabulate not available") |
| class TestCompileBenchmarkUtil(TestCase): |
| def test_training_and_inference(self): |
| class ToyModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.weight = torch.nn.Parameter(torch.Tensor(2, 2)) |
| |
| def forward(self, x): |
| return x * self.weight |
| |
| torchdynamo.reset() |
| model = ToyModel().cuda() |
| |
| inference_table = bench_all(model, torch.ones(1024, 2, 2).cuda(), 5) |
| self.assertTrue("Inference" in inference_table and "Eager" in inference_table and "-" in inference_table) |
| |
| training_table = bench_all(model, torch.ones(1024, 2, 2).cuda(), 5, optimizer=torch.optim.SGD(model.parameters(), lr=0.01)) |
| self.assertTrue("Train" in training_table and "Eager" in training_table and "-" in training_table) |
| |
| if __name__ == '__main__': |
| run_tests() |