| """Example of Timer and Compare APIs: |
| |
| $ python -m examples.compare |
| """ |
| |
| import pickle |
| import sys |
| import time |
| |
| import torch |
| |
| import torch.utils.benchmark as benchmark_utils |
| |
| |
| class FauxTorch: |
| """Emulate different versions of pytorch. |
| |
| In normal circumstances this would be done with multiple processes |
| writing serialized measurements, but this simplifies that model to |
| make the example clearer. |
| """ |
| def __init__(self, real_torch, extra_ns_per_element): |
| self._real_torch = real_torch |
| self._extra_ns_per_element = extra_ns_per_element |
| |
| def extra_overhead(self, result): |
| # time.sleep has a ~65 us overhead, so only fake a |
| # per-element overhead if numel is large enough. |
| numel = int(result.numel()) |
| if numel > 5000: |
| time.sleep(numel * self._extra_ns_per_element * 1e-9) |
| return result |
| |
| def add(self, *args, **kwargs): |
| return self.extra_overhead(self._real_torch.add(*args, **kwargs)) |
| |
| def mul(self, *args, **kwargs): |
| return self.extra_overhead(self._real_torch.mul(*args, **kwargs)) |
| |
| def cat(self, *args, **kwargs): |
| return self.extra_overhead(self._real_torch.cat(*args, **kwargs)) |
| |
| def matmul(self, *args, **kwargs): |
| return self.extra_overhead(self._real_torch.matmul(*args, **kwargs)) |
| |
| |
| def main(): |
| tasks = [ |
| ("add", "add", "torch.add(x, y)"), |
| ("add", "add (extra +0)", "torch.add(x, y + zero)"), |
| ] |
| |
| serialized_results = [] |
| repeats = 2 |
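    # Build one Timer per (branch, stmt, size, num_threads) combination:
    # 3 emulated branches x 2 statements x 6 sizes x 2 thread counts = 72 Timers.
    # "master" uses the real torch module; the other branches route calls through
    # FauxTorch, whose per-element delay gives Compare a regression to surface.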
    timers = [
        benchmark_utils.Timer(
            stmt=stmt,
            globals={
                "torch": torch if branch == "master" else FauxTorch(torch, overhead_ns),
                "x": torch.ones((size, 4)),
                "y": torch.ones((1, 4)),
                "zero": torch.zeros(()),
            },
            label=label,
            sub_label=sub_label,
            description=f"size: {size}",
            env=branch,
            num_threads=num_threads,
        )
        for branch, overhead_ns in [("master", None), ("my_branch", 1), ("severe_regression", 5)]
        for label, sub_label, stmt in tasks
        for size in [1, 10, 100, 1000, 10000, 50000]
        for num_threads in [1, 4]
    ]

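    # Collect `repeats` measurements per Timer and pickle each one, mimicking the
    # usual workflow in which separate processes write serialized measurements.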
    for i, timer in enumerate(timers * repeats):
        serialized_results.append(pickle.dumps(
            timer.blocked_autorange(min_run_time=0.05)
        ))
        print(f"\r{i + 1} / {len(timers) * repeats}", end="")
        sys.stdout.flush()
    print()

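    # Deserialize the measurements and hand them to Compare, which collates
    # them into the tables printed below.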
    comparison = benchmark_utils.Compare([
        pickle.loads(i) for i in serialized_results
    ])

    print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
    comparison.print()

    print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
    comparison.trim_significant_figures()
    comparison.colorize()
    comparison.print()


if __name__ == "__main__":
    main()