# blob: caa7217ed81c3b06f592a4cde7aac16f1b5cd66f [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import functools
import weakref
import torch
import torch._dynamo
import torch._dynamo.test_case
from torch._C._dynamo import guards
from torch._dynamo.convert_frame import GlobalStateGuard
from torch.testing._internal.common_utils import set_default_dtype
RootGuardManager = guards.RootGuardManager
DictGuardManager = guards.DictGuardManager
DictSubclassGuardManager = guards.DictSubclassGuardManager
GetAttrGuardAccessor = guards.GetAttrGuardAccessor
GetItemGuardAccessor = guards.GetItemGuardAccessor
TypeGuardAccessor = guards.TypeGuardAccessor
TENSOR_ALIASING = guards.TENSOR_ALIASING
install_tensor_aliasing_guard = guards.install_tensor_aliasing_guard
NO_TENSOR_ALIASING = guards.NO_TENSOR_ALIASING
install_no_tensor_aliasing_guard = guards.install_no_tensor_aliasing_guard
x = torch.tensor(4)
weakref_x = weakref.ref(x)
default_mgr_enum = torch._dynamo.guards.GuardManagerType.GUARD_MANAGER
class Pair:
    """Plain two-field container; an instance is kept as a module global."""

    def __init__(self, x, y):
        # Assigned left to right: x first, then y.
        self.x, self.y = x, y


# Module-level instance that test_globals guards through the globals dict.
global_pair = Pair(torch.randn(4), 1)
def id_type(x):
    """Return ``id(type(x))``, the value TYPE_MATCH guards are built from."""
    cls = type(x)
    return id(cls)
def equals_match(x, expected):
    """Return ``x == expected``; bound with functools.partial as a guard predicate."""
    outcome = x == expected
    return outcome
def equals_match_verbose_code_parts(expected):
    """Return the one-element verbose-code-parts list describing equals_match."""
    return ["x == {}".format(expected)]
def ge_match(x, expected):
    """Return ``x >= expected``; bound with functools.partial as a guard predicate."""
    outcome = x >= expected
    return outcome
def ge_match_verbose_code_parts(expected):
    """Return the verbose code parts describing ``ge_match`` for *expected*.

    Like ``equals_match_verbose_code_parts``, this returns a one-element list
    phrased in terms of the guarded value ``x`` (previously it returned a bare
    string reading ``"expected >= N"``).
    """
    return [f"x >= {expected}"]
def less_match(x, expected):
    """Return ``x < expected``; bound with functools.partial as a guard predicate."""
    outcome = x < expected
    return outcome
def less_match_verbose_code_parts(expected):
    """Return the one-element verbose-code-parts list describing ``less_match``.

    Phrased in terms of the guarded value ``x``, consistent with
    ``equals_match_verbose_code_parts`` (previously read ``"expected < N"``).
    """
    return [f"x < {expected}"]
