| # Owner(s): ["module: unknown"] |
| |
| import torch |
| from torch.testing._internal.common_utils import run_tests, TemporaryFileName, TestCase |
| from torch.utils import ThroughputBenchmark |
| |
| |
class TwoLayerNet(torch.jit.ScriptModule):
    """TorchScript two-input network used as a ThroughputBenchmark target.

    Both inputs go through the same first linear layer (``linear1``). The two
    activations are clamped at zero and concatenated along dim 1. The result
    is then projected by ``linear2``, which is why ``linear2`` takes
    ``2 * H`` input features.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # Shared by both inputs in forward().
        self.linear1 = torch.nn.Linear(D_in, H)
        # Consumes the concatenation of the two H-wide hidden activations.
        self.linear2 = torch.nn.Linear(2 * H, D_out)

    @torch.jit.script_method
    def forward(self, x1, x2):
        hidden = torch.cat(
            [
                torch.clamp(self.linear1(x1), min=0),
                torch.clamp(self.linear1(x2), min=0),
            ],
            dim=1,
        )
        return self.linear2(hidden)
| |
| |
class TwoLayerNetModule(torch.nn.Module):
    """Eager-mode (non-scripted) version of the two-input test network.

    It computes the same function as ``TwoLayerNet``:
    ``linear2(cat(clamp(linear1(x1), 0), clamp(linear1(x2), 0), dim=1))``.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # Applied to both inputs, so its weights are shared between x1 and x2.
        self.linear1 = torch.nn.Linear(D_in, H)
        # Input width is 2 * H because the two hidden activations are concatenated.
        self.linear2 = torch.nn.Linear(2 * H, D_out)

    def forward(self, x1, x2):
        # x1 is processed first, then x2, matching the order of concatenation.
        activations = [self.linear1(x).clamp(min=0) for x in (x1, x2)]
        return self.linear2(torch.cat(activations, dim=1))
| |
| |
class TestThroughputBenchmark(TestCase):
    """Smoke tests for ``torch.utils.ThroughputBenchmark``.

    Each test checks that ``run_once`` matches a direct module call. It then
    checks that ``benchmark`` completes, for a TorchScript module, for an
    eager ``nn.Module``, and with a profiler output path supplied.
    """

    def linear_test(self, Module, profiler_output_path=""):
        """Benchmark a two-input module constructed as ``Module(D_in, H, D_out)``.

        Args:
            Module: module class taking ``(D_in, H, D_out)``. Its ``forward``
                must accept two tensors named ``x1`` and ``x2``. ``x2`` is
                passed by keyword to ``add_input``.
            profiler_output_path: forwarded unchanged to
                ``ThroughputBenchmark.benchmark``. The tests that do not
                profile leave it at the empty-string default.
        """
        D_in = 10
        H = 5
        D_out = 15
        B = 8
        NUM_INPUTS = 2

        module = Module(D_in, H, D_out)

        # Each entry is one [x1, x2] pair, each tensor of shape (B, D_in).
        inputs = [
            [torch.randn(B, D_in), torch.randn(B, D_in)] for _ in range(NUM_INPUTS)
        ]
        bench = ThroughputBenchmark(module)

        for x1, x2 in inputs:
            # add_input accepts both positional and keyword arguments.
            bench.add_input(x1, x2=x2)

        for example in inputs:
            # run_once also accepts the inputs unpacked positionally.
            module_result = module(*example)
            bench_result = bench.run_once(*example)
            torch.testing.assert_close(bench_result, module_result)

        stats = bench.benchmark(
            num_calling_threads=4,
            num_warmup_iters=100,
            num_iters=1000,
            profiler_output_path=profiler_output_path,
        )

        print(stats)

    def test_script_module(self):
        """Benchmark the TorchScript network."""
        self.linear_test(TwoLayerNet)

    def test_module(self):
        """Benchmark the eager-mode network."""
        self.linear_test(TwoLayerNetModule)

    def test_profiling(self):
        """Benchmark the eager-mode network with a profiler output file path."""
        with TemporaryFileName() as fname:
            self.linear_test(TwoLayerNetModule, profiler_output_path=fname)
| |
| |
if __name__ == "__main__":
    # When executed as a script, run this file's tests via the ``run_tests`` harness.
    run_tests()