# Owner(s): ["oncall: distributed"]
import contextlib
import sys
from copy import deepcopy
from functools import partial
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
offload_wrapper,
)
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
FullyShardedDataParallel as FSDP,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import _maybe_wrap_fsdp, FSDPTest
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
)
from torch.utils.checkpoint import checkpoint
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
sys.exit(0)
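# Module-level flag set by the patched save_on_cpu below; tests use it to
# verify that activation offloading actually invoked save_on_cpu.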
_save_on_cpu_called = False
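# Returns a wrapper around checkpoint_wrapper's save_on_cpu that records the
# call in _save_on_cpu_called before delegating to the original hook.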
def get_patched_save_on_cpu():
orig_save_on_cpu = (
torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu
)
def patched_save_on_cpu(*args, **kwargs):
global _save_on_cpu_called
_save_on_cpu_called = True
return orig_save_on_cpu(*args, **kwargs)
return patched_save_on_cpu
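# Temporarily replaces checkpoint_wrapper's save_on_cpu with the given callable
# and restores the original on exit.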
@contextlib.contextmanager
def patch_save_on_cpu(new_save_on_cpu):
orig_save_on_cpu = (
torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu
)
torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu = (
new_save_on_cpu
)
try:
yield
finally:
torch.distributed.algorithms._checkpoint.checkpoint_wrapper.save_on_cpu = (
orig_save_on_cpu
)
class TestFSDPCheckpoint(FSDPTest):
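    # A stack of three linear layers; each layer can optionally be wrapped with
    # activation checkpointing (checkpoint_wrapper or offload_wrapper) and/or FSDP.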
class SequentialModule(nn.Module):
def __init__(
self,
checkpoint_layer=False,
offload_activations=False,
wrap_fsdp=False,
*fsdp_args,
**fsdp_kwargs,
):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
super().__init__()
l1 = nn.Linear(3, 3).cuda()
l2 = nn.Linear(3, 3).cuda()
l3 = nn.Linear(3, 3).cuda()
if checkpoint_layer:
if offload_activations:
ckpt_wrapper = offload_wrapper
else:
ckpt_wrapper = checkpoint_wrapper
l1 = ckpt_wrapper(l1)
l2 = ckpt_wrapper(l2)
l3 = ckpt_wrapper(l3)
            fsdp_wrapper = partial(
                _maybe_wrap_fsdp, *fsdp_args, wrap_fsdp=wrap_fsdp, **fsdp_kwargs
            )
self.ffn = nn.Sequential(
fsdp_wrapper(l1),
fsdp_wrapper(l2),
fsdp_wrapper(l3),
)
def forward(self, x):
return self.ffn(x)
def _verify_parity(self, losses, outputs, models):
assert losses
assert outputs
assert models
        for loss, output in zip(losses[1:], outputs[1:]):
            self.assertEqual(losses[0], loss)
            self.assertEqual(outputs[0], output)
# Verify grads
ref_model = models[0]
ref_grads = [p.grad for p in ref_model.parameters()]
for m in models[1:]:
grads = [p.grad for p in m.parameters()]
for ref_g, g in zip(ref_grads, grads):
self.assertEqual(ref_g, g)
@skip_if_lt_x_gpu(2)
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
)
@parametrize("offload_activations", [True, False])
@parametrize("use_orig_params", [False, True])
def test_checkpoint_fsdp_wrapping(
self,
cpu_offload: CPUOffload,
offload_activations: bool,
use_orig_params: bool,
):
# Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
if offload_activations:
wrapper_to_use = offload_wrapper
else:
wrapper_to_use = checkpoint_wrapper
fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
ckpt_sequential_wrapped_fsdp = wrapper_to_use(
TestFSDPCheckpoint.SequentialModule(
wrap_fsdp=True,
**fsdp_kwargs,
),
)
# Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
inner_ckpt = TestFSDPCheckpoint.SequentialModule(
checkpoint_layer=True,
offload_activations=offload_activations,
wrap_fsdp=True,
**fsdp_kwargs,
)
baseline = TestFSDPCheckpoint.SequentialModule(
wrap_fsdp=True,
**fsdp_kwargs,
)
        # Note that reentrant-based checkpointing requires the input to have
        # its requires_grad flag set.
inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)
global _save_on_cpu_called
models = [ckpt_sequential_wrapped_fsdp, inner_ckpt, baseline]
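        # Patch save_on_cpu so that _save_on_cpu_called records whether
        # activation offloading actually engaged the CPU-offload hook.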
with patch_save_on_cpu(get_patched_save_on_cpu()):
for i in range(2):
losses = []
outputs = []
for m in models:
                    check_offload = (
                        m is not baseline and i == 0 and offload_activations
                    )
if check_offload:
self.assertFalse(_save_on_cpu_called)
out = m(inp)
if check_offload:
self.assertTrue(_save_on_cpu_called)
_save_on_cpu_called = False
loss = out.sum()
loss.backward()
losses.append(loss)
outputs.append(out)
self._verify_parity(losses, outputs, models)
dist.barrier()
@skip_if_lt_x_gpu(2)
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
)
@parametrize("offload_activations", [True, False])
@parametrize("use_orig_params", [False, True])
def test_basic_checkpoint_end_to_end(
self,
cpu_offload: CPUOffload,
offload_activations: bool,
use_orig_params: bool,
):
fsdp_kwargs = {"cpu_offload": cpu_offload, "use_orig_params": use_orig_params}
global _save_on_cpu_called
with patch_save_on_cpu(get_patched_save_on_cpu()):
seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
# Runs FSDP with no checkpointing
fsdp_only_seq = FSDP(deepcopy(seq), **fsdp_kwargs)
# Runs checkpoint-wrapped FSDP
if offload_activations:
wrapper_to_use = offload_wrapper
else:
wrapper_to_use = checkpoint_wrapper
checkpointed_fsdp = wrapper_to_use(
FSDP(deepcopy(seq), **fsdp_kwargs),
)
# Runs FSDP-wrapped checkpointed module
fsdp_wrapped_checkpoint = FSDP(
wrapper_to_use(deepcopy(seq)),
**fsdp_kwargs,
)
# Runs FSDP with manual calls to checkpoint.
fsdp_call_checkpoint = FSDP(deepcopy(seq), **fsdp_kwargs)
            # Note that reentrant-based checkpointing requires the input to have
            # its requires_grad flag set.
inp = torch.randn(
10, 3, device=torch.cuda.current_device(), requires_grad=True
)
models = [
fsdp_only_seq,
checkpointed_fsdp,
fsdp_wrapped_checkpoint,
fsdp_call_checkpoint,
]
            # Ensure save_on_cpu has not been called yet.
self.assertFalse(_save_on_cpu_called)
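            # Run several iterations; losses, outputs, and gradients should
            # match across all FSDP/checkpoint combinations on every iteration.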
for i in range(6):
losses = []
outputs = []
for m in models:
                    check_offload = (
                        m is not fsdp_only_seq and i == 0 and offload_activations
                    )
                    if m is fsdp_call_checkpoint:
# _save_on_cpu should not be called yet
self.assertFalse(_save_on_cpu_called)
offload_ctx = (
get_patched_save_on_cpu()(pin_memory=True)
if offload_activations
else contextlib.nullcontext()
)
with offload_ctx:
out = checkpoint(m, inp, use_reentrant=True)
else:
# _save_on_cpu should not be called yet
self.assertFalse(_save_on_cpu_called)
out = m(inp)
if check_offload:
self.assertTrue(_save_on_cpu_called)
loss = out.sum()
loss.backward()
losses.append(loss)
outputs.append(out)
_save_on_cpu_called = False
self._verify_parity(losses, outputs, models)
dist.barrier()
instantiate_parametrized_tests(TestFSDPCheckpoint)
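# A small stack of linear layers that optionally runs its forward pass under
# torch.utils.checkpoint, with configurable reentrancy.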
class CheckpointModule(nn.Module):
def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
super().__init__()
self.seq = nn.Sequential(*[nn.Linear(100, 100) for _ in range(4)])
self.checkpoint = checkpoint
self.use_reentrant = use_reentrant
def forward(self, x):
return (
checkpoint(self.seq, x, use_reentrant=self.use_reentrant)
if self.checkpoint
else self.seq(x)
)
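# A module with two CheckpointModule submodules sandwiched between an input
# and an output linear layer.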
class ModelWithCheckpointSubmodule(nn.Module):
def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
super().__init__()
self.l1 = nn.Linear(100, 100)
self.s1 = CheckpointModule(checkpoint, use_reentrant)
self.s2 = CheckpointModule(checkpoint, use_reentrant)
self.relu = nn.ReLU()
self.l2 = nn.Linear(100, 100)
def forward(self, x):
return self.l2(self.relu(self.s2(self.s1(self.l1(x)))))
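# Top-level model composed of two ModelWithCheckpointSubmodule blocks.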
class TestModel(nn.Module):
def __init__(self, checkpoint: bool = False, use_reentrant: bool = True):
super().__init__()
self.l1 = nn.Linear(100, 100)
self.relu = nn.ReLU()
self.checkpoint1 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant)
self.checkpoint2 = ModelWithCheckpointSubmodule(checkpoint, use_reentrant)
self.l2 = nn.Linear(100, 100)
def forward(self, x):
return self.l2(self.relu(self.checkpoint2(self.checkpoint1(self.l1(x)))))
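# Checks that FSDP-wrapping submodules that internally use activation
# checkpointing yields the same gradients as the non-checkpointed baseline.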
class TestFSDPCheckpointSubmodule(FSDPTest):
    # TODO: grad value checks occasionally fail when use_reentrant = True
@skip_if_lt_x_gpu(2)
@parametrize("use_reentrant", [False])
def test_checkpoint_submodule(self, use_reentrant: bool):
model = TestModel(use_reentrant=use_reentrant).cuda()
model_ac = deepcopy(model)
        # Turn on activation checkpointing for every CheckpointModule in the copy.
        for m in model_ac.modules():
if isinstance(m, CheckpointModule):
m.checkpoint = True
self.assertTrue(model_ac.checkpoint1.s1.checkpoint)
self.assertTrue(model_ac.checkpoint2.s2.checkpoint)
fsdp_kwargs = {
"device_id": torch.cuda.current_device(),
"sharding_strategy": ShardingStrategy.NO_SHARD,
}
# Wrap no checkpointing model submodules with FSDP
model.checkpoint1 = FSDP(module=model.checkpoint1, **fsdp_kwargs)
model.checkpoint2 = FSDP(module=model.checkpoint2, **fsdp_kwargs)
# Wrap checkpointing model submodules with FSDP
model_ac.checkpoint1 = FSDP(module=model_ac.checkpoint1, **fsdp_kwargs)
model_ac.checkpoint2 = FSDP(module=model_ac.checkpoint2, **fsdp_kwargs)
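        # Forward/backward with the same input; gradients should match
        # parameter-by-parameter between the two models.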
x = torch.randn(2, 100, device="cuda")
model(x).sum().backward()
model_ac(x).sum().backward()
for (n1, p1), (n2, p2) in zip(
model.named_parameters(), model_ac.named_parameters()
):
self.assertEqual(n1, n2)
self.assertTrue(p1.grad.allclose(p2.grad))
instantiate_parametrized_tests(TestFSDPCheckpointSubmodule)
if __name__ == "__main__":
run_tests()