class GuardManagerTests(torch._dynamo.test_case.TestCase):
    """Tests for the C++ guard primitives exposed by ``torch._C._dynamo.guards``.

    Covers standalone leaf guards (GLOBAL_STATE, TYPE_MATCH, ID_MATCH,
    EQUALS_MATCH, ...), ``RootGuardManager`` trees built through accessors
    (getattr, getitem, dict getitem, globals, type, lambda, tuple iterator),
    relational tensor-aliasing guards, and ``DictGuardManager``.
    """

    def test_global_state_guard(self):
        guard = guards.GLOBAL_STATE(["global_state_check"])
        self.assertTrue(guard(None))
        # Changing the default dtype must fail the guard until it is restored.
        with set_default_dtype(torch.double):
            self.assertFalse(guard(None))
            self.assertExpectedInline(
                str(guard.check_verbose(None)),
                """\
GuardDebugInfo(
result=0,
verbose_code_parts=['GLOBAL_STATE changed: default_dtype '],
num_guards_executed=0)
""",
            )
        self.assertTrue(guard(None))
        self.assertTrue(guard.check_verbose(None).result)
        # Same for the deterministic-algorithms flag; always restore it so a
        # failing assertion does not leak the flipped flag into other tests.
        orig_deterministic = torch.are_deterministic_algorithms_enabled()
        try:
            torch.use_deterministic_algorithms(not orig_deterministic)
            self.assertFalse(guard(None))
            self.assertExpectedInline(
                str(guard.check_verbose(None)),
                """\
GuardDebugInfo(
result=0,
verbose_code_parts=['GLOBAL_STATE changed: deterministic_algorithms '],
num_guards_executed=0)
""",
            )
        finally:
            torch.use_deterministic_algorithms(orig_deterministic)
        self.assertTrue(guard(None))
        self.assertTrue(guard.check_verbose(None).result)

    def test_global_state_reason(self):
        # Named `state_guard` rather than `guards` so the local does not shadow
        # the `guards` module used throughout this file.
        with torch.enable_grad():
            state_guard = GlobalStateGuard()
        with torch.no_grad():
            self.assertIs(state_guard.check(), False)
            self.assertEqual(state_guard.reason(), "grad_mode ")

    def test_python_lambda_leaf_guard(self):
        const_guard = guards.LAMBDA_GUARD(
            functools.partial(equals_match, expected=5),
            equals_match_verbose_code_parts(5),
        )
        self.assertTrue(const_guard(5))
        self.assertFalse(const_guard(4))
        self.assertFalse(const_guard("foo"))

    def test_type_guard(self):
        foo = 4
        guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == int"])
        self.assertTrue(guard(5))
        self.assertTrue(guard(4))
        self.assertFalse(guard("foo"))

        foo = {"a": 1}
        guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == dict"])
        self.assertTrue(guard(foo))
        self.assertTrue(guard({}))
        self.assertFalse(guard(5))
        self.assertFalse(guard("foo"))

        class Foo:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        foo = Foo(1, 2)
        guard = guards.TYPE_MATCH(id_type(foo), ["type(x) == Foo"])
        self.assertTrue(guard(foo))
        self.assertFalse(guard({}))
        self.assertFalse(guard(5))
        self.assertFalse(guard("foo"))

    def test_id_guard(self):
        foo = 4
        guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"])
        self.assertTrue(guard(foo))
        self.assertFalse(guard(5))
        self.assertFalse(guard("foo"))

        # An equal but distinct dict must not match.
        foo = {"a": 1}
        guard = guards.ID_MATCH(id(foo), ["id(x) == id(foo)"])
        self.assertTrue(guard(foo))
        self.assertFalse(guard({"a": 1}))
        self.assertFalse(guard({}))
        self.assertFalse(guard(5))

    def test_equals_guard(self):
        foo = 4
        guard = guards.EQUALS_MATCH(foo, ["x == 4"])
        self.assertTrue(guard(4))
        self.assertFalse(guard(5))
        self.assertFalse(guard("foo"))

        # tuple
        foo = (1, 2, 3)
        guard = guards.EQUALS_MATCH(foo, ["x == foo"])
        self.assertTrue(guard(foo))
        self.assertTrue(guard((1, 2, 3)))
        self.assertFalse(guard((1, 2, 3, 4)))
        self.assertFalse(guard({}))

        # list
        foo = [1, 2, 3]
        guard = guards.EQUALS_MATCH(foo, ["x == foo"])
        self.assertTrue(guard(foo))
        self.assertTrue(guard([1, 2, 3]))
        self.assertFalse(guard([1, 2, 3, 4]))

        # type
        foo = int
        guard = guards.EQUALS_MATCH(foo, ["x == foo"])
        self.assertTrue(guard(foo))
        self.assertTrue(guard(int))
        self.assertFalse(guard(float))

    def test_default_device_guard(self):
        foo = 1
        guard = guards.DEFAULT_DEVICE(["cpu device"])
        self.assertTrue(guard(foo))
        try:
            torch.set_default_device("cuda")
            self.assertFalse(guard(foo))
        finally:
            # Always reset so later tests see the original default device.
            torch.set_default_device(None)

    def test_data_ptr_match_guard(self):
        foo = torch.tensor([1, 2, 3])
        guard = guards.DATA_PTR_MATCH(foo, ["x.data_ptr() == foo.data_ptr()"])
        self.assertTrue(guard(foo))
        self.assertFalse(guard(torch.tensor([1, 2, 3])))

    def test_length_check_guard(self):
        foo = [1, 2, 3]
        guard = guards.LENGTH_CHECK(len(foo), ["len(x) == len(foo)"])
        self.assertTrue(guard(foo))
        self.assertFalse(guard([]))

    def test_no_hasattr_guard(self):
        class Bar:
            def __init__(self):
                self.bar = 2

        bar = Bar()

        class Foo:
            def __init__(self):
                self.foo = 2

        foo = Foo()

        guard = guards.NO_HASATTR("foo", ["hasattr(x, 'foo') == False"])
        self.assertTrue(guard(bar))
        self.assertFalse(guard(foo))

    def test_tensor_aliasing_guard(self):
        guard_manager = RootGuardManager()

        a = torch.randn(3, 4)

        class Foo:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        f_locals = Foo(a, a)

        x_guard_mgr = guard_manager.getattr_manager("x", "", a, default_mgr_enum)
        y_guard_mgr = guard_manager.getattr_manager("y", "", a, default_mgr_enum)
        install_tensor_aliasing_guard(x_guard_mgr, y_guard_mgr, ["x is y"])

        # Check structure
        x_guards = x_guard_mgr.get_leaf_guards()
        y_guards = y_guard_mgr.get_leaf_guards()
        self.assertEqual(len(x_guards), 1)
        self.assertEqual(len(y_guards), 1)
        self.assertIsInstance(x_guards[0], TENSOR_ALIASING)
        self.assertIsInstance(y_guards[0], TENSOR_ALIASING)
        # Check that the two guards are the same object
        self.assertIs(x_guards[0], y_guards[0])

        f_locals_unaliased = Foo(torch.randn(3, 4), torch.randn(3, 4))
        self.assertEqual(len(x_guard_mgr.get_leaf_guards()), 1)
        self.assertEqual(len(y_guard_mgr.get_leaf_guards()), 1)
        self.assertTrue(guard_manager.check(f_locals))

        self.assertFalse(guard_manager.check(f_locals_unaliased))

    def test_dict_version_guard(self):
        foo = {"a": 1, "b": 2}
        guard = guards.DICT_VERSION(foo, ["x.version == foo.version"])

        self.assertTrue(guard(foo))
        self.assertFalse(guard(dict(foo)))
        # Mutating the dict in place must invalidate the version guard.
        foo["a"] = 2
        self.assertFalse(guard(foo))
        self.assertFalse(guard({"a": 1, "b": 2}))
        self.assertFalse(guard({}))

    def test_dynamic_indices_guard(self):
        guard1 = guards.DYNAMIC_INDICES(set(), ["x.size(0) == y.size(0)"])
        guard2 = guards.DYNAMIC_INDICES({0, 1}, ["x.size(0) == y.size(0)"])

        x = torch.randn(4)
        self.assertTrue(guard1(x))
        self.assertTrue(guard2(x))

        x._dynamo_dynamic_indices = {0}
        self.assertFalse(guard1(x))
        self.assertTrue(guard2(x))

        x._dynamo_dynamic_indices = {2}
        self.assertFalse(guard1(x))
        self.assertFalse(guard2(x))

    def test_tensor_match_guard(self):
        guard_manager = RootGuardManager()
        x = torch.randn(4, 4)
        size = list(x.size())
        stride = list(x.stride())
        guard_manager.add_tensor_match_guard(x, size, stride, "x", ["check_tensor(x)"])
        self.assertTrue(guard_manager.check(x))
        self.assertTrue(guard_manager.check_verbose(x).result)
        self.assertTrue(guard_manager.check(torch.randn(4, 4)))
        self.assertTrue(guard_manager.check_verbose(torch.randn(4, 4)).result)
        # An in-place transpose changes the strides, which must fail the guard.
        self.assertFalse(guard_manager.check(x.t_()))

        x = torch.randn(4, 4)
        x.t_()
        debug_info = guard_manager.check_verbose(x)
        self.assertIn("tensor 'x' stride mismatch", debug_info.verbose_code_parts[0])

    def test_no_tensor_aliasing_guard(self):
        guard_manager = RootGuardManager()

        a = torch.randn(3, 4)

        class Foo:
            def __init__(self, x, y, z):
                self.x = x
                self.y = y
                self.z = z

        f_locals = Foo(a, a, a)

        x_guard_mgr = guard_manager.getattr_manager("x", "", a, default_mgr_enum)
        y_guard_mgr = guard_manager.getattr_manager("y", "", a, default_mgr_enum)
        z_guard_mgr = guard_manager.getattr_manager("z", "", a, default_mgr_enum)
        install_no_tensor_aliasing_guard(
            [x_guard_mgr, y_guard_mgr, z_guard_mgr],
            ["x", "y", "z"],
            ["no_aliasing(x, y, z)"],
        )

        # Check structure
        x_guards = x_guard_mgr.get_leaf_guards()
        y_guards = y_guard_mgr.get_leaf_guards()
        z_guards = z_guard_mgr.get_leaf_guards()
        self.assertEqual(len(x_guards), 1)
        self.assertEqual(len(y_guards), 1)
        self.assertEqual(len(z_guards), 1)
        self.assertIsInstance(x_guards[0], NO_TENSOR_ALIASING)
        self.assertIsInstance(y_guards[0], NO_TENSOR_ALIASING)
        self.assertIsInstance(z_guards[0], NO_TENSOR_ALIASING)
        # Check that all three managers share the same guard object
        self.assertIs(x_guards[0], y_guards[0])
        self.assertIs(y_guards[0], z_guards[0])
        self.assertFalse(guard_manager.check(f_locals))
        self.assertFalse(guard_manager.check_verbose(f_locals).result)

        f_locals_unaliased = Foo(
            torch.randn(3, 4),
            torch.randn(3, 4),
            torch.randn(3, 4),
        )
        self.assertTrue(guard_manager.check(f_locals_unaliased))
        self.assertTrue(guard_manager.check_verbose(f_locals_unaliased).result)
        # Check that hash map is cleared: a second check must pass again.
        self.assertTrue(guard_manager.check(f_locals_unaliased))

        # Partial aliasing (x and z share a tensor) must also fail.
        f_locals_unaliased = Foo(
            a,
            torch.randn(3, 4),
            a,
        )
        self.assertFalse(guard_manager.check(f_locals_unaliased))
        self.assertFalse(guard_manager.check_verbose(f_locals_unaliased).result)

    def test_weakref_alive_guard(self):
        x = torch.rand(3, 4)
        weakref_x = weakref.ref(x)

        guard = guards.NOT_NONE(["weakref_x is not None"])
        self.assertTrue(guard(weakref_x()))
        # Dropping the last strong reference kills the weakref referent.
        del x
        self.assertFalse(guard(weakref_x()))

    def test_guard_manager_leaf_guard(self):
        guard_manager = RootGuardManager()
        guard_manager.add_type_match_guard(id_type(5), ["type(x) == int"])
        guard_manager.add_lambda_guard(
            functools.partial(ge_match, expected=5),
            ge_match_verbose_code_parts(expected=5),
        )
        guard_manager.add_lambda_guard(
            functools.partial(less_match, expected=10),
            less_match_verbose_code_parts(expected=10),
        )
        self.assertEqual(len(guard_manager.get_leaf_guards()), 3)
        self.assertEqual(len(guard_manager.get_accessors()), 0)
        self.assertTrue(guard_manager.check(6))
        self.assertFalse(guard_manager.check(4))
        self.assertFalse(guard_manager.check("foo"))

    def test_attr_guard_manager(self):
        class Foo:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        foo = Foo(1, 2)
        guard_manager = RootGuardManager()
        guard_manager.add_type_match_guard(id_type(foo), ["type(x) == Foo"])
        guard_manager.getattr_manager("x", "x", 1, default_mgr_enum).add_lambda_guard(
            functools.partial(equals_match, expected=foo.x),
            equals_match_verbose_code_parts(foo.x),
        )
        guard_manager.getattr_manager("y", "y", 2, default_mgr_enum).add_lambda_guard(
            functools.partial(equals_match, expected=foo.y),
            equals_match_verbose_code_parts(foo.y),
        )
        self.assertEqual(len(guard_manager.get_leaf_guards()), 1)
        # 2 child managers, one for x and one for y
        self.assertEqual(len(guard_manager.get_accessors()), 2)
        self.assertIsInstance(guard_manager.get_accessors()[0], GetAttrGuardAccessor)
        self.assertIsInstance(guard_manager.get_accessors()[1], GetAttrGuardAccessor)
        # Check leaf guards on child managers. Re-requesting an existing
        # accessor returns the already-installed child manager.
        self.assertEqual(
            len(
                guard_manager.getattr_manager(
                    attr="x",
                    source="x",
                    example_value=None,
                    guard_manager_enum=default_mgr_enum,
                ).get_leaf_guards()
            ),
            1,
        )
        self.assertEqual(
            len(
                guard_manager.getattr_manager(
                    "y", "y", None, default_mgr_enum
                ).get_leaf_guards()
            ),
            1,
        )

        self.assertTrue(guard_manager.check(foo))
        self.assertFalse(guard_manager.check(Foo(3, 4)))
        self.assertFalse(guard_manager.check("foo"))

    def test_item_guard_manager(self):
        foo = [1, 2]
        guard_manager = RootGuardManager()
        guard_manager.add_type_match_guard(id_type(foo), ["type(x) == list"])
        guard_manager.getitem_manager(0, "", 1, default_mgr_enum).add_lambda_guard(
            functools.partial(equals_match, expected=foo[0]),
            equals_match_verbose_code_parts(foo[0]),
        )
        guard_manager.getitem_manager(1, "", 2, default_mgr_enum).add_lambda_guard(
            functools.partial(equals_match, expected=foo[1]),
            equals_match_verbose_code_parts(foo[1]),
        )
        self.assertEqual(len(guard_manager.get_leaf_guards()), 1)
        # 2 child managers, one for index 0 and one for index 1
        self.assertEqual(len(guard_manager.get_accessors()), 2)
        self.assertIsInstance(guard_manager.get_accessors()[0], GetItemGuardAccessor)
        self.assertIsInstance(guard_manager.get_accessors()[1], GetItemGuardAccessor)
        # Check leaf guards on child managers
        self.assertEqual(
            len(
                guard_manager.getitem_manager(
                    0, "", None, default_mgr_enum
                ).get_leaf_guards()
            ),
            1,
        )
        self.assertEqual(
            len(
                guard_manager.getitem_manager(
                    1, "", None, default_mgr_enum
                ).get_leaf_guards()
            ),
            1,
        )

        self.assertTrue(guard_manager.check(foo))
        self.assertFalse(guard_manager.check([3, 4]))
        self.assertFalse(guard_manager.check("foo"))

    def test_dict_getitem_accessor(self):
        foo = {
            "a": 1,
            "b": 2,
        }

        guards_manager = RootGuardManager()
        guards_manager.add_type_match_guard(id_type(foo), ["type(x) == dict"])
        guards_manager.dict_getitem_manager(
            "a", "", 1, default_mgr_enum
        ).add_equals_match_guard(1, ["a == 1"])
        guards_manager.dict_getitem_manager(
            "b", "", 2, default_mgr_enum
        ).add_equals_match_guard(2, ["b == 2"])

        self.assertTrue(guards_manager.check(foo))
        self.assertFalse(guards_manager.check({"a": 1, "b": 3}))

    def test_globals(self):
        # `global_pair` and `Pair` are only read here (the attribute mutation
        # below does not rebind the name), so no `global` declaration is needed.
        guard_manager = RootGuardManager()
        gpair_mgr = guard_manager.globals_dict_manager(
            globals(), "", None, default_mgr_enum
        ).getitem_manager("global_pair", "", global_pair, default_mgr_enum)

        gpair_mgr.add_lambda_guard(
            lambda x: isinstance(x, Pair)
            and isinstance(x.x, torch.Tensor)
            and isinstance(x.y, int),
            "global guard fail",
        )

        self.assertTrue(guard_manager.check(global_pair))
        orig_y = global_pair.y
        try:
            # The manager reads the object from globals(), so mutating the
            # module-level instance must be observed by the guard.
            global_pair.y = "foo"
            self.assertFalse(guard_manager.check(global_pair))
        finally:
            # Restore so this test does not leak state into other tests.
            global_pair.y = orig_y

    def test_type_manager(self):
        guard_manager = RootGuardManager()

        class A:
            a = 4

        class B(A):
            def mul(self, x):
                super().mul(x)

        foo = B()
        f_locals = {"foo": foo}

        # len(type(foo).__mro__) == 3, i.e. (B, A, object)
        foo_mgr = guard_manager.getitem_manager("foo", "", foo, default_mgr_enum)
        type_manager = foo_mgr.type_manager("", type(foo), default_mgr_enum)
        self.assertIsInstance(foo_mgr.get_accessors()[0], TypeGuardAccessor)
        mro_manager = type_manager.getattr_manager(
            "__mro__", "", type(foo).__mro__, default_mgr_enum
        )
        self.assertIsInstance(type_manager.get_accessors()[0], GetAttrGuardAccessor)
        mro_manager.add_length_check_guard(
            3,
            "Expected len(type(foo).__mro__) == 3",
        )

        # type(foo).__mro__[1].a == 4, where __mro__[1] is A
        item_manager = mro_manager.getitem_manager(
            1, "", type(foo).__mro__[1], default_mgr_enum
        )
        self.assertIsInstance(mro_manager.get_accessors()[0], GetItemGuardAccessor)
        attr_manager = item_manager.getattr_manager(
            "a", "", type(foo).__mro__[1].a, default_mgr_enum
        )
        self.assertIsInstance(item_manager.get_accessors()[0], GetAttrGuardAccessor)
        attr_manager.add_lambda_guard(
            lambda x: x == 4,
            "Expected value 4",
        )

        self.assertTrue(guard_manager.check(f_locals))

    def test_tuple_iterator_getitem(self):
        a = (1, 2, 3, 4, 5, 6)
        foo = iter(a)
        next(foo)  # foo points at index=1

        guard_manager = RootGuardManager()
        # Check a[3] which is tuple_iterator_getitem(foo, 2)
        guard_manager.add_tuple_iterator_length_guard(
            5, id_type(iter(())), ["len == 5"]
        )
        guard_manager.tuple_iterator_getitem_manager(
            2, "", foo, default_mgr_enum
        ).add_equals_match_guard(a[3], ["x==4"])

        # Check that type match works
        self.assertFalse(guard_manager.check(False))

        self.assertTrue(guard_manager.check(foo))

        # Check that index error fails gracefully
        b = (1, 2)
        b_foo = iter(b)
        self.assertFalse(guard_manager.check(b_foo))

    def test_global_weakref(self):
        guard_manager = RootGuardManager()
        globals_manager = guard_manager.globals_dict_manager(
            globals(), "", None, default_mgr_enum
        )
        weakref_manager = globals_manager.global_weakref_manager(
            "weakref_x", "", None, default_mgr_enum
        )

        weakref_manager.add_lambda_guard(
            lambda x: isinstance(x, torch.Tensor),
            "global weakref fail",
        )

        self.assertTrue(guard_manager.check(None))
        # Deleting the module-level strong reference kills the referent of
        # `weakref_x`. NOTE: this permanently removes the global `x`.
        global x
        del x
        self.assertFalse(guard_manager.check(None))

    def test_lambda_manager(self):
        a = (1, 1, 3, 4, 5, 6)

        guard_manager = RootGuardManager()

        # A lambda accessor feeds x[2] of the root value to its child manager.
        foo_mgr = guard_manager.lambda_manager(
            lambda x: x[2], "", None, default_mgr_enum
        )
        foo_mgr.add_lambda_guard(
            lambda x: x == 3,
            "Expected value 3",
        )
        self.assertTrue(guard_manager.check(a))

        # test that an exception raised by the accessor fails the check
        guard_manager = RootGuardManager()

        def fn(x):
            raise AssertionError("Test")

        guard_manager.lambda_manager(fn, "", None, default_mgr_enum)

        self.assertFalse(guard_manager.check(None))
        debug_info = guard_manager.check_verbose(None)
        self.assertFalse(debug_info.result)
        self.assertIn("Test", debug_info.verbose_code_parts[0])

    def test_dict_contains_guard(self):
        foo = {"a": 1, "b": 2}

        guard = guards.DICT_CONTAINS(True, "a", ["has a"])

        self.assertTrue(guard(foo))
        self.assertTrue(guard({"a": 1, "b": 2}))
        self.assertFalse(guard({"b": 2, "c": 3}))
        self.assertFalse(guard({}))

        guard = guards.DICT_CONTAINS(False, "c", ["not has c"])
        self.assertTrue(guard(foo))
        self.assertTrue(guard({"a": 1, "b": 2}))
        self.assertFalse(guard({"b": 2, "c": 3}))
        self.assertTrue(guard({}))

    def test_dict_guard_manager(self):
        root = RootGuardManager()

        def nothing():
            pass

        f_locals = {
            "d": {"a": 1, nothing: {"z": 3}, 100: torch.randn(4)},
        }

        # It's a getitem_manager just for f_locals. But the child guard manager
        # should be a DictGuardManager.
        dict_mgr = root.getitem_manager(
            "d",
            "",
            f_locals["d"],
            torch._dynamo.guards.GuardManagerType.DICT_GUARD_MANAGER,
        )
        self.assertIsInstance(dict_mgr, DictGuardManager)

        self.assertTrue(root.check(f_locals))

        # Check that no one can add a leaf guard
        with self.assertRaises(RuntimeError):
            dict_mgr.add_id_match_guard(id_type(f_locals), "id match")

        # Check that no one can add an arbitrary accessor
        with self.assertRaises(RuntimeError):
            dict_mgr.getitem_manager("a", "", f_locals["d"]["a"])

        # Check that it fails with different length dict
        f_locals_prime = {
            "d": {"a": 1, "b": 2},
        }
        self.assertFalse(root.check(f_locals_prime))

        # Add key-value manager ("a" : 1)
        self.assertTrue(root.check(f_locals))
        dict_mgr.get_key_manager(0, "", "a", default_mgr_enum).add_equals_match_guard(
            "a",
            ["dict.keys()[0] == a"],
        )
        self.assertTrue(root.check(f_locals))
        dict_mgr.get_value_manager(0, "", 1, default_mgr_enum).add_equals_match_guard(
            1, ["d[0] == 1"]
        )
        self.assertTrue(root.check(f_locals))

        # Add key-value manager (nothing : {"z" : 3})
        self.assertTrue(root.check(f_locals))
        dict_mgr.get_key_manager(1, "", nothing, default_mgr_enum).add_lambda_guard(
            lambda x: x is nothing, ["x is nothing"]
        )
        self.assertTrue(root.check(f_locals))
        value_mgr = dict_mgr.get_value_manager(
            1,
            "",
            f_locals["d"][nothing],
            torch._dynamo.guards.GuardManagerType.DICT_GUARD_MANAGER,
        )
        self.assertIsInstance(value_mgr, DictGuardManager)
        self.assertTrue(root.check(f_locals))

        # Check structure
        # Check that we are only guarding on two keys. This is common in
        # LazyVariableTracker.
        self.assertEqual(len(dict_mgr.get_key_value_managers()), 2)

        f_locals["d"]["a"] = 2
        self.assertFalse(root.check(f_locals))
        self.assertFalse(root.check_verbose(f_locals).result)

        f_locals["d"]["a"] = 1
        self.assertTrue(root.check(f_locals))

        f_locals["d"].pop(100)
        # fails because of len check
        self.assertFalse(root.check(f_locals))
if __name__ == "__main__":
    # Run this module's tests through dynamo's test runner; the
    # torch._dynamo.test_case module is already imported at the top of the file.
    torch._dynamo.test_case.run_tests()