# blob: a7dd07175996d9c7bbf38a17acb632d3eb0e6a70
# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
try:
from . import utils
except ImportError:
import utils
class Pair:  # noqa: B903
    """A minimal two-field record exposing plain ``x`` and ``y`` attributes."""

    def __init__(self, x, y):
        # Assigned left to right: ``x`` first, then ``y``.
        self.x, self.y = x, y
def Foo():
    """Build and return a fresh ``Pair(1, 1)``."""
    pair = Pair(1, 1)
    return pair
# Module-level state that the STORE_GLOBAL tests below read and mutate.
g_counter = 1
g_list = list(range(3))  # [0, 1, 2]
g_dict = dict(a=0, b=1)  # {"a": 0, "b": 1}
g_object = Foo()  # a Pair(1, 1)
g_tensor = torch.zeros(10)
# Counter backing fresh_name(); reset_name() puts it back to zero.
_name: int = 0


def fresh_name() -> str:
    """Return the next unique variable name: "v0", then "v1", "v2", ...

    Formats the current value of the module-level counter ``_name`` and then
    advances the counter by one.
    """
    global _name
    label = f"v{_name}"
    _name = _name + 1
    return label


def reset_name():
    """Rewind the fresh_name() counter so the next name produced is "v0"."""
    global _name
    _name = 0
class TestGlobals(torch._dynamo.test_case.TestCase):
    """Dynamo tests for reading and writing module-level (global) state.

    Each test compiles a small function with ``torch._dynamo.optimize`` (or
    ``torch.compile``) that mutates a global. It then checks the mutation is
    visible afterwards, usually by comparing against a second call.
    """

    def test_store_global_1(self):
        def fn(x):
            global g_counter
            val = x + g_counter
            g_counter += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        # The compiled call must have bumped g_counter, so the eager call
        # sees a counter exactly one larger.
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_2(self):
        def fn(x):
            global g_counter
            val = x + g_counter
            g_counter += 1
            g_counter += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        # Wrap the second call with torch._dynamo as well.
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x)
        self.assertTrue(same(res2 - res1, 2 * torch.ones(10)))

    def test_store_global_new(self):
        def fn(x):
            # Test creating a new global.
            global g_counter_new
            g_counter_new = x + 1
            return x + g_counter_new

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        self.assertTrue(same(res1, x + x + 1))

    def test_store_global_list(self):
        def fn(x):
            global g_list
            val = x + g_list[1]
            # Strictly speaking, we are not testing STORE_GLOBAL here,
            # since STORE_SUBSCR is actually used to store.
            g_list[1] += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_list_2(self):
        def fn(x):
            global g_list
            val = x + g_list[1]
            # Rebind the global to a brand-new list (a real STORE_GLOBAL).
            # The loop variable is deliberately not named ``x`` so it does
            # not shadow the tensor argument.
            g_list = [v + 1 for v in g_list]
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_dict(self):
        def fn(x):
            global g_dict
            val = x + g_dict["b"]
            # Strictly speaking, we are not testing STORE_GLOBAL here,
            # since STORE_SUBSCR is actually used to store.
            g_dict["b"] += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_dict_2(self):
        def fn(x):
            global g_dict
            g_dict = {key: value + 1 for key, value in g_dict.items()}
            val = x + g_dict["b"]
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_object(self):
        def fn(x):
            global g_object
            val = x + g_object.y
            # Attribute mutation on a global object (STORE_ATTR).
            g_object.y += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_cross_file(self):
        def fn(x):
            # Reads and rebinds a global that lives in another module.
            val = x + utils.g_tensor_export
            utils.g_tensor_export = utils.g_tensor_export + 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_inline_1(self):
        # Borrowed from test_python_autograd.py
        class Variable:
            def __init__(self, value: torch.Tensor, name: str = None):
                self.value = value
                self.name = name or fresh_name()

        def fn(a, b):
            a = Variable(a)
            b = Variable(b)
            return a.value + b.value, a.name + b.name

        # The expected names depend on the global counter. Start from zero
        # and restore it even if an assertion below fails, so naming state
        # cannot leak between tests.
        reset_name()
        self.addCleanup(reset_name)

        a = torch.randn(10)
        b = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        v0, s0 = opt_fn(a, b)
        self.assertEqual(s0, "v0v1")

    def test_store_global_inline_2(self):
        # Borrowed from test_python_autograd.py
        class Variable:
            def __init__(self, value: torch.Tensor, name: str = None):
                self.value = value
                self.name = name or fresh_name()

            @staticmethod
            def constant(value: torch.Tensor, name: str = None):
                return Variable(value, name)

        def fn(a, b):
            a = Variable.constant(a)
            b = Variable.constant(b)
            return a.value + b.value, a.name + b.name

        # The expected names depend on the global counter. Start from zero
        # and restore it even if an assertion below fails, so naming state
        # cannot leak between tests.
        reset_name()
        self.addCleanup(reset_name)

        a = torch.randn(10)
        b = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        v0, s0 = opt_fn(a, b)
        self.assertEqual(s0, "v0v1")

    def test_store_global_crossfile_inline(self):
        try:
            from . import mock_store_global_crossfile_inline
        except ImportError:
            import mock_store_global_crossfile_inline

        @torch.compile()
        def fn(x):
            mock_store_global_crossfile_inline.set_flag_true()
            mock_store_global_crossfile_inline.set_flag_false()
            return x + 1

        @torch.compile()
        def fn_set_true(x):
            mock_store_global_crossfile_inline.set_flag_true()
            return x + 1

        # The helper module's set_flag_* functions are expected to write its
        # ``global_flag``. The compiled callers must make those writes
        # visible to eager code.
        fn_set_true(torch.ones(2, 2))
        self.assertTrue(mock_store_global_crossfile_inline.global_flag)
        fn(torch.ones(2, 2))
        self.assertFalse(mock_store_global_crossfile_inline.global_flag)
if __name__ == "__main__":
    # Delegate to Dynamo's test runner when this file is executed directly.
    torch._dynamo.test_case.run_tests()