blob: adaa5bf43547d4038687757ba4ac0644b2e6cf0e [file] [log] [blame]
# Owner(s): ["oncall: export"]
import torch
import torch.utils._pytree as pytree
from torch._dynamo.testing import EagerAndRecordGraphs
from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch._higher_order_ops.wrap import wrap
from torch._library.fake_class_registry import FakeScriptObject
from torch.export import export
from torch.export._trace import _export
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skipIfTorchDynamo,
TestCase,
)
from torch.testing._internal.torchbind_impls import (
_empty_tensor_queue,
init_torchbind_implementations,
)
def _assertEqualSkipScriptObject(test_case, exp, actual):
flat_exp = pytree.tree_leaves(exp)
flat_actual = pytree.tree_leaves(actual)
test_case.assertEqual(len(flat_exp), len(flat_actual))
for a, b in zip(flat_exp, flat_actual):
if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
continue
test_case.assertEqual(a, b)
def _check_script_obj_equal(test_case, a: torch.ScriptObject, b: torch.ScriptObject):
    """Assert that two script objects have the same class and flattened state.

    Compares the qualified TorchScript class name of ``a`` and ``b`` and then
    the values returned by their ``__obj_flatten__()``.

    Returns:
        None. Mismatches are reported by ``test_case.assertEqual`` raising.
    """
    # Two separate statements on purpose: ``assertEqual`` returns None, so
    # chaining the checks with ``and`` would short-circuit and silently skip
    # the flattened-state comparison.
    test_case.assertEqual(a._type().qualified_name(), b._type().qualified_name())
    test_case.assertEqual(a.__obj_flatten__(), b.__obj_flatten__())
def _assertEqualScriptObject(
    test_case, exp, actual, check_obj_eq=_check_script_obj_equal
):
    """Compare two pytrees leaf by leaf, delegating script-object pairs.

    Both trees are flattened with ``pytree.tree_leaves`` and must have the same
    number of leaves. When both leaves at a position are ``torch.ScriptObject``
    they are compared with ``check_obj_eq(test_case, a, b)``; every other pair
    is compared with ``test_case.assertEqual``.
    """
    exp_leaves = pytree.tree_leaves(exp)
    actual_leaves = pytree.tree_leaves(actual)
    test_case.assertEqual(len(exp_leaves), len(actual_leaves))
    for lhs, rhs in zip(exp_leaves, actual_leaves):
        if not (
            isinstance(lhs, torch.ScriptObject) and isinstance(rhs, torch.ScriptObject)
        ):
            test_case.assertEqual(lhs, rhs)
            continue
        check_obj_eq(test_case, lhs, rhs)
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
class TestExportTorchbind(TestCase):
    def setUp(self):
        """Load torchbind test implementations and register counting fake classes.

        Registers fake classes for ``_TorchScriptTesting::_Foo`` and
        ``_TorchScriptTesting::_TensorQueue``. Their methods increment counters
        stored on this TestCase (``foo_add_tensor_counter``, ``tq_push_counter``,
        ``tq_pop_counter``, ``tq_size_counter``), so tests can assert how many
        times each fake method ran. Also collects the torchbind operator packets
        into ``self.torch_bind_ops``.
        """
        init_torchbind_implementations()

        # The fake classes reach the TestCase through this closure variable.
        owner = self
        for counter_name in (
            "tq_push_counter",
            "tq_pop_counter",
            "tq_size_counter",
            "foo_add_tensor_counter",
        ):
            setattr(owner, counter_name, 0)

        @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
        class FakeFoo:
            def __init__(self, x: int, y: int):
                self.x = x
                self.y = y

            @classmethod
            def __obj_unflatten__(cls, flat_state):
                # The flattened state is turned into a dict and passed to
                # __init__ as keyword arguments.
                return cls(**dict(flat_state))

            def add_tensor(self, z):
                owner.foo_add_tensor_counter += 1
                return (self.x + self.y) * z

        @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
        class FakeTensorQueue:
            def __init__(self, queue):
                self.queue = queue

            @classmethod
            def __obj_unflatten__(cls, flat_state):
                return cls(**dict(flat_state))

            def push(self, x):
                owner.tq_push_counter += 1
                self.queue.append(x)

            def pop(self):
                owner.tq_pop_counter += 1
                return self.queue.pop(0)

            def size(self):
                owner.tq_size_counter += 1
                return len(self.queue)

        op_namespace = torch.ops._TorchScriptTesting
        self.torch_bind_ops = [
            getattr(op_namespace, op_name)
            for op_name in (
                "takes_foo",
                "takes_foo_python_meta",
                "takes_foo_list_return",
                "takes_foo_tuple_return",
                "take_an_instance",
                "take_an_instance_inferred",
                "takes_foo_cia",
                "queue_pop",
                "queue_push",
                "queue_size",
            )
        ]
