# Owner(s): ["module: dynamo"]
import dataclasses
import unittest
from contextlib import contextmanager
from dataclasses import dataclass
import torch
import torch._dynamo as torchdynamo
from functorch.experimental.control_flow import map, cond
from torch import Tensor
from torch.export import Constraint
from torch._export import DEFAULT_EXPORT_DYNAMO_CONFIG, dynamic_dim, export, capture_pre_autograd_graph
from torch._export.constraints import constrain_as_size, constrain_as_value
from torch._export.utils import (
get_buffer,
get_param,
is_buffer,
is_param,
register_dataclass_as_pytree_node,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._pytree import (
LeafSpec,
tree_flatten,
tree_unflatten,
TreeSpec,
treespec_loads,
treespec_dumps
)
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestDynamismExpression(TestCase):
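# constrain_as_size() lets export treat a data-dependent value (here, the
# result of .item()) as a valid tensor size, so torch.full can be traced
# with a symbolic shape instead of failing on the unbacked value.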
def test_export_inline_constraints(self):
def f(x):
b = x.item()
constrain_as_size(b)
return torch.full((b, 1), 1)
inp = (torch.tensor([3]),)
ref = f(*inp)
gm = export(f, inp)
res = gm(*inp)
self.assertTrue(torchdynamo.utils.same(ref, res))
gm = make_fx(f, tracing_mode="symbolic")(*inp)
res = gm(*inp)
self.assertTrue(torchdynamo.utils.same(ref, res))
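# Constraint violations should surface as errors: an example input that
# falls outside a dynamic_dim range fails at export time, while conflicting
# inline constraints fail at runtime when the exported program is called.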
def test_export_constraints_error(self):
def invalid_input_conflict_with_input_constraints(x):
return x + 1
inp = torch.zeros([3])
inp_constraints = [
dynamic_dim(inp, 0) > 5,
]
with self.assertRaisesRegex(torchdynamo.exc.UserError, "not in range"):
export(
invalid_input_conflict_with_input_constraints,
(inp,),
constraints=inp_constraints,
)
def conflicting_constraints(x):
b = x.item()
constrain_as_size(b)
constrain_as_value(b, min=4, max=5)
return torch.full((b, 1), 1)
inp = (torch.tensor([3]),)
ep = export(conflicting_constraints, inp)
with self.assertRaisesRegex(RuntimeError, r"is outside of inline constraint \[4, 5\]"):
ep(torch.tensor([3]))
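# Export specializes shapes to the example inputs by default, so a branch
# on a concrete shape value traces cleanly without any constraint.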
def test_export_assume_static_by_default(self):
def branch_on_shape(x: torch.Tensor):
if x.shape[0] == 4:
return x + 1
else:
return x
inp = (torch.rand(4, 5),)
# Being able to export means shape is preserved as static
export(branch_on_shape, inp)
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExport(TestCase):
def _test_export_same_as_eager(self, f, args, kwargs=None):
kwargs = kwargs or {}
exported_program = export(f, args, kwargs)
reversed_kwargs = {key: kwargs[key] for key in reversed(kwargs)}
self.assertEqual(exported_program(*args, **kwargs), f(*args, **kwargs))
self.assertEqual(exported_program(*args, **reversed_kwargs), f(*args, **reversed_kwargs))
def test_basic(self):
def f(x, y):
return x[0] + y
inp = ([torch.ones(1, 3)], torch.ones(1, 3))
self._test_export_same_as_eager(f, inp)
def test_export_preserve_signature(self):
class NestedChild(torch.nn.Module):
def forward(self, zx, y):
return {"x": y["key"] + zx[1], "w": y["key"] * zx[1]}
class Child1(torch.nn.Module):
def __init__(self):
super().__init__()
self.nested = NestedChild()
def forward(self, x, y):
z = torch.ones_like(x)
xw = self.nested((z, x), y={"key": y})
return xw["w"] + z - xw["x"]
class Child2(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x - 1
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.foo = Child1()
self.bar = Child2()
def forward(self, x, y):
x = self.foo(x, y)
x = self.bar(x)
return x
orig_eager = MyModule()
inps = torch.rand(2, 3), torch.rand(2, 3)
ep = export(
orig_eager,
inps,
{},
preserve_module_call_signature=("foo.nested", "foo"),
)
ep._validate()
self.assertEqual(len(ep.module_call_graph), 3)
# TODO(zhxchen17) unflattener
# unflattened = unflatten(export_module)
# self.compare_outputs(export_module, unflattened, inps)
# unflattened.foo.nested = NestedChild()
# self.compare_outputs(export_module, unflattened, inps)
def test_raise_user_error_when_guard_on_data_dependent_operation(self):
def fn_ddo(x):
y = x.nonzero()
z = y.shape[0]
if z > 2:
return x.cos()
else:
return x.sin()
with self.assertRaisesRegex(
torchdynamo.exc.UserError,
"trying to get a value out of symbolic int"
):
_ = export(fn_ddo, (torch.tensor([2, 3, 5]),), constraints=None)
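# Export functionalizes the traced program: in-place ops such as add_()
# should be replaced with out-of-place equivalents, while view ops remain.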
def test_if_functional(self):
def foo(x):
z = x + 4
z.add_(4)
y = z.view(x.shape)
return x.cos() + y.cos()
gm = export(foo, (torch.tensor([2, 3, 5]),), constraints=None)
view_count = 0
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.add_.Tensor:
# No more inplace mutation
self.assertNotEqual(
node.target,
torch.ops.aten.add_.Tensor,
"There shouldn't be any inplace mutation node in the graph."
)
if node.op == "call_function" and node.target == torch.ops.aten.view.default:
view_count += 1
# There should be nonzero view nodes in the graph
self.assertTrue(view_count > 0)
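# When the constraints inferred for a dynamic dimension are too complex to
# express, export asks the user to specialize that dimension and installs
# runtime checks for the specialized value.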
def test_export_mod_constraints(self):
class BasicDynamicShapeModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.view(x.shape[0] - 1, -1)
m = BasicDynamicShapeModel()
a = torch.randn(3, 4)
constraints = [3 <= dynamic_dim(a, 0), dynamic_dim(a, 1)]
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
(
"Some dynamic dimensions need to be specialized because "
"the constraints inferred for them are too complex to specify"
".*\n.*\\[0\\], which was marked dynamic, must be specialized to 3"
".*\n.*\\[1\\], which was marked dynamic, must be specialized to 4"
),
):
torch._export.export(m, (a,), constraints=constraints)
em = torch._export.export(m, (a,))
x = torch.randn(3, 5)
with self.assertRaisesRegex(RuntimeError, "\\[1\\] is specialized at 4"):
em(x)
def test_not_correct_dim(self):
def f(x):
return x.cos()
def g(x):
return x + 4
inp_for_f = torch.tensor(5)
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Cannot mark 0-dimension tensors to be dynamic"):
constraints = [dynamic_dim(inp_for_f, 0)]
inp_for_f_mul_dim = torch.ones(5, 5)
with self.assertRaisesRegex(
torchdynamo.exc.UserError,
"Expected the dimension passed to dynamic_dim to be in the range \\[0:1\\]"
):
constraints = [dynamic_dim(inp_for_f_mul_dim, 2)]
inp_for_g = 4
with self.assertRaisesRegex(torchdynamo.exc.UserError, "Expected tensor as input to dynamic_dim"):
constraints = [dynamic_dim(inp_for_g, 0)]
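# control_flow.map applies the body over the leading dimension of xs; the
# exported program should match eager execution element for element.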
def test_map(self):
def list_tensor_map(xs, y, z):
def body(x, y, z):
return x + y + z
return map(body, xs, y, z)
inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
self._test_export_same_as_eager(list_tensor_map, inps)
def test_export_func_with_kwargs(self):
def kw_func(arg1, arg2, kw1, kw2):
return arg1 + arg2, kw1 + kw2
args = (torch.ones(6, 4), torch.ones(1, 1))
kwargs = {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)}
self._test_export_same_as_eager(kw_func, args, kwargs)
def test_export_func_with_pytree_kwargs(self):
def kw_func(arg1, arg2, a, b):
return arg1 + a["kw1"] + b[0], arg2 + a["kw2"] + b[1]
args = (torch.ones(2, 3), torch.ones(3, 4))
kwargs = {"a": {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4)}, "b": [torch.ones(2, 3), torch.ones(3, 4)]}
self._test_export_same_as_eager(kw_func, args, kwargs)
def test_export_func_with_default_kwargs(self):
def kw_func(arg1, arg2, a, b=1):
return arg1 + arg2, a["kw1"] + a["kw2"] + b
def kw_func2(arg1, arg2, a=1, b=2):
return arg1 + a, arg2 + b
args = (torch.ones(6, 4), torch.ones(1, 1))
kwargs1 = {"a": {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)}}
kwargs2 = {"a": {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)}, "b": 2}
self._test_export_same_as_eager(kw_func, args, kwargs1)
self._test_export_same_as_eager(kw_func, args, kwargs2)
kwargs3 = {"b": 1}
self._test_export_same_as_eager(kw_func2, args, kwargs3)
def test_export_func_with_var_positional_args(self):
def kw_func(arg1, arg2, *args):
return arg1 + args[0], arg2 + args[1]
args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4))
self._test_export_same_as_eager(kw_func, args)
def test_export_func_with_keyword_only_args(self):
def kw_func(arg1, arg2, *args, kw1, kw2):
return arg1 + args[0] + kw1, arg2 + args[1] + kw2
args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4))
kwargs = {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4)}
self._test_export_same_as_eager(kw_func, args, kwargs)
def test_export_func_with_var_keyword_args(self):
def kw_func(arg1, arg2, *args, kw1, kw2, **kwargs):
return arg1 + args[0] + kw1 + kwargs["kw3"], arg2 + args[1] + kw2 + kwargs["kw4"]
args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4))
kwargs = {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4), "kw3": torch.ones(2, 3), "kw4": torch.ones(3, 4)}
self._test_export_same_as_eager(kw_func, args, kwargs)
def test_export_func_with_var_keyword_pytree_args(self):
def kw_func(arg1, arg2, *args, kw1, kw2, **kwargs):
return arg1 + arg2[0][0] + args[0] + kw1[0] + kwargs["kw3"][0], arg2[1] + args[1] + kw2 + kwargs["kw4"]
args = (torch.ones(2, 3), [(torch.ones(2, 3), ), torch.ones(3, 4)], torch.ones(2, 3), torch.ones(3, 4))
kwargs = {"kw1": (torch.ones(2, 3), ), "kw2": torch.ones(3, 4),
"kw3": (torch.ones(2, 3), torch.ones(3, 4)), "kw4": torch.ones(3, 4)}
self._test_export_same_as_eager(kw_func, args, kwargs)
def test_linear_conv(self):
class MyLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.randn(20, 98)
self.bias = torch.randn(20)
def forward(self, x):
return torch.nn.functional.linear(x, self.weight, self.bias)
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(16, 33, 3)
self.linear = MyLinear()
def forward(self, x):
x_conv = self.conv(x)
x_linear = self.linear(x_conv)
return x_linear.cos()
ep = export(Foo(), (torch.randn(20, 16, 50, 100),))
for node in ep.graph.nodes:
if (
node.op == "placeholder" and
(node.name in ep.graph_signature.inputs_to_buffers or
node.name in ep.graph_signature.inputs_to_parameters)
):
self.assertTrue("source_fn" in node.meta)
self.assertTrue("nn_module_stack" in node.meta)
def test_error_does_not_reference_eager_fallback(self):
def fn_ddo(x):
y = x.nonzero()
z = y.shape[0]
if z > 2:
return x.cos()
else:
return x.sin()
with self.assertRaisesRegex(
torchdynamo.exc.UserError,
r"^(?!.*fall back to eager).*"
):
_ = export(fn_ddo, (torch.tensor([2, 3, 5]),), constraints=None)
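# register_dataclass_as_pytree_node() turns a dataclass into a pytree
# container. Before registration it flattens as a single leaf; afterwards
# each field becomes a leaf, with None-valued fields dropped unless the
# node is registered with return_none_fields=True.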
def test_pytree_register_data_class(self):
@dataclass
class MyDataClass:
x: int
y: int
z: int = None
dt = MyDataClass(x=3, y=4)
flat, spec = tree_flatten(dt)
self.assertEqual(spec, LeafSpec())
self.assertEqual(len(flat), 1)
register_dataclass_as_pytree_node(MyDataClass)
flat, spec = tree_flatten(dt)
self.assertEqual(
spec,
TreeSpec(
MyDataClass,
(
MyDataClass,
['x', 'y'],
['z']
),
[LeafSpec(), LeafSpec()]
)
)
self.assertEqual(flat, [3, 4])
orig_dt = tree_unflatten(flat, spec)
self.assertTrue(isinstance(orig_dt, MyDataClass))
self.assertEqual(orig_dt.x, 3)
self.assertEqual(orig_dt.y, 4)
self.assertEqual(orig_dt.z, None)
roundtrip_spec = treespec_loads(treespec_dumps(spec))
self.assertEqual(roundtrip_spec, spec)
# Override the registration with keep none fields
register_dataclass_as_pytree_node(MyDataClass, return_none_fields=True)
flat, spec = tree_flatten(dt)
self.assertEqual(
spec,
TreeSpec(
MyDataClass,
(
MyDataClass,
['x', 'y', 'z'],
[],
),
[LeafSpec(), LeafSpec(), LeafSpec()]
)
)
self.assertEqual(flat, [3, 4, None])
orig_dt = tree_unflatten(flat, spec)
self.assertTrue(isinstance(orig_dt, MyDataClass))
self.assertEqual(orig_dt.x, 3)
self.assertEqual(orig_dt.y, 4)
self.assertEqual(orig_dt.z, None)
roundtrip_spec = treespec_loads(treespec_dumps(spec))
self.assertEqual(roundtrip_spec, spec)
def test_pytree_register_nested_data_class(self):
@dataclass
class Inner:
x: int
y: int
@dataclass
class Outer:
xy: Inner
ab: Inner
xy = Inner(1, 2)
ab = Inner(3, 4)
dt = Outer(xy, ab)
inp = {"dt1": (dt, ({},)), "dt2": ((torch.ones(1),), dt)}
register_dataclass_as_pytree_node(Inner)
register_dataclass_as_pytree_node(Outer)
flat, spec = tree_flatten(inp)
self.assertEqual(flat, [1, 2, 3, 4, torch.ones(1), 1, 2, 3, 4])
unflat = tree_unflatten(flat, spec)
self.assertEqual(unflat, inp)
roundtrip_spec = treespec_loads(treespec_dumps(spec))
self.assertEqual(roundtrip_spec, spec)
def test_param_util(self):
class Basic(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(10, 1)
def forward(self, x):
return self.lin(x)
ep = export(Basic(), (torch.randn(5, 10),))
num_params = 0
params = []
for node in ep.graph.nodes:
if is_param(ep, node):
num_params += 1
params.append(get_param(ep, node))
self.assertEqual(num_params, 2)
self.assertEqual(params[0].shape, [1, 10]) # weight
self.assertEqual(params[1].shape, [1]) # bias
def test_buffer_util(self):
ep = export(torch.nn.BatchNorm2d(100, affine=False), (torch.ones(20, 100, 35, 45), ))
num_buffer = 0
buffer = []
for node in ep.graph.nodes:
if is_buffer(ep, node):
num_buffer += 1
buffer.append(get_buffer(ep, node))
self.assertEqual(num_buffer, 3)
self.assertEqual(buffer[0].shape, torch.Size([100])) # running_mean
self.assertEqual(buffer[1].shape, torch.Size([100])) # running_var
self.assertEqual(buffer[2].shape, torch.Size([])) # num_batches_tracked
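# DEFAULT_EXPORT_DYNAMO_CONFIG is a module-level dataclass instance, so the
# context manager below mutates it in place and restores the original
# field values on exit.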
def test_export_dynamo_config(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.lstm = torch.nn.LSTM(input_size=4, hidden_size=5, num_layers=1)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return self.lstm(inputs)
config = DEFAULT_EXPORT_DYNAMO_CONFIG
mod = MyModule()
@contextmanager
def _patch_config(kwargs):
orig_config_dict = dataclasses.asdict(config)
try:
for k, v in kwargs.items():
setattr(config, k, v)
yield
finally:
for k, v in orig_config_dict.items():
setattr(config, k, v)
inp = (torch.rand(5, 4), )
exported_program = export(mod, inp)
with _patch_config({"allow_rnn": False}):
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported,
"TorchDynamo purposely graph breaks on RNN, GRU, LSTMs"
):
_ = export(mod, inp)
def test_module(self):
class MyLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.randn(20, 98)
self.bias = torch.randn(20)
def forward(self, x):
return torch.nn.functional.linear(x, self.weight, self.bias)
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(16, 33, 3)
self.linear = MyLinear()
def forward(self, x):
a, b = x
a_conv = self.conv(a)
a_linear = self.linear(a_conv)
b_conv = self.conv(b)
b_linear = self.linear(b_conv)
return (a_linear.cos() + b_linear.sin(), a_linear.sin() + b_linear.cos())
inp_container = ((torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),)
ep = export(Foo(), inp_container)
ep_rexported = export(ep.module(), inp_container)
inp_test = ((torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),)
self.assertTrue(torch.allclose(ep(*inp_test)[0], ep_rexported(*inp_test)[0]))
self.assertTrue(torch.allclose(ep(*inp_test)[1], ep_rexported(*inp_test)[1]))
def test_module_with_dict_container_inp_out(self):
class MyLinear(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.randn(20, 98)
self.bias = torch.randn(20)
def forward(self, x):
return torch.nn.functional.linear(x, self.weight, self.bias)
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(16, 33, 3)
self.linear = MyLinear()
def forward(self, x):
a1, a2 = x["a"]
b = x["b"]
a1_conv = self.conv(a1)
a1_linear = self.linear(a1_conv)
a2_conv = self.conv(a2)
a2_linear = self.linear(a2_conv)
b_conv = self.conv(b)
b_linear = self.linear(b_conv)
return {"a": a1_linear.cos() + b_linear.sin(), "b": a2_linear.sin() + b_linear.cos()}
inp_container = ({"a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)), "b": torch.randn(20, 16, 50, 100)},)
ep = export(Foo(), inp_container)
ep_rexported = export(ep.module(), inp_container)
inp_test = ({"a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)), "b": torch.randn(20, 16, 50, 100)},)
self.assertTrue(torch.allclose(ep(*inp_test)["a"], ep_rexported(*inp_test)["a"]))
self.assertTrue(torch.allclose(ep(*inp_test)["b"], ep_rexported(*inp_test)["b"]))
def test_args_type_checked(self):
def fn(x):
return x + 1
inp = torch.rand(2, 2)
with self.assertRaisesRegex(torch._dynamo.exc.UserError, "to be a tuple"):
# Intentionally not wrapping `inp` in a tuple to trigger the error
_ = export(fn, inp)
def test_constrain_value_with_no_default(self):
def fn(x, y):
n = x.max().item()
constrain_as_value(n)
return y + n
ep = export(fn, (torch.randint(3, 5, (2, 2)), torch.randint(3, 5, (2, 3))))
test_inp = (torch.randint(3, 5, (2, 2)), torch.randint(3, 5, (2, 3)))
self.assertTrue(torch.allclose(ep(*test_inp), fn(*test_inp)))
def test_constrain_value_with_symfloat(self):
def fn(x, y):
n = x.max().item()
constrain_as_value(n)
return y + n
with self.assertRaisesRegex(torch._dynamo.exc.TorchRuntimeError, "Constraining SymFloat or Symbool is nyi"):
_ = export(fn, (torch.rand(2, 2), torch.rand(2, 3)))
def test_constrain_size_in_eager(self):
def fn(x, y):
n = x.max().item()
constrain_as_size(n)
return y + n
ep = export(fn, (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3))))
test_inp = (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))
self.assertTrue(torch.allclose(ep(*test_inp), fn(*test_inp)))
def test_constrain_size_with_constrain_value(self):
def fn(x, y):
n = x.max().item()
constrain_as_value(n, 2, 10)
constrain_as_size(n)
return y + n
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 10\]."):
_ = fn(torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))
ep = export(fn, (torch.randint(3, 4, (2, 2)), torch.randint(3, 5, (2, 3))))
with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint"):
test_inp = (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3)))
_ = ep(*test_inp)
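# constrain_as_size() bounds are checked eagerly as well: values outside
# [min, max] raise, and a max of 1 is rejected because the constrained
# range must allow sizes greater than 2 (see case_3 below).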
def test_constrain_size_with_various_cases(self):
def case_1(x, y):
n = x.item()
constrain_as_size(n, min=0)
return y.sum() + torch.ones(n, 5).sum()
def case_2(x, y):
n = x.item()
constrain_as_size(n, min=0, max=6)
return y.sum() + torch.ones(n, 5).sum()
def case_3(x, y):
n = x.item()
constrain_as_size(n, min=0, max=1)
return y.sum() + torch.ones(n, 5).sum()
def case_4(x, y):
n = x.item()
constrain_as_size(n, min=2)
return y.sum() + torch.ones(n, 5).sum()
def case_5(x, y):
n = x.item()
constrain_as_size(n, min=1)
return y.sum() + torch.ones(n, 5).sum()
ep = export(case_1, (torch.tensor(1), torch.ones(4, 5)))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for -1 between"):
_ = case_1(torch.tensor(-1), torch.randn(4, 5))
self.assertTrue(
torch.allclose(
ep(torch.tensor(1), torch.ones(4, 5)),
case_1(torch.tensor(1), torch.ones(4, 5)),
)
)
ep = export(case_2, (torch.tensor(5), torch.randn(4, 5)))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 7 between"):
_ = case_2(torch.tensor(7), torch.randn(4, 5))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 9 between"):
_ = case_2(torch.tensor(9), torch.randn(4, 5))
self.assertTrue(
torch.allclose(
ep(torch.tensor(5), torch.ones(4, 5)),
case_2(torch.tensor(5), torch.ones(4, 5)),
)
)
with self.assertRaisesRegex(RuntimeError, "Max value to constrain_range_for_size must be greater than 2. got: 1"):
_ = case_3(torch.tensor(1), torch.randn(4, 5))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1 between \[2, 9223372036854775807\]."):
_ = case_4(torch.tensor(1), torch.randn(4, 5))
ep = export(case_4, (torch.tensor(5), torch.randn(4, 5)))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 1"):
_ = case_4(torch.tensor(1), torch.randn(4, 5))
self.assertTrue(
torch.allclose(
ep(torch.tensor(5), torch.ones(4, 5)),
case_4(torch.tensor(5), torch.ones(4, 5)),
)
)
ep = export(case_5, (torch.tensor(5), torch.randn(4, 5)))
with self.assertRaisesRegex(RuntimeError, r"Invalid value range for 0"):
_ = case_5(torch.tensor(0), torch.randn(4, 5))
self.assertTrue(
torch.allclose(
ep(torch.tensor(5), torch.ones(4, 5)),
case_5(torch.tensor(5), torch.ones(4, 5)),
)
)
def test_mixed_input(self):
def func(a, b, alpha: int):
return torch.add(a, b, alpha=alpha)
a = torch.rand(1, 2)
b = torch.rand(1, 2)
alpha = 10
exported = torch._export.export(func, (a, b, alpha))
for node in exported.graph_module.graph.nodes:
if node.op == "placeholder":
self.assertTrue(isinstance(node.meta["val"], (Tensor, int)))
def test_export_with_inline_constraints(self):
def f(x):
a = x.item()
constrain_as_value(a, 4, 7)
return torch.empty((a, 4))
ep = export(f, (torch.tensor([5]),))
self.assertEqual(ep(torch.tensor([6])).shape, (6, 4))
FileCheck().check_count(
"torch.ops.aten.sym_constrain_range.default", 1, exactly=True
).run(ep.graph_module.code)
with self.assertRaisesRegex(
RuntimeError,
r"_local_scalar_dense is outside of inline constraint \[4, 7\]",
):
ep(torch.tensor([30]))
def test_export_with_inline_constraints_complex(self):
def f(x):
a = x.item()
constrain_as_value(a, 4, 7)
empty = torch.empty((a, 4))
return torch.cat((empty.transpose(0, 1), torch.zeros(6, a)), 0)
ep = export(f, (torch.tensor([6]),))
self.assertEqual(ep(torch.tensor([5])).shape, (10, 5))
FileCheck().check_count(
"torch.ops.aten.sym_constrain_range.default", 1, exactly=True
).run(ep.graph_module.code)
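# exported.module() returns a stateful GraphModule: buffer mutations are
# applied back to the buffers, so repeated calls match eager semantics.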
def test_to_module_with_mutated_buffer(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buf", torch.zeros(1))
def forward(self, x):
self.buf.add_(1)
return x.sum() + self.buf.sum()
exported = torch._export.export(Foo(), (torch.ones(5, 5),))
stateful_gm = exported.module()
export_return_val = stateful_gm(torch.ones(5, 5))
eager = Foo()
eager_return_val = eager(torch.ones(5, 5))
self.assertTrue(torch.allclose(eager_return_val, export_return_val))
for name, buffer in stateful_gm.named_buffers():
self.assertTrue(torch.allclose(torch.ones(1), buffer))
changed = stateful_gm.graph.eliminate_dead_code()
self.assertFalse(changed)
self.assertTrue(torch.allclose(stateful_gm(torch.ones(5, 5)), eager(torch.ones(5, 5))))
for name, buffer in stateful_gm.named_buffers():
self.assertTrue(torch.allclose(torch.tensor(2, dtype=torch.float), buffer))
def test_to_module_with_mutated_buffer_multiple(self):
class Bar(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buf", torch.ones(1))
def forward(self, x):
self.buf.add_(1)
return x.sum() + self.buf.sum()
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buf", torch.zeros(1))
self.bar = Bar()
def forward(self, x):
self.buf.add_(1)
self.bar.buf.add_(2)
bar = self.bar(x)
return bar.sum() + self.buf.sum()
exported = torch._export.export(Foo(), (torch.ones(5, 5),))
stateful_gm = exported.module()
export_return_val = stateful_gm(torch.ones(5, 5))
eager = Foo()
eager_return_val = eager(torch.ones(5, 5))
self.assertTrue(torch.allclose(eager_return_val, export_return_val))
for name, buffer in stateful_gm.named_buffers():
if name == "L__self___buf":
self.assertTrue(torch.allclose(torch.ones(1), buffer))
if name == "L__self___bar_buf":
self.assertTrue(torch.allclose(torch.tensor(4, dtype=torch.float), buffer))
changed = stateful_gm.graph.eliminate_dead_code()
self.assertFalse(changed)
self.assertTrue(torch.allclose(stateful_gm(torch.ones(5, 5)), eager(torch.ones(5, 5))))
for name, buffer in stateful_gm.named_buffers():
if name == "L__self___buf":
self.assertTrue(torch.allclose(torch.tensor(2, dtype=torch.float), buffer))
if name == "L__self___bar_buf":
self.assertTrue(torch.allclose(torch.tensor(7, dtype=torch.float), buffer))
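# Non-tensor inputs (int, float, str) are specialized to their traced
# values; the exported program installs runtime asserts that reject any
# other value for those arguments.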
def test_runtime_assert_for_prim(self):
def f(x, y):
return x + y
tensor_inp = torch.ones(7, 5)
exported = torch._export.export(f, (tensor_inp, 5), constraints=[dynamic_dim(tensor_inp, 0) > 5])
self.assertTrue(torch.allclose(exported(torch.ones(8, 5), 5), f(torch.ones(8, 5), 5)))
with self.assertRaisesRegex(RuntimeError, "Input arg1_1 is specialized to be 5 at tracing time"):
_ = exported(torch.ones(8, 5), 6)
exported = torch._export.export(f, (tensor_inp, 5.0), constraints=[dynamic_dim(tensor_inp, 0) > 5])
with self.assertRaisesRegex(RuntimeError, "Input arg1_1 is specialized to be 5.0 at tracing time"):
_ = exported(torch.ones(7, 5), 6.0)
def test_runtime_assert_for_prm_str(self):
def g(a, b, mode):
return torch.div(a, b, rounding_mode=mode)
inps = (torch.randn(4, 4), torch.randn(4), "trunc")
exported = torch._export.export(g, inps)
with self.assertRaisesRegex(RuntimeError, "Input arg2_1 is specialized to be trunc at"):
_ = exported(torch.randn(4, 4), torch.randn(4), "floor")
self.assertTrue(torch.allclose(exported(*inps), g(*inps)))
def test_to_module_with_mutated_buffer_multiple_update_sub_later(self):
class Bar(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buf", torch.ones(1))
def forward(self, x):
self.buf.add_(1)
return x.sum() + self.buf.sum()
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buf", torch.zeros(1))
self.bar = Bar()
def forward(self, x):
self.buf.add_(1)
bar = self.bar(x)
self.bar.buf.add_(2)
return bar.sum() + self.buf.sum()
exported = torch._export.export(Foo(), (torch.ones(5, 5),))
stateful_gm = exported.module()
export_return_val = stateful_gm(torch.ones(5, 5))
eager = Foo()
eager_return_val = eager(torch.ones(5, 5))
self.assertTrue(torch.allclose(eager_return_val, export_return_val))
for name, buffer in stateful_gm.named_buffers():
if name == "L__self___buf":
self.assertTrue(torch.allclose(torch.ones(1), buffer))
if name == "L__self___bar_buf":
self.assertTrue(torch.allclose(torch.tensor(4, dtype=torch.float), buffer))
changed = stateful_gm.graph.eliminate_dead_code()
self.assertFalse(changed)
self.assertTrue(torch.allclose(stateful_gm(torch.ones(5, 5)), eager(torch.ones(5, 5))))
for name, buffer in stateful_gm.named_buffers():
if name == "L__self___buf":
self.assertTrue(torch.allclose(torch.tensor(2, dtype=torch.float), buffer))
if name == "L__self___bar_buf":
self.assertTrue(torch.allclose(torch.tensor(7, dtype=torch.float), buffer))
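# An ExportedProgram can itself be passed back to export(). Re-export
# preserves dynamic-shape behavior, but providing fresh constraints for an
# already-exported program is an error.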
def test_retraceable_ep(self):
class Bar(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buf", torch.ones(1))
def forward(self, x):
self.buf.add_(1)
return x.sum() + self.buf.sum()
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buf", torch.zeros(1))
self.bar = Bar()
def forward(self, x):
self.buf.add_(1)
bar = self.bar(x)
self.bar.buf.add_(2)
return bar.sum() + self.buf.sum()
inp = torch.ones(5, 5)
exported = torch._export.export(Foo(), (inp,))
reexported = torch._export.export(exported, (inp,))
self.assertTrue(torch.allclose(exported(inp), reexported(inp)))
inp = torch.ones(5, 5)
exported = torch._export.export(Foo(), (inp,), constraints=[dynamic_dim(inp, 0)])
reexported = torch._export.export(exported, (inp,))
self.assertTrue(torch.allclose(exported(torch.ones(7, 5)), reexported(torch.ones(7, 5))))
exported = torch._export.export(Foo(), (inp,), constraints=[dynamic_dim(inp, 0)])
# This seems fine because the exported program is generalized to work for dynamic shapes.
reexported = torch._export.export(exported, (inp,))
self.assertTrue(torch.allclose(exported(torch.ones(7, 5)), reexported(torch.ones(7, 5))))
exported = torch._export.export(Foo(), (inp,), constraints=[dynamic_dim(inp, 0)])
with self.assertRaisesRegex(torch._dynamo.exc.UserError, 'Cannot provide constraints for already exported program.'):
_ = torch._export.export(exported, (inp,), constraints=[dynamic_dim(inp, 0)])
# Reexported program should still work for dynamic shapes.
reexported = torch._export.export(exported, (inp,))
self.assertTrue(torch.allclose(reexported(torch.ones(7, 5)), Foo()(torch.ones(7, 5))))
def test_retrace_graph_level_meta_preservation(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
if x.shape[0] > 4:
return x.cos()
return x.sin()
inp = torch.ones(7, 5)
exported = torch._export.export(Foo(), (inp,), constraints=[dynamic_dim(inp, 0) > 5])
stateful_module = exported.module()
self.assertEqual(len(stateful_module.meta["input_shape_constraints"]), 1)
re_exported = torch._export.export(stateful_module, (inp,))
self.assertEqual(len(re_exported.graph_module.meta["input_shape_constraints"]), 1)
self.assertTrue(torch.allclose(exported(torch.ones(7, 5)), re_exported(torch.ones(7, 5))))
re_exported_v2 = torch._export.export(exported, (inp,))
self.assertEqual(len(re_exported_v2.graph_module.meta["input_shape_constraints"]), 1)
self.assertTrue(torch.allclose(exported(torch.ones(7, 5)), re_exported_v2(torch.ones(7, 5))))
def test_constrain_as_size_error(self):
def f(x):
a = x.item()
return torch.full((a, 4), 0)
with self.assertRaisesRegex(
torch._dynamo.exc.UserError,
"Tried to use data-dependent value in the subsequent computation"
):
_ = export(f, (torch.tensor(6),))
def test_constraint_directly_construct(self):
with self.assertRaisesRegex(
TypeError,
"torch.export.Constraint has no public constructor. Please use torch.export.dynamic_dim"
):
_ = Constraint()
def test_train_eval_on_exported_preautograd_module(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
if x.shape[0] > 4:
return x.cos()
return x.sin()
graph_module = capture_pre_autograd_graph(Foo(), (torch.ones(7, 5),))
with self.assertRaisesRegex(NotImplementedError, r"Calling train\(\) is not supported yet."):
graph_module.train()
with self.assertRaisesRegex(NotImplementedError, r"Calling eval\(\) is not supported yet."):
graph_module.eval()
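# Subgraphs created for cond() branches should retain "source_fn" metadata
# pointing back at the original user methods, one entry per branch.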
def test_export_cond_preserve_stack_trace_for_subgraphs(self):
class MySubModule(torch.nn.Module):
def foo(self, x):
return x.cos()
def forward(self, x):
return self.foo(x)
class CondBranchClassMethod(torch.nn.Module):
def __init__(self):
super().__init__()
self.subm = MySubModule()
def bar(self, x):
return x.sin()
def forward(self, x):
return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
from torch._export import capture_pre_autograd_graph
example_inputs = (torch.randn(1, 3, 3, 3),)
m = CondBranchClassMethod()
m.eval()
gm = capture_pre_autograd_graph(m, example_inputs)
actual_source_fns = []
for mod in gm.modules():
for node in mod.graph.nodes:
if node.name in {"sin", "cos"}:
actual_source_fns.append(node.meta.get("source_fn", None))
exp_source_fns = [("cos", "cos"), ("sin", "sin")]
self.assertEqual(actual_source_fns, exp_source_fns)
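# lift_constant_tensor_pass rewrites get_attr constant tensors into lifted
# buffer inputs, recorded in the graph signature and state_dict and placed
# after the existing params/buffers.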
def test_lift_constants(self) -> None:
from torch._export.passes.lift_constant_tensor_pass import lift_constant_tensor_pass
def f(x):
return x + torch.tensor(3)
ep = export(f, (torch.tensor(1),))
ep = lift_constant_tensor_pass(ep)
for node in ep.graph.nodes:
self.assertTrue(node.op != "get_attr")
self.assertEqual(len(ep.graph_signature.buffers), 1)
self.assertEqual(len(ep.state_dict), 1)
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = torch.tensor(3)
def forward(self, x):
list_tensor = [torch.tensor(3), torch.tensor(4)]
return x + self.a + list_tensor[0] + list_tensor[1]
ep = export(Foo(), (torch.tensor(1),))
ep = lift_constant_tensor_pass(ep)
nodes = list(ep.graph.nodes)
for node in nodes:
self.assertTrue(node.op != "get_attr")
self.assertEqual(len(ep.graph_signature.buffers), 3)
self.assertEqual(len(ep.state_dict), 3)
# These constants should be placed after the param/buffers
self.assertTrue(
nodes[1].name in ep.graph_signature.inputs_to_buffers and
nodes[2].name in ep.graph_signature.inputs_to_buffers
)
if __name__ == '__main__':
run_tests()