# blob: ec87cd8e13b36b44a1405eca8aad2ef0eb0d1caf [file] [log] [blame]
# Owner(s): ["module: meta tensors"]
from torch.testing._internal.common_utils import (
TestCase, run_tests, skipIfCrossRef, skipIfRocm, skipIfTorchDynamo, parametrize,
instantiate_parametrized_tests)
import torch
import torch._dynamo
import itertools
import numpy as np
from torch.testing._internal.jit_utils import RUN_CUDA
from torch._subclasses.fake_tensor import (
FakeTensor,
FakeTensorMode,
FakeTensorConverter,
DynamicOutputShapeException,
UnsupportedOperatorException,
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_device_type import instantiate_device_type_tests, OpDTypes
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch._dynamo.testing import rand_strided
from torch.testing import FileCheck
from torch import nn
import unittest
import torch._prims as prims
import contextlib
import weakref
import copy
import torch._functorch.config
import torch.testing._internal.optests as optests
from unittest.mock import patch
from torch import distributed as dist
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten
class FakeTensorTest(TestCase):
    """Behavioral tests for FakeTensor / FakeTensorMode: construction under the
    mode, conversion from real tensors, device propagation, fallback kernels
    and the errors raised for unsupported situations."""

    def checkType(self, t, device_str, size):
        """Assert ``t`` is a FakeTensor on device type ``device_str`` with shape ``size``."""
        self.assertTrue(isinstance(t, FakeTensor))
        self.assertEqual(t.device.type, device_str)
        self.assertEqual(list(t.size()), size)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cuda_initialized(self):
        # doesnt error
        with FakeTensorMode():
            p = torch.randn(4, 2, requires_grad=True, device='cuda')
            x = torch.randn(8, 4, device='cuda')
            y = torch.mm(x, p).square().sum()
            y.backward()

    def test_basic(self):
        x = torch.empty(2, 2, device="cpu")
        y = torch.empty(4, 2, 2, device="cpu")
        with FakeTensorMode() as mode:
            x = mode.from_tensor(x)
            y = mode.from_tensor(y)
            z = x + y
            self.assertEqual(z.shape, (4, 2, 2))
            self.assertEqual(z.device, torch.device("cpu"))
            self.assertTrue(isinstance(z, FakeTensor))

    def test_basic_forced_memo_only(self):
        x = torch.empty(2, 2, device="cpu")
        y = torch.empty(4, 2, 2, device="cpu")
        with FakeTensorMode() as mode:
            # Convert ``x`` once so that the memoized-only lookup below hits.
            x_fake = mode.from_tensor(x)  # noqa: F841
            x2 = mode.from_tensor(x, memoized_only=True)
            self.assertTrue(x2 is not None)
            # ``y`` was never converted, so a memoized-only lookup yields None.
            y = mode.from_tensor(y, memoized_only=True)
            self.assertIs(y, None)

    def test_custom_op_fallback(self):
        from torch.library import Library, impl

        test_lib = Library("my_test_op", "DEF")
        test_lib.define('foo(Tensor self) -> Tensor')

        @impl(test_lib, 'foo', 'CPU')
        def foo_impl(self):
            return self.cos()

        x = torch.empty(2, 2, device="cpu")
        with self.assertRaisesRegex(UnsupportedOperatorException, "my_test_op.foo.default"):
            with FakeTensorMode(allow_fallback_kernels=True) as mode:
                x = mode.from_tensor(x)
                torch.ops.my_test_op.foo(x)

    def test_parameter_instantiation(self):
        with FakeTensorMode():
            x = torch.rand([4])
            y = torch.nn.parameter.Parameter(x)
            self.assertTrue(isinstance(y, torch.nn.Parameter))

    @unittest.skipIf(not dist.is_available(), "requires distributed")
    def test_fsdp_flat_param(self):
        from torch.distributed.fsdp.flat_param import FlatParameter
        with FakeTensorMode():
            data = torch.randn(2, 2)
            param = FlatParameter(data, requires_grad=True)
        self.assertIsInstance(param, FlatParameter)
        self.assertIsInstance(param, torch.nn.Parameter)
        self.assertIsInstance(param, FakeTensor)

    def test_non_parameter_grad(self):
        mode = FakeTensorMode()
        t = torch.rand([4], requires_grad=True)
        fake_t = mode.from_tensor(t)
        self.assertEqual(fake_t.requires_grad, t.requires_grad)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_index_cuda_with_cpu(self):
        with FakeTensorMode():
            x = torch.rand([2048], device='cuda')
            out = x[torch.zeros([36], dtype=torch.int64)]
            self.checkType(out, "cuda", [36])

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_shape_take_not_device(self):
        with FakeTensorMode():
            x = torch.empty(1, device="cpu")
            y = torch.empty(8, 8, device="cuda")
            out = x.resize_as_(y)
            self.assertEqual(out.shape, (8, 8))
            self.assertEqual(out.device.type, "cpu")
            self.assertTrue(isinstance(out, FakeTensor))

    def test_repr(self):
        with FakeTensorMode():
            x = torch.empty(2, 2, device="cpu")
            self.assertEqual(repr(x), 'FakeTensor(..., size=(2, 2))')
            x = torch.empty(2, 2, device="meta")
            self.assertEqual(repr(x), "FakeTensor(..., device='meta', size=(2, 2))")

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_zero_dim(self):
        with FakeTensorMode():
            x = torch.tensor(0.)
            y = torch.rand([4, 4], device="cuda")
            out = x + y
            self.assertEqual(out.shape, (4, 4))
            self.assertEqual(out.device, y.device)
            self.assertTrue(isinstance(out, FakeTensor))

    def test_nan_to_num(self):
        with FakeTensorMode():
            for dtype in [torch.float16, torch.float32]:
                x = torch.rand([4], dtype=dtype)
                y = torch.nan_to_num(x, nan=None)
                z = torch.nan_to_num(x, 0.0)
                self.assertEqual(dtype, y.dtype)
                self.assertEqual(dtype, z.dtype)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_throw(self):
        x = torch.tensor(0.)  # TODO: tensor() errors
        with FakeTensorMode() as mode:
            x_conv = mode.from_tensor(x)
            y = torch.rand([4, 4], device="cuda")
            z = torch.rand([4, 4], device="cpu")
            self.assertRaises(Exception, lambda: torch.lerp(x_conv, y, z))

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_type_as(self):
        with FakeTensorMode():
            x = torch.rand([16, 1], device="cpu")
            y = torch.rand([4, 4], device="cuda")
            out = x.type_as(y)
            self.assertEqual(out.device.type, "cuda")
            self.assertTrue(isinstance(out, FakeTensor))

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_setitem(self):
        for device in ["cpu", "cuda"]:
            with FakeTensorMode():
                x = torch.rand([16, 1], device=device)
                x[..., 0] = 0

    def test_fake_dispatch_keys(self):
        with FakeTensorMode():
            x = torch.rand([4])
            f = FileCheck().check("CPU").check("ADInplaceOrView").check("AutogradCPU").check("AutocastCPU")
            f.run(torch._C._dispatch_key_set(x))

            # Tensors created under inference_mode carry no autograd keys.
            with torch.inference_mode():
                x = torch.rand([4])
                y = x + x
                FileCheck().check("CPU").check("AutocastCPU").run(torch._C._dispatch_key_set(y))
                FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y))

    def test_constructor(self):
        with FakeTensorMode():
            x = torch.rand([4, 4], device="cpu")

        self.assertTrue(isinstance(x, FakeTensor))
        self.assertTrue(x.device.type == "cpu")

    def test_mode(self):
        with FakeTensorMode():
            y = torch.rand([4], device="cpu")
            out = y + y

        self.assertTrue(isinstance(out, FakeTensor))

    def test_full(self):
        # Test torch.full returns tensor with correct dtype
        with torch._subclasses.CrossRefFakeMode():
            torch.full((4, 4), 1)

    def check_function_with_fake(self, fn):
        """Run ``fn`` for real and under a fresh FakeTensorMode, then check the
        results leaf by leaf: tensor leaves must have matching metadata
        (including strides) and non-tensor leaves must stay non-tensors."""
        out = fn()
        with torch._subclasses.FakeTensorMode():
            out_fake = fn()

        # tree_flatten returns (leaves, spec). Compare the leaves themselves;
        # zipping the raw tuples would pair the leaf *lists* and never reach
        # compare_tensor_meta.
        real_leaves = tree_flatten(out)[0]
        fake_leaves = tree_flatten(out_fake)[0]
        self.assertEqual(len(real_leaves), len(fake_leaves))
        for a, b in zip(real_leaves, fake_leaves):
            # ``a`` comes from the real run, so it is never a FakeTensor:
            # distinguish tensor vs. non-tensor leaves with torch.Tensor.
            if not isinstance(a, torch.Tensor):
                self.assertFalse(isinstance(b, torch.Tensor))
                continue

            prims.utils.compare_tensor_meta(a, b, check_strides=True)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_non_kwarg_device(self):
        with FakeTensorMode():
            x = torch.rand([16, 1], device="cpu")
            y = x.to(torch.device("cpu"))
            self.assertIs(x, y)
            z = x.to(torch.device("cuda"))
            self.assertEqual(z.device.type, "cuda")

    def test_non_overlapping_stride_zero(self):
        def foo():
            x = torch.empty_strided([1, 3, 427, 640], (0, 1, 1920, 3))
            return x.half()

        self.check_function_with_fake(foo)

    def test_fake_mode_error(self):
        x = torch.rand([4, 4])

        # Real tensors cannot be used directly inside a FakeTensorMode.
        with self.assertRaisesRegex(Exception, "Please convert all Tensors"):
            with FakeTensorMode():
                x[0]

    def test_fake_grad_copy(self):
        x = torch.rand([4, 4], requires_grad=True)
        x.grad = torch.rand([4, 4])
        mode = FakeTensorMode()
        fake_x = mode.from_tensor(x)
        prims.utils.compare_tensor_meta(fake_x, x)
        prims.utils.compare_tensor_meta(fake_x.grad, x.grad)

        self.assertTrue(isinstance(fake_x.grad, FakeTensor))

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_like_constructor(self):
        with FakeTensorMode():
            x = torch.rand([4, 4])
            y = torch.ones_like(x)
            self.assertTrue(isinstance(y, FakeTensor))
            self.assertEqual(y.device.type, "cpu")
            z = torch.ones_like(x, device="cuda")
            self.assertTrue(isinstance(z, FakeTensor))
            self.assertEqual(z.device.type, "cuda")

    def test_binary_op_type_promotion(self):
        with FakeTensorMode():
            x = torch.empty([2, 2], dtype=torch.float)
            y = torch.empty([2, 2], dtype=torch.int64)
            out = x / y
            self.assertEqual(out.dtype, torch.float)
            self.assertEqual(out.device.type, "cpu")

    def test_from_numpy(self):
        with FakeTensorMode():
            x = torch.tensor(np.zeros([4, 4]))
            self.checkType(x, "cpu", [4, 4])

    def test_randperm(self):
        x = torch.randperm(10)
        y = torch.randperm(5, device="cpu")
        with FakeTensorMode():
            x1 = torch.randperm(10)
            prims.utils.compare_tensor_meta(x, x1)
            y1 = torch.randperm(5, device="cpu")
            prims.utils.compare_tensor_meta(y, y1)

    def test_print_in_fake_mode(self):
        x = torch.zeros(2)
        # does not fail
        with FakeTensorMode():
            out = str(x)
        assert "FakeTensor" not in out

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_upsample_bilinear_small_channels(self):
        out = []
        mode = FakeTensorMode()
        # Run the same computation once for real and once under fake mode.
        for context in (contextlib.nullcontext, lambda: mode):
            with context():
                arg0_1 = torch.empty_strided((3, 427, 640), (1, 1920, 3), dtype=torch.float32, device='cuda')
                unsqueeze = torch.ops.aten.unsqueeze.default(arg0_1, 0)
                out.append(torch.ops.aten.upsample_bilinear2d.default(unsqueeze, [800, 1199], False))

        self.assertTrue(out[1].is_contiguous())
        self.checkMetaProps(out[0], out[1])

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cpu_fallback(self):
        with FakeTensorMode(allow_fallback_kernels=False):
            filters = torch.randn(8, 4, 3, 3).cuda()
            inputs = torch.randn(1, 4, 5, 5).cuda()
            out = torch.nn.functional.conv2d(inputs, filters, padding=1)
            self.assertEqual(out.device.type, "cuda")
            self.assertEqual(list(out.size()), [1, 8, 5, 5])

        with FakeTensorMode(allow_fallback_kernels=True):
            # intentionally bad inputs
            filters = torch.randn(8, 20, 3, 3).cuda()
            inputs = torch.randn(1, 7, 10, 5).cuda()
            with self.assertRaises(RuntimeError):
                torch.nn.functional.conv2d(inputs, filters, padding=1)

        with FakeTensorMode(allow_fallback_kernels=True):
            filters = torch.randn(8, 4, 3, 3).cuda()
            inputs = torch.randn(1, 4, 5, 5).cuda()

            out = torch.nn.functional.conv2d(inputs, filters, padding=1)
            self.assertEqual(out.device.type, "cuda")
            self.assertEqual(list(out.size()), [1, 8, 5, 5])

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_out_multi_device(self):
        with FakeTensorMode():
            x = torch.rand([4])
            y = torch.rand([4], device="cuda")

            with self.assertRaisesRegex(Exception, "found two different devices"):
                torch.sin(x, out=y)

            with self.assertRaisesRegex(Exception, "found two different devices"):
                x.add_(y)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_normalize_device(self):
        with FakeTensorMode():
            x = torch.empty(1, device="cuda")
            y = torch.empty(1, device=f"cuda:{torch.cuda.current_device()}")
            out = x + y
        self.checkType(out, "cuda", [1])

    def test_recursive_invocation(self):
        mode = FakeTensorMode()
        with mode:
            x = torch.tensor(2)
            mode.in_kernel_invocation = True
            x + x
            self.assertTrue(mode.in_kernel_invocation)

    @skipIfRocm
    @parametrize("allow_fallback_kernels", [False, True],
                 lambda a: 'with_fallback' if a else 'without_fallback')
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cudnn_rnn(self, allow_fallback_kernels):
        def fn(
            a0,
            b0, b1, b2, b3, b4, b5, b6, b7,
            b8, b9, b10, b11, b12, b13, b14, b15,
            a3,
            a4,
            a5,
        ):
            a1 = [
                b0, b1, b2, b3, b4, b5, b6, b7,
                b8, b9, b10, b11, b12, b13, b14, b15,
            ]
            return torch.ops.aten._cudnn_rnn(
                a0,
                a1,
                4,
                a3,
                a4,
                a5,
                2,
                2048,
                0,
                2,
                False,
                0.0,
                False,
                True,
                [],
                None,
            )

        mode = FakeTensorMode(allow_fallback_kernels=allow_fallback_kernels)
        # i == 0: real tensors; i == 1: fake tensors.
        for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
            with context():
                inps1 = [
                    torch.randn([92, 8, 2048]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192, 4096]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192, 4096]).cuda(),
                    torch.randn([8192, 2048]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([8192]).cuda(),
                    torch.randn([167837696]).cuda(),
                    torch.randn([4, 8, 2048]).cuda(),
                    torch.randn([4, 8, 2048]).cuda(),
                ]
                # Copy the list: assigning ``inps2 = inps1`` would alias it, and
                # clearing ``cx`` below would then also clear it in ``inps1``,
                # so the case with a real ``cx`` would never run.
                inps2 = list(inps1)
                inps2[len(inps2) - 1] = None  # argument `cx` can be None

                for inps in [inps1, inps2]:
                    out = fn(*inps)
                    self.assertIs(out[4], inps[-3])
                    for ten in out:
                        if i == 1:
                            self.assertTrue(isinstance(ten, FakeTensor))
                        self.assertEqual(ten.device.type, 'cuda')

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_cuda_lstm(self):
        # Ensure CUDA (non-cuDNN) impl succeeds with fake tensors.
        with torch.backends.cudnn.flags(enabled=False):
            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
            with fake_tensor_mode:
                N = 5
                L = 4
                H_in = 2
                hidden_size = 3
                proj_size = 2
                num_layers = 2
                bidir = False
                D = 2 if bidir else 1
                H_out = proj_size if proj_size > 0 else hidden_size

                lstm = torch.nn.LSTM(input_size=H_in, hidden_size=hidden_size,
                                     num_layers=num_layers, proj_size=proj_size, batch_first=False,
                                     bias=True, bidirectional=bidir, device='cuda')

                h_0 = torch.randn((num_layers * D, N, H_out), device='cuda')
                c_0 = torch.randn((num_layers * D, N, hidden_size), device='cuda')
                inp = torch.randn((L, N, H_in), device='cuda')
                (output, (h_n, c_n)) = lstm(inp, (h_0, c_0))
                output.sum().backward()

                self.assertEqual(output.shape, (L, N, D * H_out))
                self.assertEqual(h_n.shape, (D * num_layers, N, H_out))
                self.assertEqual(c_n.shape, (D * num_layers, N, hidden_size))

    @skipIfRocm
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_fallback_memory_prop(self):
        m = nn.Conv2d(16, 33, 3, stride=2, device="cuda", dtype=torch.half)
        m = m.to(memory_format=torch.channels_last)
        mode = FakeTensorMode()
        # TODO: module.to() doesn't work because it assigns .data, which is ignored
        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
            mod_copied = copy.deepcopy(m)

        with mode:
            input = torch.rand(20, 16, 50, 100, dtype=torch.half, device="cuda").to(memory_format=torch.channels_last)
            out = mod_copied(input)
            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
            self.checkType(out, "cuda", [20, 33, 24, 49])

    def test_data_dependent_operator(self):
        with FakeTensorMode(allow_fallback_kernels=False):
            x = torch.rand([10, 10])

            self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x))

    def checkMetaProps(self, t1, t2):
        """Assert ``t1`` and ``t2`` have matching tensor metadata, strides included."""
        prims.utils.compare_tensor_meta(t1, t2, check_strides=True)

    @skipIfCrossRef
    def test_deepcopy(self):
        with FakeTensorMode() as mode:
            pass
        mod = torch.nn.BatchNorm2d(10)
        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
            mod_copied = copy.deepcopy(mod)

        def check_copy(mod, mod_copied):
            for name, param in itertools.chain(mod.named_parameters(), mod.named_buffers()):
                param_copied = getattr(mod_copied, name)
                self.checkMetaProps(param, param_copied)
                self.assertTrue(isinstance(param_copied, FakeTensor))
                self.assertEqual(isinstance(param, torch.nn.Parameter), isinstance(param_copied, torch.nn.Parameter))
                self.assertEqual(param.requires_grad, param_copied.requires_grad)

        check_copy(mod, mod_copied)

        class ModuleNew(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.rand([10, 2])
                self.b = self.a
                self.c = self.a[0]

        mod = ModuleNew()
        with torch._subclasses.fake_tensor.FakeCopyMode(mode):
            mod_copied = copy.deepcopy(mod)

        # Aliasing between attributes must survive the fake deepcopy.
        self.assertIs(mod_copied.a, mod_copied.b)
        self.assertEqual(mod_copied.b.storage()._cdata, mod_copied.a.storage()._cdata)

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_new(self):
        with FakeTensorMode():
            a = torch.rand([16, 1])
            self.checkType(a.new(10, 10), "cpu", [10, 10])
            self.checkType(a.new([1, 2, 3, 4]), "cpu", [4])
            b = torch.rand([4, 4], device='cuda')
            self.checkType(b.new(device='cuda'), "cuda", [0])
            self.checkType(a.new(torch.rand([1])), "cpu", [1])

    def test_scalar_inputs(self):
        with FakeTensorMode():
            self.checkType(torch.div(3, 2), "cpu", [])
            ten = torch.zeros(2, dtype=torch.int32) * 2.0
            self.assertEqual(ten.dtype, torch.float)
            self.checkType(ten, "cpu", [2])

    def test_allow_meta(self):
        def run_meta():
            with FakeTensorMode():
                x = torch.rand([4], device="meta")
                return x + x

        self.checkType(run_meta(), "meta", [4])

        with patch.object(torch._functorch.config, "fake_tensor_allow_meta", False):
            self.assertRaises(Exception, run_meta)

    def test_mixed_real_and_fake_inputs(self):
        class _TestPattern(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.bn = torch.nn.BatchNorm2d(1)

            def forward(self, input):
                running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
                scale_factor = self.bn.weight / running_std
                weight_shape = [1] * len(self.conv.weight.shape)
                weight_shape[0] = -1
                bias_shape = [1] * len(self.conv.weight.shape)
                bias_shape[1] = -1
                scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape)
                zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype)
                conv = self.conv._conv_forward(input, scaled_weight, zero_bias)
                conv_orig = conv / scale_factor.reshape(bias_shape)
                conv_orig = conv_orig + self.conv.bias.reshape(bias_shape)
                conv = self.bn(conv_orig)
                return conv

        # The module's parameters are real tensors created outside the mode;
        # allow_non_fake_inputs lets them mix with fake activations.
        mod = _TestPattern()
        with FakeTensorMode(allow_non_fake_inputs=True):
            out = mod(torch.randn(1, 1, 3, 3))
        self.checkType(out, "cpu", (1, 1, 3, 3))

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_aten_copy_multi_device(self):
        with FakeTensorMode():
            x1 = torch.rand(4, device="cpu")
            x2 = torch.rand(4, device="cuda")
            copy1 = torch.ops.aten.copy.default(x1, x2)
            copy2 = torch.ops.aten.copy.default(x2, x1)
            out = torch.empty(4, device="cpu")
            torch.ops.aten.copy.out(x1, x2, out=out)
        self.checkType(copy1, "cpu", (4,))
        self.checkType(copy2, "cuda", (4,))
        self.checkType(out, "cpu", (4,))

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_aten_index_multi_device(self):
        with FakeTensorMode():
            x1 = torch.rand(4, 4, device="cpu")
            x2 = torch.rand(4, 4, device="cuda")
            i1 = torch.tensor([0, 1], device="cuda")
            i2 = torch.tensor([0, 1], device="cpu")
            r1 = torch.ops.aten.index(x1, i1)
            r2 = torch.ops.aten.index(x2, i2)

            y1 = torch.rand(4, device="cpu")
            y2 = torch.rand(4, device="cuda")
            j1 = torch.tensor([2], device="cuda")
            j2 = torch.tensor([2], device="cpu")
            r3 = torch.ops.aten.index_put.default(x1, j1, y1)
            r4 = torch.ops.aten.index_put.default(x2, j2, y2)
        self.checkType(r1, "cpu", ())
        self.checkType(r2, "cuda", ())
        self.checkType(r3, "cpu", (4, 4))
        self.checkType(r4, "cuda", (4, 4))

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_aten_slice_scatter_multi_device(self):
        with FakeTensorMode():
            x1 = torch.rand(4, 4, device="cpu")
            y1 = torch.rand(2, 4, device="cuda")
            x2 = torch.rand(4, 4, device="cuda")
            y2 = torch.rand(2, 4, device="cpu")
            out = torch.empty(4, 4, device="cpu")
            r1 = torch.ops.aten.slice_scatter.default(x1, y1, start=2)
            r2 = torch.ops.aten.slice_scatter.default(x2, y2, start=2)
            r3 = torch.ops.aten.slice_scatter.out(x1, y1, out=out, start=2)
        self.checkType(r1, "cpu", (4, 4))
        self.checkType(r2, "cuda", (4, 4))
        self.checkType(r3, "cpu", (4, 4))
        self.checkType(out, "cpu", (4, 4))

    def test__adaptive_avg_pool2d_backward(self):
        with FakeTensorMode():
            grad_out = torch.rand(2, 3, 4, 4)
            inp = torch.rand(2, 3, 4, 4).to(memory_format=torch.channels_last)
            grad_in = torch.ops.aten._adaptive_avg_pool2d_backward(grad_out, inp)
            self.assertTrue(torch._prims_common.suggest_memory_format(grad_in) == torch.channels_last)
