| # Owner(s): ["module: ci"] |
| |
| import subprocess |
| import sys |
| import unittest |
| import pathlib |
| |
| from torch.testing._internal.common_utils import TestCase, run_tests, IS_LINUX, IS_IN_CI |
| |
| |
| REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent |
| |
| try: |
| # Just in case PyTorch was not built in 'develop' mode |
| sys.path.append(str(REPO_ROOT)) |
| from tools.stats.scribe import rds_write, register_rds_schema |
| except ImportError: |
| register_rds_schema = None |
| rds_write = None |
| |
| |
| # these tests could eventually be changed to fail if the import/init |
| # time is greater than a certain threshold, but for now we just use them |
| # as a way to track the duration of `import torch` in our ossci-metrics |
| # S3 bucket (see tools/stats/print_test_stats.py) |
| class TestImportTime(TestCase): |
| def test_time_import_torch(self): |
| TestCase.runWithPytorchAPIUsageStderr("import torch") |
| |
| def test_time_cuda_device_count(self): |
| TestCase.runWithPytorchAPIUsageStderr( |
| "import torch; torch.cuda.device_count()", |
| ) |
| |
| @unittest.skipIf(not IS_LINUX, "Memory test is only implemented for Linux") |
| @unittest.skipIf(not IS_IN_CI, "Memory test only runs in CI") |
| @unittest.skipIf(rds_write is None, "Cannot import rds_write from tools.stats.scribe") |
| def test_peak_memory(self): |
| def profile(module, name): |
| command = f"import {module}; import resource; print(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)" |
| result = subprocess.run( |
| [sys.executable, "-c", command], |
| stdout=subprocess.PIPE, |
| ) |
| max_rss = int(result.stdout.decode().strip()) |
| |
| return { |
| "test_name": name, |
| "peak_memory_bytes": max_rss, |
| } |
| |
| data = profile("torch", "pytorch") |
| baseline = profile("sys", "baseline") |
| try: |
| rds_write("import_stats", [data, baseline]) |
| except Exception as e: |
| raise unittest.SkipTest(f"Failed to record import_stats: {e}") |
| |
| |
| if __name__ == "__main__": |
| if register_rds_schema and IS_IN_CI: |
| try: |
| register_rds_schema( |
| "import_stats", |
| { |
| "test_name": "string", |
| "peak_memory_bytes": "int", |
| "time_ms": "int", |
| }, |
| ) |
| except Exception as e: |
| print(f"Failed to register RDS schema: {e}") |
| |
| run_tests() |