import argparse
import logging
import os
from functools import partial

import torch
import torch._dynamo as dynamo
import torch.utils._pytree as pytree
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.profiler import profile, ProfilerActivity, record_function

try:
    from .common import timed
    from .dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup
except ImportError:
    from common import timed
    from dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup

log = logging.getLogger(__name__)


def torchviz_model(args, model, inputs, rank):
    # Build the autograd graph for a single forward pass and render it from rank 0.
    from torchviz import make_dot

    outputs = model(*inputs)
    loss = reduce_to_scalar_loss(outputs)
    parameter_names = dict(model.named_parameters())
    dot = make_dot(loss, params=parameter_names, show_attrs=True, show_saved=True)
    if rank == 0:
        dot.render("torchviz.dot")
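    # Note (assumption, not in the original script): graphviz's render() writes the DOT
    # source plus a PDF by default; to get an SVG instead one could use, e.g.:
    #     dot.format = "svg"
    #     dot.render("torchviz", cleanup=True)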


def profile_model(args, model, inputs, rank):
    # Profile args.repeat forward/backward iterations and export a chrome trace from rank 0.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(args.repeat):
            with record_function("Forward"):
                outputs = model(*inputs)
                loss = reduce_to_scalar_loss(outputs)
            with record_function("Backward"):
                loss.backward()
    if rank == 0:
        prof.export_chrome_trace(args.trace_file)
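

# A minimal sketch (not part of the original script): print an aggregated, per-op summary
# from the profiler on rank 0 as a lighter-weight alternative to exporting a chrome trace.
def summarize_profile(prof, rank):
    if rank == 0:
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))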


def run_model(args, model, inputs, key):
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))

    setup(rank, world_size)
    if args.device == "cuda":
        # needed for FSDP
        torch.cuda.set_device(rank)

    # Move the model and every tensor in the (possibly nested) inputs to this rank's device.
    dev_rank = f"{args.device}:{rank}"
    model = model.to(dev_rank)

    def move_tensor(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.to(dev_rank)
        return maybe_tensor

    inputs = pytree.tree_map(move_tensor, inputs)

    # Wrap the model for distributed execution.
    if args.fsdp:
        model = apply_fsdp(
            args,
            model,
            use_checkpointing=args.fsdp_checkpoint,
            use_wrap_policy=args.fsdp_wrap,
        )
    elif args.ddp:
        model = DDP(model)

    if args.verbose:
        print(model)

    if args.dynamo:
        dynamo.reset()
        if args.verbose:
            dynamo.config.verbose = True
            dynamo.config.log_level = logging.DEBUG
        if args.dynamo_no_optimize_ddp:
            dynamo.config.optimize_ddp = False
        if args.dynamo == "inductor" and args.fsdp:
            # ensure the inductor config module is loaded before toggling it
            import torch._inductor.config

            torch._inductor.config.triton.cudagraphs = False
            log.warning("disabling inductor cudagraphs for compatibility with FSDP")

        def print_compile(gm, example_inputs):
            # debug backend: print the captured FX graph, then run it eagerly
            print(
                f"print_compile:\n{str(gm.graph)}\n-----------------------------------------"
            )
            return gm

        dynamo_ctx = dynamo.optimize(
            print_compile if args.dynamo == "print" else args.dynamo
        )
        model = dynamo_ctx(model)

    # warmup
    _ = timed(model, model_iter_fn, inputs, times=3, return_result=False)
    # timed run
    t_total = timed(
        model, model_iter_fn, inputs, times=args.repeat, return_result=False
    )
    if args.torchviz:
        torchviz_model(args, model, inputs, rank)
    if args.profile:
        profile_model(args, model, inputs, rank)

    cleanup()
    return t_total
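

# Note (assumption): run_model relies on the RANK / WORLD_SIZE environment variables that a
# launcher such as torchrun provides. For a quick single-process debug run one could set
# them by hand before calling it, e.g.:
#     os.environ.setdefault("RANK", "0")
#     os.environ.setdefault("WORLD_SIZE", "1")
# (plus MASTER_ADDR / MASTER_PORT if setup() does not default them).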


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda")
    parser.add_argument(
        "--dynamo",
        default=None,
        help="if set, compile with this TorchDynamo backend (e.g. 'inductor', 'eager', "
        "or 'print' to dump the captured graphs); otherwise run eagerly",
    )
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--batch-size", "--batch_size", default=None, type=int)
    parser.add_argument(
        "--torchviz", action="store_true", help="Dump autograd graph with torchviz"
    )
    parser.add_argument("--profile", action="store_true", help="Run the profiler")
    parser.add_argument(
        "--trace-file",
        "--trace_file",
        default="profile.json",
        help="File to write the profiler's chrome trace to",
    )
    parser.add_argument(
        "--repeat", default=10, type=int, help="Number of repeats for the timed run"
    )
    parser.add_argument(
        "--dynamo-no-optimize-ddp",
        "--dynamo_no_optimize_ddp",
        action="store_true",
        help="Disable dynamo's ddp optimizer (enabled by default)",
    )
    parser.add_argument(
        "--fsdp-checkpoint",
        "--fsdp_checkpoint",
        action="store_true",
        help="Use gradient checkpointing via model-specific policy",
    )
    parser.add_argument(
        "--fsdp-wrap",
        "--fsdp_wrap",
        action="store_true",
        help="Apply FSDP to submodules via model-specific policy",
    )

    dist_arg = parser.add_mutually_exclusive_group()
    dist_arg.add_argument("--ddp", action="store_true")
    dist_arg.add_argument("--fsdp", action="store_true")

    model_arg = parser.add_mutually_exclusive_group(required=True)
    model_arg.add_argument(
        "--torchbench-model",
        "--torchbench_model",
        help="name of torchbench model, e.g. hf_Bert",
    )
    model_arg.add_argument(
        "--toy-model", "--toy_model", action="store_true", help="use toy model instead"
    )
    args = parser.parse_args()

    model_name = args.torchbench_model
    if args.toy_model:
        model_name = "ToyModel"
    model, inputs = get_model(args)

    fn = partial(run_model, args, model, inputs)

    world_size = os.getenv("WORLD_SIZE", 1)
    t_total = fn(f"{model_name}_{world_size}")
    print(f"mean latency {t_total / args.repeat} across {args.repeat} runs")