class FakeTensorConstHandling(TestCase):
    """Tests for the constant value FakeTensorMode tracks on small tensors
    (``FakeTensor.constant``) and for the operations that invalidate it."""

    def assertConst(self, *args):
        """Fail unless every tensor in ``args`` still has a tracked constant."""
        for tensor in args:
            self.assertTrue(tensor.constant is not None)

    def assertNotConst(self, *args):
        """Fail unless no tensor in ``args`` has a tracked constant."""
        for tensor in args:
            self.assertTrue(tensor.constant is None)

    def test_simple(self):
        with FakeTensorMode():
            scalar = torch.tensor(4.)
            self.assertEqual(scalar.item(), 4.)

    def test_inplace_add(self):
        with FakeTensorMode():
            base = torch.tensor(4.)
            returned = base.add_(1)
            self.assertEqual(base.item(), 5.)
            self.assertEqual(returned.item(), 5.)
            self.assertConst(base, returned)

    def test_shared_storages(self):
        with FakeTensorMode():
            base = torch.tensor([4.])
            alias = base[:]

            # A view shares storage with its base; so do their constants.
            self.assertEqual(base.storage()._cdata, alias.storage()._cdata)
            self.assertEqual(base.constant.storage()._cdata, alias.constant.storage()._cdata)

    def test_constant_invalidation(self):
        with FakeTensorMode():
            target = torch.tensor([1.])
            self.assertConst(target)
            noise = torch.rand([1])
            # An in-place add of a non-constant tensor drops the constant.
            target.add_(noise)
            self.assertNotConst(target)

    def test_inplace_view_invalidation(self):
        with FakeTensorMode():
            target = torch.tensor([1])
            self.assertConst(target)
            target.resize_([2])
            self.assertEqual(target.size(0), 2)
            self.assertNotConst(target)

    def test_fake_tensor_in_intlist_repro(self):
        def pad_to_max(tensor_list):
            # Presumably the data-dependent step is turning ``max_hw``'s
            # values into shape entries.
            max_hw = torch.tensor([800, 1216], dtype=torch.int64)
            shape = [len(tensor_list)] + list(tensor_list[0].shape[:-2]) + list(max_hw)
            return tensor_list[0].new_full(shape, 0.0)

        with self.assertRaises(torch._subclasses.fake_tensor.DataDependentOutputException):
            with torch._subclasses.fake_tensor.FakeTensorMode():
                first = torch.randn(3, 800, 1199)
                second = torch.randn(3, 800, 800)
                pad_to_max([first, second])

    def test_fake_tensor_batch_norm_cpu(self):
        with torch._subclasses.CrossRefFakeMode():
            model = torch.nn.Sequential(
                torch.nn.BatchNorm2d(10),
                torch.nn.ReLU(),
            )
            model.eval()
            model(torch.randn([2, 10, 8, 8]))

    def test_shared_storage_invalidation(self):
        with FakeTensorMode():
            base = torch.tensor([1.])
            alias = base[:]
            self.assertConst(base, alias)
            # Mutating the view invalidates its base's constant as well.
            alias.add_(torch.rand([1]))
            self.assertNotConst(base, alias)

    def test_aliased_const_write(self):
        with FakeTensorMode():
            base = torch.tensor([1])
            expanded = base.expand([4])
            self.assertNotConst(expanded)
            expanded[0] = 1
            self.assertNotConst(base)

    def test_constant_propagate_through_functions(self):
        with FakeTensorMode():
            quotient = torch.div(4, 4, rounding_mode='trunc')
            self.assertConst(quotient)
