# Owner(s): ["oncall: distributed"]

import functools
import sys
from collections import namedtuple
from contextlib import suppress
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import (
    CPUOffload,
    FlatParameter,
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
    always_wrap_policy,
    transformer_auto_wrap_policy,
)
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    _assert_module_states,
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    NestedWrappedModule,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestFSDPMisc(FSDPTest):
    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    @skip_if_lt_x_gpu(2)
    def test_fsdp_namedtuple(self):
        # Ensure namedtuple support, preventing issues such as
        # https://github.com/pytorch/pytorch/issues/83053
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(100, 100)

            def forward(self, x):
                return x

        m = MyModule().cuda()
        m = FSDP(m)
        t = torch.ones(1, device="cuda", requires_grad=True)
        MyOutputType = namedtuple(
            "MyOutputType", ["a", "b", "c", "d"], defaults=(t, t, t, t)
        )
        inp = MyOutputType()
        out = m(inp)
        # Ensure hooks are registered
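        # FSDP registers pre-backward hooks on each tensor in the forward
        # output, including tensors nested inside a namedtuple, so that it can
        # re-gather parameters before backward.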
        for x in out:
            self.assertNotEqual([], list(x._backward_hooks.values()))
        # TODO: we should check backward() and param is resharded
        # as well, but this is blocked by
        # https://github.com/pytorch/pytorch/issues/83107 and
        # https://github.com/pytorch/pytorch/issues/83129

    @skip_if_lt_x_gpu(2)
    def test_fsdp_not_all_outputs_used_in_loss(self):
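        """Tests that FSDP matches local training and still reshards its
        parameters when only a subset of the forward outputs is used in the
        loss."""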
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = nn.Linear(4, 4)
                self.lin2 = nn.Linear(4, 4)

            def forward(self, x):
                a = self.lin1(x)
                b = self.lin2(x)
                return (a, b)

        def _check_resharded(fsdp_module):
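            # After backward, sharded strategies should have freed the padded
            # all-gathered parameter and left the flat parameter viewing its
            # local shard.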
            for handle in fsdp_module._handles:
                param = handle.flat_param
                if handle.uses_sharded_strategy:
                    full_param = param._full_param_padded
                    self.assertEqual(full_param.storage().size(), 0)
                self.assertEqual(param.data_ptr(), param._local_shard.data_ptr())

        def _check_equal(local, fsdp):
            with FSDP.summon_full_params(fsdp):
                for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                    torch.testing.assert_close(p1, p2)

        for sharding_strategy in [
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
            ShardingStrategy.NO_SHARD,
        ]:
            with self.subTest(sharding_strategy=sharding_strategy):
                fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy)
                m = MyModule().cuda()
                m_local = deepcopy(m)
                local_m = m_local
                prev_params = [p.clone() for p in m_local.parameters()]
                m.lin1 = fsdp_ctor(m.lin1)
                m = fsdp_ctor(m)
                _check_equal(m_local, m)

                opt = torch.optim.SGD(m.parameters(), lr=1e-3)
                opt_local = torch.optim.SGD(local_m.parameters(), lr=1e-3)

                for i in range(6):
                    t = torch.ones(4, device="cuda")
                    a, b = m(t)
                    local_a, local_b = local_m(t)
                    if i < 2:
                        # Use both outputs in the loss computation. In later
                        # iterations, `b` goes unused, and we check that the
                        # gradients match local training.
                        loss = (a @ b).sum()
                        loss_local = (local_a @ local_b).sum()
                    else:
                        loss = a.sum()
                        loss_local = local_a.sum()
                    loss.backward()
                    loss_local.backward()
                    _check_resharded(m)
                    opt.step()
                    opt_local.step()
                    _check_equal(m_local, m)
                    # Ensure at least some change from the previous parameters;
                    # otherwise, the above check would be vacuously true.
                    self.assertTrue(
                        any(
                            not torch.equal(p1, p2)
                            for p1, p2 in zip(prev_params, m_local.parameters())
                        )
                    )
                    prev_params = [p.clone() for p in local_m.parameters()]
                    opt.zero_grad()
                    opt_local.zero_grad()

        dist.barrier()

    @skip_if_lt_x_gpu(2)
    @parametrize("use_second_layer", [True, False])
    @parametrize("sharding_strategy", [ShardingStrategy.NO_SHARD, None])
    def test_fsdp_module_no_compute_grad(self, use_second_layer, sharding_strategy):
        # When use_second_layer=True, `b` is involved in the forward computation
        # but does not receive a gradient in backward. Otherwise, `b` is not
        # involved in the forward computation at all.
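        # Note that `sharding_strategy=None` falls back to FSDP's default,
        # ShardingStrategy.FULL_SHARD.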
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Linear(10, 10)
                self.b = nn.Linear(10, 10)

            def forward(self, x, y):
                out1 = self.a(x)
                if use_second_layer:
                    out2 = self.b(y)
                    return out1, out2
                else:
                    return out1

        fsdp = FSDP(
            MyModel().cuda(),
            sharding_strategy=sharding_strategy,
            auto_wrap_policy=always_wrap_policy,
        )
        x = torch.randn(10, 10, device="cuda")
        y = torch.randn(10, 10, device="cuda")
        for i in range(4):
            if use_second_layer:
                a, b = fsdp(x, y)
            else:
                a = fsdp(x, y)
            loss = a.sum()
            loss.backward()
            # self.a receives grad, self.b does not
            a_grad = fsdp.module.a._handles[0].flat_param.grad
            b_grad = fsdp.module.b._handles[0].flat_param.grad
            self.assertIsNotNone(a_grad)
            self.assertIsNone(b_grad)

    @skip_if_lt_x_gpu(2)
    def test_device_id_auto_wrap(self):
        """Tests that ``auto_wrap_policy`` propagates ``device_id`` to all
        nested FSDP instances."""
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
        )
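        # The policy wraps every encoder/decoder layer in its own nested FSDP
        # instance, each of which should inherit `device_id`.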
        fsdp_kwargs = {
            "auto_wrap_policy": auto_wrap_policy,
            "device_id": torch.cuda.current_device(),
        }
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        for fsdp_module in FSDP.fsdp_modules(fsdp_model):
            self.assertEqual(
                fsdp_module.compute_device,
                torch.device("cuda", torch.cuda.current_device()),
            )

    @skip_if_lt_x_gpu(2)
    def test_fsdp_device_id_cpu_offload(self):
        """
        Ensures that even if ``device_id`` is specified, the module's
        parameters are on CPU after init when CPU offloading is enabled.
        """

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Linear(10, 10)
                self.b = nn.Linear(10, 10)

            def forward(self, x):
                return self.b(self.a(x))

        model = MyModel()
        fsdp = FSDP(
            model,
            auto_wrap_policy=always_wrap_policy,
            cpu_offload=CPUOffload(offload_params=True),
            device_id=torch.cuda.current_device(),
        )
        cpu_device = torch.device("cpu")
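        # With `offload_params=True`, the sharded flat parameters should stay
        # on CPU after init; `device_id` only determines the compute device.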
        for fsdp_unit in FSDP.fsdp_modules(fsdp):
            # This FSDP unit may not directly manage any parameters.
            if len(fsdp_unit.params) > 0:
                fsdp_param = fsdp_unit.params[0]
                self.assertEqual(fsdp_param.device, cpu_device)

    @skip_if_lt_x_gpu(2)
    @parametrize("use_index", [True, False])
    def test_fsdp_device_id(self, use_index):
        """
        Tests the FSDP ``device_id`` argument:
        - Wrapping a CPU module should move the module to the GPU matching
          ``device_id``
        - Wrapping a GPU module already on the GPU matching ``device_id``
          should not raise an error
        - Wrapping a GPU module already on GPU and passing a GPU device
          without specifying a device ID (i.e. ``torch.device("cuda")``) warns
        """
        dev_id = (
            torch.cuda.current_device()
            if use_index
            else torch.device("cuda", torch.cuda.current_device())
        )

        def _check_device_matches(module, device_id):
            """Checks that the ``FlatParameter``s in ``module`` have device
            matching ``device_id``."""
            devices = {
                p.device for p in module.parameters() if isinstance(p, FlatParameter)
            }
            assert len(devices) > 0
            self.assertEqual(1, len(devices))
            found_device = devices.pop()
            if use_index and not isinstance(device_id, torch.device):
                device = torch.device("cuda", device_id)
            else:
                device = device_id
            self.assertEqual(found_device, device)

        # Check that FSDP parameters are moved to `device_id` for a CPU module
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_NEVER,
            fsdp_kwargs={"device_id": dev_id},
        )
        _check_device_matches(nested_wrapped_module, dev_id)
        # Check that specifying `device_id` for a GPU module already on that
        # device does not raise an error
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs={"device_id": dev_id},
        )
        _check_device_matches(nested_wrapped_module, dev_id)
        # Check that passing in `torch.device("cuda")` for a GPU module warns
        regex = "does not have an explicit index"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex
        )
        with context:
            nested_wrapped_module = NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_BEFORE,
                fsdp_kwargs={"device_id": torch.device("cuda")},
            )
        _check_device_matches(
            nested_wrapped_module, torch.device("cuda", torch.cuda.current_device())
        )

    @skip_if_lt_x_gpu(2)
    def test_module_device_mismatches_device_id(self):
        """Tests that specifying a ``device_id`` argument to FSDP for a GPU
        module that does not match the GPU device ID raises an error."""
        context = (
            self.assertRaisesRegex(ValueError, f"cuda:{self.rank} vs cuda:0")
            if self.rank != 0
            else suppress()
        )
        with context:
            NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                # Move wrapped modules to CUDA before wrapping with FSDP
                cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
                # Should raise error since rank 1 is given `device_id=0` when
                # the model is on cuda:1
                fsdp_kwargs={"device_id": 0},
            )

    @skip_if_lt_x_gpu(2)
    def test_multi_device_not_supported(self):
        """Tests that wrapping a multi-device module (i.e. with submodules on
        both GPU and CPU) with FSDP raises an error."""

        class MultiDeviceModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Linear(1, 1).cuda()
                self.b = nn.Linear(1, 1)

        with self.assertRaisesRegex(
            RuntimeError, "FSDP only supports single device modules"
        ):
            FSDP(MultiDeviceModule())

    @skip_if_lt_x_gpu(2)
    def test_no_params(self):
        """
        Tests that ``device_id`` and CPU init work if the module has no
        parameters (they are effectively no-ops, but ensure that FSDP does not
        assume the module has parameters during init).
        """
        # Test CPU
        no_params = nn.ReLU()
        module = FSDP(no_params)
        # Test CUDA
        no_params = nn.ReLU().cuda()
        module = FSDP(no_params)
        # Test CPU + device_id
        no_params = nn.ReLU()
        module = FSDP(no_params, device_id=torch.cuda.current_device())
        # For modules with no parameters, a mismatched `device_id` raises an
        # error about the inconsistency between `compute_device` and
        # `device_id`, since `compute_device` is computed as
        # `torch.cuda.current_device()` when there are no parameters.
        no_params = nn.ReLU().cuda()
        context = (
            self.assertRaisesRegex(
                ValueError, f"Inconsistent.*cuda:{self.rank} vs cuda:0"
            )
            if self.rank != 0
            else suppress()
        )
        with context:
            module = FSDP(no_params, device_id=0)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_cpu_init_stays_on_cpu(self):
        """Tests that passing a CPU module to FSDP preserves that the wrapped
        module is on CPU after FSDP initialization, albeit after logging a
        warning, and that FSDP moves CPU input to GPU before the forward."""
        torch.cuda.set_device(self.rank)
        regex = "passed-in `module` is on CPU"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex
        )
        with context:
            nested_wrapped_module = NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_NEVER,
            )
            fsdp_model = FSDP(nested_wrapped_module, self.process_group)
        devices = {p.device for p in fsdp_model.parameters()}
        self.assertEqual(1, len(devices))
        self.assertEqual(torch.device("cpu"), devices.pop())
        fsdp_model = fsdp_model.cuda()
        # Ensure that forward + backward can be performed after moving to CUDA.
        # Passing CPU input also tests that the input is correctly moved to the
        # appropriate CUDA device.
        inp = fsdp_model.module.get_input(device=torch.device("cpu"))
        fsdp_model(*inp).sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_cpu_init_with_sync_module_states(self):
        """Tests that passing ``sync_module_states=True`` raises an error for
        a CPU module since the synchronization requires GPU communication,
        while additionally passing ``device_id`` does not raise an error."""
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_NEVER,
        )
        with self.assertRaisesRegex(
            ValueError, "The module has CPU parameters when `sync_module_states=True`"
        ):
            FSDP(nested_wrapped_module, self.process_group, sync_module_states=True)
        # Specifying `device_id` with `sync_module_states=True` works.
        FSDP(
            nested_wrapped_module,
            self.process_group,
            device_id=torch.cuda.current_device(),
            sync_module_states=True,
        )

    @skip_if_lt_x_gpu(2)
    def test_fsdp_same_model_across_ranks(self):
        """
        Tests that FSDP with ``sync_module_states=True`` broadcasts the model
        from rank 0 so that all ranks start off with the same values.
        """

        class MyModel(nn.Module):
            def __init__(self, rank):
                super().__init__()
                # Seed via rank to make the model different across ranks
                torch.manual_seed(rank)
                torch.cuda.manual_seed(rank)
                self.lin = nn.Linear(10, 10, bias=False)
                self.register_buffer("buffer", torch.ones(1) * rank)
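                # The buffer is also seeded per rank since
                # `sync_module_states=True` should broadcast buffers as well
                # as parameters from rank 0.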

        m = MyModel(self.rank).cuda()
        _assert_module_states(
            m, process_group=self.process_group, assert_fn=self.assertNotEqual
        )
        # Passing `sync_module_states=True` into FSDP makes the model the same
        # across ranks during init.
        fsdp = FSDP(m, sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(
                fsdp, process_group=self.process_group, assert_fn=self.assertEqual
            )
        # `sync_module_states` also works for a CPU module when `device_id` is
        # passed in.
        m = MyModel(self.rank)
        _assert_module_states(
            m, process_group=self.process_group, assert_fn=self.assertNotEqual
        )
        # Passing `sync_module_states=True` into FSDP makes the model the same
        # across ranks during init.
        fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(
                fsdp, process_group=self.process_group, assert_fn=self.assertEqual
            )


instantiate_parametrized_tests(TestFSDPMisc)

if __name__ == "__main__":
    run_tests()