def tearDown(self):
torch._library.fake_class_registry.deregister_fake_class(
"_TorchScriptTesting::_Foo"
)
torch._library.fake_class_registry.deregister_fake_class(
"_TorchScriptTesting::_TensorQueue"
)
def _test_export_same_as_eager(
self, f, args, kwargs=None, strict=True, pre_dispatch=False
):
kwargs = kwargs or {}
def export_wrapper(f, args, kwargs, strcit, pre_dispatch):
with enable_torchbind_tracing():
if pre_dispatch:
exported_program = _export(
f, args, kwargs, strict=strict, pre_dispatch=True
)
else:
exported_program = export(f, args, kwargs, strict=strict)
return exported_program
exported_program = export_wrapper(f, args, kwargs, strict, pre_dispatch)
reversed_kwargs = {key: kwargs[key] for key in reversed(kwargs)}
unlifted = exported_program.module()
exp = f(*args, **kwargs)
self.assertEqual(unlifted(*args, **kwargs), exp)
self.assertEqual(
unlifted(*args, **reversed_kwargs),
exp,
)
# check re-tracing
retraced_ep = export_wrapper(unlifted, args, kwargs, strict, pre_dispatch)
self.assertEqual(retraced_ep.module()(*args, **kwargs), exp)
return exported_program
    @parametrize("pre_dispatch", [True, False])
    def test_none(self, pre_dispatch):
        # Export a module that owns a _Foo script-object attribute and also takes
        # an unused ``None`` input ``n``. ``n`` must still show up as an input in
        # both the unlifted module and the exported graph.
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x, n):
                return x + self.attr.add_tensor(x)

        ep = self._test_export_same_as_eager(
            MyModule(),
            (torch.ones(2, 3), None),
            strict=False,
            pre_dispatch=pre_dispatch,
        )
        # Unlifted module: the script object is read from ``self.attr`` and the
        # method call is recorded as a call_torchbind node.
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x, n):
    x, n, = fx_pytree.tree_flatten_spec(([x, n], {}), self._in_spec)
    attr = self.attr
    call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None
    add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
        )
        # Exported graph: the script-object attribute is lifted to an
        # ``obj_attr`` graph input.
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, obj_attr, x, n):
    call_torchbind = torch.ops.higher_order.call_torchbind(obj_attr, 'add_tensor', x); obj_attr = None
    add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
    return (add,)""",
        )
    @parametrize("pre_dispatch", [True, False])
    def test_attribute(self, pre_dispatch):
        # Export a module that calls a method on its _Foo script-object attribute.
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return x + self.attr.add_tensor(x)

        ep = self._test_export_same_as_eager(
            MyModule(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
        )
        # Unlifted module reads the attribute back and records call_torchbind.
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None
    add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
        )
        # Exported graph takes the lifted script object as ``obj_attr``.
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, obj_attr, x):
    call_torchbind = torch.ops.higher_order.call_torchbind(obj_attr, 'add_tensor', x); obj_attr = None
    add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
    return (add,)""",
        )
    @parametrize("pre_dispatch", [True, False])
    def test_attribute_as_custom_op_argument(self, pre_dispatch):
        # Pass a _Foo script-object attribute as an argument to the custom op
        # ``takes_foo`` instead of calling one of its methods.
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x)

        ep = self._test_export_same_as_eager(
            MyModule(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
        )
        # Unlifted module calls takes_foo.default directly.
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x); attr = None
    add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
        )
        # In the exported graph the op is wrapped in with_effects: an extra
        # ``token`` input is threaded through and a token is returned first.
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, obj_attr, x):
    with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, x); token = obj_attr = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None
    return (getitem, add)""", # noqa: B950
        )
    @parametrize("pre_dispatch", [True, False])
    @parametrize("fakify_script_obj", [True, False])
    def test_input(self, pre_dispatch, fakify_script_obj):
        # Export a module that receives a _Foo script object as a forward input.
        # With fakify_script_obj=False the fake _Foo class is deregistered first,
        # so export has no fake class to use for the object.
        cc = torch.classes._TorchScriptTesting._Foo(10, 20)
        if not fakify_script_obj:
            qual_name = cc._type().qualified_name()  # type: ignore[att-defined]
            if torch._library.fake_class_registry.has_fake_class(qual_name):
                torch._library.fake_class_registry.deregister_fake_class(
                    "_TorchScriptTesting::_Foo"
                )

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, cc):
                return x + cc.add_tensor(x)

        ep = self._test_export_same_as_eager(
            MyModule(), (torch.ones(2, 3), cc), strict=False, pre_dispatch=pre_dispatch
        )
        # The script-object input stays a plain input (no lifting needed).
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x, cc):
    x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec)
    call_torchbind = torch.ops.higher_order.call_torchbind(cc, 'add_tensor', x); cc = None
    add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
        )
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, x, cc):
    call_torchbind = torch.ops.higher_order.call_torchbind(cc, 'add_tensor', x); cc = None
    add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
    return (add,)""",
        )
        # aot_export_function runs the program twice (in
        # run_functionalized_fw_and_collect_metadata and
        # create_aot_dispatcher_function), and the re-tracing check in
        # _test_export_same_as_eager doubles that, so the fake add_tensor runs 4 times.
        if fakify_script_obj:
            self.assertEqual(self.foo_add_tensor_counter, 4)
    @parametrize("pre_dispatch", [True, False])
    @parametrize("fakify_script_obj", [True, False])
    def test_input_as_custom_op_argument(self, pre_dispatch, fakify_script_obj):
        # Pass a _Foo script-object *input* to the custom op ``takes_foo``.
        # With fakify_script_obj=False the fake _Foo class is deregistered first.
        cc = torch.classes._TorchScriptTesting._Foo(10, 20)
        if not fakify_script_obj:
            qual_name = cc._type().qualified_name()  # type: ignore[att-defined]
            if torch._library.fake_class_registry.has_fake_class(qual_name):
                torch._library.fake_class_registry.deregister_fake_class(
                    "_TorchScriptTesting::_Foo"
                )

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, cc):
                return x + torch.ops._TorchScriptTesting.takes_foo(cc, x)

        # Temporarily remove the Python Meta kernel of takes_foo.default (and
        # clear its dispatch cache so the removal takes effect).
        del torch.ops._TorchScriptTesting.takes_foo.default.py_kernels[
            torch._C.DispatchKey.Meta
        ]
        torch.ops._TorchScriptTesting.takes_foo.default._dispatch_cache.clear()
        # Even though a C++ implementation for takes_foo.default is registered,
        # we still need the python implementation for takes_foo.default to trace with FakeFoo.
        if fakify_script_obj:
            with self.assertRaisesRegex(
                RuntimeError, "no python implementation is found"
            ):
                self._test_export_same_as_eager(
                    MyModule(),
                    (torch.ones(2, 3), cc),
                    strict=False,
                    pre_dispatch=pre_dispatch,
                )

        # Re-register a Meta Python kernel that forwards to the (fake) object's
        # add_tensor, then export again and check the graphs.
        torch.ops._TorchScriptTesting.takes_foo.default.py_impl(
            torch._C.DispatchKey.Meta
        )(lambda cc, x: cc.add_tensor(x))
        ep = self._test_export_same_as_eager(
            MyModule(),
            (torch.ones(2, 3), cc),
            strict=False,
            pre_dispatch=pre_dispatch,
        )
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x, cc):
    x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec)
    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(cc, x); cc = None
    add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
        )
        # Exported graph wraps the op in with_effects with a threaded token.
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, x, cc):
    with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, cc, x); token = cc = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None
    return (getitem, add)""", # noqa: B950
        )
    @parametrize("pre_dispatch", [True, False])
    def test_unlift_custom_obj(self, pre_dispatch):
        # The same script-object attribute is used by two consecutive custom-op
        # calls; unlifting must reuse a single ``attr = self.attr`` read.
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, a)
                return x + b

        input = torch.ones(2, 3)
        ep = self._test_export_same_as_eager(
            MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
        )
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None
    add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
    return pytree.tree_unflatten((add,), self._out_spec)""", # noqa: B950
        )
        # The effect token from the first with_effects feeds the second one.
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, obj_attr, x):
    with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, x); token = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, getitem_1); getitem = obj_attr = getitem_1 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1]; with_effects_1 = None
    add = torch.ops.aten.add.Tensor(x, getitem_3); x = getitem_3 = None
    return (getitem_2, add)""", # noqa: B950
        )
    @parametrize("pre_dispatch", [True, False])
    def test_custom_obj_list_out(self, pre_dispatch):
        # A torchbind op returning a list (takes_foo_list_return) whose three
        # elements are summed and fed into a second torchbind op.
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x)
                y = a[0] + a[1] + a[2]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return x + b

        input = torch.ones(2, 3)
        ep = self._test_export_same_as_eager(
            MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
        )
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    takes_foo_list_return_default = torch.ops._TorchScriptTesting.takes_foo_list_return.default(attr, x)
    getitem_2 = takes_foo_list_return_default[0]
    getitem_3 = takes_foo_list_return_default[1]
    getitem_4 = takes_foo_list_return_default[2]; takes_foo_list_return_default = None
    add = torch.ops.aten.add.Tensor(getitem_2, getitem_3); getitem_2 = getitem_3 = None
    add_1 = torch.ops.aten.add.Tensor(add, getitem_4); add = getitem_4 = None
    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, add_1); attr = add_1 = None
    add_2 = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
    return pytree.tree_unflatten((add_2,), self._out_spec)""",
        )
        # In the exported graph the list output is element 1 of with_effects
        # (element 0 is the token) and is then indexed per element.
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, obj_attr, x):
    with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_list_return.default, obj_attr, x); token = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    getitem_2 = getitem_1[0]
    getitem_3 = getitem_1[1]
    getitem_4 = getitem_1[2]; getitem_1 = None
    add = torch.ops.aten.add.Tensor(getitem_2, getitem_3); getitem_2 = getitem_3 = None
    add_1 = torch.ops.aten.add.Tensor(add, getitem_4); add = getitem_4 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, add_1); getitem = obj_attr = add_1 = None
    getitem_5 = with_effects_1[0]
    getitem_6 = with_effects_1[1]; with_effects_1 = None
    add_2 = torch.ops.aten.add.Tensor(x, getitem_6); x = getitem_6 = None
    return (getitem_5, add_2)""", # noqa: B950
        )
    @parametrize("pre_dispatch", [True, False])
    def test_custom_obj_tuple_out(self, pre_dispatch):
        # A torchbind op returning a tuple (takes_foo_tuple_return) whose two
        # elements are summed and fed into a second torchbind op.
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return x + b

        input = torch.ones(2, 3)
        ep = self._test_export_same_as_eager(
            MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
        )
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(attr, x)
    getitem_1 = takes_foo_tuple_return_default[0]
    getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None
    add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None
    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, add); attr = add = None
    add_1 = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
    return pytree.tree_unflatten((add_1,), self._out_spec)""",
        )
        # Unlike the list case, the tuple elements are flattened directly into
        # the with_effects outputs (indices 1 and 2, after the token).
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, obj_attr, x):
    with_effects = torch._higher_order_ops.effects.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, obj_attr, x); token = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]
    getitem_2 = with_effects[2]; with_effects = None
    add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, add); getitem = obj_attr = add = None
    getitem_3 = with_effects_1[0]
    getitem_4 = with_effects_1[1]; with_effects_1 = None
    add_1 = torch.ops.aten.add.Tensor(x, getitem_4); x = getitem_4 = None
    return (getitem_3, add_1)""", # noqa: B950
        )
    @parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
    def test_make_fx_tensor_queue_methods(self, make_fx_tracing_mode):
        # Trace _TensorQueue *method* calls (push/pop/size) with make_fx and check
        # that they run on the fake queue, leave the real queue untouched, and are
        # recorded as call_torchbind nodes.
        test = self

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 2)
                self.check_tq_is_fake = True

            def forward(self, tq, x):
                # During tracing the queue must arrive as a FakeScriptObject.
                if self.check_tq_is_fake:
                    test.assertTrue(isinstance(tq, FakeScriptObject))
                tq.push(x.cos())
                tq.push(x.sin())
                x_cos = tq.pop() + tq.size()
                x_sin = tq.pop() - tq.size()
                return x_sin, x_cos, tq

        mod = Model()
        tq = torch.classes._TorchScriptTesting._TensorQueue(
            torch.empty(
                0,
            ).fill_(-1)
        )
        tq1 = torch.classes._TorchScriptTesting._TensorQueue(
            torch.empty(
                0,
            ).fill_(-1)
        )
        x = torch.ones(2, 3)
        gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
        # Each fake method (counters from setUp) ran exactly twice during tracing,
        # and the real ``tq`` is still empty.
        self.assertEqual(self.tq_push_counter, 2)
        self.assertEqual(self.tq_pop_counter, 2)
        self.assertEqual(self.tq_size_counter, 2)
        self.assertEqual(tq.size(), 0)
        self.assertExpectedInline(
            gm.code.strip("\n"),
            """\
