import argparse
import logging
import os
from functools import partial
import torch
import torch._dynamo as dynamo
import torch.utils._pytree as pytree
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.profiler import profile, ProfilerActivity, record_function
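
# The benchmark helpers (common, dist_util) live alongside this file; the
# try/except below lets the script be imported as part of a package or run
# directly as a standalone file.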
try:
    from .common import timed
    from .dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup
except ImportError:
    from common import timed
    from dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup

log = logging.getLogger(__name__)
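

# Render the autograd graph of a single forward pass with torchviz; only rank 0
# writes the rendered graph to disk.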
def torchviz_model(args, model, inputs, rank):
    from torchviz import make_dot

    outputs = model(*inputs)
    loss = reduce_to_scalar_loss(outputs)
    parameter_names = dict(model.named_parameters())
    dot = make_dot(loss, params=parameter_names, show_attrs=True, show_saved=True)
    if rank == 0:
        dot.render("torchviz.dot")
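

# Profile args.repeat forward/backward iterations with the PyTorch profiler and
# export a chrome trace from rank 0.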
def profile_model(args, model, inputs, rank):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for i in range(args.repeat):
            with record_function("Forward"):
                outputs = model(*inputs)
                loss = reduce_to_scalar_loss(outputs)
            with record_function("Backward"):
                loss.backward()
    if rank == 0:
        prof.export_chrome_trace(args.trace_file)
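

# Initialize the process group, move the model and inputs to this rank's device,
# optionally wrap the model in FSDP or DDP and compile it with dynamo, then time
# args.repeat iterations of the training step.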
def run_model(args, model, inputs, key):
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    # result_q = []
    setup(rank, world_size)
    if args.device == "cuda":
        # needed for FSDP
        torch.cuda.set_device(rank)

    dev_rank = f"{args.device}:{rank}"
    model = model.to(dev_rank)

    def move_tensor(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.to(dev_rank)
        return maybe_tensor

    inputs = pytree.tree_map(move_tensor, inputs)

    if args.fsdp:
        model = apply_fsdp(
            args,
            model,
            use_checkpointing=args.fsdp_checkpoint,
            use_wrap_policy=args.fsdp_wrap,
        )
    elif args.ddp:
        model = DDP(model)

    if args.verbose:
        print(model)

    if args.dynamo:
        dynamo.reset()
        if args.verbose:
            dynamo.config.verbose = True
            dynamo.config.log_level = logging.DEBUG
        if args.dynamo_no_optimize_ddp:
            dynamo.config.optimize_ddp = False
        if args.dynamo == "inductor" and args.fsdp:
            torch._inductor.config.triton.cudagraphs = False
            log.warning("disabling inductor cudagraphs for compatibility with FSDP")

        def print_compile(gm, ex):
            print(
                f"print_compile:\n{str(gm.graph)}\n-----------------------------------------"
            )
            return gm

        dynamo_ctx = dynamo.optimize(
            print_compile if args.dynamo == "print" else args.dynamo
        )
        model = dynamo_ctx(model)

    # warmup
    _ = timed(model, model_iter_fn, inputs, times=3, return_result=False)
    t_total = timed(
        model, model_iter_fn, inputs, times=args.repeat, return_result=False
    )
    if args.torchviz:
        torchviz_model(args, model, inputs, rank)
    if args.profile:
        profile_model(args, model, inputs, rank)

    cleanup()
    return t_total
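

# Entry point. RANK and WORLD_SIZE are read from the environment, so the script
# is meant to be run under a distributed launcher, e.g. (assuming the file is
# saved as distributed.py):
#   torchrun --nproc_per_node=2 distributed.py --toy_model --ddp --dynamo inductor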
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda")
    parser.add_argument(
        "--dynamo",
        default=None,
        help="If set, use this torch._dynamo backend (e.g. 'inductor', "
        "or 'print' to dump compiled graphs); otherwise run eager",
    )
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--batch-size", "--batch_size", default=None)
    parser.add_argument(
        "--torchviz", action="store_true", help="Dump autograd graph with torchviz"
    )
    parser.add_argument("--profile", action="store_true", help="Run the profiler")
    parser.add_argument(
        "--trace-file",
        "--trace_file",
        default="profile.json",
        help="Output file for the profiler's chrome trace",
    )
    parser.add_argument(
        "--repeat", default=10, type=int, help="Repeats for timing run"
    )
    parser.add_argument(
        "--dynamo-no-optimize-ddp",
        "--dynamo_no_optimize_ddp",
        action="store_true",
        help="Disable dynamo's ddp optimizer (enabled by default)",
    )
    parser.add_argument(
        "--fsdp-checkpoint",
        "--fsdp_checkpoint",
        action="store_true",
        help="Use gradient checkpointing via model-specific policy",
    )
    parser.add_argument(
        "--fsdp-wrap",
        "--fsdp_wrap",
        action="store_true",
        help="Apply fsdp to submodules via model-specific policy",
    )
    dist_arg = parser.add_mutually_exclusive_group()
    dist_arg.add_argument("--ddp", action="store_true")
    dist_arg.add_argument("--fsdp", action="store_true")
    model_arg = parser.add_mutually_exclusive_group(required=True)
    model_arg.add_argument(
        "--torchbench-model",
        "--torchbench_model",
        help="name of torchbench model, e.g. hf_Bert",
    )
    model_arg.add_argument(
        "--toy-model", "--toy_model", action="store_true", help="use toy model instead"
    )
    args = parser.parse_args()

    model_name = args.torchbench_model
    if args.toy_model:
        model_name = "ToyModel"
    model, inputs = get_model(args)

    fn = partial(run_model, args, model, inputs)

    world_size = os.getenv("WORLD_SIZE", 1)
    t_total = fn(f"{model_name}_{world_size}")
    print(f"mean latency {t_total / args.repeat} across {args.repeat} runs")