def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
    """Return True if ``maybe_contained_type`` is a subtype of ``type`` or of
    any type reachable through ``type.containedTypes()`` (searched recursively).

    The search uses an explicit work list instead of recursion. Visit order does
    not affect the answer, because each check is a pure ``isSubtypeOf`` query.
    """
    pending = [type]
    while pending:
        candidate = pending.pop()
        if maybe_contained_type.isSubtypeOf(candidate):
            return True
        pending.extend(candidate.containedTypes())
    return False
class FakeTensorOpInfoTest(TestCase):
    """Runs ``optests.fake_check`` on each sample input of every op in custom_op_db."""

    @ops(custom_op_db, dtypes=OpDTypes.any_one)
    def test_fake(self, device, dtype, op):
        # Custom ops whose outputs are data-dependent; membership of op.name in
        # this set is forwarded to fake_check as its final argument.
        dynamic_output_ops = {
            'NumpyNMSCustomOp',
            'NumpyNonzeroCustomOp',
        }
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            positional = (sample.input,) + sample.args
            optests.fake_check(op, positional, sample.kwargs, op.name in dynamic_output_ops)
class FakeTensorConverterTest(TestCase):
    """Tests for FakeTensorConverter: memoized real->fake conversion, storage
    sharing between converted tensors, weak-reference cleanup, and how tensors
    from different fake modes interact."""

    def test_memoized_conversion_to_meta(self):
        x = torch.rand(2, 2, 2)
        mode = FakeTensorMode()
        self.assertTrue(mode.from_tensor(x) is mode.from_tensor(x))

    def test_memoized_conversion_from_meta(self):
        x = torch.rand(2, 2).to(device="meta")
        mode = FakeTensorMode()
        converter = mode.fake_tensor_converter
        self.assertTrue(converter.from_meta_and_device(mode, x, "cpu") is converter.from_meta_and_device(mode, x, "cpu"))

    def test_separate_tensor_storages_view(self):
        x = torch.rand(2, 2, 2)
        y = x[0]
        mode = FakeTensorMode()
        converter = mode.fake_tensor_converter
        x_conv = converter(mode, x)
        y_conv = converter(mode, y)
        # A view and its base convert to fakes that share one storage.
        self.assertEqual(torch._C._storage_id(x_conv), torch._C._storage_id(y_conv))

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_separate_tensor_storages_non_view(self):
        x = torch.rand(2, 2, 2)
        y = torch.rand(4, 2)
        y.set_(x.storage())
        mode = FakeTensorMode()
        converter = mode.fake_tensor_converter
        x_conv = converter(mode, x)
        y_conv = converter(mode, y)
        stor_id = torch._C._storage_id(x_conv)
        self.assertEqual(stor_id, torch._C._storage_id(y_conv))
        # Memo entries disappear as the real tensors they are keyed on die.
        del x
        self.assertEqual(len(converter.tensor_memo), 1)
        converter.meta_converter.check_for_expired_weak_storages()
        self.assertEqual(len(converter.meta_converter.storage_memo), 1)
        del y
        self.assertEqual(len(converter.tensor_memo), 0)
        converter.meta_converter.check_for_expired_weak_storages()
        self.assertEqual(len(converter.meta_converter.storage_memo), 0)

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_dead_weak_ref(self):
        x = torch.rand(2, 2, 2)
        y = x[0]
        mode = FakeTensorMode()
        converter = FakeTensorConverter()
        x_conv = converter(mode, x)
        x_conv_storage = torch._C._storage_id(x_conv)
        del x_conv
        self.assertFalse(x in converter.tensor_memo)
        y_conv = converter(mode, y)
        self.assertEqual(x_conv_storage, torch._C._storage_id(y_conv))

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_dead_key(self):
        x = torch.rand(2, 2, 2)
        mode = FakeTensorMode()
        converter = FakeTensorConverter()
        x_conv = converter(mode, x)
        self.assertEqual(len(converter.tensor_memo), 1)
        x_conv2 = converter(mode, x)
        assert x_conv2 is x_conv
        del x
        self.assertEqual(len(converter.tensor_memo), 0)

    def test_no_active_mode(self):
        with FakeTensorMode() as mode:
            x = torch.empty(2, 2, device="cpu")
            y = torch.empty(2, 2, device="cpu")

        # Operating on fake tensors outside the mode still uses their mode.
        out = x + y
        self.assertEqual(mode, out.fake_mode)
        self.assertTrue(isinstance(out, FakeTensor))
        self.assertEqual(out.device.type, "cpu")

    def test_multiple_modes(self):
        t = torch.rand([4])
        t2 = torch.rand([4])
        with FakeTensorMode() as m:
            with FakeTensorMode() as m2:
                t_fake = m.from_tensor(t)
                t2_fake = m2.from_tensor(t2)

                with self.assertRaisesRegex(Exception, "Mixing fake modes"):
                    t_fake + t2_fake

    def test_separate_mode_error(self):
        with FakeTensorMode():
            x = torch.empty(2, 2, device="cpu")
        with FakeTensorMode():
            y = torch.empty(2, 2, device="cpu")
        # Combining fake tensors owned by two different modes must fail. The
        # lambda has to perform the operation itself: passing ``y`` as an
        # extra argument to a zero-arg lambda only raises a TypeError.
        self.assertRaises(Exception, lambda: x + y)

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_no_ref_cycle(self):
        x = torch.rand([4])
        mode = FakeTensorMode()
        y = mode.from_tensor(x)
        self.assertEqual(len(mode.fake_tensor_converter.tensor_memo), 1)
        mode_weak = weakref.ref(mode)
        y_weak = weakref.ref(y)
        del mode
        del y
        # Without a reference cycle, the mode and the fake tensor are both
        # freed by refcounting alone (no gc.collect() is run here).
        assert mode_weak() is None
        assert y_weak() is None
