blob: 98cc274a4c669018fa48a7756c21638f2e5081ae [file] [log] [blame]
import argparse
import functools
import importlib
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
try:
from .torchbench import setup_torchbench_cwd
except ImportError:
from torchbench import setup_torchbench_cwd
from transformers.models.bert.modeling_bert import BertLayer, BertLMPredictionHead
from transformers.models.t5.modeling_t5 import T5Block
def setup(rank, world_size):
os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost")
os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355")
os.environ["RANK"] = os.getenv("RANK", "0")
os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1")
dist.init_process_group("nccl")
def cleanup():
dist.destroy_process_group()
class CustomLinear(torch.nn.Module):
def __init__(self, a, b):
super().__init__()
self.weight = nn.Parameter(torch.randn(a, b))
def forward(self, x):
return torch.mm(x, self.weight)
class MyModule(torch.nn.Module):
def __init__(self, a, b):
super().__init__()
self.net = nn.Sequential(
nn.Linear(a, b),
nn.ReLU(),
)
def forward(self, x):
return self.net(x)
class ToyModel(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
*[nn.Linear(10, 10000), nn.ReLU()]
+ [nn.Linear(10000, 10000), nn.ReLU()]
+ [MyModule(10000, 10000)]
+ [MyModule(10000, 1000)]
+ [MyModule(1000, 1000)]
+ [MyModule(1000, 1000)]
+ [MyModule(1000, 1000)]
+ [MyModule(1000, 1000)]
+ [MyModule(1000, 1000)]
+ [MyModule(1000, 1000)]
+ [MyModule(1000, 1000)]
+ [nn.Linear(1000, 5)]
)
def forward(self, x):
return self.net(x)
def model_iter_fn(model, example_inputs, collect_outputs=False):
outputs = model(*example_inputs)
loss = reduce_to_scalar_loss(outputs)
loss.backward()
if collect_outputs:
return outputs
def get_model(args):
if args.torchbench_model:
old_cwd = setup_torchbench_cwd()
module = importlib.import_module(
f"torchbenchmark.models.{args.torchbench_model}"
)
benchmark_cls = getattr(module, "Model", None)
bm = benchmark_cls(test="train", device=args.device, batch_size=args.batch_size)
model, inputs = bm.get_module()
elif args.toy_model:
model = ToyModel()
inputs = (torch.randn(20, 10),)
else:
raise argparse.ArgumentError(
args.torchbench_model, message="Must specify a model"
)
return model, inputs
def fsdp_checkpointing_base(model, blocks):
"""apply activation checkpointing to model
returns None as model is updated directly
"""
non_reentrant_wrapper = functools.partial(
checkpoint_wrapper,
offload_to_cpu=False,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
def check_fn(submodule):
return isinstance(submodule, blocks)
apply_activation_checkpointing(
model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)
MODEL_FSDP_WRAP = {
"toy_model": (MyModule,),
"hf_Bert": (BertLayer, BertLMPredictionHead),
"hf_T5": (T5Block,),
}
def apply_fsdp(args, model, use_checkpointing=False, use_wrap_policy=True):
wrap_policy = None
blocks = MODEL_FSDP_WRAP[
"toy_model" if model.__class__ is ToyModel else args.torchbench_model
]
if use_wrap_policy:
wrap_policy = ModuleWrapPolicy(blocks)
model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True)
if use_checkpointing:
fsdp_checkpointing_base(model, blocks)
return model