| # Owner(s): ["module: dynamo"] |
| import torch |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| from torch._dynamo.testing import same |
| |
| |
| try: |
| from . import utils |
| except ImportError: |
| import utils |
| |
| |
class Pair:  # noqa: B903
    """Plain mutable two-field record; tests mutate ``.y`` in place."""

    def __init__(self, x, y):
        # Assign both fields in a single tuple-unpacking statement.
        self.x, self.y = x, y
| |
| |
def Foo():
    """Construct a fresh ``Pair`` whose ``x`` and ``y`` fields are both 1."""
    return Pair(x=1, y=1)
| |
| |
# Module-level state that the compiled test functions read from and write to
# (via STORE_GLOBAL, STORE_SUBSCR, or STORE_ATTR).
g_counter = 1
g_list = list(range(3))
g_dict = dict(a=0, b=1)
g_object = Foo()
g_tensor = torch.zeros(10)


# Counter consumed by fresh_name(); rewound to 0 by reset_name().
_name: int = 0
| |
| |
def fresh_name() -> str:
    """Return the next unique variable name: "v0", then "v1", "v2", ...

    The name is built from the current value of the module-level ``_name``
    counter, which is then advanced by one.
    """
    global _name
    idx = _name
    _name = idx + 1
    return f"v{idx}"
| |
| |
def reset_name():
    """Rewind the counter behind fresh_name() so the next name is "v0"."""
    global _name
    _name = 0
| |
| |
class TestGlobals(torch._dynamo.test_case.TestCase):
    """Check that dynamo-compiled functions correctly read and write globals.

    Most tests mutate module-level (or cross-module) state inside a compiled
    function, then call the eager function and compare results. The
    difference confirms that the compiled call actually performed the write.
    """

    def test_store_global_1(self):
        def fn(x):
            global g_counter
            val = x + g_counter
            g_counter += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        # The eager call must observe the increment done by the compiled call.
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_2(self):
        def fn(x):
            global g_counter
            val = x + g_counter
            g_counter += 1
            g_counter += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        # Wrap the second call with torch._dynamo as well.
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x)
        self.assertTrue(same(res2 - res1, 2 * torch.ones(10)))

    def test_store_global_new(self):
        # Drop the global afterwards so a re-run still creates it from scratch.
        self.addCleanup(globals().pop, "g_counter_new", None)

        def fn(x):
            # Test create a new global
            global g_counter_new
            g_counter_new = x + 1
            return x + g_counter_new

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        self.assertTrue(same(res1, x + x + 1))

    def test_store_global_list(self):
        def fn(x):
            global g_list
            val = x + g_list[1]
            # Strictly speaking, we are not testing STORE_GLOBAL here,
            # since STORE_SUBSCR is actually used to store.
            g_list[1] += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_list_2(self):
        def fn(x):
            global g_list
            val = x + g_list[1]
            # Rebind the global to a brand-new list (a real STORE_GLOBAL).
            g_list = [v + 1 for v in g_list]
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_dict(self):
        def fn(x):
            global g_dict
            val = x + g_dict["b"]
            # Strictly speaking, we are not testing STORE_GLOBAL here,
            # since STORE_SUBSCR is actually used to store.
            g_dict["b"] += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_dict_2(self):
        def fn(x):
            global g_dict
            # Rebind the global to a brand-new dict before reading from it.
            g_dict = {key: value + 1 for key, value in g_dict.items()}
            val = x + g_dict["b"]
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_object(self):
        def fn(x):
            global g_object
            val = x + g_object.y
            # Attribute mutation on a global object (STORE_ATTR).
            g_object.y += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_cross_file(self):
        def fn(x):
            # Read and rebind a global that lives in another module (utils).
            val = x + utils.g_tensor_export
            utils.g_tensor_export = utils.g_tensor_export + 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_inline_1(self):
        # The expected names below assume the counter starts at 0. Reset it
        # up front, and again on exit (even on failure), so this test
        # neither depends on nor leaks into other tests.
        reset_name()
        self.addCleanup(reset_name)

        # Borrowed from test_python_autograd.py
        class Variable:
            def __init__(self, value: torch.Tensor, name: str = None):
                self.value = value
                # fresh_name() does a STORE_GLOBAL inside inlined code.
                self.name = name or fresh_name()

        def fn(a, b):
            a = Variable(a)
            b = Variable(b)
            return a.value + b.value, a.name + b.name

        a = torch.randn(10)
        b = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        v0, s0 = opt_fn(a, b)
        self.assertTrue(same(v0, a + b))
        self.assertEqual(s0, "v0v1")

    def test_store_global_inline_2(self):
        # Same counter isolation as test_store_global_inline_1.
        reset_name()
        self.addCleanup(reset_name)

        # Borrowed from test_python_autograd.py
        class Variable:
            def __init__(self, value: torch.Tensor, name: str = None):
                self.value = value
                self.name = name or fresh_name()

            @staticmethod
            def constant(value: torch.Tensor, name: str = None):
                return Variable(value, name)

        def fn(a, b):
            # Go through a staticmethod so the global store is two levels deep.
            a = Variable.constant(a)
            b = Variable.constant(b)
            return a.value + b.value, a.name + b.name

        a = torch.randn(10)
        b = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        v0, s0 = opt_fn(a, b)
        self.assertTrue(same(v0, a + b))
        self.assertEqual(s0, "v0v1")

    def test_store_global_crossfile_inline(self):
        try:
            from . import mock_store_global_crossfile_inline
        except ImportError:
            import mock_store_global_crossfile_inline

        @torch.compile()
        def fn(x):
            # Two successive cross-module stores; only the last must stick.
            mock_store_global_crossfile_inline.set_flag_true()
            mock_store_global_crossfile_inline.set_flag_false()
            return x + 1

        @torch.compile()
        def fn_set_true(x):
            mock_store_global_crossfile_inline.set_flag_true()
            return x + 1

        fn_set_true(torch.ones(2, 2))
        self.assertTrue(mock_store_global_crossfile_inline.global_flag)
        fn(torch.ones(2, 2))
        self.assertFalse(mock_store_global_crossfile_inline.global_flag)
| |
| |
if __name__ == "__main__":
    # Hand off to the dynamo TestCase runner; the module is imported at the top.
    torch._dynamo.test_case.run_tests()