class FakeTensorOperatorInvariants(TestCase):
    """Invariants between the registered ``aten`` schemas and the FakeTensor
    implementation tables, plus targeted regression tests for individual ops."""

    @staticmethod
    def get_aten_op(schema):
        """Resolve an ``aten::`` schema to its ``torch.ops.aten`` overload,
        using ``default`` when the schema has no overload name."""
        namespace, name = schema.name.split("::")
        overload = schema.overload_name if schema.overload_name else "default"
        assert namespace == "aten"
        return getattr(getattr(torch.ops.aten, name), overload)

    @staticmethod
    def get_all_aten_schemas():
        """Yield every registered JIT schema whose namespace is ``aten``."""
        for schema in torch._C._jit_get_all_schemas():
            namespace = schema.name.split("::")[0]
            if namespace != "aten":
                continue
            yield schema

    def test_non_kwarg_only_device(self):
        # Every tensor-involving aten op that accepts a positional
        # (non-kwarg-only) Optional[Device] must be in _device_not_kwarg_ops.
        ten_type = torch._C.TensorType.get()
        opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
        for schema in self.get_all_aten_schemas():
            if not any(
                contains_type(arg.type, ten_type)
                for arg in itertools.chain(schema.arguments, schema.returns)
            ):
                continue

            has_non_kwarg_device = any(
                not arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
                for arg in schema.arguments
            )
            if has_non_kwarg_device:
                self.assertTrue(
                    self.get_aten_op(schema) in torch._subclasses.fake_tensor._device_not_kwarg_ops
                )

    def test_tensor_constructors_all_have_kwarg_device(self):
        # Every tensor constructor must take `device` as a kwarg-only
        # argument, except aten._list_to_tensor.
        opt_device = torch._C.OptionalType(torch._C.DeviceObjType.get())
        for schema in self.get_all_aten_schemas():
            op = self.get_aten_op(schema)
            if not torch._subclasses.fake_tensor._is_tensor_constructor(op):
                continue

            has_kwarg_device = any(
                arg.kwarg_only and arg.type.isSubtypeOf(opt_device)
                for arg in schema.arguments
            )

            self.assertTrue(
                has_kwarg_device or op == torch.ops.aten._list_to_tensor.default
            )

    @unittest.expectedFailure
    def test_sparse_new(self):
        with FakeTensorMode():
            indices = torch.randn(1, 1, dtype=torch.int64)
            values = torch.randn(1)
            extra = (2,)
            sparse = torch.randn(1).to_sparse()
            # This used to segfault, now it does not, but it still raises an
            # error
            sparse.new(indices, values, extra)

    def test_tensor_new(self):
        with FakeTensorMode():
            x = torch.Tensor([1, 2, 3])
        self.assertIsInstance(x, FakeTensor)

    def test_like_ops(self):
        # Every aten `*_like` op must be registered as a like-constructor.
        for schema in self.get_all_aten_schemas():
            if schema.name.endswith("_like"):
                op = self.get_aten_op(schema)
                self.assertIn(op, torch._subclasses.fake_tensor._like_tensor_constructors)

    # at::_embedding_bag has no op info,
    # and returns extra tensors that at::embedding bag throws away
    def test_embedding_bag_private(self):
        args = [
            torch.ones(6, 1),
            torch.ones(6, dtype=torch.int64),
            torch.arange(2, dtype=torch.int64),
            False,
            2,  # mode = max
        ]

        ref_out = torch.ops.aten._embedding_bag(*args)
        with FakeTensorMode() as m:
            meta_args = [m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
            meta_out = torch.ops.aten._embedding_bag(*meta_args)

        self.assertEqual(len(ref_out), len(meta_out))
        for ref_o, meta_o in zip(ref_out, meta_out):
            self.assertEqual(ref_o.size(), meta_o.size())

    def test_cross_entropy_loss(self):
        inp = torch.randn(3, 5)
        target = torch.randint(5, (3,), dtype=torch.long)
        weight = torch.rand(5)
        fn = torch.nn.functional.cross_entropy
        for w in (weight, None):
            args = (inp, target, w)
            # Use the identical call (including label_smoothing) for the real
            # reference and the fake run so the two results are comparable.
            ref = fn(*args, label_smoothing=0.5)
            with FakeTensorMode() as m:
                meta_args = [m.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args]
                meta_out = fn(*meta_args, label_smoothing=0.5)

            self.assertEqual(ref.size(), meta_out.size())

    @skipIfRocm
    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_conv_c1_backward(self):
        class Repro(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, arg1, arg2, arg3):
                torch.ops.aten.convolution_backward.default(
                    arg1,
                    arg2,
                    arg3,
                    [1],
                    [1, 1],
                    [1, 1],
                    [1, 1],
                    False,
                    [0, 0],
                    1,
                    [True, True, False],
                )

        args_new = [
            ((16, 1, 128, 128), (16384, 16384, 128, 1), torch.float16, "cuda"),
            ((16, 64, 128, 128), (1048576, 1, 8192, 64), torch.float16, "cuda"),
            ((1, 64, 3, 3), (576, 9, 3, 1), torch.float16, "cuda"),
        ]
        args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args_new]

        with torch._subclasses.CrossRefFakeMode():
            Repro()(*args)

    def test_no_dispatch_with_like_function(self):
        class CountingMode(TorchDispatchMode):
            """Dispatch mode that counts every op routed through it."""

            def __init__(self):
                super().__init__()
                self.count = 0

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                self.count += 1
                return func(*args, **(kwargs or {}))

        with FakeTensorMode():
            x = torch.randn(2)
            with CountingMode() as mode:
                # no_dispatch() must keep zeros_like from reaching CountingMode.
                with no_dispatch():
                    torch.zeros_like(x)

        self.assertEqual(mode.count, 0)
