# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
# This file is a model zoo for testing torch.distributed.pipelining.
import torch
from torch.distributed.pipelining import pipe_split, SplitPoint


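# Toy model with matmul parameters, a non-trainable buffer, and a manual
# pipe_split() marker in forward().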
class ExampleCode(torch.nn.Module):
    def __init__(self, d_hid):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.register_buffer("cval", torch.randn((d_hid,), requires_grad=False))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = torch.mm(x, self.mm_param0)
        x = torch.relu(x)
        # try passing a value that doesn't require_grad across skip boundaries
        a_constant = self.cval.clone()
        x = self.lin0(x)
        pipe_split()
        x = torch.relu(x) + a_constant
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x
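

# Illustrative sketch (not used by the tests): one way to cut ExampleCode into
# pipeline stages at its pipe_split() marker via torch.distributed.pipelining.pipeline.
# The helper name, sizes, and call pattern here are assumptions for illustration only.
def _example_code_pipeline_sketch(d_hid: int = 16, batch_size: int = 4):
    from torch.distributed.pipelining import pipeline

    model = ExampleCode(d_hid)
    x = torch.randn(batch_size, d_hid)
    # pipeline() traces the module with the example micro-batch and splits it into
    # stages at each pipe_split() call.
    return pipeline(model, mb_args=(x,))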


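# Model whose forward() takes a keyword argument with a tensor default value.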
class ModelWithKwargs(torch.nn.Module):
    DEFAULT_DHID = 512
    DEFAULT_BATCH_SIZE = 256

    def __init__(self, d_hid: int = DEFAULT_DHID):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x
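

# Illustrative sketch (not used by the tests): example keyword inputs such as y above
# can be handed to the tracer separately from positional inputs. The helper name and
# batch size here are assumptions for illustration only.
def _model_with_kwargs_pipeline_sketch(batch_size: int = 4):
    from torch.distributed.pipelining import pipeline

    model = ModelWithKwargs()
    x = torch.randn(batch_size, ModelWithKwargs.DEFAULT_DHID)
    y = torch.randn(batch_size, ModelWithKwargs.DEFAULT_DHID)
    # Positional example inputs go in mb_args; keyword example inputs go in mb_kwargs.
    return pipeline(model, mb_args=(x,), mb_kwargs={"y": y})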


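# Model whose mm_param1 / lin1 alias mm_param0 / lin0 (the same Parameter and
# Linear registered under two attribute names).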
class ModelWithParamAlias(torch.nn.Module):
    default_dhid = 512
    default_batch_size = 256

    def __init__(self, d_hid: int = default_dhid):
        super().__init__()
        self.mm_param1 = self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin1 = self.lin0 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


# MLP Layer
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


# Multi-MLP model
class MultiMLP(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList([MLPModule(d_hid) for _ in range(n_layers)])
        # For testing purposes only; in real use the split spec should be defined by the user.
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
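

# Illustrative sketch (not used by the tests): MultiMLP's split_spec marks a split
# before each "layers.<i>" submodule, so passing it to pipeline() yields one stage per
# MLP layer. The helper name and sizes here are assumptions for illustration only.
def _multi_mlp_pipeline_sketch(d_hid: int = 16, n_layers: int = 4, batch_size: int = 8):
    from torch.distributed.pipelining import pipeline

    model = MultiMLP(d_hid, n_layers=n_layers)
    x = torch.randn(batch_size, d_hid)
    # SplitPoint.BEGINNING at "layers.i" cuts right before layers[i] executes.
    return pipeline(model, mb_args=(x,), split_spec=model.split_spec)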