# blob: bbefffadeae561a78852558a15966f9ff71f8f22 [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import os
import sys
import torch
# Make the helper files in test/ importable: the directory two levels above
# this file's resolved (symlink-free) path is appended to the module search path.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
    # Running this file as a script is unsupported; fail loudly and point the
    # user at the test_jit.py entry point instead.
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
class TestLogging(JitTestCase):
def test_bump_numeric_counter(self):
class ModuleThatLogs(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
for i in range(x.size(0)):
x += 1.0
torch.jit._logging.add_stat_value('foo', 1)
if bool(x.sum() > 0.0):
torch.jit._logging.add_stat_value('positive', 1)
else:
torch.jit._logging.add_stat_value('negative', 1)
return x
logger = torch.jit._logging.LockingLogger()
old_logger = torch.jit._logging.set_logger(logger)
try:
mtl = ModuleThatLogs()
for i in range(5):
mtl(torch.rand(3, 4, 5))
self.assertEqual(logger.get_counter_val('foo'), 15)
self.assertEqual(logger.get_counter_val('positive'), 5)
finally:
torch.jit._logging.set_logger(old_logger)
def test_trace_numeric_counter(self):
def foo(x):
torch.jit._logging.add_stat_value('foo', 1)
return x + 1.0
traced = torch.jit.trace(foo, torch.rand(3, 4))
logger = torch.jit._logging.LockingLogger()
old_logger = torch.jit._logging.set_logger(logger)
try:
traced(torch.rand(3, 4))
self.assertEqual(logger.get_counter_val('foo'), 1)
finally:
torch.jit._logging.set_logger(old_logger)
def test_time_measurement_counter(self):
class ModuleThatTimes(torch.jit.ScriptModule):
def forward(self, x):
tp_start = torch.jit._logging.time_point()
for i in range(30):
x += 1.0
tp_end = torch.jit._logging.time_point()
torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
return x
mtm = ModuleThatTimes()
logger = torch.jit._logging.LockingLogger()
old_logger = torch.jit._logging.set_logger(logger)
try:
mtm(torch.rand(3, 4))
self.assertGreater(logger.get_counter_val('mytimer'), 0)
finally:
torch.jit._logging.set_logger(old_logger)
def test_time_measurement_counter_script(self):
class ModuleThatTimes(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
tp_start = torch.jit._logging.time_point()
for i in range(30):
x += 1.0
tp_end = torch.jit._logging.time_point()
torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
return x
mtm = ModuleThatTimes()
logger = torch.jit._logging.LockingLogger()
old_logger = torch.jit._logging.set_logger(logger)
try:
mtm(torch.rand(3, 4))
self.assertGreater(logger.get_counter_val('mytimer'), 0)
finally:
torch.jit._logging.set_logger(old_logger)
def test_counter_aggregation(self):
def foo(x):
for i in range(3):
torch.jit._logging.add_stat_value('foo', 1)
return x + 1.0
traced = torch.jit.trace(foo, torch.rand(3, 4))
logger = torch.jit._logging.LockingLogger()
logger.set_aggregation_type('foo', torch.jit._logging.AggregationType.AVG)
old_logger = torch.jit._logging.set_logger(logger)
try:
traced(torch.rand(3, 4))
self.assertEqual(logger.get_counter_val('foo'), 1)
finally:
torch.jit._logging.set_logger(old_logger)
def test_logging_levels_set(self):
torch._C._jit_set_logging_option('foo')
self.assertEqual('foo', torch._C._jit_get_logging_option())