| # Owner(s): ["oncall: jit"] |
| |
| import os |
| import sys |
| import unittest |
| |
| import torch |
| |
| # as with test_jit tests, requires global dtype set |
| torch.set_default_dtype(torch.double) |
| |
| # Make the helper files in test/ importable |
| pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) |
| sys.path.append(pytorch_test_dir) |
| from torch.testing._internal.jit_utils import JitTestCase, enable_profiling_mode |
| from torch.testing._internal.jit_metaprogramming_utils import try_get_nn_module_compiled_mod_and_inputs, \ |
| get_nn_mod_test_name, get_all_nn_module_tests, nn_functional_tests, get_nn_functional_compiled_fn_and_inputs |
| from torch.testing._internal.common_utils import run_tests, suppress_warnings, IS_FBCODE |
| |
| |
| def num_ifs_loops(graph): |
| graph_str = str(graph) |
| # only look at body of graph |
| graph_body = graph_str[0:graph_str.find("return")] |
| return graph_body.count("prim::Loop") + graph_body.count("prim::If") |
| |
| def num_non_tensor_nodes(block): |
| num_non_tensor = 0 |
| for node in block.nodes(): |
| kind = node.kind() |
| # GetAttr don't provide useful signal here, since they are non-optimizable except with freezing |
| # Constant is not executed, bailouts should be a separate tests, don't provide useful signal here |
| if kind == "prim::Constant" or "prim::Bailout" in kind or "GetAttr" in kind: |
| continue |
| for b in node.blocks(): |
| num_non_tensor += num_non_tensor_nodes(b) |
| tensor_out = False |
| for out in node.outputs(): |
| if "Tensor" in str(out.type()): |
| tensor_out = True |
| break |
| num_non_tensor += int(not tensor_out) |
| return num_non_tensor |
| |
| class TestComplexity(JitTestCase): |
| def setUp(self): |
| super(TestComplexity, self).setUp() |
| self.grad_enabled = torch.is_grad_enabled() |
| torch.set_grad_enabled(False) |
| |
| def tearDown(self): |
| super(TestComplexity, self).tearDown() |
| torch.set_grad_enabled(self.grad_enabled) |
| |
| @suppress_warnings |
| def test_generated_functional_tests(self): |
| with enable_profiling_mode(): |
| stats = [("Name", "Ifs/Loops", "non-tensor ops")] |
| for test in nn_functional_tests: |
| test_name = test[0] |
| |
| fn, inputs = get_nn_functional_compiled_fn_and_inputs(*test) |
| for _ in range(6): |
| fn(*inputs) |
| |
| g = torch.jit.last_executed_optimized_graph() |
| stats.append((test_name, num_ifs_loops(g), num_non_tensor_nodes(g))) |
| for line in stats: |
| print(line) |
| |
| @suppress_warnings |
| @unittest.skipIf(IS_FBCODE, "Causes a RecursionError in fbcode") |
| def test_nn_module_tests(self): |
| with enable_profiling_mode(): |
| stats = [("Name", "Ifs/Loops", "non-tensor ops")] |
| for test in get_all_nn_module_tests(): |
| out = try_get_nn_module_compiled_mod_and_inputs(**test) |
| if not out: |
| continue |
| |
| mod, inputs = out |
| test_name = get_nn_mod_test_name(**test) |
| for _ in range(6): |
| mod(*inputs) |
| |
| g = torch.jit.last_executed_optimized_graph() |
| stats.append((test_name, num_ifs_loops(g), num_non_tensor_nodes(g))) |
| |
| for line in stats: |
| print(line) |
| |
| if __name__ == '__main__': |
| run_tests() |