| # Owner(s): ["oncall: jit"] |
| |
| import os |
| import sys |
| |
| import torch |
| |
| # Make the helper files in test/ importable |
| pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) |
| sys.path.append(pytorch_test_dir) |
| from torch.testing._internal.jit_utils import JitTestCase |
| |
if __name__ == '__main__':
    # The tests in this module are collected and driven by test/test_jit.py;
    # running the file on its own is refused with a pointer to that entry point.
    _run_hint = ("This test file is not meant to be run directly, use:\n\n"
                 "\tpython test/test_jit.py TESTNAME\n\n"
                 "instead.")
    raise RuntimeError(_run_hint)
| |
| class TestLogging(JitTestCase): |
| def test_bump_numeric_counter(self): |
| class ModuleThatLogs(torch.jit.ScriptModule): |
| @torch.jit.script_method |
| def forward(self, x): |
| for i in range(x.size(0)): |
| x += 1.0 |
| torch.jit._logging.add_stat_value('foo', 1) |
| |
| if bool(x.sum() > 0.0): |
| torch.jit._logging.add_stat_value('positive', 1) |
| else: |
| torch.jit._logging.add_stat_value('negative', 1) |
| return x |
| |
| logger = torch.jit._logging.LockingLogger() |
| old_logger = torch.jit._logging.set_logger(logger) |
| try: |
| |
| mtl = ModuleThatLogs() |
| for i in range(5): |
| mtl(torch.rand(3, 4, 5)) |
| |
| self.assertEqual(logger.get_counter_val('foo'), 15) |
| self.assertEqual(logger.get_counter_val('positive'), 5) |
| finally: |
| torch.jit._logging.set_logger(old_logger) |
| |
| def test_trace_numeric_counter(self): |
| def foo(x): |
| torch.jit._logging.add_stat_value('foo', 1) |
| return x + 1.0 |
| |
| traced = torch.jit.trace(foo, torch.rand(3, 4)) |
| logger = torch.jit._logging.LockingLogger() |
| old_logger = torch.jit._logging.set_logger(logger) |
| try: |
| traced(torch.rand(3, 4)) |
| |
| self.assertEqual(logger.get_counter_val('foo'), 1) |
| finally: |
| torch.jit._logging.set_logger(old_logger) |
| |
| def test_time_measurement_counter(self): |
| class ModuleThatTimes(torch.jit.ScriptModule): |
| def forward(self, x): |
| tp_start = torch.jit._logging.time_point() |
| for i in range(30): |
| x += 1.0 |
| tp_end = torch.jit._logging.time_point() |
| torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start) |
| return x |
| |
| mtm = ModuleThatTimes() |
| logger = torch.jit._logging.LockingLogger() |
| old_logger = torch.jit._logging.set_logger(logger) |
| try: |
| mtm(torch.rand(3, 4)) |
| self.assertGreater(logger.get_counter_val('mytimer'), 0) |
| finally: |
| torch.jit._logging.set_logger(old_logger) |
| |
| def test_time_measurement_counter_script(self): |
| class ModuleThatTimes(torch.jit.ScriptModule): |
| @torch.jit.script_method |
| def forward(self, x): |
| tp_start = torch.jit._logging.time_point() |
| for i in range(30): |
| x += 1.0 |
| tp_end = torch.jit._logging.time_point() |
| torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start) |
| return x |
| |
| mtm = ModuleThatTimes() |
| logger = torch.jit._logging.LockingLogger() |
| old_logger = torch.jit._logging.set_logger(logger) |
| try: |
| mtm(torch.rand(3, 4)) |
| self.assertGreater(logger.get_counter_val('mytimer'), 0) |
| finally: |
| torch.jit._logging.set_logger(old_logger) |
| |
| def test_counter_aggregation(self): |
| def foo(x): |
| for i in range(3): |
| torch.jit._logging.add_stat_value('foo', 1) |
| return x + 1.0 |
| |
| traced = torch.jit.trace(foo, torch.rand(3, 4)) |
| logger = torch.jit._logging.LockingLogger() |
| logger.set_aggregation_type('foo', torch.jit._logging.AggregationType.AVG) |
| old_logger = torch.jit._logging.set_logger(logger) |
| try: |
| traced(torch.rand(3, 4)) |
| |
| self.assertEqual(logger.get_counter_val('foo'), 1) |
| finally: |
| torch.jit._logging.set_logger(old_logger) |
| |
| def test_logging_levels_set(self): |
| torch._C._jit_set_logging_option('foo') |
| self.assertEqual('foo', torch._C._jit_get_logging_option()) |