# Owner(s): ["oncall: distributed"]

import sys
import time
from statistics import mean
from unittest.mock import patch

import torch
import torch.nn as nn
from torch import distributed as dist
from torch.cuda import Event
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    get_cycles_per_ms,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
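

# Emulates one layer of "compute": spins the GPU for `compute_cycles` cycles via
# torch.cuda._sleep and brackets its forward with CUDA events so the test can
# read back the per-layer GPU time.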
class Layer(nn.Module):
    def __init__(self, compute_cycles, has_params: bool):
        super().__init__()
        self.sleep_cycles = compute_cycles
        self.optional_param = None
        if has_params:
            self.optional_param = nn.Parameter(torch.rand(1))

    def forward(self, x):
        # Get 2 events.
        self.e1 = Event(enable_timing=True)
        self.e2 = Event(enable_timing=True)

        # Record the fake forward compute time.
        self.e1.record()
        if self.sleep_cycles > 0:
            torch.cuda._sleep(self.sleep_cycles)
        if self.optional_param is not None:
            x = x + self.optional_param  # force the param to be part of the graph
        self.e2.record()
        return x

    def get_time(self):
        # Return the recorded duration.
        return self.e1.elapsed_time(self.e2)
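

# Wraps each of four Layer instances in its own FSDP unit (plus an outer FSDP
# around the whole nn.Sequential), so each layer's parameters are gathered
# separately during the forward pass.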
def _create_model(compute_cycles, has_params: bool):
    model = FSDP(
        nn.Sequential(
            FSDP(Layer(compute_cycles, has_params)),
            FSDP(Layer(compute_cycles, has_params)),
            FSDP(Layer(compute_cycles, has_params)),
            FSDP(Layer(compute_cycles, has_params)),
        )
    ).cuda()
    return model
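

# Tracks the 10 smallest samples seen so far; averaging only these filters out
# outliers caused by unrelated, nondeterministic system activity.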
class Min10:
    def __init__(self):
        self.data = []

    def add(self, new_data):
        if len(self.data) < 10:
            self.data.append(new_data)
        else:
            self.data = sorted(self.data)
            if new_data < self.data[-1]:
                self.data[-1] = new_data

    def avg(self):
        return mean(self.data)
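

# Times four compute/all-gather combinations and checks that (a) the CPU runs
# ahead of the GPU, (b) the GPU timings dominate when there is real GPU work,
# and (c) with more than one rank, layer compute overlaps with the all-gather
# communication.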
class TestForwardOverlapWorldSizeOne(FSDPTest):
    @property
    def world_size(self):
        return 1

    def _dist_train(self):
        rank = self.rank
        world_size = self.world_size
        # Save the original torch.distributed.all_gather_into_tensor function since
        # we will patch it to include an artificial delay.
        orig_all_gather = torch.distributed.all_gather_into_tensor

        def run(compute_cycles, all_gather_cycles):
            has_params = all_gather_cycles > 0
            model = _create_model(compute_cycles, has_params)

            # Get the input and set its requires_grad to True because
            # we have a fake compute in the forward pass.
            batch = torch.rand(1).cuda()
            batch.requires_grad = True

            # Run one dummy iteration to trigger the execution order validation
            # all-gathers.
            out = model(batch)
            out.backward()
            model.zero_grad(set_to_none=True)

            # We run 20 iterations but only keep the 10 smallest timing samples,
            # because nondeterministic system events can disturb the timing.
            cpu_iter = Min10()
            cpu_wait = Min10()
            gpu_compute = Min10()
            gpu_total = Min10()
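
            # Per-iteration measurements:
            #   cpu_iter_time: CPU time to issue the forward/backward (the GPU work
            #       is queued asynchronously).
            #   cpu_wait_for_gpu_time: additional CPU time spent in out.item(), which
            #       blocks until the queued GPU work has finished.
            #   gpu_compute / overall_gpu_time: CUDA-event timings, i.e. the summed
            #       per-layer compute from Layer.get_time() and the overall e1 -> e2 span.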
            for _ in range(20):
                # Get two events for measuring the overall time.
                e1 = Event(enable_timing=True)
                e2 = Event(enable_timing=True)

                cpu_start = time.process_time()

                all_gather_called = False

                def _delayed_all_gather(*args, **kwargs):
                    nonlocal all_gather_called
                    all_gather_called = True
                    torch.cuda._sleep(all_gather_cycles)
                    assert orig_all_gather
                    return orig_all_gather(*args, **kwargs)

                # forward pass
                #
                # Even though both e1 & e2 are on the compute stream, since
                # compute depends on all_gather, e2 - e1 includes all_gather time.
                e1.record()
                with patch(
                    "torch.distributed.all_gather_into_tensor", _delayed_all_gather
                ):
                    out = model(batch)
                    if has_params and world_size > 1:
                        self.assertTrue(all_gather_called)
                    else:
                        self.assertFalse(all_gather_called)
                e2.record()

                # backward pass
                out.backward()
                model.zero_grad(set_to_none=True)

                cpu_iter_time = time.process_time() - cpu_start

                # wait for gpu
                out.item()
                cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time

                # get sum of the compute time
                times = []
                for mod in model.modules():
                    if not isinstance(mod, Layer):
                        continue
                    times.append(mod.get_time())

                # get gpu compute + all_gather time
                overall_gpu_time = e1.elapsed_time(e2)

                cpu_iter.add(cpu_iter_time)
                cpu_wait.add(cpu_wait_for_gpu_time)
                gpu_compute.add(sum(times))
                gpu_total.add(overall_gpu_time)

            del model
            return {
                "cpu_iter": cpu_iter.avg(),
                "cpu_wait": cpu_wait.avg(),
                "gpu_compute": gpu_compute.avg(),
                "gpu_total": gpu_total.avg(),
            }
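
        # Roughly 100 ms worth of GPU cycles (get_cycles_per_ms() estimates cycles
        # per millisecond for torch.cuda._sleep), used for both the fake compute and
        # the artificial all-gather delay.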
        sleep_cycles = int(100 * get_cycles_per_ms())

        e1 = run(0, 0)  # no compute, no all-gather
        e2 = run(0, sleep_cycles)  # no compute, only all-gather
        e3 = run(sleep_cycles, 0)  # only compute, no all-gather
        e4 = run(sleep_cycles, sleep_cycles)  # both compute and all-gather
        debug_string = f"\nrank{rank}:\n e1: {e1}\n e2: {e2}\n e3: {e3}\n e4: {e4}"
        print(debug_string)

        # Check the CPU/GPU timing. The CPU should run ahead of the GPU, so the
        # cpu-gpu wait should be long, except when there is no real work on the GPU.
        #
        # If the assertions below fail, we likely have an unintended cpu-gpu wait
        # in the forward/backward pass. Note that e4["cpu_iter"] may not be short,
        # since the CPU may take some time to queue both the compute and the all-gather.
        short = [
            e1["cpu_iter"],
            e2["cpu_iter"],
            e3["cpu_iter"],
            e1["cpu_wait"],
        ]
        long = [e3["cpu_wait"], e4["cpu_wait"]]
        if world_size == 1:
            short.append(e2["cpu_wait"])  # all-gather should not be happening.
        else:
            long.append(
                e2["cpu_wait"]
            )  # all-gather should happen and prolong the cpu-gpu wait.
        for s in short:
            for l in long:
                # 10X longer is a safe margin, since the GPU work takes around 100X
                # longer than the CPU work.
                self.assertTrue(s * 10 < l)

        # Check the GPU timing.
        short = [e1["gpu_compute"], e1["gpu_total"], e2["gpu_compute"]]
        long = [e3["gpu_compute"], e3["gpu_total"], e4["gpu_compute"], e4["gpu_total"]]
        if world_size == 1:
            short.append(e2["gpu_total"])  # all-gather should not be happening.
        else:
            long.append(
                e2["gpu_total"]
            )  # all-gather should happen and prolong the total GPU time.
        for s in short:
            for l in long:
                # 10X longer is a safe margin, since the time is around 100X longer
                # when there is work on the GPU vs. no work.
                self.assertTrue(s * 10 < l)

        # Check GPU overlap when there is an all-gather: if compute and all-gather
        # overlap, running both together should take noticeably less time than the
        # sum of running each alone.
        if world_size > 1:
            compute_only = e3["gpu_compute"]
            all_gather_only = e2["gpu_total"]
            both = e4["gpu_total"]
            self.assertTrue(compute_only + all_gather_only > 1.1 * both)

    @skip_if_lt_x_gpu(2)
    def test_forward_overlap(self):
        self._dist_train()
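

# Same checks with two ranks, so the all-gather path is actually exercised and
# the world_size > 1 overlap assertions apply.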
class TestForwardOverlapWorldSizeTwo(TestForwardOverlapWorldSizeOne):
    @property
    def world_size(self):
        return 2


if __name__ == "__main__":
    run_tests()