| """File invoked through subprocess to actually carry out measurements. |
| |
| `worker/main.py` is deliberately isolated from the rest of the benchmark |
| infrastructure. Other parts of the benchmark rely on this file, but |
| `worker/` has only one Python file and does not import ANYTHING from the rest |
| of the benchmark suite. The reason that this is important is that we can't |
| rely on paths to access the other files (namely `core.api`) since a source |
| command might change the CWD. It also helps keep startup time down by limiting |
| spurious definition work. |
| |
| The life of a worker is very simple: |
| It receives a file containing a `WorkerTimerArgs` telling it what to run, |
| and writes a `WorkerOutput` result back to the same file. |
| |
| Because this file only expects to run in a child context, error handling means |
| plumbing failures up to the caller, not raising in this process. |
| """ |
| import argparse |
| import dataclasses |
| import io |
| import os |
| import pickle |
| import sys |
| import timeit |
| import traceback |
| from typing import Any, Tuple, TYPE_CHECKING, Union |
| |
| |
| if TYPE_CHECKING: |
| # Benchmark utils are only partially strict compliant, so MyPy won't follow |
| # imports using the public namespace. (Due to an exclusion rule in |
| # mypy-strict.ini) |
| from torch.utils.benchmark.utils.timer import Language, Timer |
| from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import ( |
| CallgrindStats, |
| ) |
| |
| else: |
| from torch.utils.benchmark import CallgrindStats, Language, Timer |
| |
| |
| WORKER_PATH = os.path.abspath(__file__) |
| |
| |
| # ============================================================================= |
| # == Interface ================================================================ |
| # ============================================================================= |
| |
| # While the point of this is mainly to collect instruction counts, we're going |
| # to have to compile C++ timers anyway (as they're used as a check before |
| # calling Valgrind), so we may as well grab wall times for reference. They |
| # are comparatively inexpensive. |
| MIN_RUN_TIME = 5 |
| |
| # Repeats are inexpensive as long as they are all run in the same process. This |
| # also lets us filter outliers (e.g. malloc arena reorganization), so we don't |
| # need a high CALLGRIND_NUMBER to get good data. |
| CALLGRIND_NUMBER = 100 |
| CALLGRIND_REPEATS = 5 |
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class WorkerTimerArgs: |
| """Container for Timer constructor arguments. |
| |
| This dataclass serves two roles. First, it is a simple interface for |
| defining benchmarks. (See core.api.GroupedStmts and core.api.GroupedModules |
| for the advanced interfaces.) Second, it provides serialization for |
| controlling workers. `Timer` is not pickleable, so instead the main process |
| will pass `WorkerTimerArgs` instances to workers for processing. |
| """ |
| |
| stmt: str |
| setup: str = "pass" |
| global_setup: str = "" |
| num_threads: int = 1 |
| language: Language = Language.PYTHON |
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class WorkerOutput: |
| # Only return values to reduce communication between main process and workers. |
| wall_times: Tuple[float, ...] |
| instructions: Tuple[int, ...] |
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class WorkerFailure: |
| # If a worker fails, we attach the string contents of the Exception |
| # rather than the Exception object itself. This is done for two reasons: |
| # 1) Depending on the type thrown, `e` may or may not be pickleable |
| # 2) If we re-throw in the main process, we lose the true stack trace. |
| failure_trace: str |
| |
| |
| class WorkerUnpickler(pickle.Unpickler): |
| def find_class(self, module: str, name: str) -> Any: |
| """Resolve import for pickle. |
| |
| When the main runner uses a symbol `foo` from this file, it sees it as |
| `worker.main.foo`. However the worker (called as a standalone file) |
| sees the same symbol as `__main__.foo`. We have to help pickle |
| understand that they refer to the same symbols. |
| """ |
| symbol_map = { |
| # Only blessed interface Enums and dataclasses need to be mapped. |
| "WorkerTimerArgs": WorkerTimerArgs, |
| "WorkerOutput": WorkerOutput, |
| "WorkerFailure": WorkerFailure, |
| } |
| |
| if name in symbol_map: |
| return symbol_map[name] |
| |
| return super().find_class(module, name) |
| |
| def load_input(self) -> WorkerTimerArgs: |
| result = self.load() |
| assert isinstance(result, WorkerTimerArgs) |
| return result |
| |
| def load_output(self) -> Union[WorkerTimerArgs, WorkerOutput, WorkerFailure]: |
| """Convenience method for type safe loading.""" |
| result = self.load() |
| assert isinstance(result, (WorkerTimerArgs, WorkerOutput, WorkerFailure)) |
| return result |
| |
| |
| # ============================================================================= |
| # == Execution ================================================================ |
| # ============================================================================= |
| |
| |
| def _run(timer_args: WorkerTimerArgs) -> WorkerOutput: |
| timer = Timer( |
| stmt=timer_args.stmt, |
| setup=timer_args.setup or "pass", |
| global_setup=timer_args.global_setup, |
| # Prevent NotImplementedError on GPU builds and C++ snippets. |
| timer=timeit.default_timer, |
| num_threads=timer_args.num_threads, |
| language=timer_args.language, |
| ) |
| |
| m = timer.blocked_autorange(min_run_time=MIN_RUN_TIME) |
| |
| stats: Tuple[CallgrindStats, ...] = timer.collect_callgrind( |
| number=CALLGRIND_NUMBER, |
| collect_baseline=False, |
| repeats=CALLGRIND_REPEATS, |
| retain_out_file=False, |
| ) |
| |
| return WorkerOutput( |
| wall_times=tuple(m.times), |
| instructions=tuple(s.counts(denoise=True) for s in stats), |
| ) |
| |
| |
| def main(communication_file: str) -> None: |
| result: Union[WorkerOutput, WorkerFailure] |
| try: |
| with open(communication_file, "rb") as f: |
| timer_args: WorkerTimerArgs = WorkerUnpickler(f).load_input() |
| assert isinstance(timer_args, WorkerTimerArgs) |
| result = _run(timer_args) |
| |
| except KeyboardInterrupt: |
| # Runner process sent SIGINT. |
| sys.exit() |
| |
| except BaseException: |
| trace_f = io.StringIO() |
| traceback.print_exc(file=trace_f) |
| result = WorkerFailure(failure_trace=trace_f.getvalue()) |
| |
| if not os.path.exists(os.path.split(communication_file)[0]): |
| # This worker is an orphan, and the parent has already cleaned up the |
| # working directory. In that case we can simply exit. |
| print(f"Orphaned worker {os.getpid()} exiting.") |
| return |
| |
| with open(communication_file, "wb") as f: |
| pickle.dump(result, f) |
| |
| |
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--communication-file", "--communication_file", type=str) |
| communication_file = parser.parse_args().communication_file |
| main(communication_file) |