def forward(self, arg0_1, arg1_1):
    cos = torch.ops.aten.cos.default(arg1_1)
    call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'push', cos); cos = None
    sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
    call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'push', sin); sin = None
    call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
    call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
    add = torch.ops.aten.add.Tensor(call_torchbind_2, 1); call_torchbind_2 = None
    call_torchbind_4 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
    call_torchbind_5 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
    sub = torch.ops.aten.sub.Tensor(call_torchbind_4, 0); call_torchbind_4 = None
    return (sub, add, arg0_1)
    """,
        )
        # Run the traced graph and eager on separate, equally empty real queues.
        mod.check_tq_is_fake = False
        _assertEqualSkipScriptObject(self, gm(tq, x), mod(tq1, x))
    @parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
    def test_make_fx_tensor_queue_methods_fakify_internal_states(
        self, make_fx_tracing_mode
    ):
        # Like test_make_fx_tensor_queue_methods, but the queue is pre-filled, so
        # tracing has to work on a fakified copy of the queue's existing contents.
        test = self

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 2)
                self.check_tq_is_fake = True
                self.current_test = test

            def forward(self, tq, x):
                if self.check_tq_is_fake:
                    self.current_test.assertTrue(isinstance(tq, FakeScriptObject))
                x_cos = tq.pop() + tq.size() + x
                x_sin = tq.pop() - tq.size() + x
                return x_sin, x_cos, tq

        mod = Model()
        tq = torch.classes._TorchScriptTesting._TensorQueue(
            torch.empty(
                0,
            ).fill_(-1)
        )
        tq1 = torch.classes._TorchScriptTesting._TensorQueue(
            torch.empty(
                0,
            ).fill_(-1)
        )
        # Give both real queues the same two elements.
        for _ in range(2):
            tq.push(torch.ones(2, 3))
            tq1.push(torch.ones(2, 3))
        x = torch.ones(2, 3)
        prev_size = tq.size()
        gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
        # Tracing popped from the fake queue only; the real ``tq`` keeps its size.
        self.assertEqual(self.tq_push_counter, 0)
        self.assertEqual(self.tq_pop_counter, 2)
        self.assertEqual(self.tq_size_counter, 2)
        self.assertEqual(tq.size(), prev_size)
        self.assertExpectedInline(
            gm.code.strip("\n"),
            """\