class FakeTensorPropTest(TestCase):
    """Tests for running FakeTensorProp over symbolically traced nn.Modules."""

    def test_fake_tensor_prop_on_nn_module(self):
        class ToyNnModuleWithParameters(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = torch.nn.Linear(4, 3)
                self.layer2 = torch.nn.Linear(3, 2)

            def forward(self, value):
                value = self.layer1(value)
                value = torch.relu(value)
                value = self.layer2(value)
                return value

        model = ToyNnModuleWithParameters()
        value = torch.randn(5, 4)
        # Convert nn.Module to GraphModule so that FakeTensorProp runs.
        graph_model = torch.fx.symbolic_trace(model, (value,))
        # The following block runs FakeTensorProp on graph_module w/to the same FakeTensorMode
        #
        # TODO(wschin): there should be an API to run FakeTensorProp for GraphModule
        # with parameters and buffers.
        with FakeTensorMode() as fake_tensor_mode:

            def to_fake_tensor(x):
                if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor):
                    return fake_tensor_mode.from_tensor(x)
                return x

            fake_parameters_and_buffers = {
                k: to_fake_tensor(v)
                for k, v in itertools.chain(
                    graph_model.named_parameters(), graph_model.named_buffers()
                )
            }
            with torch.nn.utils.stateless._reparametrize_module(
                graph_model, fake_parameters_and_buffers
            ):
                # This case uses the **same** fake tensor mode to
                #  1. create fake parameters and fake buffers, and
                #  2. run FakeTensorProp
                # The result should be correct.
                result = FakeTensorProp(graph_model, fake_tensor_mode).propagate(value)
                self.assertTrue(isinstance(result, FakeTensor))
                self.assertEqual(result.shape, (5, 2))
                # This case uses the **different** fake tensor modes to
                #  1. create fake parameters and fake buffers, and
                #  2. run FakeTensorProp
                # The following code should fail with
                # AssertionError: tensor's device must be `meta`, got cpu instead
                with self.assertRaises(AssertionError):
                    FakeTensorProp(graph_model).propagate(value)

    def test_fake_tensor_prop_on_nn_module_with_optional_args(self):
        class OptionalArgumentInBetween(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer1 = torch.nn.Linear(4, 3)
                self.layer2 = torch.nn.Linear(3, 2)

            def forward(self, value, another_value=None, another_optional_value=None):
                # Mimic huggingface's `forward` methods which have several optional arguments.
                # For example, GPT accepts forward(self, input_ids, None, attention_mask, ...).
                # To apply FakeTensorProp, its from_real_tensor(...) needs to accept None.
                if another_value is None:
                    another_value = torch.rand_like(value)
                if another_optional_value is None:
                    another_optional_value = torch.rand_like(value)
                value = value + another_value + another_optional_value
                return value * value

        fake_mode = FakeTensorMode(allow_non_fake_inputs=True, allow_fallback_kernels=False)
        with fake_mode:
            model = OptionalArgumentInBetween()
            value = torch.randn(5, 4)
            another_optional_value = torch.randn(5, 4)
            graph_model = torch.fx.symbolic_trace(model, (value, None, another_optional_value))
            FakeTensorProp(graph_model, fake_mode).propagate(value, None, another_optional_value)
# Instantiate the @parametrize-decorated tests in FakeTensorTest
# (e.g. test_cudnn_rnn with/without fallback kernels).
instantiate_parametrized_tests(FakeTensorTest)

# Generate device-specific variants of the OpInfo-driven tests, limited to the
# device types listed in `only_for`.
only_for = ("cpu", "cuda")
instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=only_for)

if __name__ == "__main__":
    run_tests()