blob: dc984626888a024db4218a53d368b0d8fe951e27 [file] [log] [blame]
#!/usr/bin/env python3
#
# Computes difference between measurements produced by ./benchmark.py.
#
import argparse
import json
import numpy as np
def load(path):
    """Read the JSON document stored at ``path`` and return the parsed object.

    The file is decoded explicitly as UTF-8 (the encoding JSON interchange
    uses) instead of whatever the current locale defaults to, so non-ASCII
    values load the same way on every platform.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(path, 'r', encoding="utf-8") as f:
        return json.load(f)
# Percentiles reported for every GPU count. Shared by the table header and the
# table rows so the number of header groups always matches the row layout.
_PERCENTILES = (75, 95)


def _print_metadata(ja, jb):
    """Print every top-level key except "benchmark_results", side by side.

    ``ja`` is the baseline document and ``jb`` the test document. A key that
    is present in only one of them is shown as "-" on the other side.
    """
    keys = (set(ja.keys()) | set(jb.keys())) - set(["benchmark_results"])
    print("{:20s} {:>20s} {:>20s}".format("", "baseline", "test"))
    print("{:20s} {:>20s} {:>20s}".format("", "-" * 20, "-" * 20))
    for key in sorted(keys):
        va = str(ja.get(key, "-"))
        vb = str(jb.get(key, "-"))
        print("{:20s} {:>20s} vs {:>20s}".format(key + ":", va, vb))
    print("")


def _print_benchmark(ra, rb):
    """Print the percentile comparison table for one baseline/test entry pair.

    ``ra`` comes from the baseline file and ``rb`` from the test file. Nothing
    is printed when the two entries differ in model or batch size.
    """
    if ra["model"] != rb["model"] or ra["batch_size"] != rb["batch_size"]:
        return

    model = ra["model"]
    batch_size = int(ra["batch_size"])
    name = "{} with batch size {}".format(model, batch_size)
    print("Benchmark: {}".format(name))

    # Print header: one (sec/iter, ex/sec, diff) column group per percentile
    print("")
    print("{:>10s}".format(""), end='')  # noqa: E999
    for _ in _PERCENTILES:
        print("{:>16s}{:>10s}{:>10s}".format("sec/iter", "ex/sec", "diff"), end='')  # noqa: E999
    print("")

    # Print measurements
    for i, (xa, xb) in enumerate(zip(ra["result"], rb["result"])):
        # Ignore round without ddp
        if i == 0:
            continue
        # Sanity check: ignore if number of ranks is not equal
        if len(xa["ranks"]) != len(xb["ranks"]):
            continue

        ngpus = len(xa["ranks"])
        ma = sorted(xa["measurements"])
        mb = sorted(xb["measurements"])
        print("{:>4d} GPUs:".format(ngpus), end='')  # noqa: E999
        for p in _PERCENTILES:
            va = np.percentile(ma, p)
            vb = np.percentile(mb, p)
            # We're measuring time, so lower is better (hence the negation):
            # a positive diff means the test run was faster than the baseline.
            delta = -100 * ((vb - va) / va)
            print(" p{:02d}: {:8.3f}s {:7d}/s {:+8.1f}%".format(p, vb, int(batch_size / vb), delta), end='')  # noqa: E999
        print("")
    print("")


def main():
    """Diff two benchmark JSON files given on the command line.

    The first positional argument is treated as the baseline and the second as
    the test run. The metadata of both files is printed first, followed by one
    table per pair of entries in their "benchmark_results" lists (paired by
    position).
    """
    parser = argparse.ArgumentParser(description='PyTorch distributed benchmark diff')
    parser.add_argument("file", nargs=2)
    args = parser.parse_args()

    # nargs=2 makes argparse itself reject any other number of arguments, so
    # args.file is always a two-element list here.
    baseline_path, test_path = args.file
    ja = load(baseline_path)
    jb = load(test_path)

    _print_metadata(ja, jb)

    ba = ja["benchmark_results"]
    bb = jb["benchmark_results"]
    for ra, rb in zip(ba, bb):
        _print_benchmark(ra, rb)
# Run the diff only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()