# Owner(s): ["module: dynamo"]
import os
import unittest
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch.distributed as dist
from torch import nn
from torch._dynamo import config
from torch._dynamo.utils import same
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyModel(nn.Module):
    def __init__(self, in_feat=10, hidden_feat=5000, num_hidden=2, out_feat=5):
        super().__init__()
        # Note: the list multiplication below repeats references to the same
        # Linear/ReLU modules, so the hidden layers share parameters.
        self.net = nn.Sequential(
            *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()] * num_hidden
            + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
        )

    def forward(self, inputs):
        return self.net(inputs)


class CheckSplitsCompiler:
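    """
    Counts how many times it is invoked as a backend, i.e. one call per
    compiled subgraph produced by graph splitting.
    """
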
    def __init__(self):
        self.compiler_called = 0

    def compile_fn(self, gm, example_inputs):
        self.compiler_called += 1
        return gm


class TestDistributed(torch._dynamo.test_case.TestCase):
    """
    Test harness that initializes a single-rank 'gloo' process group so
    DDP/FSDP-wrapped modules can be compiled and run in a single process.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # _exit_stack is set up in TestCase
        cls._exit_stack.enter_context(
            patch.dict(
                os.environ,
                {
                    "MASTER_ADDR": "localhost",
                    "MASTER_PORT": "12355",
                },
            )
        )
        cls.rank = 0
        cls.device = f"cpu:{cls.rank}"
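        # DDP's device_ids argument is only meant for single-device CUDA
        # modules, so pass None when running on CPU.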
        cls.device_ids = None if "cpu" in cls.device else [cls.rank]
        dist.init_process_group("gloo", rank=cls.rank, world_size=1)

    @classmethod
    def tearDownClass(cls):
        dist.destroy_process_group()
        super().tearDownClass()

    def get_model(self):
        m = ToyModel().to(self.device)
        inputs = torch.randn(20, 10).to(self.device)
        outputs = m(inputs)
        return m, inputs, outputs

    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_aot_eager(self):
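        # optimize_ddp=False disables the DDPOptimizer, so this baseline only
        # checks that a DDP-wrapped module compiles and matches eager outputs
        # without any graph splitting.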
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch._dynamo.optimize("aot_eager")(ddp_m)
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_inductor(self):
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch._dynamo.optimize("inductor")(ddp_m)
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    # TODO(whc) move these tests to 'distributed' shard to get nccl, or see if it's available already in pytorch CI?
    @unittest.skip(
        "can't run with gloo (no support for _allgather_base) and nccl not available in CI"
    )
    @patch.object(config, "optimize_ddp", False)
    def test_fsdp_baseline_aot_eager(self):
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

        m, inputs, correct_outputs = self.get_model()
        fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None)
        fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m)
        outputs = fsdp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @unittest.skip("hangs/crashes with inductor currently")
    @patch.object(config, "optimize_ddp", False)
    def test_fsdp_baseline_inductor(self):
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

        m, inputs, correct_outputs = self.get_model()
        fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None)
        fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m)
        outputs = fsdp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_graph_split(self):
| """ |
| Just ensures that the appropriate number of splits happen (based on |
| bucket size and model parameters) - verifies the number of times |
| the user-provided compiler is called by the DDPOptimizer which is |
| doing the graph splitting |
| """ |
| |
| m, inputs, correct_outputs = self.get_model() |
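        # Each 5000x5000 Linear holds roughly 100 MB of fp32 parameters, so with
        # bucket_cap_mb=25 the parameters span several DDP buckets and the
        # DDPOptimizer should call the backend once per resulting subgraph.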
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

| @unittest.skip("hangs/crashes with inductor currently") |
| @patch.object(config, "optimize_ddp", True) |
| def test_graph_split_inductor(self): |
| """ |
| Same as above, but using inductor backend. |
| We observed issues with inductor/fx interface in the past. |
| """ |
| m, inputs, correct_outputs = self.get_model() |
| ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) |
| |
| @torch._dynamo.optimize("inductor") |
| def opt_fn(inputs): |
| return ddp_m(inputs) |
| |
| opt_outputs = opt_fn(inputs) |
| self.assertTrue(same(correct_outputs, opt_outputs)) |
| |
| @patch.object(config, "optimize_ddp", True) |
    def test_no_split(self):
        """
        Ensures the DDPOptimizer returns a correct, compiled module without
        introducing graph splits, since all of the model's parameters fit
        within a single bucket.
        """
        # DDP always creates a small 'first bucket', so only a tiny model avoids being split
        m = ToyModel(hidden_feat=5).to(self.device)
        inputs = torch.randn(20, 10).to(self.device)
        correct_outputs = m(inputs)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250)

        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 1)

    @patch.object(config, "optimize_ddp", True)
    def test_aot_autograd(self):
        """
        Explicitly check that the AotAutograd family of compilers works, since
        these compilers require example inputs to be propagated between graph
        splits.
        """
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        @torch._dynamo.optimize("aot_eager")
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        opt_outputs.sum().backward()
        self.assertTrue(same(correct_outputs, opt_outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_custom_layer(self):
        """
        Like test_graph_split, verifies the number of graph splits via the
        number of calls to the user-provided compiler, but with a hand-written
        parameter module (MyCustomLinear) mixed in with standard nn.Linear
        layers.
        """

        class MyCustomLinear(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = nn.Parameter(torch.randn(512, 512))

            def forward(self, x):
                return torch.mm(x, self.weight.t())

        class MyLinear(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(512, 512)

            def forward(self, x):
                return self.linear(x)

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                mods = [
                    (MyLinear(), torch.nn.ReLU()),
                    # sandwich the custom layer in the middle so standard layers
                    # come both before and after it
                    (MyCustomLinear(), torch.nn.ReLU()),
                    (MyLinear(), torch.nn.ReLU()),
                ]
                self.seq = torch.nn.Sequential(*[x for items in mods for x in items])

            def forward(self, x):
                return self.seq(x)

        m = MyModule().to(self.device)
        inputs = torch.randn((512, 512)).to(self.device)
        correct_outputs = m(inputs)
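        # Each 512x512 fp32 weight is roughly 1 MB, so bucket_cap_mb=1 should
        # force bucket (and hence graph split) boundaries around each layer,
        # including the custom one.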
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1)

        check_splits_compiler = CheckSplitsCompiler()

        @torch._dynamo.optimize(check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

    def test_empty_graph(self):
        def fn():
            get_world_size = torch.distributed.distributed_c10d.get_world_size()
            return (get_world_size,)

        opt_fn = torch._dynamo.optimize("inductor")(fn)
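        # The traced function contains no tensor ops; this checks that compiling
        # an effectively empty graph does not crash and that it still returns
        # the world size (1 for this single-rank process group).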
        res = None
        try:
            res = opt_fn()[0]
        except Exception:
            pass
        self.assertEqual(res, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()