# Owner(s): ["module: dynamo"]
import os
import unittest
from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch.distributed as dist
from torch import nn
from torch._dynamo import config
from torch._dynamo.utils import same
from torch.nn.parallel import DistributedDataParallel as DDP
class ToyModel(nn.Module):
def __init__(self, in_feat=10, hidden_feat=5000, num_hidden=2, out_feat=5):
super().__init__()
self.net = nn.Sequential(
*[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
+ [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()] * num_hidden
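            # note: list multiplication above repeats the *same* Linear/ReLU instances
            # num_hidden times, so the repeated hidden layers share one set of weights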
            + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
)
def forward(self, inputs):
return self.net(inputs)
class CheckSplitsCompiler:
def __init__(self):
self.compiler_called = 0
def compile_fn(self, gm, example_inputs):
self.compiler_called += 1
return gm
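# For reference, a torch._dynamo backend is just a callable that takes a GraphModule and
# its example inputs and returns a callable; the simplest possible backend returns the
# graph unchanged. The sketch below is illustrative only (the name noop_compile_fn is not
# used by the tests) - the tests use CheckSplitsCompiler so that each test gets its own
# independent call counter.
def noop_compile_fn(gm, example_inputs):
    # a GraphModule is itself callable, so returning it runs the captured graph eagerly
    return gm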
class TestDistributed(torch._dynamo.test_case.TestCase):
"""
Test harness initializes dist process group
"""
@classmethod
def setUpClass(cls):
super().setUpClass()
# _exit_stack is set up in TestCase
cls._exit_stack.enter_context(
patch.dict(
os.environ,
{
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12355",
},
)
)
cls.rank = 0
cls.device = f"cpu:{cls.rank}"
cls.device_ids = None if "cpu" in cls.device else [cls.rank]
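        # gloo with world_size=1 gives a single-process "distributed" setup that runs on
        # plain CPU, so these tests need neither multiple ranks nor a GPU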
dist.init_process_group("gloo", rank=cls.rank, world_size=1)
@classmethod
def tearDownClass(cls):
dist.destroy_process_group()
super().tearDownClass()
def get_model(self):
m = ToyModel().to(self.device)
inputs = torch.randn(20, 10).to(self.device)
outputs = m(inputs)
return m, inputs, outputs
@patch.object(config, "optimize_ddp", False)
def test_ddp_baseline_aot_eager(self):
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=self.device_ids)
ddp_m = torch._dynamo.optimize("aot_eager")(ddp_m)
outputs = ddp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))
@patch.object(config, "optimize_ddp", False)
def test_ddp_baseline_inductor(self):
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=self.device_ids)
ddp_m = torch._dynamo.optimize("inductor")(ddp_m)
outputs = ddp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))
# TODO(whc) move these tests to 'distributed' shard to get nccl, or see if it's available already in pytorch CI?
@unittest.skip(
"can't run with gloo (no support for _allgather_base) and nccl not available in CI"
)
@patch.object(config, "optimize_ddp", False)
def test_fsdp_baseline_aot_eager(self):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
m, inputs, correct_outputs = self.get_model()
fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None)
fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m)
outputs = fsdp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))
@unittest.skip("hangs/crashes with inductor currently")
@patch.object(config, "optimize_ddp", False)
def test_fsdp_baseline_inductor(self):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
m, inputs, correct_outputs = self.get_model()
fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None)
fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m)
outputs = fsdp_m(inputs)
self.assertTrue(same(correct_outputs, outputs))
@patch.object(config, "optimize_ddp", True)
def test_graph_split(self):
"""
        Ensures the expected number of graph splits happens (based on bucket
        size and model parameters), verified by counting how many times the
        user-provided compiler is called by the DDPOptimizer, which performs
        the graph splitting.
"""
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
check_splits_compiler = CheckSplitsCompiler()
@torch._dynamo.optimize(check_splits_compiler.compile_fn)
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 3)
# hangs/crashes with inductor currently
@unittest.skip("hangs/crashes with inductor currently")
@patch.object(config, "optimize_ddp", True)
def test_graph_split_inductor(self):
"""
        Same as above, but using the inductor backend.
        We have observed issues with the inductor/FX interface in the past.
"""
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
@torch._dynamo.optimize("inductor")
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
@patch.object(config, "optimize_ddp", True)
def test_no_split(self):
"""
Ensures the DDPOptimizer returns a correct, compiled module without
        introducing graph splits (based on the model parameters fitting in one bucket).
"""
        # DDP always allocates a 'first bucket' with a very small size, so only a tiny
        # model keeps all of its parameters inside that bucket and avoids graph splits
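        # Rough arithmetic (assuming fp32 and DDP's default ~1 MB first bucket):
        # ToyModel(hidden_feat=5) has only on the order of a hundred parameters
        # (a few hundred bytes), so everything fits in the first bucket and the
        # DDPOptimizer has no bucket boundary to split at.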
m = ToyModel(hidden_feat=5).to(self.device)
inputs = torch.randn(20, 10).to(self.device)
correct_outputs = m(inputs)
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250)
check_splits_compiler = CheckSplitsCompiler()
@torch._dynamo.optimize(check_splits_compiler.compile_fn)
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 1)
@patch.object(config, "optimize_ddp", True)
def test_aot_autograd(self):
"""
        Explicitly check that the AotAutograd family of compilers works, since
        they require example inputs to be propagated between graph splits.
"""
m, inputs, correct_outputs = self.get_model()
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
@torch._dynamo.optimize("aot_eager")
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
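        # run backward as well, so the AotAutograd-compiled backward graphs are
        # exercised rather than just the forward pass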
opt_outputs.sum().backward()
self.assertTrue(same(correct_outputs, opt_outputs))
@patch.object(config, "optimize_ddp", True)
def test_custom_layer(self):
"""
        Ensures the expected number of graph splits happens when a custom,
        hand-written parameterized layer is mixed in with standard nn.Linear
        layers - again verified by counting how many times the user-provided
        compiler is called by the DDPOptimizer.
"""
class MyCustomLinear(torch.nn.Module):
def __init__(self):
                super().__init__()
self.weight = nn.Parameter(torch.randn(512, 512))
def forward(self, x):
return torch.mm(x, self.weight.t())
class MyLinear(torch.nn.Module):
def __init__(self):
                super().__init__()
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
return self.linear(x)
class MyModule(torch.nn.Module):
def __init__(self):
                super().__init__()
mods = [
(MyLinear(), torch.nn.ReLU()),
                    # sandwich the custom layer in the middle so standard layers come both before and after it
(MyCustomLinear(), torch.nn.ReLU()),
(MyLinear(), torch.nn.ReLU()),
]
self.seq = torch.nn.Sequential(*[x for items in mods for x in items])
def forward(self, x):
return self.seq(x)
m = MyModule().to(self.device)
inputs = torch.randn((512, 512)).to(self.device)
correct_outputs = m(inputs)
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1)
check_splits_compiler = CheckSplitsCompiler()
@torch._dynamo.optimize(check_splits_compiler.compile_fn)
def opt_fn(inputs):
return ddp_m(inputs)
opt_outputs = opt_fn(inputs)
self.assertTrue(same(correct_outputs, opt_outputs))
self.assertEqual(check_splits_compiler.compiler_called, 3)
def test_empty_graph(self):
def fn():
get_world_size = torch.distributed.distributed_c10d.get_world_size()
return (get_world_size,)
opt_fn = torch._dynamo.optimize("inductor")(fn)
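        # if compilation or execution raises, res stays None and the assertEqual below
        # fails, surfacing the problem as a test failure rather than an uncaught error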
res = None
try:
res = opt_fn()[0]
except Exception:
pass
self.assertEqual(res, 1)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()