def forward(self, arg0_1, arg1_1):
    call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
    call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
    add = torch.ops.aten.add.Tensor(call_torchbind, 1); call_torchbind = None
    add_1 = torch.ops.aten.add.Tensor(add, arg1_1); add = None
    call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
    call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size')
    sub = torch.ops.aten.sub.Tensor(call_torchbind_2, 0); call_torchbind_2 = None
    add_2 = torch.ops.aten.add.Tensor(sub, arg1_1); sub = arg1_1 = None
    return (add_2, add_1, arg0_1)
    """,
        )
        # turn off tq type checking in eager execution
        mod.check_tq_is_fake = False
        _assertEqualSkipScriptObject(self, gm(tq, x), mod(tq1, x))
        # Running the graph / eager module pops both elements from the real queues.
        self.assertEqual(tq.size(), 0)
        self.assertEqual(tq1.size(), 0)
def test_identifying_torchbind_ops(self):
for op in self.torch_bind_ops:
self.assertTrue(op._has_torchbind_op_overload)
for op in [
torch.ops.aten.add,
torch.ops.aten.cos,
]:
self.assertFalse(op._has_torchbind_op_overload)
def test_torchbind_op_register_fallthrough(self):
TEST_DISPATCH_KEY = torch._C.DispatchKey.AutocastCPU
TEST_DISPATCH_KEY_STR = "AutocastCPU"
for op_packet in self.torch_bind_ops:
op = op_packet.default
ns, _ = torch._library.utils.parse_namespace(op_packet._qualified_op_name)
with torch.library._scoped_library(ns, "FRAGMENT") as lib:
lib.impl(
op.name(), torch.library.fallthrough_kernel, TEST_DISPATCH_KEY_STR
)
self.assertTrue(
torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
op.name(), TEST_DISPATCH_KEY
)
)
def test_torchbind_op_fallthrough_keys_respects_lib_impl(self):
TEST_DISPATCH_KEY = torch._C.DispatchKey.AutogradCPU
TEST_DISPATCH_KEY_STR = "AutogradCPU"
tested = 0
for op_packet in self.torch_bind_ops:
op = op_packet.default
ns, _ = torch._library.utils.parse_namespace(op_packet._qualified_op_name)
if (
not torch._C._dispatch_has_kernel_for_dispatch_key(
op.name(), TEST_DISPATCH_KEY
)
and TEST_DISPATCH_KEY not in op.py_kernels
):
tested += 1
with torch.library._scoped_library(ns, "FRAGMENT") as lib:
lib.impl(
op.name(), lambda *args, **kwargs: args, TEST_DISPATCH_KEY_STR
)
self.assertTrue(TEST_DISPATCH_KEY not in op._fallthrough_keys())
with torch.library._scoped_library(ns, "FRAGMENT") as lib:
lib.impl(
op.name(),
torch.library.fallthrough_kernel,
TEST_DISPATCH_KEY_STR,
)
self.assertTrue(TEST_DISPATCH_KEY in op._fallthrough_keys())
self.assertTrue(tested > 0)
def test_make_fx_schema_checking_script_object(self):
class Model(torch.nn.Module):
def forward(self, tq, x, foo):
torch.ops._TorchScriptTesting.queue_push(foo, x.cos())
return tq
class ModelCallByKW(torch.nn.Module):
def forward(self, tq, x, foo):
torch.ops._TorchScriptTesting.queue_push(x=x.cos(), foo=foo)
return tq
mod = Model()
modkw = ModelCallByKW()
foo = torch.classes._TorchScriptTesting._Foo(10, 20)
x = torch.ones(3, 3)
tq = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
ns = "_TorchScriptTesting"
with torch.library._scoped_library(ns, "FRAGMENT") as lib:
op = torch.ops._TorchScriptTesting.queue_push
lib.impl(op.__name__, torch.library.fallthrough_kernel, "AutogradCPU")
lib.impl(op.__name__, torch.library.fallthrough_kernel, "ADInplaceOrView")
lib.impl(
op.__name__,
torch.library.fallthrough_kernel,
"PythonTLSSnapshot",
)
with self.assertRaisesRegex(
RuntimeError, "is expected to be a FakeScriptObject"
):
_ = make_fx(mod, tracing_mode="fake")(tq, x, foo)
with self.assertRaisesRegex(
RuntimeError, "is expected to be a FakeScriptObject"
):
_ = make_fx(modkw, tracing_mode="fake")(tq, x, foo)
@parametrize("fallthrough_via", ["lib_impl", "py_impl"])
def test_make_fx_tensor_queue_operators(self, fallthrough_via):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, tq, x):
with torch.autocast("cuda", dtype=torch.bfloat16):
torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
x_sin = torch.ops._TorchScriptTesting.queue_pop(
tq
) - torch.ops._TorchScriptTesting.queue_size(tq)
x_cos = torch.ops._TorchScriptTesting.queue_pop(
tq
) + torch.ops._TorchScriptTesting.queue_size(tq)
return x_sin, x_cos, tq
mod = Model()
tq1 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
tq2 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
x = torch.ones(2, 3)
mod(tq1, x)
ops = [
torch.ops._TorchScriptTesting.queue_push,
torch.ops._TorchScriptTesting.queue_pop,
torch.ops._TorchScriptTesting.queue_size,
]
if fallthrough_via == "lib_impl":
ns = "_TorchScriptTesting"
with torch.library._scoped_library(ns, "FRAGMENT") as lib:
for op in ops:
lib.impl(
op.__name__, torch.library.fallthrough_kernel, "AutocastCUDA"
)
gm = make_fx(mod, tracing_mode="fake")(tq1, x)
else:
for op in ops:
op.default.py_impl(torch._C.DispatchKey.AutocastCUDA)(
torch.library.fallthrough_kernel
)
gm = make_fx(mod, tracing_mode="fake")(tq1, x)
for op in ops:
op.default._dispatch_cache.clear()
del op.default.py_kernels[torch._C.DispatchKey.AutocastCUDA]
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, arg0_1, arg1_1):
cos = torch.ops.aten.cos.default(arg1_1)
queue_push = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, cos); cos = None
sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
queue_push_1 = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, sin); sin = None
queue_pop = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1)
queue_size = torch.ops._TorchScriptTesting.queue_size.default(arg0_1)
sub = torch.ops.aten.sub.Tensor(queue_pop, 1); queue_pop = None
queue_pop_1 = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1)
queue_size_1 = torch.ops._TorchScriptTesting.queue_size.default(arg0_1)
add = torch.ops.aten.add.Tensor(queue_pop_1, 0); queue_pop_1 = None
return (sub, add, arg0_1)""",
)
_assertEqualSkipScriptObject(self, gm(tq1, x), mod(tq2, x))
    def test_aot_export_tensor_queue_operators(self):
        # aot_export_module a model that calls the queue custom ops on a fake
        # queue built explicitly with to_fake_obj. Each op is wrapped in
        # with_effects, threading an effect token through the graph.
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, tq, x):
                torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
                torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
                x_sin = torch.ops._TorchScriptTesting.queue_pop(
                    tq
                ) - torch.ops._TorchScriptTesting.queue_size(tq)
                x_cos = torch.ops._TorchScriptTesting.queue_pop(
                    tq
                ) + torch.ops._TorchScriptTesting.queue_size(tq)
                return x_sin, x_cos, tq

        mod = Model()

        tq1 = torch.classes._TorchScriptTesting._TensorQueue(
            torch.empty(
                0,
            ).fill_(-1)
        )

        x = torch.ones(2, 3)

        # Fakify both inputs in one FakeTensorMode before exporting.
        fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
        fake_tq1 = torch._library.fake_class_registry.to_fake_obj(fake_mode, tq1)
        fake_x = fake_mode.from_tensor(x)
        gm = aot_export_module(mod, (fake_tq1, fake_x), trace_joint=False)[0]

        # inputs: token, tq, x
        # return: token, x_sin, x_cos, tq
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    cos = torch.ops.aten.cos.default(arg2_1)
    with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops._TorchScriptTesting.queue_push.default, arg1_1, cos); arg0_1 = cos = None
    getitem = with_effects[0]; with_effects = None
    sin = torch.ops.aten.sin.default(arg2_1); arg2_1 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.queue_push.default, arg1_1, sin); getitem = sin = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    with_effects_2 = torch._higher_order_ops.effects.with_effects(getitem_2, torch.ops._TorchScriptTesting.queue_pop.default, arg1_1); getitem_2 = None
    getitem_4 = with_effects_2[0]
    getitem_5 = with_effects_2[1]; with_effects_2 = None
    with_effects_3 = torch._higher_order_ops.effects.with_effects(getitem_4, torch.ops._TorchScriptTesting.queue_size.default, arg1_1); getitem_4 = None
    getitem_6 = with_effects_3[0]; with_effects_3 = None
    sub = torch.ops.aten.sub.Tensor(getitem_5, 1); getitem_5 = None
    with_effects_4 = torch._higher_order_ops.effects.with_effects(getitem_6, torch.ops._TorchScriptTesting.queue_pop.default, arg1_1); getitem_6 = None
    getitem_8 = with_effects_4[0]
    getitem_9 = with_effects_4[1]; with_effects_4 = None
    with_effects_5 = torch._higher_order_ops.effects.with_effects(getitem_8, torch.ops._TorchScriptTesting.queue_size.default, arg1_1); getitem_8 = None
    getitem_10 = with_effects_5[0]; with_effects_5 = None
    add = torch.ops.aten.add.Tensor(getitem_9, 0); getitem_9 = None
    return (getitem_10, sub, add, arg1_1)""", # noqa: B950
        )
class TestCompileTorchbind(TestCase):
    def setUp(self):
        """Load torchbind test implementations, register a fake _TensorQueue
        class, and start from a clean dynamo state."""
        init_torchbind_implementations()

        class FakeTensorQueue:
            def __init__(self, queue):
                self.queue = queue

            @classmethod
            def __obj_unflatten__(cls, flat_state):
                # The flattened state becomes keyword arguments for __init__.
                return cls(**dict(flat_state))

            def push(self, x):
                self.queue.append(x)

            def pop(self):
                return self.queue.pop(0)

            def size(self):
                return len(self.queue)

        # Equivalent to applying register_fake_class(...) as a decorator.
        torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")(
            FakeTensorQueue
        )

        torch._dynamo.reset()
def tearDown(self):
torch._dynamo.reset()
def test_compile_script_object_input(self):
backend = EagerAndRecordGraphs()
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.check_tq_is_fake = True
def forward(self, tq, x):
tq.push(x.cos())
tq.push(x.sin())
x_sin = tq.pop() - tq.size()
return x_sin, tq
mod = Model()
tq1 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
tq2 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
tq3 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
tq4 = torch.classes._TorchScriptTesting._TensorQueue(
torch.empty(
0,
).fill_(-1)
)
x = torch.randn(2, 3)
ret = torch.compile(mod, backend=backend)(tq1, x)
eager_ret = mod(tq2, x)
_assertEqualSkipScriptObject(self, ret, eager_ret)
self.assertEqual(ret[1].size(), eager_ret[1].size())
self.assertEqual(ret[1].pop(), eager_ret[1].pop())
# Note that dynamo captured graph
# does not return L_tq_ as output. This is because it's able
# to detect that L_tq_ is an input therefore don't return
# it as graph output. Related logic is in dynamo/codegen.py
self.assertExpectedInline(
backend.graphs[0].code.strip(),
"""\
def forward(self, L_tq_ : torch.ScriptObject, L_x_ : torch.Tensor):
l_tq_ = L_tq_
l_x_ = L_x_
cos = l_x_.cos()
call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', cos); cos = None
sin = l_x_.sin(); l_x_ = None
call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', sin); sin = None
call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'pop')
call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None
x_sin = call_torchbind_2 - 1; call_torchbind_2 = None
return (x_sin,)""",
)
def test_compile_script_object_input_guards(self):
    """Check which changes to a torchbind queue input make Dynamo recompile.

    A different queue length, or a queued tensor with a different shape,
    requires_grad flag, or dtype, each add a compiled frame. A queue whose
    tensor metadata matches an earlier call reuses the cached frame.
    """

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.check_tq_is_fake = True

        def forward(self, tq, x):
            tq.push(x.cos())
            tq.push(x.sin())
            x_sin = tq.pop() - tq.size()
            return x_sin, tq

    mod = Model()
    cnt = torch._dynamo.testing.CompileCounter()
    x = torch.randn(2, 3)

    def queue_of(*tensors):
        # Build an empty queue and push the given tensors in order.
        queue = _empty_tensor_queue()
        for t in tensors:
            queue.push(t)
        return queue

    def frames_after(queue):
        # Run a freshly wrapped compiled module. cnt keeps a running total of
        # compiled frames across all calls in this test.
        torch.compile(mod, backend=cnt)(queue, x)
        return cnt.frame_count

    tq1 = queue_of()
    self.assertEqual(frames_after(tq1), 1)

    # Queue length change causes re-compile
    tq2 = queue_of(*[torch.randn(4, 5, requires_grad=False) for _ in range(10)])
    self.assertEqual(frames_after(tq2), 2)

    # Tensor in queue changes shape causes re-compile
    tq3 = queue_of(torch.randn(2, 3, requires_grad=False))
    self.assertEqual(frames_after(tq3), 3)

    # Same tensor metadata as tq3: no recompile
    tq4 = queue_of(torch.randn(2, 3, requires_grad=False))
    self.assertEqual(frames_after(tq4), 3)

    # Tensor in queue changes dispatch key causes re-compile
    tq5 = queue_of(torch.randn(2, 3, requires_grad=True))
    self.assertEqual(frames_after(tq5), 4)

    # Tensor in queue changes dtype causes re-compile
    tq6 = queue_of(torch.randn(2, 3, requires_grad=True, dtype=torch.float64))
    self.assertEqual(frames_after(tq6), 5)
def test_compile_script_object_input_automatic_dynamic_shape(self):
    """Check automatic dynamic shapes for a tensor held in a queue input.

    Changing the second dim of the queued tensor causes exactly one
    recompile. After that, a further size change in the same dim reuses the
    recompiled frame.
    """

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.check_tq_is_fake = True

        def forward(self, tq, x):
            tq.push(x.cos())
            tq.push(x.sin())
            x_sin = tq.pop() - tq.size()
            return x_sin, tq

    mod = Model()
    cnt = torch._dynamo.testing.CompileCounter()
    x = torch.randn(2, 3)
    tq1 = _empty_tensor_queue()
    tq1.push(torch.randn(2, 3, requires_grad=False))
    torch.compile(mod, backend=cnt)(tq1, x)
    self.assertEqual(cnt.frame_count, 1)
    tq2 = _empty_tensor_queue()
    # Changing the first queued tensor's second dim (3 -> 4) recompiles and
    # makes that dim dynamic.
    tq2.push(torch.randn(2, 4, requires_grad=False))
    torch.compile(mod, backend=cnt)(tq2, x)
    self.assertEqual(cnt.frame_count, 2)
    tq3 = _empty_tensor_queue()
    tq3.push(torch.randn(2, 5, requires_grad=False))
    # The second dim is now dynamic, so another size (5) must not recompile.
    torch.compile(mod, backend=cnt)(tq3, x)
    self.assertEqual(cnt.frame_count, 2)
def test_compile_error_on_input_aliasing_contents(self):
    """Compiling must fail when a queue input holds the very tensor that is
    also passed as a separate input.
    """
    backend = EagerAndRecordGraphs()

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.check_tq_is_fake = True

        def forward(self, tq, x):
            return x.sin(), tq.pop().cos()

    inp = torch.randn(2, 3)
    model = Model()
    queue = _empty_tensor_queue()
    # The same tensor object goes into the queue and is passed as `x`.
    queue.push(inp)
    with self.assertRaisesRegex(RuntimeError, "is alising"):
        compiled = torch.compile(model, backend=backend)
        compiled(queue, inp)
def test_compile_error_on_script_obj_setattr(self):
    """Compiling must reject attribute assignment on a script object."""

    def assign_attr(tq):
        tq.a = 1
        return tq

    expected_msg = "call method __setattr__ on script object is not safe"
    with self.assertRaisesRegex(RuntimeError, expected_msg):
        compiled = torch.compile(assign_attr, backend="eager")
        compiled(_empty_tensor_queue())
def test_compile_error_on_script_obj_missing_attr(self):
    """Compiling must fail when reading an attribute the script object
    does not define.
    """

    # This helper reads an attribute, so it is named for getattr. The old name
    # `setattr_f` was copy-pasted from the setattr test above.
    def getattr_f(tq):
        return tq._not_defined_attr

    with self.assertRaisesRegex(
        RuntimeError, "doesn't define method _not_defined_attr"
    ):
        torch.compile(getattr_f, backend="eager")(_empty_tensor_queue())
def test_compile_body_aliasing_contents(self):
    """Push views of an input into a queue inside the compiled function.

    Compiled and eager results must agree, and each queue method call in the
    graph must become ``call_torchbind``.
    """
    backend = EagerAndRecordGraphs()

    def f(tq, x):
        x1 = x.view(-1)
        x2 = x.permute(1, 0)
        tq.push(x1)
        tq.push(x2)
        return x1 - tq.size(), x2 + tq.size(), tq

    x = torch.randn(2, 3)
    eager_out = f(_empty_tensor_queue(), x)
    compiled_out = torch.compile(f, backend=backend)(_empty_tensor_queue(), x)
    _assertEqualScriptObject(self, eager_out, compiled_out)

    # The graph check only applies when the test itself is not being traced.
    if torch._dynamo.is_compiling():
        return
    self.assertExpectedInline(
        backend.graphs[0].code.strip(),
        """\
