| # Owner(s): ["module: dynamo"] |
| |
| import torch |
| import torch._dynamo |
| import torch._dynamo.test_case |
| import torch.nn as nn |
| from torch._dynamo.source import ( |
| AttrSource, |
| GlobalSource, |
| is_from_local_source, |
| LocalSource, |
| ) |
| |
| |
class CausalLMOutputWithPast:
    """Placeholder type that ``test_supported_nodes`` registers as a pytree node.

    Only the class's identity as a key in the pytree registry is exercised.
    """

    # Plain class attribute; not read by the tests visible in this module.
    value = 5
| |
| |
class SourceTests(torch._dynamo.test_case.TestCase):
    """Tests for dynamo ``Source`` helpers and for tracing code that reads
    properties with closures and the pytree node registry."""

    def test_is_local(self):
        """An ``AttrSource`` rooted at a ``LocalSource`` counts as local; one
        rooted at a ``GlobalSource`` does not."""
        x_src = LocalSource("x")
        y_src = GlobalSource("y")

        attr_x_a = AttrSource(x_src, "a")
        attr_y_b = AttrSource(y_src, "b")

        self.assertTrue(is_from_local_source(attr_x_a))
        self.assertFalse(is_from_local_source(attr_y_b))

    def test_property_closure(self):
        """A ``property`` whose getter reads a closure variable compiles with
        ``fullgraph=True`` and matches the eager result."""

        def external_property():
            closed_value = 7

            def internal_function(self):
                return closed_value

            return internal_function

        class Elements:
            myprop = property(external_property())

        def func(elements):
            # closed_value (7) is truthy, so the ``else`` branch is expected.
            if not elements.myprop:
                return torch.tensor([1, 2, 3])
            else:
                return torch.tensor([4, 5, 6])

        e = Elements()
        a = func(e)
        b = torch.compile(func, backend="eager", fullgraph=True)(e)
        self.assertEqual(a, b)

    def test_supported_nodes(self):
        """Export a module whose ``forward`` reads
        ``torch.utils._pytree.SUPPORTED_NODES``.

        The pytree registration below mutates process-wide state. It is
        undone on cleanup so it cannot leak into other tests and so re-running
        this test in the same process does not re-register the same type.
        """

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.randn(10, 10)

            def forward(self):
                if (
                    torch.utils._pytree.SUPPORTED_NODES[CausalLMOutputWithPast].type
                    == int
                ):
                    x = torch.sin(self.x)
                else:
                    x = torch.cos(self.x)
                return x

        torch.utils._pytree.register_pytree_node(
            CausalLMOutputWithPast,
            lambda x: ((), None),
            lambda x, _: CausalLMOutputWithPast(),
        )
        # Undo the global registration once the test finishes (pass or fail).
        self.addCleanup(
            torch.utils._pytree._deregister_pytree_node, CausalLMOutputWithPast
        )

        model = Model()
        ep = torch.export.export(model, ())
        # The registered node's ``type`` is CausalLMOutputWithPast, not ``int``,
        # so the exported program should compute the ``cos`` branch.
        self.assertEqual(ep.module()(), torch.cos(model.x))
| |
if __name__ == "__main__":
    # Running the file directly hands control to dynamo's test runner.
    from torch._dynamo.test_case import run_tests

    run_tests()