blob: 457b9648d73c95a97df43bcb6e571e809ac7bc49 [file] [log] [blame] [edit]
# Owner(s): ["module: unknown"]
from copy import copy
import torch
from torch import nn
from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo
from torch.utils.checkpoint import checkpoint
from torch.utils.module_tracker import ModuleTracker
class TestModuleTracker(TestCase):
# "https://github.com/pytorch/pytorch/issues/127112
@xfailIfTorchDynamo
def test_module_hierarchy(self):
seen_fw = []
seen_bw = []
class Foo(nn.Module):
def forward(self, x):
x = x["a"].relu_()
seen_fw.append((copy(tracker.parents), tracker.is_bw))
x.register_hook(
lambda grad: seen_bw.append((copy(tracker.parents), tracker.is_bw))
)
return {"a": torch.mm(x, x)}
class Mod(nn.Module):
def __init__(self) -> None:
super().__init__()
self.a = Foo()
self.b = nn.ModuleDict({"nest": Foo()})
self.c = nn.ModuleList([Foo()])
def forward(self, x):
x = self.c[0](x)
return self.b["nest"](self.a(x))
mod = Mod()
with ModuleTracker() as tracker:
mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
"a"
].sum().backward()
mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
"a"
].sum().backward()
self.assertEqual(
seen_fw,
[
({"Global", "Mod", "Mod.c.0"}, False),
({"Global", "Mod", "Mod.a"}, False),
({"Global", "Mod", "Mod.b.nest"}, False),
({"Global", "Mod", "Mod.c.0"}, False),
({"Global", "Mod", "Mod.a"}, False),
({"Global", "Mod", "Mod.b.nest"}, False),
],
)
self.assertEqual(
seen_bw,
[
({"Global", "Mod", "Mod.b.nest"}, True),
({"Global", "Mod", "Mod.a"}, True),
({"Global", "Mod", "Mod.c.0"}, True),
({"Global", "Mod", "Mod.b.nest"}, True),
({"Global", "Mod", "Mod.a"}, True),
({"Global", "Mod", "Mod.c.0"}, True),
],
)
def test_confused_hierarchy(self):
class MyMod(nn.Module):
def __init__(self):
super().__init__()
self.inner = nn.Linear(2, 2)
self.ran = False
def forward(self, inp):
if not self.ran:
self.ran = True
return self(inp)
else:
self.ran = False
return self.inner(inp)
mod = MyMod()
inp = torch.rand(1, 2, requires_grad=True)
# Should not fail
with ModuleTracker() as tracker:
res = mod(inp)
res.sum().backward()
# Should not fail
with ModuleTracker() as tracker:
res = checkpoint(lambda inp: mod(inp), inp)
res.sum().backward()
def test_bw_detection(self):
mod = nn.Linear(2, 2)
with ModuleTracker() as tracker:
mod(torch.rand(2, requires_grad=True)).sum().backward()
self.assertFalse(tracker.is_bw)
self.assertEqual(tracker.parents, {"Global"})
# Entry point when this file is executed directly rather than collected by a
# test runner: hand control to PyTorch's shared ``run_tests`` helper.
if __name__ == "__main__":
    run_tests()