def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject):
l_x_ = L_x_
l_tq_ = L_tq_
x1 = l_x_.view(-1)
x2 = l_x_.permute(1, 0); l_x_ = None
call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', x1)
call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', x2)
call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'size')
sub = x1 - 2; x1 = None
call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None
add = x2 + 2; x2 = None
return (sub, add)""",
    )
def test_compile_error_on_non_fakified_method(self):
    """Compiling must fail on a method that only the real queue implements."""
    backend = EagerAndRecordGraphs()

    def f(tq, x):
        x1 = x.view(-1)
        x2 = x.permute(1, 0)
        tq.push(x1)
        tq.push(x2)
        # The real tensor queue implements clone_queue, but its fakified
        # version does not.
        return tq.clone_queue()

    x = torch.randn(2, 3)
    compiled_f = torch.compile(f, backend=backend)
    with self.assertRaisesRegex(
        RuntimeError, "FakeScriptObject doesn't define method"
    ):
        compiled_f(_empty_tensor_queue(), x)
def test_compile_obj_as_hop_input(self):
    """Pass a script object as an operand of the ``wrap`` higher-order op."""

    def f(tq, x):
        def body(tq, x):
            tq.push(x)
            return x.sin()

        return wrap(body, tq, x)

    x = torch.randn(2, 3)
    expected = f(_empty_tensor_queue(), x)
    actual = torch.compile(f, backend="eager")(_empty_tensor_queue(), x)
    _assertEqualScriptObject(self, expected, actual)
def test_compile_obj_closure(self):
    """Use a script object captured by a nested closure, in eager and compiled mode."""
    # Both runs share this queue. Each run pushes one tensor and pops it back.
    tq = _empty_tensor_queue()

    def f(x):
        def inner_f(x):
            tq.push(x.sin())

        inner_f(x)
        return tq.pop(), tq

    x = torch.randn(3, 2)
    opt_f = torch.compile(f, backend="eager")
    eager_result = f(x)
    compiled_result = opt_f(x)
    _assertEqualScriptObject(self, eager_result, compiled_result)
def test_compile_global_obj(self):
    """Use a module-level script object from a compiled function and compare
    the result with eager mode.
    """
    global _TENSOR_QUEUE_GLOBAL_TEST
    _TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue()
    # Remove the module global once this test finishes. Otherwise the queue
    # would stay in the module namespace for the rest of the test session.
    self.addCleanup(globals().pop, "_TENSOR_QUEUE_GLOBAL_TEST", None)

    def f(x):
        _TENSOR_QUEUE_GLOBAL_TEST.push(x.sin())
        return _TENSOR_QUEUE_GLOBAL_TEST.pop(), _TENSOR_QUEUE_GLOBAL_TEST

    opt_f = torch.compile(f, backend="eager")
    x = torch.randn(3, 2)
    eager_ret = f(x)
    opt_ret = opt_f(x)
    _assertEqualScriptObject(self, eager_ret, opt_ret)
def test_compile_obj_graph_breaks(self):
    """Queue operations split across explicit graph breaks must give the same
    result as eager mode.
    """
    cnt = torch._dynamo.testing.CompileCounter()

    def f(tq, x):
        tq.push(x.sin())
        tq.push(x.sin())
        torch._dynamo.graph_break()
        tq.pop()
        torch._dynamo.graph_break()
        tq.push(x.cos() + tq.size())
        torch._dynamo.graph_break()
        tq.push(x.cos() - tq.size())
        return x, tq.pop(), tq

    opt_f = torch.compile(f, backend=cnt)
    x = torch.randn(3, 2)
    eager_ret = f(_empty_tensor_queue(), x)
    compiled_ret = opt_f(_empty_tensor_queue(), x)
    _assertEqualScriptObject(self, eager_ret, compiled_ret)
    # Three graph_break() calls leave four regions to compile.
    self.assertEqual(cnt.frame_count, 4)
def test_compile_obj_attributes(self):
    """Use a torchbind queue stored as a module attribute under torch.compile.

    The compiled result must match eager mode, and the attribute must show up
    as a graph input.
    """
    backend = EagerAndRecordGraphs()

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.tq = _empty_tensor_queue()

        def forward(self, x):
            self.tq.push(x)
            return self.tq.pop()

    x = torch.randn(2, 3)
    opt_f = torch.compile(Model(), backend=backend)
    _assertEqualScriptObject(self, Model()(x), opt_f(x))
    self.assertEqual(len(backend.graphs), 1)
    # Dynamo lifts the module's torchbind attribute `self.tq` as a graph input
    # (L_self_tq below). In the future, we would want to consolidate this
    # with non-strict behavior, where such objects are set as attributes.
    self.assertExpectedInline(
        backend.graphs[0].code.strip(),
        """\
