| # Owner(s): ["oncall: r2p"] |
| |
| import tempfile |
| import time |
| |
| from datetime import datetime, timedelta |
| |
| from torch.monitor import ( |
| Aggregation, |
| Event, |
| log_event, |
| register_event_handler, |
| Stat, |
| TensorboardEventHandler, |
| unregister_event_handler, |
| _WaitCounter, |
| ) |
| from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase |
| |
| class TestMonitor(TestCase): |
| def test_interval_stat(self) -> None: |
| events = [] |
| |
| def handler(event): |
| events.append(event) |
| |
| handle = register_event_handler(handler) |
| s = Stat( |
| "asdf", |
| (Aggregation.SUM, Aggregation.COUNT), |
| timedelta(milliseconds=1), |
| ) |
| self.assertEqual(s.name, "asdf") |
| |
| s.add(2) |
| for _ in range(100): |
| # NOTE: different platforms sleep may be inaccurate so we loop |
| # instead (i.e. win) |
| time.sleep(1 / 1000) # ms |
| s.add(3) |
| if len(events) >= 1: |
| break |
| self.assertGreaterEqual(len(events), 1) |
| unregister_event_handler(handle) |
| |
| def test_fixed_count_stat(self) -> None: |
| s = Stat( |
| "asdf", |
| (Aggregation.SUM, Aggregation.COUNT), |
| timedelta(hours=100), |
| 3, |
| ) |
| s.add(1) |
| s.add(2) |
| name = s.name |
| self.assertEqual(name, "asdf") |
| self.assertEqual(s.count, 2) |
| s.add(3) |
| self.assertEqual(s.count, 0) |
| self.assertEqual(s.get(), {Aggregation.SUM: 6.0, Aggregation.COUNT: 3}) |
| |
| def test_log_event(self) -> None: |
| e = Event( |
| name="torch.monitor.TestEvent", |
| timestamp=datetime.now(), |
| data={ |
| "str": "a string", |
| "float": 1234.0, |
| "int": 1234, |
| }, |
| ) |
| self.assertEqual(e.name, "torch.monitor.TestEvent") |
| self.assertIsNotNone(e.timestamp) |
| self.assertIsNotNone(e.data) |
| log_event(e) |
| |
| @skipIfTorchDynamo("Really weird error") |
| def test_event_handler(self) -> None: |
| events = [] |
| |
| def handler(event: Event) -> None: |
| events.append(event) |
| |
| handle = register_event_handler(handler) |
| e = Event( |
| name="torch.monitor.TestEvent", |
| timestamp=datetime.now(), |
| data={}, |
| ) |
| log_event(e) |
| self.assertEqual(len(events), 1) |
| self.assertEqual(events[0], e) |
| log_event(e) |
| self.assertEqual(len(events), 2) |
| |
| unregister_event_handler(handle) |
| log_event(e) |
| self.assertEqual(len(events), 2) |
| |
| def test_wait_counter(self) -> None: |
| wait_counter = _WaitCounter( |
| "test_wait_counter", |
| ) |
| with wait_counter.guard() as wcg: |
| pass |
| |
| |
| @skipIfTorchDynamo("Really weird error") |
| class TestMonitorTensorboard(TestCase): |
| def setUp(self): |
| global SummaryWriter, event_multiplexer |
| try: |
| from torch.utils.tensorboard import SummaryWriter |
| from tensorboard.backend.event_processing import ( |
| plugin_event_multiplexer as event_multiplexer, |
| ) |
| except ImportError: |
| return self.skipTest("Skip the test since TensorBoard is not installed") |
| self.temp_dirs = [] |
| |
| def create_summary_writer(self): |
| temp_dir = tempfile.TemporaryDirectory() # noqa: P201 |
| self.temp_dirs.append(temp_dir) |
| return SummaryWriter(temp_dir.name) |
| |
| def tearDown(self): |
| # Remove directories created by SummaryWriter |
| for temp_dir in self.temp_dirs: |
| temp_dir.cleanup() |
| |
| def test_event_handler(self): |
| with self.create_summary_writer() as w: |
| handle = register_event_handler(TensorboardEventHandler(w)) |
| |
| s = Stat( |
| "asdf", |
| (Aggregation.SUM, Aggregation.COUNT), |
| timedelta(hours=1), |
| 5, |
| ) |
| for i in range(10): |
| s.add(i) |
| self.assertEqual(s.count, 0) |
| |
| unregister_event_handler(handle) |
| |
| mul = event_multiplexer.EventMultiplexer() |
| mul.AddRunsFromDirectory(self.temp_dirs[-1].name) |
| mul.Reload() |
| scalar_dict = mul.PluginRunToTagToContent("scalars") |
| raw_result = { |
| tag: mul.Tensors(run, tag) |
| for run, run_dict in scalar_dict.items() |
| for tag in run_dict |
| } |
| scalars = { |
| tag: [e.tensor_proto.float_val[0] for e in events] for tag, events in raw_result.items() |
| } |
| self.assertEqual(scalars, { |
| "asdf.sum": [10], |
| "asdf.count": [5], |
| }) |
| |
| |
if __name__ == "__main__":
    # Only when executed directly: delegate to PyTorch's test runner.
    run_tests()