# Owner(s): ["oncall: distributed"]
import torch
from torch.distributed.fsdp._trace_utils import _ExecOrderTracer
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TestCase,
)
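

# `Model` is a toy module chosen to exercise the tracer's bookkeeping: it has
# an unused parameter (`weight_unused`), a shared activation module (`relu`)
# called several times, and `layer0` is run twice in `forward`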
class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight1 = torch.nn.Parameter(torch.randn(6, 6))
        self.weight2 = torch.nn.Parameter(torch.randn(6, 6))
        self.weight_unused = torch.nn.Parameter(torch.randn(2, 2))
        self.layer0 = torch.nn.Linear(6, 6)
        self.layer1 = torch.nn.Linear(6, 6, bias=False)
        self.layer2 = torch.nn.Sequential(
            torch.nn.Linear(6, 3, bias=False),
            torch.nn.ReLU(),
            torch.nn.Linear(3, 6, bias=False),
        )
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor, run_all_layers: bool) -> torch.Tensor:
        z = self.relu(self.layer0(x))
        z = self.relu(self.layer2(z))
        z = z @ self.weight1
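        # `run_all_layers` is supplied via `concrete_args` during tracing, so
        # FX specializes this branch to the concrete value instead of tracing
        # it symbolically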
        if run_all_layers:
            z = self.relu(self.layer1(z))
            z = z @ self.weight2
            # Use `layer0` twice to check the handling of multiplicity in the
            # saved data structures
            z = self.relu(self.layer0(x))
        return z


class TestSymbolicTracing(TestCase):
    def test_symbolic_tracing_outputs(self):
        """
        Tests running ``tracer.trace()`` inside ``patch_tracer()`` by checking
        the saved data structures.
        """
        model = Model()
        tracer = torch.fx.Tracer()
        orig_call_module = tracer.call_module
        orig_create_proxy = tracer.create_proxy
        exec_order_tracer = _ExecOrderTracer()
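        # `patch_tracer()` temporarily overrides the tracer's `call_module`
        # and `create_proxy` methods to record execution information while
        # tracing runs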
        with exec_order_tracer.patch_tracer(tracer=tracer, root_module=model):
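            # Specializing `run_all_layers=True` ensures that the trace covers
            # `layer1` and `weight2`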
            concrete_args = {"run_all_layers": True}
            tracer.trace(model, concrete_args)
        # Check that the tracer methods are unchanged after exiting the context
        self.assertEqual(orig_call_module, tracer.call_module)
        self.assertEqual(orig_create_proxy, tracer.create_proxy)
        # Check `module_forward_order`
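        # Each module is recorded once per forward call, so the shared `relu`
        # appears four times and `layer0` appears twice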
        correct_module_forward_order = [
            model,
            model.layer0,
            model.relu,
            model.layer2,
            model.layer2[0],
            model.layer2[1],
            model.layer2[2],
            model.relu,
            model.layer1,
            model.relu,
            model.layer0,
            model.relu,
        ]
        exec_info = exec_order_tracer.exec_info
        self.assertEqual(exec_info.module_forward_order, correct_module_forward_order)
        # Check `module_to_param_usage_infos`
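        # Maps each module to the ordered (module, parameters) pairs used
        # during its forward pass; the root sees every parameter use,
        # including `layer0`'s second one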
        self.assertEqual(
            exec_info.module_to_param_usage_infos[model],
            [
                (model.layer0, list(model.layer0.named_parameters())),
                (model.layer2, list(model.layer2.named_parameters())),
                (model, [("weight1", model.weight1)]),
                (model.layer1, list(model.layer1.named_parameters())),
                (model, [("weight2", model.weight2)]),
                (model.layer0, list(model.layer0.named_parameters())),
            ],
        )
        self.assertEqual(
            exec_info.module_to_param_usage_infos[model.layer0],
            [(model.layer0, list(model.layer0.named_parameters()))],
        )
        self.assertEqual(
            exec_info.module_to_param_usage_infos[model.layer1],
            [(model.layer1, list(model.layer1.named_parameters()))],
        )
        self.assertEqual(
            exec_info.module_to_param_usage_infos[model.layer2],
            [
                (model.layer2[0], list(model.layer2[0].named_parameters())),
                (model.layer2[2], list(model.layer2[2].named_parameters())),
            ],
        )
        self.assertEqual(exec_info.module_to_param_usage_infos[model.relu], [])
        # Check `param_forward_order`
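        # Parameters are recorded at first use only: `layer0`'s parameters are
        # not repeated for its second call, and `weight_unused` never appears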
        correct_param_order = [
            model.layer0.weight,
            model.layer0.bias,
            model.layer2[0].weight,
            model.layer2[2].weight,
            model.weight1,
            model.layer1.weight,
            model.weight2,
        ]
        self.assertEqual(exec_info.param_forward_order, correct_param_order)
        # Check `visited_params`
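        # `visited_params` should be the set view of `param_forward_order`:
        # same parameters, no duplicates or ordering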
        self.assertEqual(
            len(exec_info.visited_params), len(exec_info.param_forward_order)
        )
        self.assertEqual(exec_info.visited_params, set(exec_info.param_forward_order))


instantiate_parametrized_tests(TestSymbolicTracing)

if __name__ == "__main__":
    run_tests()