# Owner(s): ["module: codegen"]
import torch
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo, TEST_WITH_TORCHDYNAMO
from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, capture_logs
from torch.utils._pytree import tree_map
from torch.fx.experimental.proxy_tensor import make_fx
import unittest
import logging
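# Helper: a rough check for whether x and y are views of the same tensor, based on
# their ._base attributes. This is not a general aliasing check, but it is sufficient
# for the assertions in this file.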
def are_aliased(x, y):
if x._base is None and y._base is None:
return False
if x._base is not None and y._base is None:
return x._base is y
if x._base is None and y._base is not None:
return y._base is x
return x._base is y._base
# Just for testing: a logging tensor that also transforms out-of-place ops into inplace ops.
# That way even if the outer wrapper is functionalized, the inner wrapper will also need functionalization.
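# For example, dispatching aten.add.Tensor on an InplaceLoggingTensor actually executes
# aten.add_.Tensor on the wrapped tensor, while the log still records the original
# (out-of-place) op that was dispatched.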
class InplaceLoggingTensor(LoggingTensorReentrant):
@staticmethod
def __new__(cls, e):
r = torch.Tensor._make_wrapper_subclass(cls, e.shape, dtype=e.dtype, requires_grad=False)
r.elem = e
return r
__torch_function__ = torch._C._disabled_torch_function_impl
def __str__(self):
return f'InplaceLoggingTensor({self.elem})'
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(e):
if isinstance(e, InplaceLoggingTensor):
return e.elem
else:
return e
def wrap(e):
if isinstance(e, torch.Tensor):
return InplaceLoggingTensor(e)
else:
return e
f = func
        # this subclass converts `aten.add.Tensor` calls into in-place `aten.add_.Tensor` calls
if f is torch.ops.aten.add.Tensor:
f = torch.ops.aten.add_.Tensor
with cls.context():
rs = tree_map(wrap, f(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
# after running the (potentially transformed) op,
# log the original op that we saw.
logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
return rs
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "https://github.com/pytorch/pytorch/issues/81457")
class TestFunctionalization(TestCase):
# We can unify testing and use functionalize() here instead
# if/when functorch moves into core.
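    # _functionalize(f) returns a wrapper that runs f under the functionalization pass:
    # it wraps the input in a functional tensor, enables functionalization while f runs,
    # then syncs and unwraps the outputs before returning them.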
def _functionalize(self, f, *, reapply_views: bool):
def wrapped(a):
input_functional = torch._to_functional_tensor(a)
torch._enable_functionalization(reapply_views=reapply_views)
try:
out = f(input_functional)
finally:
torch._disable_functionalization()
torch._sync(input_functional)
tree_map(torch._sync, out)
out_unwrapped = tree_map(torch._from_functional_tensor, out)
return out_unwrapped
return wrapped
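    # get_logs traces the functionalized version of `func` with make_fx and returns the
    # generated FX code, so tests can assert on the exact op sequence the pass produces.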
def get_logs(self, func, inpt, *, reapply_views=False):
traced_f = make_fx(self._functionalize(func, reapply_views=reapply_views))(inpt)
return traced_f.code
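    # assert_functionalization runs `func` on clones of the input with and without
    # functionalization, and checks that both the outputs and any input mutations match.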
def assert_functionalization(self, func, inpt, *, reapply_views=False):
input_clone = inpt.clone()
input_clone2 = inpt.clone()
input_functional = torch._to_functional_tensor(input_clone2)
# Compare outputs (and mutated inputs), with and without functionalization.
out_ref = func(inpt)
torch._enable_functionalization(reapply_views=reapply_views)
try:
out_functional = func(input_functional)
finally:
torch._disable_functionalization()
# We need to sync the input tensors first, in case there are any queued mutations left.
torch._sync(input_functional)
self.assertEqual(inpt, torch._from_functional_tensor(input_functional)) # input mutations should still occur
# Handle tests with multi-tensor outputs
if isinstance(out_ref, tuple) and isinstance(out_functional, tuple):
out_refs, out_functionals = list(out_ref), list(out_functional)
else:
out_refs, out_functionals = [out_ref], [out_functional]
for out_ref_, out_functional_ in zip(out_refs, out_functionals):
self.assertEqual(out_ref_.size(), out_functional_.size())
torch._sync(out_functional_)
out_functional_unwrapped = torch._from_functional_tensor(out_functional_)
self.assertEqual(out_ref_, out_functional_unwrapped)
def test_save_for_backwards_segfault(self):
inp = torch._to_functional_tensor(LoggingTensor(torch.randn(2, 2))).requires_grad_(True)
inp.exp()
def test_multiple_views_of_same_base(self):
def f(x):
y = x.view(-1)
z = x.view(-1)
x.add_(1)
# y should have been updated.
y2 = y + 1
# z should have been updated too.
z2 = z + 1
return z2
self.assert_functionalization(f, torch.ones(4))
def test_simple(self):
def f(x):
# simple test: 1 view op, 1 inplace op
tmp = torch.ones(4, 2)
y = x.view(4, 2)
y.add_(tmp)
z = x * x
return y
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2])
mul_tensor = torch.ops.aten.mul.Tensor(view_copy_default_1, view_copy_default_1); view_copy_default_1 = None
return add_tensor
""")
def test_simple_out(self):
def f(x):
tmp = torch.ones(4, 2)
y = x.view(4, 2)
            # the out= tensor will get resized, since it starts out as a 0-dim scalar tensor.
z = torch.empty(())
torch.add(y, tmp, out=z)
w = z * z
return w
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
empty_1 = torch.ops.aten.empty.SymInt([], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None
return mul_tensor
""")
def test_multi_out(self):
def f(x):
# aminmax.out returns a tuple of tensors.
# functionalization should properly handle the tuple.
out_min = torch.empty(4)
out_max = torch.empty(4)
torch.aminmax(x, dim=0, out=(out_max, out_min))
return out_max
self.assert_functionalization(f, torch.arange(8, dtype=torch.float32))
logs = self.get_logs(f, torch.arange(8, dtype=torch.float32))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.SymInt([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
empty_1 = torch.ops.aten.empty.SymInt([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
aminmax_default = torch.ops.aten.aminmax.default(a_1, dim = 0); a_1 = None
getitem = aminmax_default[0]
getitem_1 = aminmax_default[1]; aminmax_default = None
return getitem
""")
def test_tensor_ctr(self):
def f(x):
y = torch.tensor((1, 2, 3))
z = y.view(-1)
z.add_(1)
return y
self.assert_functionalization(f, torch.arange(3, dtype=torch.float32))
def test_inplace_on_non_view(self):
def f(x):
            # test the case where we functionalize an in-place op on the input tensor itself, not on a view.
# This is worth checking because the tensor will have an empty ViewMeta stack, which needs to be special cased.
tmp = torch.ones(4, 2)
y = x.view(4, 2)
x.add_(tmp)
return y
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
add_tensor = torch.ops.aten.add.Tensor(a_1, fill_scalar); a_1 = fill_scalar = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
return view_copy_default_1
""")
# Some ops that are mutable are neither inplace nor out= ops.
# They also need special handling.
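    # (Functionalization swaps in a functional variant that returns the would-be-mutated
    # tensors as extra outputs - see the _fused_moving_avg_obs_fq_helper_functional call
    # in the expected trace below.)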
def test_mutable_op_not_inplace_or_other(self):
def f(x):
return torch._fused_moving_avg_obs_fq_helper(x, x, x, x, x, x, x, 1.0, 0, 1, 0)
logs = self.get_logs(f, torch.ones(1))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
_fused_moving_avg_obs_fq_helper_functional_default = torch.ops.aten._fused_moving_avg_obs_fq_helper_functional.default(a_1, a_1, a_1, a_1, a_1, a_1, a_1, 1.0, 0, 1, 0); a_1 = None
getitem = _fused_moving_avg_obs_fq_helper_functional_default[0]
getitem_1 = _fused_moving_avg_obs_fq_helper_functional_default[1]
getitem_2 = _fused_moving_avg_obs_fq_helper_functional_default[2]
getitem_3 = _fused_moving_avg_obs_fq_helper_functional_default[3]
getitem_4 = _fused_moving_avg_obs_fq_helper_functional_default[4]
getitem_5 = _fused_moving_avg_obs_fq_helper_functional_default[5]; _fused_moving_avg_obs_fq_helper_functional_default = None
return (getitem, getitem_1)
""") # noqa: B950
def test_as_strided(self):
def f(x):
y = x.as_strided((2,), (2,), 1)
y.add_(1)
return x
self.assert_functionalization(f, torch.ones(9))
logs = self.get_logs(f, torch.ones(9))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
as_strided_copy_default = torch.ops.aten.as_strided_copy.default(a_1, [2], [2], 1)
add_tensor = torch.ops.aten.add.Tensor(as_strided_copy_default, 1); as_strided_copy_default = None
as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(a_1, add_tensor, [2], [2], 1); a_1 = add_tensor = None
return as_strided_scatter_default
""")
def test_tensor_list_composite(self):
def f(x):
# Test an op with TensorList input
y = torch.block_diag(x, x)
return y
self.assert_functionalization(f, torch.ones(2, 2))
logs = self.get_logs(f, torch.ones(2, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
block_diag_default = torch.ops.aten.block_diag.default([a_1, a_1]); a_1 = None
return block_diag_default
""")
def test_cat(self):
def f(x):
out = torch.empty(0)
torch.cat((x,), out=out)
return out
self.assert_functionalization(f, torch.ones(2, 2))
logs = self.get_logs(f, torch.ones(2, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.SymInt([0], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
cat_default = torch.ops.aten.cat.default([a_1]); a_1 = None
return cat_default
""")
def test_diagonal(self):
def f(x):
# test: view ops that take a subset of the original tensor (select/diagonal)
tmp = torch.ones(2)
y = x.diagonal()
y.add_(tmp)
z = x * x
return z
self.assert_functionalization(f, torch.ones(2, 2))
logs = self.get_logs(f, torch.ones(2, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(a_1)
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, fill_scalar); diagonal_copy_default = fill_scalar = None
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(a_1, add_tensor); a_1 = add_tensor = None
mul_tensor = torch.ops.aten.mul.Tensor(diagonal_scatter_default, diagonal_scatter_default); diagonal_scatter_default = None
return mul_tensor
""")
def test_diagonal_mutated_input(self):
def f(x):
# simple test: there are pending updates afterwards, which the test syncs manually
tmp = torch.ones(2)
y = x.diagonal()
y.add_(tmp)
return x
x = torch.ones(2, 2)
self.assert_functionalization(f, x)
def test_split(self):
def f(x):
# test: view ops that return multiple tensors (split)
tmp = torch.ones(2)
y1, y2 = x.split(2)
y3 = y2.diagonal()
y3.add_(tmp)
z = x * x
return y3
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
split_copy_tensor = torch.ops.aten.split_copy.Tensor(a_1, 2)
getitem = split_copy_tensor[0]
getitem_1 = split_copy_tensor[1]; split_copy_tensor = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None
add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, fill_scalar); diagonal_copy_default = fill_scalar = None
split_copy_tensor_1 = torch.ops.aten.split_copy.Tensor(a_1, 2)
getitem_2 = split_copy_tensor_1[0]
getitem_3 = split_copy_tensor_1[1]; split_copy_tensor_1 = None
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(getitem_3, add_tensor); getitem_3 = None
slice_scatter_default = torch.ops.aten.slice_scatter.default(a_1, diagonal_scatter_default, 0, 2, 4); a_1 = diagonal_scatter_default = None
mul_tensor = torch.ops.aten.mul.Tensor(slice_scatter_default, slice_scatter_default); slice_scatter_default = None
return add_tensor
""") # noqa: B950
def test_view_inplace(self):
def f(x):
# test: view + inplace op (transpose_)
tmp = torch.ones(4)
x.transpose_(1, 0)
y = x[0]
y.add_(tmp)
return x
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
transpose_copy_int = torch.ops.aten.transpose_copy.int(a_1, 1, 0)
select_copy_int = torch.ops.aten.select_copy.int(transpose_copy_int, 0, 0); transpose_copy_int = None
add_tensor = torch.ops.aten.add.Tensor(select_copy_int, fill_scalar); select_copy_int = fill_scalar = None
transpose_copy_int_1 = torch.ops.aten.transpose_copy.int(a_1, 1, 0); a_1 = None
select_scatter_default = torch.ops.aten.select_scatter.default(transpose_copy_int_1, add_tensor, 0, 0); transpose_copy_int_1 = add_tensor = None
transpose_copy_int_2 = torch.ops.aten.transpose_copy.int(select_scatter_default, 1, 0); select_scatter_default = None
transpose_copy_int_3 = torch.ops.aten.transpose_copy.int(transpose_copy_int_2, 1, 0); transpose_copy_int_2 = None
return transpose_copy_int_3
""") # noqa: B950
def test_optional_tensor_list(self):
def f(x):
# test: an operator that takes in a List[Optional[Tensor]] argument
# (index_put)
y = x.view(8)
indices = torch.arange(4)
values = torch.arange(4, dtype=y.dtype)
y.index_put_((indices,), values, accumulate=False)
return y
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
view_copy_default = torch.ops.aten.view_copy.default(a_1, [8]); a_1 = None
empty = torch.ops.aten.empty.memory_format([0], dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
arange = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'))
empty_1 = torch.ops.aten.empty.memory_format([0], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
arange_1 = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
index_put_default = torch.ops.aten.index_put.default(view_copy_default, [arange], arange_1); view_copy_default = arange = arange_1 = None
view_copy_default_1 = torch.ops.aten.view_copy.default(index_put_default, [4, 2])
return index_put_default
""") # noqa: B950
def test_scalars(self):
def f(x):
# test: the pass can handle scalar inputs properly
tmp = torch.ones(4, 2)
y = x.view(4, 2)
y.add_(1)
z = 2 * y
z.div_(1)
return z
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None
mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, 2)
div_tensor = torch.ops.aten.div.Tensor(mul_tensor, 1); mul_tensor = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
return div_tensor
""")
@skipIfTorchDynamo("Test does not work with TorchDynamo")
def test_metadata_change(self):
def f(x):
# ops like ge_() are allowed to change the dtype of the input.
# functionalization should pick up on that.
return x.ge_(0)
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
ge_scalar = torch.ops.aten.ge.Scalar(a_1, 0); a_1 = None
_to_copy_default = torch.ops.aten._to_copy.default(ge_scalar, dtype = torch.float32, layout = torch.strided); ge_scalar = None
_tensor_constant0 = self._tensor_constant0
return _tensor_constant0
""")
def test_only_one_view(self):
def f(x):
# This tests that we don't have any unnecessary views in the trace.
# If the input wasn't mutated, we don't need to regenerate it,
# so there should be a total of 1 op in the output trace.
return x.view(4, 2)
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
return view_copy_default
""")
def test_everything(self):
def f(x):
# test: everything
tmp = torch.ones(2, 2)
x2 = x + x
y = x2.view(8)
z0 = y.reshape(2, 4)
z1 = z0.transpose(1, 0)
z1.unsqueeze_(0)
z1.squeeze_()
z2, z3 = z1.split(2)
z2.add_(tmp)
z4 = z0[0] + z2.reshape(4)
return z2
self.assert_functionalization(f, torch.ones(4, 2))
logs = self.get_logs(f, torch.ones(4, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [8])
_reshape_alias_copy_default = torch.ops.aten._reshape_alias_copy.default(view_copy_default, [2, 4], [4, 1]); view_copy_default = None
transpose_copy_int = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_default, 1, 0)
unsqueeze_copy_default = torch.ops.aten.unsqueeze_copy.default(transpose_copy_int, 0); transpose_copy_int = None
squeeze_copy_default = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_default); unsqueeze_copy_default = None
split_copy_tensor = torch.ops.aten.split_copy.Tensor(squeeze_copy_default, 2); squeeze_copy_default = None
getitem = split_copy_tensor[0]
getitem_1 = split_copy_tensor[1]; split_copy_tensor = None
add_tensor_1 = torch.ops.aten.add.Tensor(getitem, fill_scalar); getitem = fill_scalar = None
select_copy_int = torch.ops.aten.select_copy.int(_reshape_alias_copy_default, 0, 0); _reshape_alias_copy_default = None
clone_default = torch.ops.aten.clone.default(add_tensor_1, memory_format = torch.contiguous_format)
_unsafe_view_default = torch.ops.aten._unsafe_view.default(clone_default, [4]); clone_default = None
view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [8]); add_tensor = None
_reshape_alias_copy_default_1 = torch.ops.aten._reshape_alias_copy.default(view_copy_default_1, [2, 4], [4, 1]); view_copy_default_1 = None
transpose_copy_int_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_default_1, 1, 0); _reshape_alias_copy_default_1 = None
unsqueeze_copy_default_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_int_1, 0); transpose_copy_int_1 = None
squeeze_copy_default_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_default_1); unsqueeze_copy_default_1 = None
slice_scatter_default = torch.ops.aten.slice_scatter.default(squeeze_copy_default_1, add_tensor_1, 0, 0, 2); squeeze_copy_default_1 = None
unsqueeze_copy_default_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter_default, 0); slice_scatter_default = None
squeeze_copy_dim = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_default_2, 0); unsqueeze_copy_default_2 = None
transpose_copy_int_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_dim, 1, 0); squeeze_copy_dim = None
_reshape_alias_copy_default_2 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_int_2, [8], [1]); transpose_copy_int_2 = None
view_copy_default_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_default_2, [4, 2]); _reshape_alias_copy_default_2 = None
view_copy_default_3 = torch.ops.aten.view_copy.default(view_copy_default_2, [8]); view_copy_default_2 = None
_reshape_alias_copy_default_3 = torch.ops.aten._reshape_alias_copy.default(view_copy_default_3, [2, 4], [4, 1]); view_copy_default_3 = None
select_copy_int_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_default_3, 0, 0); _reshape_alias_copy_default_3 = None
add_tensor_2 = torch.ops.aten.add.Tensor(select_copy_int_1, _unsafe_view_default); select_copy_int_1 = _unsafe_view_default = None
return add_tensor_1
""") # noqa: B950
def test_reapply_views_simple(self):
def f(x):
tmp = torch.ones(4, 2)
y = x.view(4, 2)
y.add_(tmp)
z = x * x
return y
self.assert_functionalization(f, torch.ones(4, 2), reapply_views=True)
logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True)
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
view_default = torch.ops.aten.view.default(a_1, [4, 2]); a_1 = None
add_tensor = torch.ops.aten.add.Tensor(view_default, fill_scalar); view_default = fill_scalar = None
view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2])
mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1); view_default_1 = None
return add_tensor
""")
def test_aliases_maintained_after_pass_when_reapplying_views(self):
def f(x):
tmp = torch.ones(4, 2)
y = x.view(4, 2)
z = x.view(4, 2)
y.add_(tmp)
return y, z
input_functional = torch._to_functional_tensor(torch.ones(4, 2))
torch._enable_functionalization(reapply_views=True)
try:
y, z = f(input_functional)
torch._sync(y)
torch._sync(z)
finally:
torch._disable_functionalization()
# y and z are aliases inside of the function, and that aliasing relationship should be maintained.
_y = torch._from_functional_tensor(y)
_z = torch._from_functional_tensor(z)
self.assertTrue(are_aliased(_y, _z))
# copy_() gets its own test, because it is special cased in functionalization.
# self.copy_(src) decomposes into src.to(self).expand_as(self).
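    # In the functionalized traces below, the in-place copy_() therefore shows up as an
    # out-of-place aten.copy.default call.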
def test_copy_(self):
def f(x):
tmp = torch.zeros(2, 2)
# NOTE: LoggingTensor isn't a mode, which means that the diagonal call
# will not be logged. This is fine for testing.
tmp_slice = tmp.diagonal()
y = tmp_slice.copy_(x)
z = y.add_(x)
return z
# Test 1: copy_() with same dtype and shape
        # to() is a composite op that is a no-op when the dtype and shape already match, so nothing gets logged for it.
# self.assert_functionalization(f, torch.ones(2))
logs = self.get_logs(f, torch.ones(2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
zero_default = torch.ops.aten.zero.default(empty); empty = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
return add_tensor
""")
# Test 2: copy_() with same dtype, different shape
self.assert_functionalization(f, torch.ones(1))
logs = self.get_logs(f, torch.ones(1))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
zero_default = torch.ops.aten.zero.default(empty); empty = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
return add_tensor
""")
# Test 3: copy_() with different dtype, same shape
self.assert_functionalization(f, torch.ones(2, dtype=torch.long))
logs = self.get_logs(f, torch.ones(2, dtype=torch.long))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
zero_default = torch.ops.aten.zero.default(empty); empty = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
return add_tensor
""")
# Test 4: copy_() with different dtype, different shape
self.assert_functionalization(f, torch.ones(1, dtype=torch.long))
logs = self.get_logs(f, torch.ones(1, dtype=torch.long))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
zero_default = torch.ops.aten.zero.default(empty); empty = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
return add_tensor
""")
def test_expand_symint(self):
# Once some existing SymInt bugs are ironed out, we should update
# this test to plumb FakeSymbolicTensors through it
def f(x):
return x.expand(x.size(0), x.size(1))
self.assert_functionalization(f, torch.ones(2, 2))
logs = self.get_logs(f, torch.ones(2, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
expand_copy_sym_int = torch.ops.aten.expand_copy.SymInt(a_1, [2, 2]); a_1 = None
return expand_copy_sym_int
""")
def test_fill_(self):
def f(x):
y = x + x
z = y.diagonal()
z.fill_(0)
return y
self.assert_functionalization(f, torch.ones(2, 2))
logs = self.get_logs(f, torch.ones(2, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
diagonal_copy_default = torch.ops.aten.diagonal_copy.default(add_tensor)
fill_scalar = torch.ops.aten.fill.Scalar(diagonal_copy_default, 0); diagonal_copy_default = None
diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(add_tensor, fill_scalar); add_tensor = fill_scalar = None
return diagonal_scatter_default
""")
def test_resize_smaller(self):
def f(w):
# Resizing to a smaller size doesn't affect storage
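            # (in the trace below, the shrinking resize_ becomes aten.resize plus an
            # as_strided_copy of the smaller region, rather than an actual reallocation.)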
x = w + 1
y = x.view(4, 4)
y.resize_(3, 3)
y2 = y.view(-1)
y2.add_(1)
z = y + 1
return z
self.assert_functionalization(f, torch.ones(8, 2))
logs = self.get_logs(f, torch.ones(8, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [4, 4])
resize_default = torch.ops.aten.resize.default(view_copy_default, [3, 3])
as_strided_copy_default = torch.ops.aten.as_strided_copy.default(view_copy_default, [3, 3], [3, 1]); view_copy_default = None
view_copy_default_1 = torch.ops.aten.view_copy.default(as_strided_copy_default, [-1]); as_strided_copy_default = None
add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1); view_copy_default_1 = None
view_copy_default_2 = torch.ops.aten.view_copy.default(add_tensor, [4, 4]); add_tensor = None
as_strided_copy_default_1 = torch.ops.aten.as_strided_copy.default(view_copy_default_2, [3, 3], [3, 1])
view_copy_default_3 = torch.ops.aten.view_copy.default(add_tensor_1, [3, 3]); add_tensor_1 = None
as_strided_scatter_default = torch.ops.aten.as_strided_scatter.default(view_copy_default_2, view_copy_default_3, [3, 3], [3, 1]); view_copy_default_2 = view_copy_default_3 = None
view_copy_default_4 = torch.ops.aten.view_copy.default(as_strided_scatter_default, [8, 2]); as_strided_scatter_default = None
view_copy_default_5 = torch.ops.aten.view_copy.default(view_copy_default_4, [4, 4]); view_copy_default_4 = None
as_strided_copy_default_2 = torch.ops.aten.as_strided_copy.default(view_copy_default_5, [3, 3], [3, 1]); view_copy_default_5 = None
add_tensor_2 = torch.ops.aten.add.Tensor(as_strided_copy_default_2, 1); as_strided_copy_default_2 = None
return add_tensor_2
""") # noqa: B950
def test_resize_larger_valid(self):
def f(x):
y = x + 1
# resizing a tensor to a larger size is only currently allowed
# if the tensor-to-resize is not a view / has no outstanding views.
# See Note [resize_() in functionalization pass]
y.resize_(5, 5)
y2 = y.view(25)
# Do a mutation to ensure that aliases of the output of resize_()
# propagate mutations correctly.
# I'm using fill_ specifically because I want to guarantee that
# none of the output has uninitialized memory at the end
# (since these tests compare the data output against a reference impl)
y2.fill_(1)
out = y + 1
return y, out
self.assert_functionalization(f, torch.ones(8, 2))
logs = self.get_logs(f, torch.ones(8, 2))
self.assertExpectedInline(logs, """\
def forward(self, a_1):
add_tensor = torch.ops.aten.add.Tensor(a_1, 1); a_1 = None
resize_default = torch.ops.aten.resize.default(add_tensor, [5, 5]); add_tensor = None
view_copy_default = torch.ops.aten.view_copy.default(resize_default, [25]); resize_default = None
fill_scalar = torch.ops.aten.fill.Scalar(view_copy_default, 1); view_copy_default = None
view_copy_default_1 = torch.ops.aten.view_copy.default(fill_scalar, [5, 5]); fill_scalar = None
add_tensor_1 = torch.ops.aten.add.Tensor(view_copy_default_1, 1)
return (view_copy_default_1, add_tensor_1)
""")
def test_resize_larger_invalid(self):
def f(x):
y = x + 1
z = y.view(4, 4)
# resizing a tensor to a larger size is only currently allowed
# if the tensor-to-resize is not a view / has no outstanding views.
# See Note [resize_() in functionalization pass]
# This should fail
z.resize_(5, 5)
z2 = z.view(25)
z2.fill_(1)
out = z + 1
return y, out
with self.assertRaisesRegex(
RuntimeError,
r'Attempted to resize a view tensor to a larger size. This is not allowed in the functionalization pass'):
self.assert_functionalization(f, torch.ones(8, 2))
def test_nested_functions_propagate_updates(self):
def g(x):
# Create a view of x
y = x[0]
y.add_(1)
# The view, y, gets deallocated at the end of this function
def f(x):
# Calling g(x) should mutate x
g(x)
# We expect x to be synced here, even though the alias created in g() has been deallocated!
y = x + x
return y
self.assert_functionalization(f, torch.ones(2, 2))
def test_mixed_wrappers_valid(self):
def f(x, y):
z = x + y
z.add_(1)
return z
x1_not_functional = LoggingTensor(torch.ones(4))
x2_functional = torch._to_functional_tensor(LoggingTensor(torch.ones(4)))
with capture_logs() as logs:
y = f(x1_not_functional, x2_functional)
# Make sure that functionalization ran the "+" kernel
# with a functional + non-functional tensor, and wrapped the output appropriately.
self.assertExpectedInline('\n'.join(logs), """\
$2 = torch._ops.aten.add.Tensor($0, $1)
$3 = torch._ops.aten.add.Tensor($2, 1)""")
def test_mixed_wrappers_invalid(self):
x1_not_functional = torch.ones(4)
x2_functional = torch._to_functional_tensor(torch.ones(4))
# When dealing with mixed functional + non functional tensors,
# normal_tensor.add_(functional_tensor) is not valid
# because normal_tensor would need to be "promoted" to a functional tensor.
with self.assertRaises(RuntimeError):
x1_not_functional.add_(x2_functional)
if __name__ == '__main__':
run_tests()