# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    NestedWrappedModule,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestApply(FSDPTest):
    @property
    def world_size(self):
        return 2
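
    # Fills every ``nn.Linear`` weight and bias with 1.0 in place; meant to be
    # passed to ``Module.apply()``.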
    @torch.no_grad()
    def _init_linear_weights(self, m):
        if type(m) == nn.Linear:
            m.weight.fill_(1.0)
            m.bias.fill_(1.0)
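
    # Summons the full (unsharded) parameters and runs ``check`` on every
    # ``nn.Linear`` parameter against the tensor returned by ``expected_tensor_fn``.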
    def check_weights(self, fsdp, expected_tensor_fn, check):
        with FSDP.summon_full_params(fsdp, recurse=True):
            linear_modules = [
                module for module in fsdp.modules() if type(module) == nn.Linear
            ]
            for module in linear_modules:
                for param in module.parameters():
                    expected = expected_tensor_fn(param)
                    check(param, expected, f"Got {param} but expected {expected}")
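
    # Shared by the tests below: verifies that ``apply()`` reaches all wrapped
    # submodules by flipping every linear weight and bias to 1.0.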
    def _check_apply(self, fsdp):
        # Assert linear weights are not all 1.0
        self.check_weights(
            fsdp, lambda param: torch.empty_like(param).fill_(1.0), self.assertNotEqual
        )

        fsdp.apply(self._init_linear_weights)

        # Ensure all weights are 1.0
        self.check_weights(
            fsdp, lambda param: torch.empty_like(param).fill_(1.0), self.assertEqual
        )

    @skip_if_lt_x_gpu(2)
    def test_nested_module_apply(self):
        """Tests that ``apply()`` modifies parameter values in-place on a
        non-FSDP-root nested FSDP-wrapped model."""
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
        )
        self._check_apply(nested_wrapped_module)

    @skip_if_lt_x_gpu(2)
    def test_transformer_module_apply(self):
        """Tests that ``apply()`` modifies parameter values in-place on an
        FSDP-wrapped transformer model with shared parameters."""
        transformer = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
        )
        self._check_apply(transformer)

    @skip_if_lt_x_gpu(2)
    def test_apply_in_summon_raises_error(self):
        """Tests that calling ``apply()`` on an FSDP instance inside the
        ``summon_full_params()`` context raises an error."""
        transformer = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
        )
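        # ``apply()`` validates FSDP's internal training state, so calling it
        # while full parameters are summoned is expected to raise.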
        with transformer.summon_full_params(transformer):
            with self.assertRaisesRegex(ValueError, "expected to be in states"):
                transformer.apply(self._init_linear_weights)


if __name__ == "__main__":
    run_tests()