def forward(self, L_self_tq : torch.ScriptObject, L_x_ : torch.Tensor):
l_self_tq = L_self_tq
l_x_ = L_x_
call_torchbind = torch.ops.higher_order.call_torchbind(l_self_tq, 'push', l_x_); l_x_ = None
call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_self_tq, 'pop'); l_self_tq = None
return (call_torchbind_1,)""",
    )
def test_compile_obj_torchbind_op(self):
    """Mix the queue_push/queue_pop custom ops with direct queue method calls,
    in eager and compiled mode.
    """

    def f(tq, x):
        torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
        torch.ops._TorchScriptTesting.queue_push(tq, x.cos() + 1)
        torch.ops._TorchScriptTesting.queue_pop(tq)
        torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
        return tq.pop(), tq.pop() + tq.size(), tq

    x = torch.randn(2)
    compiled_f = torch.compile(f, backend="eager")
    expected = f(_empty_tensor_queue(), x)
    actual = compiled_f(_empty_tensor_queue(), x)
    _assertEqualScriptObject(self, expected, actual)
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
class TestRegisterFakeClass(TestCase):
    """Tests for the validation done by ``torch._library.register_fake_class``."""

    def setUp(self):
        init_torchbind_implementations()

    def tearDown(self):
        # Clear the global fake-class registry after every test.
        registry = torch._library.fake_class_registry.global_fake_class_registry
        registry.clear()

    def test_register_fake_class_no_torch_bind_class(self):
        # The qualified name does not name an existing torchbind class.
        with self.assertRaisesRegex(RuntimeError, "Tried to instantiate class"):

            @torch._library.register_fake_class("_TorchScriptTesting::NOT_A_VALID_NAME")
            class Invalid:
                pass

    def test_register_fake_class_no_from_real(self):
        # The fake class lacks __obj_unflatten__ entirely.
        with self.assertRaisesRegex(
            RuntimeError, "define a classmethod __obj_unflatten__"
        ):

            @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
            class InvalidFakeFoo:
                def __init__(self):
                    pass

    def test_register_fake_class_from_real_not_classmethod(self):
        # __obj_unflatten__ is present but defined as a plain method.
        with self.assertRaisesRegex(RuntimeError, "is not a classmethod"):

            @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
            class FakeFoo:
                def __init__(self, x, y):
                    self.x = x
                    self.y = y

                def __obj_unflatten__(cls, flattened_foo):  # noqa: B902
                    return cls(**dict(flattened_foo))

    def test_register_fake_class_valid(self):
        # A fake class with a classmethod __obj_unflatten__ registers without
        # error, using the non-decorator call form.
        class FakeFoo:
            def __init__(self, x, y):
                self.x = x
                self.y = y

            @classmethod
            def __obj_unflatten__(cls, flattened_foo):
                return cls(**dict(flattened_foo))

        torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo)
# Expand any @parametrize-decorated tests in TestExportTorchbind into concrete
# test methods. TestRegisterFakeClass is not passed here and is used as-is.
instantiate_parametrized_tests(TestExportTorchbind)
# Run the suite when this file is executed directly.
if __name__ == "__main__":
    run_tests()