| import argparse |
| import functools |
| import traceback |
| from typing import Callable, List, Optional, Tuple |
| |
| from torch.utils.jit.log_extract import ( |
| extract_ir, |
| load_graph_and_inputs, |
| run_baseline_no_fusion, |
| run_nnc, |
| run_nvfuser, |
| ) |
| |
| |
"""
Usage:
1. Run your script and pipe its output into a log file:
  PYTORCH_JIT_LOG_LEVEL=">>graph_fuser" python3 my_test.py &> log.txt
2. Run log_extract on that log file:
  log_extract.py log.txt --nvfuser --nnc-dynamic --nnc-static

You can also print the list of extracted IR graphs:
  log_extract.py log.txt --output

Passing in --graphs 0 2 will only run graphs 0 and 2.
"""
| |
| |
def test_runners(
    graphs: List[str],
    runners: List[Tuple[str, Callable]],
    graph_set: Optional[List[int]],
):
    """Run every runner on every selected graph and print the timings.

    Args:
        graphs: IR strings. Each selected one is passed to
            ``load_graph_and_inputs`` to obtain the inputs for the runners.
        runners: ``(display_name, fn)`` pairs. ``fn(ir, inputs)`` is called
            and its return value is printed as a duration in milliseconds.
        graph_set: Indices into ``graphs`` to run. ``None`` or an empty list
            means "run all graphs".

    For each graph, every runner after the first successful one is also
    reported as a percentage improvement over the previous successful runner.
    A ``RuntimeError`` raised while running or reporting a runner is printed
    with its traceback, and the remaining runners still execute.
    """
    # Build the lookup once; an empty/None selection means "no filtering".
    selected = set(graph_set) if graph_set else None

    for i, ir in enumerate(graphs):
        # Filter BEFORE loading so skipped graphs are never parsed and cannot
        # fail the whole run.
        if selected is not None and i not in selected:
            continue
        _, inputs = load_graph_and_inputs(ir)

        print(f"Running Graph {i}")
        prev_result = None
        prev_runner_name = None
        for runner_name, runner_fn in runners:
            try:
                result = runner_fn(ir, inputs)
                # Truthiness is deliberate: it skips both "no previous
                # runner yet" (None) and a previous measurement of 0.
                if prev_result:
                    improvement = (prev_result / result - 1) * 100
                    print(
                        f"{runner_name} : {result:.6f} ms improvement over {prev_runner_name}: improvement: {improvement:.2f}%"
                    )
                else:
                    print(f"{runner_name} : {result:.6f} ms")
                # Only a successful runner becomes the next comparison base.
                prev_result = result
                prev_runner_name = runner_name
            except RuntimeError:
                print(f" Graph {i} failed for {runner_name} :", traceback.format_exc())


# Despite the ``test_`` prefix this is a benchmark driver, not a unit test.
# Opt out of pytest collection so importing it into a test module does not
# produce a bogus test that fails on missing fixtures.
test_runners.__test__ = False
| |
| |
def _add_toggle(parser, name, on_help, off_help):
    """Register ``--<name>`` / ``--no-<name>`` flags sharing one boolean dest (default False)."""
    dest = name.replace("-", "_")
    parser.add_argument(f"--{name}", dest=dest, action="store_true", help=on_help)
    parser.add_argument(
        f"--no-{name}", dest=dest, action="store_false", help=off_help
    )
    parser.set_defaults(**{dest: False})


def run(argv: Optional[List[str]] = None):
    """Command-line entry point: extract IR from a log file, benchmark and/or print it.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behaviour.

    The positional ``filename`` is passed to ``extract_ir``. Each of
    ``--baseline``, ``--nnc-dynamic``, ``--nnc-static`` and ``--nvfuser`` adds
    one runner, in that order, and the runners are handed to
    ``test_runners``. ``--output`` prints the selected graphs as a
    bracketed, comma-separated list of triple-quoted strings.
    ``--graphs`` limits both benchmarking and output to the given indices.
    """
    parser = argparse.ArgumentParser(
        description="Extracts torchscript IR from log files and, optionally, benchmarks it or outputs the IR"
    )
    parser.add_argument("filename", help="Filename of log file")
    _add_toggle(parser, "nvfuser", "benchmark nvfuser", "DON'T benchmark nvfuser")
    _add_toggle(
        parser, "nnc-static", "benchmark nnc static", "DON'T benchmark nnc static"
    )
    _add_toggle(
        parser,
        "nnc-dynamic",
        "nnc with dynamic shapes",
        "DON'T benchmark nnc with dynamic shapes",
    )
    _add_toggle(parser, "baseline", "benchmark baseline", "DON'T benchmark baseline")
    _add_toggle(parser, "output", "Output graph IR", "DON'T output graph IR")
    parser.add_argument(
        "--graphs", nargs="+", type=int, help="Run only specified graph indices"
    )

    args = parser.parse_args(argv)
    graphs = extract_ir(args.filename)

    # ``--graphs`` not given -> None, meaning "no filtering".
    graph_set = args.graphs or None

    # The order here is the order in which runners execute, and therefore
    # which runner each "improvement over" figure is measured against.
    options = []
    if args.baseline:
        options.append(("Baseline no fusion", run_baseline_no_fusion))
    if args.nnc_dynamic:
        options.append(("NNC Dynamic", functools.partial(run_nnc, dynamic=True)))
    if args.nnc_static:
        options.append(("NNC Static", functools.partial(run_nnc, dynamic=False)))
    if args.nvfuser:
        options.append(("NVFuser", run_nvfuser))

    test_runners(graphs, options, graph_set)

    if args.output:
        quoted = [
            '"""' + ir + '"""'
            for i, ir in enumerate(graphs)
            if not graph_set or i in graph_set
        ]
        print("[" + ", ".join(quoted) + "]")
| |
| |
# Script entry point: invoke the CLI only when this file is executed
# directly, never as a side effect of importing it.
if __name__ == "__main__":
    run()