# blob: e936e66f23f8ab218da8d3d93a13ae18cb09bbb3 [file] [log] [blame] [edit]
# Owner(s): ["module: dynamo"]
# flake8: noqa: E731, C405, F811, C418, C417
import collections
import functools
import inspect
import itertools
import math
import operator
import random
import sys
import unittest
from dataclasses import dataclass, field
from typing import Any, Dict, List, NamedTuple
from unittest.mock import patch
import numpy as np
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch import sub
from torch._dynamo.testing import (
CompileCounterWithBackend,
EagerAndRecordGraphs,
normalize_gm,
)
from torch._dynamo.utils import ifdynstaticdefault, same
from torch._dynamo.variables import ConstantVariable
from torch._dynamo.variables.lists import RangeVariable
from torch.nn import functional as F
from torch.testing._internal.common_utils import (
disable_translation_validation_if_dynamic_shapes,
instantiate_parametrized_tests,
parametrize,
)
# Defines all the kernels for tests
from torch.testing._internal.triton_utils import * # noqa: F403
# Module-level globals read by tests in this file: `d` by test_globalvar,
# `e` by test_globalmodule, and `flag` by test_load_global_bool.
d = torch.ones(10, 10)
e = torch.nn.Linear(10, 10)
flag = True
class CustomDictSubclass(collections.OrderedDict):
    """OrderedDict subclass that adds no behavior; built by test_custom_dict_kwargs."""
# torch.clip with min/max pre-bound to 0.0 and 1.0; called by test_functools_partial.
clip01 = functools.partial(torch.clip, min=0.0, max=1.0)
def constant3(a, b):
    """Return ``a - b`` plus the constant expression ``1.0 + 2``."""
    offset = 1.0 + 2
    return a - b + offset
_variable = 0


def update_global(x):
    """Bump the module-level ``_variable`` counter and return ``x`` scaled by it."""
    global _variable
    _variable = _variable + 1
    # Check that updated global variable value is picked up
    scaled = x * _variable
    return scaled
def func_with_default(a, b, some_default_arg=True):
    """Return ``a - b`` when ``some_default_arg`` is truthy, otherwise ``None``."""
    if not some_default_arg:
        return None
    return a - b
def make_test(fn=None, expected_frame_count=1):
    """Turn a plain function into a test method that runs ``standard_test`` on it.

    Usable both as ``@make_test`` and as ``@make_test(expected_frame_count=N)``.
    The number of inputs handed to ``standard_test`` is the number of
    parameters in ``fn``'s signature.
    """
    if fn is None:
        # Called with keyword options only: hand back a decorator that
        # re-enters make_test once the function is available.
        def decorator(wrapped_fn):
            return make_test(wrapped_fn, expected_frame_count=expected_frame_count)

        return decorator

    signature = inspect.signature(fn)
    nargs = len(signature.parameters)

    def test_fn(self):
        return torch._dynamo.testing.standard_test(
            self,
            fn=fn,
            nargs=nargs,
            expected_frame_count=expected_frame_count,
        )

    return test_fn
class MyCls:
    """Minimal class with a single class attribute ``a``.

    Referenced by the equality, identity and ``hasattr`` tests in this file.
    """

    a = 1
@torch.jit.script_if_tracing
def inline_script_if_tracing(x):
    """Return ``x`` shifted by 1.2."""
    shifted = x + 1.2
    return shifted
@torch.jit.ignore
def inline_ignore(x):
    """Return ``x`` shifted by 3.4."""
    shifted = x + 3.4
    return shifted
@torch.jit.unused
def inline_unused(x):
    """Return ``x`` shifted by 5.6."""
    shifted = x + 5.6
    return shifted
@functools.lru_cache
def inline_lru_cache_fn_with_default_args(x, y, _=None):
    """Return ``sin(x * y)``; the third parameter is accepted but not used."""
    product = x * y
    return torch.sin(product)
@torch.jit.script_if_tracing
def inline_script_if_tracing_fn_with_default_args(x, y, c=1.2):
    """Return ``cos(x * y) + c``, with ``c`` defaulting to 1.2."""
    cosine = torch.cos(x * y)
    return cosine + c
class FunctionTests(torch._dynamo.test_case.TestCase):
@make_test
def test_inline_jit_annotations(x):
    # Chain the three torch.jit-decorated helpers and return the result so the
    # eager and compiled outputs are actually compared; a bare ``return`` would
    # make both sides ``None``.
    x = inline_script_if_tracing(x)
    x = inline_ignore(x)
    x = inline_unused(x)
    return x
@make_test
def test_inline_script_if_tracing_fn_with_default_args(a, b):
return inline_script_if_tracing_fn_with_default_args(a, b)
@make_test
def test_inline_lru_cache_fn_with_default_args(a, b):
return inline_lru_cache_fn_with_default_args(a, 2, b)
@make_test
def test_add(a, b):
return a + b
@make_test
def test_add_(a, b):
    # Apply the in-place add to a tensor built from `a`, not to `a` itself.
    scratch = torch.tensor(a)
    return scratch.add_(b, alpha=5.0)
@make_test
def test_addcdiv(a, b, c):
# dynamo decomposes this to avoid a graph break when
# the value kwarg is populated
return torch.addcdiv(a, b, c, value=5.0)
@make_test
def test_addcdiv_(a, b, c):
    # Apply the in-place addcdiv to a tensor built from `a`, not to `a` itself.
    scratch = torch.tensor(a)
    return scratch.addcdiv_(b, c, value=5.0)
@make_test
def test_is_not_null(a, b):
if a is not None and b is not None:
return a + b
def test_foreach_lerp_(self):
    # `_foreach_lerp_` is an in-place op, so besides comparing the return
    # values this test also compares the mutated first argument of the eager
    # and compiled runs; comparing only the return values would not show
    # whether the mutation happened under compile.
    from torch._dynamo.utils import same

    def fn(x, y, s):
        return torch._foreach_lerp_(x, y, s)

    def make_inputs():
        # Fresh tensors per run: the op mutates the first list in place.
        return (
            [torch.ones(2, 2) * 4.26, torch.ones(2, 2) * 3.14],
            [torch.ones(2, 2), torch.ones(2, 2)],
            torch.tensor(0.5),
        )

    cnt = torch._dynamo.testing.CompileCounter()
    fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn)

    eager_inps = make_inputs()
    compiled_inps = make_inputs()
    expected = fn(*eager_inps)
    actual = fn_opt(*compiled_inps)

    self.assertTrue(same(expected, actual))
    self.assertTrue(same(eager_inps[0], compiled_inps[0]))
    self.assertEqual(cnt.frame_count, 1)
def test_broadcast_foreach_pow(self):
    # `_foreach_pow` with a scalar tensor base and a list of exponents must
    # compile as one full graph and match eager.
    from torch._dynamo.utils import same

    def fn(x, y):
        return torch._foreach_pow(x, y)

    cnt = torch._dynamo.testing.CompileCounter()
    fn_opt = torch.compile(backend=cnt, fullgraph=True)(fn)

    inps = (torch.tensor(0.80), [torch.tensor(3.4), torch.tensor(7.8)])

    actual = fn_opt(*inps)
    expected = fn(*inps)
    self.assertTrue(same(actual, expected))
    # assertEqual, not assertTrue: assertTrue(value, 1) would treat `1` as the
    # failure message and pass for any non-zero frame count.
    self.assertEqual(cnt.frame_count, 1)
def test_addcmul_(self):
    # In-place addcmul_ with a tensor `value` must compile as a single full
    # graph and produce the same result as eager.
    from copy import deepcopy

    from torch._dynamo.utils import same

    def fn(x, y, z, s):
        return x.addcmul_(y, z, value=s)

    counter = torch._dynamo.testing.CompileCounter()
    compiled_fn = torch.compile(backend=counter, fullgraph=True)(fn)

    compiled_args = (
        torch.ones(2, 2),
        torch.ones(2, 2) + 1,
        torch.rand(2, 2),
        torch.tensor(0.3),
    )
    # The op mutates its first argument, so eager gets an independent copy.
    eager_args = deepcopy(compiled_args)

    actual = compiled_fn(*compiled_args)
    expected = fn(*eager_args)
    self.assertTrue(same(actual, expected))
    self.assertEqual(counter.frame_count, 1)
@make_test
def test_functools_partial(a, b):
return clip01(a + b)
@make_test
def test_itertools_product(a, b):
v = a
for x, i in itertools.product([a, b], [1, 2]):
v = v + x * i
return v
@make_test
def test_itertools_chain(a, b):
v = a
for x in itertools.chain([a, b], [1, 2]):
v = v + x
return v
@make_test
def test_itertools_chain_from_iterable(a, b):
v = a
for x in itertools.chain.from_iterable([[a, b], [1, 2]]):
v = v + x
return v
def test_itertools_reconstruct(self):
    # Iterators advanced inside the compiled function and then returned must
    # continue from the same position as the ones returned by eager.
    def fn(a):
        it1 = itertools.repeat(1)
        it2 = itertools.count(2)
        for _ in range(3):
            a += next(it1)
            a += next(it2)
        return it1, it2, a

    compiled_fn = torch.compile(fn, backend="eager", fullgraph=True)
    eager_repeat, eager_count, eager_out = fn(torch.ones(3, 3))
    comp_repeat, comp_count, comp_out = compiled_fn(torch.ones(3, 3))

    self.assertEqual(next(eager_repeat), next(comp_repeat))
    self.assertEqual(next(eager_count), next(comp_count))
    self.assertEqual(eager_out, comp_out)
@make_test
def test_obj_eq(a, b):
v = a + b
if MyCls() == None: # noqa: E711
return -1
if MyCls() != None: # noqa: E711
v = v.sin()
if MyCls() == MyCls():
return -2
if MyCls() != MyCls():
return v + 1
return -3
@make_test
def test_cls_eq(a, b):
v = a + b
if MyCls == None: # noqa: E711
return -1
if MyCls != None: # noqa: E711
v = v.sin()
if MyCls != MyCls:
return -2
if MyCls == MyCls:
return v + 1
return -3
@make_test
def test_obj_is(a, b):
v = a + b
if MyCls() is None: # noqa: E711
return -1
if MyCls() is not None: # noqa: E711
v = v.sin()
if MyCls() is MyCls():
return -2
if MyCls() is not MyCls():
return v + 1
return -3
@make_test
def test_cls_is(a, b):
v = a + b
if MyCls is None: # noqa: E711
return -1
if MyCls is not None: # noqa: E711
v = v.sin()
if MyCls is not MyCls:
return -2
if MyCls is MyCls:
return v + 1
return -3
@make_test
def test_itertools_combinations(a, b):
combs = []
for size in itertools.combinations((1, 2, 3, 4), 2):
combs.append(torch.ones(size))
return combs
@make_test
def test_np_iinfo(a):
max_dim = np.iinfo(np.int16).max
return a + max_dim
@make_test
def test_np_finfo(a):
min_dim = np.finfo(np.float32).min
return a + min_dim
@make_test
def test_constant1(a, b, c):
return a - b * c + 1.0
@make_test
def test_constant2(a, b, c):
return a - b * c + 1
@make_test
def test_constant3(a):
b = 1
c = 2
d = 3
return b + c - d + a
@make_test
def test_constant4(a, b):
c = 2
d = 3
if c > d:
return a - b
return b - a
@make_test
def test_cls_hasattr(self, x):
    # NOTE(review): `self` here is an ordinary positional parameter of the
    # traced function -- make_test counts it when computing nargs -- and it is
    # never read in the body. Presumably it receives a tensor just like `x`;
    # confirm against standard_test.
    # MyCls defines `a` but not `b`, so only the first branch adds to x.
    if hasattr(MyCls, "a"):
        x = x + 1
    if hasattr(MyCls, "b"):
        x = x + 2
    return x
@make_test
def test_finfo(a, b):
if torch.iinfo(torch.int32).bits == 32:
return torch.finfo(a.dtype).min * b
@make_test
def test_globalfn(a, b):
return sub(a, b)
@make_test
def test_viatorch(a, b):
return torch.sub(a, b)
@make_test
def test_viamethod(a, b):
return a.sub(b)
@make_test
def test_indirect1(a, b):
t = a.sub
return t(b)
@make_test
def test_indirect2(a, b):
t = a.sub
args = (b,)
return t(*args)
@make_test
def test_indirect3(a, b):
t = a.sub
args = (b,)
kwargs = {}
return t(*args, **kwargs)
@make_test
def test_methodcall1(a, b, c):
return constant3(a, b) * c
@make_test
def test_methodcall2(a, b):
return constant3(a=b, b=a) + 1
@make_test
def test_methodcall3(a, b):
return constant3(a, b=1.0) + b
def test_is_integer(self):
    # float.is_integer() on a Python scalar argument selects the branch.
    @torch.compile(backend="eager", fullgraph=True)
    def forward(t, m):
        return 2 * t if m.is_integer() else t

    t = torch.tensor([1])
    for scalar, expected in ((1.0, 2), (1.5, 1)):
        self.assertEqual(forward(t, scalar).item(), expected)
@parametrize(
    "method, num_type",
    (
        ("as_integer_ratio", int),
        ("bit_length", int),
        ("conjugate", int),
        ("as_integer_ratio", float),
        ("conjugate", float),
        ("hex", float),
        ("is_integer", float),
    ),
)
def test_number_method(self, method, num_type):
    # The truthiness of calling `method` on a Python number picks the branch;
    # compiled and eager must agree for each sample value.
    def forward(t, m):
        return 2 * t if getattr(m, method)() else t

    compiled_forward = torch.compile(backend="eager", fullgraph=True)(forward)

    for raw in (0, 1, 2.5):
        number = num_type(raw)
        tensor = torch.tensor([1])
        actual = compiled_forward(tensor, number)
        self.assertEqual(actual, forward(tensor, number))
@make_test
def test_device_constant(a):
return a + torch.ones(1, device=torch.device("cpu"))
@make_test
def test_tuple1(a, b):
args = (a, b)
return sub(*args)
@make_test
def test_tuple2(a, b):
args = [a, b]
return sub(*args)
@make_test
def test_is_in_onnx_export(x, y):
if torch.onnx.is_in_onnx_export():
return x - 1
else:
return y + 1
@make_test
def test_is_fx_tracing(x, y):
if torch.fx._symbolic_trace.is_fx_tracing():
return x - 1
else:
return y + 1
@make_test
def test_listarg1(a, b):
return torch.cat([a, b])
@make_test
def test_listarg2(a, b):
return torch.cat((a, b), dim=0)
@make_test
def test_listarg3(a, b):
kwargs = {"tensors": (a, b), "dim": 0}
return torch.cat(**kwargs)
@make_test
def test_listarg4(a, b):
return torch.cat(tensors=[a, b], dim=0)
@make_test
def test_listarg5(a, b):
args = [(a, b)]
kwargs = {"dim": 0}
return torch.cat(*args, **kwargs)
def test_list_slice(self):
    # A compiled method that re-slices and appends to a list attribute must
    # leave exactly one element after a single call on an empty list.
    class Mock:
        def __init__(self):
            self.ets = []
            self.counter = 0

        @torch.compile(backend="eager")
        def run(self, x):
            self.ets = self.ets[-3:]
            self.ets.append(x)
            return torch.sin(x)

    holder = Mock()
    sample = torch.randn(4)
    holder.run(sample)
    self.assertEqual(len(holder.ets), 1)
@make_test
def test_deque(a, b):
d = collections.deque([a, b])
d.append(a + 1)
d.extend([a, b])
d.insert(0, "foo")
tmp = d.pop()
another_deque = collections.deque([tmp])
d.extendleft(another_deque)
another_deque.clear()
d.extend(another_deque)
d[2] = "setitem"
d = d.copy()
d.append(d.popleft())
empty = collections.deque()
d.extend(empty)
return d
@make_test
def test_slice1(a):
return a[5]
@make_test
def test_slice2(a):
return a[:5]
@make_test
def test_slice3(a):
return a[5:]
@make_test
def test_slice4(a):
return a[2:5]
@make_test
def test_slice5(a):
return a[::2]
@make_test
def test_slice6(a):
return torch.unsqueeze(a, 0)[:, 2:]
@make_test
def test_range1(a):
return torch.tensor(range(a.size(0)))
@make_test
def test_range2(x, y):
r = x + y
for i in range(x.size(0) + 2):
r = r / y
return r
@make_test
def test_unpack1(a):
a, b = a[:5], a[5:]
return a - b
@make_test
def test_unpack2(a):
packed = [a[:5], a[5:]]
a, b = packed
return a - b
@make_test
def test_unpack3(a):
packed = (a[:5], a[5:])
a, b = packed
return a - b
@make_test
def test_fn_with_self_set(a, b):
# avg_pool2d is an odd one with __self__ set
return F.avg_pool2d(
torch.unsqueeze(a, 0) * torch.unsqueeze(b, 1), kernel_size=2, padding=1
)
@make_test
def test_return_tuple1(a, b):
return (a - b, b - a, a, b)
@make_test
def test_globalvar(a, b):
return a - b + d
@make_test
def test_globalmodule(x):
return e(x)
@make_test
def test_inline_with_default(a, b, c):
return func_with_default(a, b) * c
@make_test
def test_inner_function(x):
def fn(x):
return torch.add(x, x)
return fn(x)
@make_test
def test_transpose_for_scores(x):
new_x_shape = x.size()[:-1] + (2, 5)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1)
@make_test
def test_return_tuple2(x):
return (torch.add(x, x), x)
@make_test
def test_load_global_bool(x):
if flag:
return torch.add(x, x)
else:
return x
@make_test
def test_len_tensor(x):
z = len(x)
return torch.add(x, z)
@make_test
def test_len_constant_list(x):
z = len([1, 2, 3])
return torch.add(x, z)
@make_test
def test_len_constant_dict(x):
z = len({"foo": "bar"})
return torch.add(x, z)
@make_test
def test_dict_copy_via_constructor(x):
    # Renamed from `test_dict_copy`: a later method in this class uses that
    # same name and would replace this one in the class namespace, so this
    # test would never be collected. This variant copies a dict by passing it
    # to the dict() constructor.
    z = dict({"foo": x + 1})
    return z
@make_test
def test_dict_keys(x):
d = {3: x}
keys = d.keys()
d[4] = x + 1
d2 = {3: 2, 4: "aa"}
return 3 in keys, 4 in keys, 5 in keys, d2.keys() == keys
@make_test
def test_dict_values(x):
d = {3: x}
values = d.values()
d[3] = x + 1
d[4] = x + 2
return len(values)
@make_test
def test_dict_setdefault1(x):
d = {"a": 1, "b": 2}
d.setdefault("a", 10)
if d["a"] == 1:
return x + 1
else:
return x - 1
@make_test
def test_dict_setdefault2(x):
d = {"a": 1, "b": 2}
d.setdefault("c", 10)
if d["c"] == 10:
return x + 1
else:
return x - 1
@make_test
def test_dict_setdefault3(x):
d = {"a": 1, "b": 2}
d.setdefault("c")
if d["c"] is None:
return x + 1
else:
return x - 1
@make_test
def test_defaultdict_setdefault1(x):
d = collections.defaultdict.fromkeys("a", "b")
d["a"] = 1
d["b"] = 2
d.setdefault("a", 10)
if d["a"] == 1:
return x + 1
else:
return x - 1
@make_test
def test_defaultdict_setdefault2(x):
d = collections.defaultdict.fromkeys("a", "b")
d["a"] = 1
d["b"] = 2
d.setdefault("c", 10)
if d["c"] == 10:
return x + 1
else:
return x - 1
@make_test
def test_defaultdict_setdefault3(x):
d = collections.defaultdict.fromkeys("a", "b")
d["a"] = 1
d["b"] = 2
d.setdefault("c")
if d["c"] is None:
return x + 1
else:
return x - 1
def test_dict_id_guard(self):
    # d2 is a second name for the same OrderedDict object as d1.
    d1 = collections.OrderedDict({"a": 2})
    d2 = d1

    def fn(x):
        # Iteration forces DictGuardManager
        for k in d1:
            x = x * d1[k] * d2[k]
        return x

    compiled_fn = torch.compile(fn, backend="eager", fullgraph=True)
    sample = torch.randn(4)
    self.assertEqual(fn(sample), compiled_fn(sample))
@make_test
def test_callable_lambda(x):
if callable(lambda x: True):
return x + 1
else:
return x - 1
@make_test
def test_callable_torch(x):
if callable(torch.abs):
return x + 1
else:
return x - 1
@make_test
def test_callable_builtin(x):
if callable(sum):
return x + 1
else:
return x - 1
def test_callable_class(self):
    # callable() must be True only for instances whose class defines
    # __call__, and False for plain instances, scalars and tensors.
    class CallableClass:
        # `self` is required: without it the instance would report as
        # callable but raise TypeError if it were ever actually called.
        def __call__(self):
            pass

    class NotCallableClass:
        pass

    @torch.compile(backend="eager", fullgraph=True)
    def fn1(x, arg):
        if callable(arg):
            return x
        return x + 1

    @torch.compile(backend="eager", fullgraph=True)
    def fn2(x, arg):
        if callable(arg):
            return x * 2
        return x + 1

    # Named `inp` rather than `input` to avoid shadowing the builtin.
    inp = torch.randn(4)

    for f in [fn1, fn2]:
        self.assertEqual(f(inp, NotCallableClass()), inp + 1)
        self.assertEqual(
            f(inp, CallableClass()), inp if f is fn1 else inp * 2
        )

        # passing tensor and scalars
        self.assertEqual(f(inp, 1), inp + 1)
        self.assertEqual(f(inp, 1.1), inp + 1)
        self.assertEqual(f(inp, True), inp + 1)
        self.assertEqual(f(inp, inp), inp + 1)
def test_callable_list(self):
    # Lists and tuples are not callable, so the `x + 1` branch is expected.
    @torch.compile(backend="eager", fullgraph=True)
    def fn(x, arg):
        if callable(arg):
            return x
        return x + 1

    sample = torch.randn(4)
    for container in ([1, 2, 3], (1, 2, 3)):
        self.assertEqual(fn(sample, container), sample + 1)
@make_test
def test_len_constant_misc_iterables(x):
a = len((1, 2, 3))
b = len("test str")
c = a + b
return torch.add(x, c)
@make_test
def test_dict_kwargs(x):
z = dict(text_embed=x + 1, other=x + 2)
return z
@make_test
def test_ordered_dict_kwargs(x):
z = collections.OrderedDict(sample=torch.ones(10))
return z
@make_test
def test_custom_dict_kwargs(x):
z = CustomDictSubclass(sample=torch.ones(10))
return z
@make_test
def test_float(x):
y = float(1.2) # noqa: UP018
y += float("1.2")
return torch.add(x, y)
@make_test
def test_is_floating_point(x):
y = x + 1
return torch.is_floating_point(y), torch.is_floating_point(input=y)
@make_test
def test_dtype(x):
if x.dtype == torch.float32:
return x + 1
@make_test
def test_get_default_dtype(x):
if x.dtype == torch.get_default_dtype():
return x + 1
else:
return x - 1
@make_test
def test_get_autocast_gpu_dtype(x):
dtype = torch.get_autocast_gpu_dtype()
return x.type(dtype)
@make_test
def test_is_any_autocast_enabled(x):
if torch._C._is_any_autocast_enabled():
return x + 1
else:
return x - 1
@make_test
def test_is_checkpoint_valid(x):
if torch.autograd._is_checkpoint_valid():
return x + 1
else:
return x - 1
@make_test
def test_list_compare_polyfill(x):
for a, b, c in [
[(1, 2, 3), (1, 2, 3), 7.77],
[(1, 4, 3), (1, 2, 3), 3.33],
[(1, 2), (1, 2, 3), 5.55],
[(1, 2, 3), (1, 2), 11.11],
[(1, -1, 3), (1, 2, 3), 13.33],
]:
if a != b:
x += 1 * c
if a == b:
x += 2 * c
if a < b:
x += 4 * c
if a > b:
x += 8 * c
if a <= b:
x += 16 * c
if a >= b:
x += 32 * c
return x
@make_test
def test_promote_types(x):
if x.dtype == torch.promote_types(torch.int32, torch.float32):
return x + 1
else:
return x - 1
@make_test
def test_cublas_allow_tf32(x):
if torch.backends.cuda.matmul.allow_tf32:
return x.sin() + 1
return x.cos() - 1
@make_test
def test_get_calculate_correct_fan(x):
fan_in = torch.nn.init._calculate_correct_fan(x, "fan_in")
return x + fan_in
@make_test
def test_is_complex(x):
if torch.is_complex(x):
return x + 1
else:
return x - 1
@make_test
def test_tensor_is_complex(x):
if x.is_complex():
return x + 1
else:
return x - 1
@make_test
def test_get_privateuse1_name(x):
if torch._C._get_privateuse1_backend_name() == "privateuseone":
return x + 1
else:
return x - 1
@make_test
def test_device(x):
if not x.is_cuda:
return x + 1
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@make_test
def test_get_device_properties_tensor_device(a):
x = a.to("cuda")
prop = torch.cuda.get_device_properties(x.device)
if prop.major == 8:
return x + prop.multi_processor_count
return x + prop.max_threads_per_multi_processor
@make_test
def test_tensor_type(a, b):
m = a.to(torch.float16)
return b.type(m.type())
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@make_test
def test_tensor_type2(a, b):
m = a.to("cuda")
return m + b.type(m.type())
@make_test
def test_tensor_type3(a, b):
m = a.type(torch.HalfTensor)
return b.type(m.type())
@make_test
def test_tensor_type4(a, b):
m = a.type("torch.HalfTensor")
return b.type(m.type())
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@make_test
def test_tensor_type5(a, b):
m = a.type(torch.cuda.HalfTensor)
return b.type(m.type())
@make_test
def test_tensor_element_size(a):
if a.element_size() > 1:
return (a + a.element_size(), a - a.element_size())
return (a - a.element_size(), a + a.element_size())
@make_test
def test_ndim(x):
if x.ndim == 2 and x.ndimension() == 2 and x.dim() == 2:
return x + 1
@make_test
def test_T(x):
return torch.ones_like(x.T)
@make_test
def test_mT(x):
return torch.ones_like(x.mT)
@make_test
def test_is_sparse(x):
if not x.is_sparse:
return x + 1
@make_test
def test_shape1(x):
if x.shape[0] == 10:
return x + 1
@make_test
def test_shape2(x):
if x.size(1) == 10:
return x + 1
@make_test
def test_del(a, b):
c = a + 1
d = c + 2
del c, a
return b + d
@make_test
def test_chunks1(x):
chunk_size = 5
assert x.shape[0] % chunk_size == 0
assert x.shape[0] // chunk_size == 2
return x[:chunk_size] - x[chunk_size:]
@make_test
def test_import1(x, y):
import torch
from torch import sub
return sub(torch.add(x, y), y)
@make_test
def test_return_dict(x, y):
z = [x + y, y, False]
return {"x": x, "z": z, "a": x, "b": z, "c": x}
@make_test
def test_return_dict2(x, y):
tmp = {"x": x}
tmp["z"] = [x + y, y]
tmp["y"] = y
tmp["z"].append(False)
return tmp
@make_test
def test_funcdef_closure(x, y):
x = x + y + 1.0
def inner(z):
nonlocal x, y
y = x + z + 20.0
x = y + z + 10.0
inner(2.0)
inner(3.0)
return x, y
@make_test
def test_module_constant(x, y):
r = x + y
for i in range(torch._dynamo.testing.three):
r = r / y
return r
@make_test
def test_inline_softmax(x, y):
    # This is common in some huggingface models
    softmax = torch.nn.Softmax(dim=-1)
    return softmax(x + y * 2)
@make_test
def test_dtype_compare(a, b):
if a.dtype == torch.float16:
return a + 10
if a.dtype == torch.float32:
return a - b * 32
@make_test
def test_build_list_unpack(a, b):
it1 = (x + 1 for x in (a, b))
it2 = (x - 1 for x in (a, b))
return torch.cat([*it1, *it2], dim=-1)
@make_test
def test_tensor_len(a, b):
return a + b + len(a) + b.__len__()
@make_test
def test_pop(a, b):
    # Grow the list to five entries, then pop from the end, the front and the
    # end again so exactly two remain to unpack.
    items = [a, b]
    items.append(a + 1)
    items.extend([b + 2, a + b])
    items.pop(-1)
    items.pop(0)
    items.pop()
    first, second = items
    return first - second
@make_test
def test_list_convert(a, b):
ll = [a + 2, b]
ll = tuple(ll)
tmp = b + 3
ll = list(ll)
v1, v2 = ll
return v1 - v2 + tmp
@make_test
def test_list_add(a, b):
l1 = (a, b)
l2 = () # being a LOAD_CONST in the bytecode
l3 = l1 + l2
return l3[0] + l3[1]
@make_test
def test_list_index_with_constant_tensor(a, b):
l1 = [a, b, a + 1, b + 1]
return l1[torch.as_tensor(2)]
@make_test
def test_startswith(a, b):
x = a + b
if "foobar".startswith("foo") and "test" in constant3.__module__:
x = x + 1
return x
@make_test
def test_dict_ops(a, b):
tmp = {"a": a + 1, "b": b + 2}
assert tmp.get("zzz") is None
v = tmp.pop("b") + tmp.get("a") + tmp.get("missing", 3) + tmp.pop("missing", 4)
tmp.update({"d": 3})
tmp["c"] = v + tmp["d"]
if "c" in tmp and "missing" not in tmp:
return tmp["c"] - tmp["a"] + len(tmp)
@make_test
def test_inline_jit__unwrap_optional(x):
if torch.jit._unwrap_optional(x) is None:
return torch.ones(2, 2)
return x.sin()
@make_test
def test_zip_longest(x):
list1 = [1, 2, 3]
list2 = ["a", "b"]
list3 = [True, False, True, False]
return torch.sin(x + 1), list(
itertools.zip_longest(list1, list2, list3, fillvalue=None)
)
def test_torch_size_as_dict_key(self):
    # x.shape is used as a dict key; the second call hits the entry that the
    # first call stored.
    def fn(x, cached):
        if x.shape not in cached:
            cached[x.shape] = x
        return x + cached[x.shape]

    compiled_fn = torch.compile(fn, backend="eager", fullgraph=True)

    first = torch.randn(2, 3)
    second = torch.randn(2, 3)

    eager_cache = {}
    eager_results = [fn(first, eager_cache), fn(second, eager_cache)]

    compiled_cache = {}
    compiled_results = [
        compiled_fn(first, compiled_cache),
        compiled_fn(second, compiled_cache),
    ]

    for ref, res in zip(eager_results, compiled_results):
        self.assertEqual(ref, res)
def test_dict_param_keys(self):
    # A dict literal that mixes a string key with an nn.Parameter key.
    a_param = torch.nn.Parameter(torch.ones([4, 4]))

    def fn(a):
        tmp = {"a": a, a_param: 3}
        return tmp["a"] + tmp[a_param]

    make_test(fn)(self)
def test_dict_mutable_map(self):
    # dict(...) applied to a user-defined MutableMapping inside a compiled
    # function must produce the same contents as in eager.
    from collections.abc import MutableMapping

    class TensorDict(MutableMapping):
        def __init__(self) -> None:
            self._dict = {}

        def add(self, key, value):
            self._dict[key] = value

        def items(self):
            return self._dict.items()

        def __delitem__(self, key):
            del self._dict[key]

        def __getitem__(self, key):
            return self._dict[key]

        def __iter__(self):
            return iter(self._dict)

        def __len__(self):
            return len(self._dict)

        def __setitem__(self, key, value):
            self._dict[key] = value

    tensor_dict = TensorDict()
    tensor_dict.add("a", torch.ones(4) * 2)

    def fn(x):
        copy_tensordict = dict(tensor_dict)
        return x * copy_tensordict["a"]

    compiled_fn = torch.compile(fn, backend="eager", fullgraph=True)
    sample = torch.randn(4)
    expected = fn(sample)
    actual = compiled_fn(sample)
    self.assertEqual(expected, actual)
def test_unpack_mutable_map(self):
    # `**mapping` unpacking of a user-defined MutableMapping into keyword
    # arguments inside a compiled function.
    from collections.abc import MutableMapping

    class TensorDict(MutableMapping):
        def __init__(self) -> None:
            self._dict = {}

        def add(self, key, value):
            self._dict[key] = value

        def items(self):
            return self._dict.items()

        def __delitem__(self, key):
            del self._dict[key]

        def __getitem__(self, key):
            return self._dict[key]

        def __iter__(self):
            return iter(self._dict)

        def __len__(self):
            return len(self._dict)

        def __setitem__(self, key, value):
            self._dict[key] = value

    tensor_dict = TensorDict()
    tensor_dict.add("a", torch.ones(4) * 2)

    def gn(x, a=1):
        return x * a

    def fn(x):
        return gn(x, **tensor_dict)

    compiled_fn = torch.compile(fn, backend="eager", fullgraph=True)
    sample = torch.randn(4)
    expected = fn(sample)
    actual = compiled_fn(sample)
    self.assertEqual(expected, actual)
def _test_default_dict_helper(self, factory):
    # Shared body for the defaultdict tests: `factory` becomes the
    # default_factory of a defaultdict that the traced function writes to and
    # then reads a missing key ("b") from.
    dd = collections.defaultdict(factory)
    param = torch.nn.Parameter(torch.ones([2, 2]))

    def fn(x):
        dd["a"] = x + 1
        dd[param] = 123
        dd["c"] = x * 2
        return dd["b"], dd

    x = torch.randn(10, 10)
    ref_missing, ref_dict = fn(x)
    opt_fn = torch._dynamo.optimize_assert("eager")(fn)
    res_missing, res_dict = opt_fn(x)

    self.assertTrue(same(ref_missing, res_missing))
    for key in ("a", "c", param):
        self.assertTrue(same(ref_dict[key], res_dict[key]))
def test_default_dict_dict(self):
self._test_default_dict_helper(dict)
def test_default_dict_list(self):
self._test_default_dict_helper(list)
def test_default_dict_tuple(self):
self._test_default_dict_helper(tuple)
def test_default_dict_set(self):
self._test_default_dict_helper(set)
def test_default_dict_lambda(self):
self._test_default_dict_helper(lambda: dict()) # noqa: C408
def test_default_dict_closure(self):
def factory():
return dict() # noqa: C408
self._test_default_dict_helper(factory)
def test_class_dict(self):
    # Membership test against a class's __dict__ inside a compiled function.
    class A:
        x = 4
        y = 5

        def __init__(self) -> None:
            self.a = 6

    a = A()

    def fn(x):
        if "x" in type(a).__dict__:
            return x + 1
        return x + 2

    compiled_fn = torch.compile(fn, backend="eager", fullgraph=True)
    sample = torch.randn(4)
    self.assertEqual(fn(sample), compiled_fn(sample))
def test_default_dict_constr(self):
    # A defaultdict constructed inside the traced function and filled via
    # item assignment plus three forms of update(): a dict, a list of pairs,
    # and a zip.
    param = torch.nn.Parameter(torch.ones([2, 2]))

    def fn(x):
        dd = collections.defaultdict(lambda: dict())  # noqa: C408
        dd["a"] = x + 1
        dd[param] = 123
        dd["c"] = x * 2
        dd.update({"b": x * 3})
        dd.update([["d", x - 2], ("e", x + 2)])
        dd.update(zip("ab", [x + 3, x + 4]))
        return dd["b"], dd

    x = torch.randn(10, 10)
    ref_b, ref_dict = fn(x)
    opt_fn = torch._dynamo.optimize_assert("eager")(fn)
    res_b, res_dict = opt_fn(x)

    self.assertTrue(same(ref_b, res_b))
    for key in ("a", "b", "c", "d", "e", param):
        self.assertTrue(same(ref_dict[key], res_dict[key]))
def test_dict_tuple_lazy_guard(self):
    # Only y[1] is read, so changing other entries of `y` must not trigger a
    # recompile; error_on_recompile turns any recompile into a failure.
    recompile_flag = "torch._dynamo.config.error_on_recompile"

    @torch.compile(backend="eager")
    def fn(x, y):
        return torch.sin(x) * y[1]

    fn(torch.randn(3), {1: 1, 2: 2})
    # Changing the value of another key should not cause recompilation
    with patch(recompile_flag, True):
        fn(torch.randn(3), {1: 1, 2: 3})

    fn(torch.randn(3), (1, 2, 3))
    # Changing the values at index 0 and 2 (not 1) should not cause recompilation
    with patch(recompile_flag, True):
        fn(torch.randn(3), (11, 2, 13))
@make_test
def test_call_dict1(x):
d1 = dict() # noqa: C408
d1["x"] = x + 1
d2 = collections.OrderedDict()
d2["x"] = x + 2
return d1["x"] + d2["x"] + 1
@make_test
def test_call_dict2(x):
d1 = dict() # noqa: C408
d1["x"] = x
d2 = collections.OrderedDict(d1)
if isinstance(d2, collections.OrderedDict):
return x + 1
else:
return x - 1
@make_test
def test_call_dict3(x):
my_list = [("a", x), ("b", x + 1), ("c", x + 2)]
d1 = dict(my_list)
d1["a"] = x + 10
d2 = collections.OrderedDict(my_list)
d2["c"] = x + 20
return d1["a"] + d2["c"] + 1
@make_test
def test_call_dict4(x):
my_list = (("a", x), ("b", x + 1), ("c", x + 2))
d1 = dict(my_list)
d1["a"] = x + 10
d2 = collections.OrderedDict(my_list)
d2["c"] = x + 20
return d1["a"] + d2["c"] + 1
@make_test
def test_call_dict5(x):
my_list = iter([("a", x), ("b", x + 1), ("c", x + 2)])
d1 = dict(my_list)
d1["a"] = x + 10
d2 = collections.OrderedDict(my_list)
d2["c"] = x + 20
return d1["a"] + d2["c"] + 1
@make_test
def test_dict_fromkeys(x, y):
lst = ["a", "b"]
d = dict.fromkeys(lst)
d1 = dict.fromkeys(d, x + 1)
d2 = collections.defaultdict.fromkeys(iter(d1), x - 2)
d3 = collections.OrderedDict.fromkeys(tuple(lst), value=y)
return d1["a"] * d2["b"] + d2["a"] + d1["b"] + d3["a"] + d3["b"] + 1
@make_test
def test_dict_copy(x):
my_list = [("a", x), ("b", x + 1), ("c", x + 2)]
d1 = dict(my_list)
d1["a"] = x + 10
d2 = d1.copy()
d2["a"] = x - 5
d2["b"] = x + 3
d3 = collections.OrderedDict(my_list)
d3["c"] = x + 20
d4 = d3.copy()
d4["c"] = x - 10
return d1["a"] * d2["a"] + d2["b"] + d3["c"] * d4["c"] + 1
@make_test
def test_dict_update(x, y, z):
d = {"a": x, "b": y}
d.update({"a": y - 1})
d.update([("b", z + 1), ["c", z]])
d.update(zip("ab", [z + 3, y + 2]))
od = collections.OrderedDict(a=x * 3, b=y + 2)
od.update({"a": y + 5})
od.update([["b", z + 6], ("c", z - 7)])
od.update(zip("ab", [z - 3, x + 2]))
return d["a"] * od["a"] + od["c"] + d["b"] + od["b"] * d["c"]
@make_test
def test_min_max(a, b):
c = a + b
a = a.sum()
b = b.sum()
a = min(max(a, 0), 1)
b = max(0, min(1, b))
return max(a, b) - min(a, b) + c
@make_test
def test_symbool_to_int(x):
# this is roughly the pattern found in einops.unpack()
if sum(s == -1 for s in x.size()) == 0:
return x + 1
else:
return x - 1
@make_test
def test_map_sum(a, b, c, d):
return sum(map(lambda x: x + 1, [a, b, c, d]))
@make_test
def test_sum(a, b, c, d):
return sum([a, b, c, d])
@make_test
def test_sum_with_start_arg(a, b, c, d):
return sum([b, c, d], a)
@make_test
def test_sum_with_start_kwarg(a, b, c, d):
return sum([b, c, d], start=a)
@make_test(expected_frame_count=0)
def test_sum_shortcut():
return sum([0, 1.0, 2, 3.0])
@make_test(expected_frame_count=0)
def test_sum_shortcut_with_start_arg():
return sum([0, 1.0, 2, 3.0], -10)
@make_test(expected_frame_count=0)
def test_sum_shortcut_with_start_kwarg():
return sum([0, 1.0, 2, 3.0], start=-10)
@make_test
def test_reduce(a, b, c, d):
return functools.reduce(operator.add, [a, b, c, d])
@make_test
def test_reduce_with_initial(a, b, c, d):
return functools.reduce(operator.add, [b, c, d], a)
@make_test(expected_frame_count=0)
def test_reduce_with_single(x):
return functools.reduce(lambda a, b: (a, b), [x])
@make_test(expected_frame_count=0)
def test_reduce_with_single_with_initial(x, y):
return functools.reduce(lambda a, b: (a, b), [y], x)
@make_test(expected_frame_count=0)
def test_reduce_with_none_initial(x):
return functools.reduce(lambda a, b: (a, b), [x], None)
@make_test
def test_tuple_contains(a, b):
v1 = "a"
v2 = "b"
v3 = "c"
vals1 = (v1, v2, v3)
vals2 = ("d", "e", "f")
if "a" in vals1 and "b" not in vals2:
return a + b
return a - b
@unittest.skipIf(
sys.version_info < (3, 9),
"SET_UPDATE was added at Python 3.9",
)
@make_test
def test_set_update_bytecode(x):
# This produces bytecode SET_UPDATE since python 3.9
var = {"apple", "banana", "cherry"}
if isinstance(var, set):
return x + 1
else:
return x - 1
@unittest.skipIf(
sys.version_info < (3, 9),
"SET_UPDATE was added at Python 3.9",
)
@make_test
def test_set_update_list_with_duplicated_items(x):
list1 = ["apple", "banana", "apple"]
list2 = ["orange", "banana"]
if len({*list1, *list2}) == 3:
return x + 1
else:
return x - 1
@make_test
def test_set_contains(a, b):
vals = set(["a", "b", "c"])
if "a" in vals:
x = a + b
else:
x = a - b
if "d" in vals:
y = a + b
else:
y = a - b
return x, y
def test_set_isdisjoint(self):
    # The two closed-over sets share "apple", so isdisjoint() is False.
    x = {"apple", "banana", "cherry"}
    y = {"google", "microsoft", "apple"}

    def fn(a):
        if x.isdisjoint(y):
            return a + 1
        else:
            return a - 1

    make_test(fn)(self)
@make_test
def test_set_intersection(a, b):
set1 = {"apple", "banana", "cherry"}
set2 = {"google", "microsoft", "apple"}
intersection_set = set1.intersection(set2)
if "apple" in intersection_set:
x = a + b
else:
x = a - b
if "banana" in intersection_set:
y = a + b
else:
y = a - b
return x, y
@make_test
def test_set_union(a, b):
set1 = {"apple", "banana", "cherry"}
set2 = {"google", "microsoft", "apple"}
union_set = set1.union(set2)
if "apple" in union_set:
x = a + b
else:
x = a - b
if "banana" in union_set:
y = a + b
else:
y = a - b
return x, y
@make_test
def test_set_difference(a, b):
set1 = {"apple", "banana", "cherry"}
set2 = {"google", "microsoft", "apple"}
difference_set = set1.difference(set2)
if "apple" in difference_set:
x = a + b
else:
x = a - b
if "banana" in difference_set:
y = a + b
else:
y = a - b
return x, y
def test_set_keys_view(self):
    # set(...) applied to a user-defined KeysView subclass that wraps a list
    # containing a duplicate.
    from collections.abc import KeysView

    class StringKeys(KeysView):
        def __init__(self, keys):
            self.keys = keys

        def __getitem__(self, key):
            return self.keys.__getitem__(key)

        def __iter__(self):
            yield from self.keys

        def __repr__(self):
            return f"{type(self).__name__}({self.keys})"

        def __len__(self):
            return len(self.keys)

        def __contains__(self, item):
            return self.keys.__contains__(item)

    a = StringKeys([1, 2, 3, 3])

    def fn(x):
        set_a = set(a)
        return len(set_a) * x

    compiled_fn = torch.compile(fn, backend="eager", fullgraph=True)
    sample = torch.rand(4)
    self.assertEqual(fn(sample), compiled_fn(sample))
def test_constant_set(self):
    # The traced function reads len() of a closed-over set; growing the set
    # between calls must be reflected in the compiled result.
    s = {1, 2}

    def fn(x):
        return torch.cos(x) * len(s)

    compiled_fn = torch.compile(fn, backend="eager", fullgraph=True)
    sample = torch.rand(4)
    self.assertEqual(fn(sample), compiled_fn(sample))

    # This should cause recompilation
    s.add(3)
    self.assertEqual(fn(sample), compiled_fn(sample))
def test_set_add(self):
    # The traced function adds an element to a closed-over set as a side
    # effect; afterwards the set must hold three elements.
    s = {1, 2}

    def fn(x):
        s.add(3)
        return torch.cos(x) * len(x)

    compiled_fn = torch.compile(fn, backend="eager", fullgraph=True)
    sample = torch.rand(4)
    self.assertEqual(fn(sample), compiled_fn(sample))
    self.assertEqual(len(s), 3)
# Augmented assignment (+=) on a tuple of tensors; returns the 4-element tuple.
@make_test
def test_tuple_iadd(a, b):
output = (a, b)
output += (a + b, a - b)
return output
# Extended unpacking with the starred target last: cd receives [x + 2, x + 3].
@make_test
def test_unpack_ex1(x):
output = (x, x + 1, x + 2, x + 3)
a, b, *cd = output
return a - b / cd[0]
# Extended unpacking with the starred target first: ab receives [x, x + 1].
@make_test
def test_unpack_ex2(x):
output = (x, x + 1, x + 2, x + 3)
*ab, c, d = output
return c - d / ab[0]
# Extended unpacking with the starred target in the middle: bc receives [x + 1, x + 2].
@make_test
def test_unpack_ex3(x):
output = (x, x + 1, x + 2, x + 3)
a, *bc, d = output
return a - d / bc[0]
# Concatenates empty constant tuples on both sides of a tensor tuple;
# the element positions are unchanged.
@make_test
def test_const_tuple_add1(x):
output = (x, x + 1, x + 2, x + 3)
output = () + output + ()
return output[2] + output[3]
# Concatenates constant (None,) tuples on both sides of a tensor tuple, which
# shifts indices by one: output[2] is x + 1 and output[3] is x + 2.
@make_test
def test_const_tuple_add2(x):
output = (x, x + 1, x + 2, x + 3)
output = (None,) + output + (None,)
return output[2] + output[3]
# Truthiness of a non-empty constant list: the a + b branch is taken.
@make_test
def test_list_truth(a, b):
tmp = [1, 2, 3]
if tmp:
return a + b
else:
return a - b
# reversed() on a list of tensors; next(iter(...)) picks the last element (a + 3).
@make_test
def test_list_reversed(a, b):
tmp = [a + 1, a + 2, a + 3]
return a + b + next(iter(reversed(tmp)))
# sorted() on a constant int list, ascending and with reverse=True.
@make_test
def test_list_sorted1(x):
tmp = [1, 10, 3, 0]
return x + 1, sorted(tmp), sorted(tmp, reverse=True)
# sorted() on a list of constant tuples: default ordering, ordering by a
# key lambda (third field), and the same key with reverse=True.
@make_test
def test_list_sorted2(x):
y = [
("john", "A", 8),
("jane", "B", 5),
("dave", "B", 10),
]
return (
x + 1,
sorted(y),
sorted(y, key=lambda student: student[2]),
sorted(y, key=lambda student: student[2], reverse=True),
)
# sorted() on a constant int tuple, ascending and with reverse=True.
@make_test
def test_tuple_sorted(x):
tmp = (1, 10, 3, 0)
return x + 1, sorted(tmp), sorted(tmp, reverse=True)
# sorted() on a constant dict (iterates its keys), ascending and with reverse=True.
@make_test
def test_dict_sorted(x):
tmp = {1: "D", 10: "B", 3: "E", 0: "F"}
return x + 1, sorted(tmp), sorted(tmp, reverse=True)
# fn dispatches on hasattr(): the "to" branch is meant for tensors and the
# "items" branch for dicts. It is compiled with fullgraph=True and compared
# against eager for a dict input and then a tensor input.
def test_dict_hasattr(self):
def fn(x):
if hasattr(x, "to"):
return x.to("cpu")
if hasattr(x, "items"):
return torch.cos(x["a"])
return x
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = dict(a=torch.randn(3))
self.assertEqual(fn(x), opt_fn(x))
x = torch.randn(4)
self.assertEqual(fn(x), opt_fn(x))
# list.clear() followed by append(): the returned list holds only a + b.
@make_test
def test_list_clear(a, b):
tmp = [a + 1, a + 2]
tmp.clear()
tmp.append(a + b)
return tmp
# `not` applied to a non-empty list holding a tensor; evaluates to False.
@make_test
def test_not_list(a):
return not [a + 1]
# itertools.islice over itertools.chain, and islice with an open-ended stop.
# Note that the parameters a and b are rebound to the sliced elements.
@make_test
def test_islice_chain(a, b):
tmp1 = [a + 1, a + 2]
tmp2 = [a + 3, a + 4]
a, b = list(itertools.islice(itertools.chain(tmp1, tmp2), 1, 3))
c = next(itertools.islice(tmp1, 1, None))
return a - b / c
# Creates a namedtuple type inside the function, then reads fields by
# attribute and by index and constructs a new instance.
@make_test
def test_namedtuple(a, b):
mytuple = collections.namedtuple("mytuple", ["x", "y", "xy"])
tmp = mytuple(a, b, a + b)
return mytuple(tmp.x, tmp[1], tmp.xy + b)
# namedtuple with defaults: mytuple(a, xy=b) leaves y at its default of 1,
# so tmp[1] is the constant 1.
@make_test
def test_namedtuple_defaults(a, b):
mytuple = collections.namedtuple(
"mytuple", ["x", "y", "xy"], defaults=(None, 1, None)
)
tmp = mytuple(a, xy=b)
return mytuple(tmp.x, tmp[1], tmp.xy + b)
# Typed NamedTuple with two tensor fields plus an instance method, a static
# method and a class method. Tests in this file reach it as
# FunctionTests.MyNamedTuple.
class MyNamedTuple(NamedTuple):
first: torch.Tensor
second: torch.Tensor
# Sum of the two tensor fields.
def add(self) -> torch.Tensor:
return self.first + self.second
# Constant result; takes no instance or class state.
@staticmethod
def static_method() -> int:
return 1
# Name of the class the method is invoked on.
@classmethod
def class_method(cls) -> str:
return cls.__name__
# Calls a user-defined instance method, static method and class method on a
# NamedTuple subclass instance.
@make_test
def test_namedtuple_user_methods(a, b):
mytuple = FunctionTests.MyNamedTuple(a, b)
return mytuple.add(), mytuple.static_method(), mytuple.class_method()
# Duck-typed namedtuple detection: isinstance(obj, tuple) plus hasattr()
# checks for "_asdict" and "_fields" decide which branch is taken.
@make_test
def test_namedtuple_hasattr(a, b):
mytuple = FunctionTests.MyNamedTuple(a, b)
def isinstance_namedtuple(obj) -> bool:
return (
isinstance(obj, tuple)
and hasattr(obj, "_asdict")
and hasattr(obj, "_fields")
)
if isinstance_namedtuple(mytuple):
return a + b
else:
return a - b
# Branches on whether x.shape exposes a namedtuple-style "_fields" attribute.
@make_test
def test_torch_size_hasattr(x):
if hasattr(x.shape, "_fields"):
return x + 1
else:
return x - 1
# Reads the is_quantized attribute; returns a + b when it is False and
# falls through (implicit None) otherwise.
@make_test
def test_is_quantized(a, b):
if not a.is_quantized:
return a + b
# f-string with a format spec on a constant float: 1.229 formatted with
# ".2f" yields "1.23 bar", so the startswith check passes.
@make_test
def test_fstrings1(a, b):
x = 1.229
tmp = f"{x:.2f} bar"
if tmp.startswith("1.23"):
return a + b
# f-string that interpolates x.shape[0]; returns x + 1 only when the
# formatted string starts with "10" (implicit None otherwise).
@make_test
def test_fstrings2(x):
tmp = f"{x.shape[0]} bar"
if tmp.startswith("10"):
return x + 1
# f-string that interpolates the class name of x; returns x + 1 only when
# the string starts with "Tensor" (implicit None otherwise).
@make_test
def test_fstrings3(x):
tmp = f"{x.__class__.__name__} foo"
if tmp.startswith("Tensor"):
return x + 1
# Substring test (`in`) against an f-string built from x.shape[0].
@make_test
def test_fstrings4(x):
tmp = f"{x.shape[0]} bar"
if "10" in tmp:
return x + 1
# Substring test (`in`) against an f-string concatenated with a constant string.
@make_test
def test_fstrings5(x):
tmp = f"{x.shape[0]} bar"
if "10" in (tmp + "haha"):
return x + 1
# f-string that interpolates an arithmetic expression over two shape entries.
@make_test
def test_fstrings6(x):
tmp = f"{x.shape[0] + x.shape[1]}"
if "20" in tmp:
return x + 1
# Tensor.new() called with a torch.Size from y.size(); asserts the new
# tensor has the same size as y. Returns None.
@make_test
def test_tensor_new_with_size(x):
y = torch.rand(5, 8)
z = x.new(y.size())
assert z.size() == y.size()
# Tensor.new() called with y.shape; asserts the new tensor has the same
# size as y. Returns None.
@make_test
def test_tensor_new_with_shape(x):
y = torch.rand(5, 8)
z = x.new(y.shape)
assert z.size() == y.size()
# Passes a tensor expression through torch.jit.annotate(Any, ...) and keeps
# computing with the result.
@make_test
def test_jit_annotate(x):
y = torch.jit.annotate(Any, x + 1)
return y + 2
# Branches on torch.jit.is_scripting() and then on
# is_contiguous(memory_format=torch.contiguous_format). Falls through
# (implicit None) when neither branch applies.
@make_test
def test_is_contiguous_memory_format(tensor):
if torch.jit.is_scripting():
return None
elif tensor.is_contiguous(memory_format=torch.contiguous_format):
return tensor + 1
# Feeds tensors with different shapes and layouts through a function that
# branches on is_contiguous(), checking both the result and the cumulative
# number of compiled frames after every call.
def test_is_contiguous_frame_counts(self):
data = [
torch.rand(10),
torch.rand(2, 3, 32, 32),
torch.rand(2, 3, 32, 32).contiguous(memory_format=torch.channels_last),
torch.rand(10)[::2],
torch.rand(12),
torch.rand(2, 3, 24, 24).contiguous(memory_format=torch.channels_last),
torch.rand(50)[::2],
torch.rand(2, 3, 32, 32)[:, :, 2:-2, 3:-3],
]
# dynamo should recompile for all inputs in static shapes mode
expected_frame_counts_static = [1, 2, 3, 4, 5, 6, 7, 8]
# dynamo should recompile for items 0, 1, 2, 3, 7 (0-indexed) in dynamic
# shapes mode; the values below are cumulative frame counts
expected_frame_counts_dynamic = [1, 2, 3, 4, 4, 4, 4, 5]
expected_frame_counts = ifdynstaticdefault(
expected_frame_counts_static, expected_frame_counts_dynamic
)
dynamic = ifdynstaticdefault(False, True)
def func(x):
if x.is_contiguous():
return x + 1
elif x.is_contiguous(memory_format=torch.channels_last):
return x + 2
else:
return x + 3
cnt = torch._dynamo.testing.CompileCounter()
cfunc = torch._dynamo.optimize_assert(cnt, dynamic=dynamic)(func)
# Nothing is compiled until the first call.
assert cnt.frame_count == 0
for i, x in enumerate(data):
expected = func(x)
output = cfunc(x)
self.assertTrue(same(output, expected))
assert cnt.frame_count == expected_frame_counts[i]
# Slice assignment on a constant list; the list itself is not returned.
@make_test
def test_list_slice_assignment(x):
m = [1, 2, 3, 4]
m[1:] = [6] * (len(m) - 1)
return x + 1
# Branches on torch.distributed.is_available().
@make_test
def test_distributed_is_available(x):
if torch.distributed.is_available():
return x + 1
else:
return x - 1
# Branches on torch.distributed.is_initialized(); skipped when the
# distributed package is not available.
@unittest.skipIf(
not torch.distributed.is_available(), "requires distributed package"
)
@make_test
def test_distributed_is_initialized(x):
if torch.distributed.is_initialized():
return x + 1
else:
return x - 1
# Builds a Normal distribution wrapped in Independent (one reinterpreted
# batch dim) and returns log_prob of the input.
@disable_translation_validation_if_dynamic_shapes
@make_test
def test_torch_distributions_functions(x):
normal = torch.distributions.Normal(x, torch.tensor(1))
independent = torch.distributions.Independent(normal, 1)
return independent.log_prob(x)
# Calls a nested function decorated with torch.no_grad(); the nested
# function uses only its own parameter (no closure variables).
@make_test
def test_context_wrapping_nested_functions_no_closure(x):
@torch.no_grad()
def augment(x: torch.Tensor) -> torch.Tensor:
return (x + 1) * 2
return augment(x)
# # These tests exercise the structural pattern matching syntax
# # ("match ... case ...") added in Python 3.10.
# # Uncomment these test cases if you run on Python 3.10+.
# @make_test
# def test_match_sequence(a):
# point = (5, 8)
# match point:
# case (0, 0):
# return a
# case (0, y):
# return a - y
# case (x, 0):
# return a + x
# case (x, y):
# return a + x - y
# @make_test
# def test_match_mapping_and_match_keys(x):
# param = {"a": 0.5}
# match param:
# case {"a": param}:
# return x * param
# case {"b": param}:
# return x / param
# math.radians() on a Python int argument inside a compiled function;
# expects the same result as eager and exactly one compiled frame.
def test_math_radians(self):
def func(x, a):
return x + math.radians(a)
cnt = torch._dynamo.testing.CompileCounter()
cfunc = torch._dynamo.optimize_assert(cnt)(func)
assert cnt.frame_count == 0
x = torch.rand(10)
expected = func(x, 12)
output = cfunc(x, 12)
self.assertTrue(same(output, expected))
assert cnt.frame_count == 1
# np.meshgrid on arrays obtained via Tensor.numpy(); both results are
# converted back with torch.from_numpy.
@make_test
def test_numpy_meshgrid(x, y):
r1, r2 = np.meshgrid(x.numpy(), y.numpy())
return torch.from_numpy(r1), torch.from_numpy(r2)
# Round-trips a tensor through numpy() and torch.from_numpy(), then
# branches on the size of the first dimension.
@make_test
def test_torch_from_numpy(x):
a = x.numpy()
b = torch.from_numpy(a)
if b.size(0) == 1:
return torch.tensor(True)
else:
return torch.tensor(False)
# Returns the ndarray `size` attribute of the array behind a tensor.
@make_test
def test_numpy_size(x):
a = x.numpy()
return a.size
# Reads several ndarray attributes; the array-valued ones (T, real, imag)
# are converted back with torch.from_numpy.
@make_test
def test_numpy_attributes(x):
a = x.numpy()
return (
a.itemsize,
a.strides,
a.shape,
a.ndim,
a.size,
torch.from_numpy(a.T),
torch.from_numpy(a.real),
torch.from_numpy(a.imag),
)
# Chains np.mean (along axis 1), np.sum and np.asarray, then converts the
# result back with torch.from_numpy.
@make_test
def test_mean_sum_np(x: torch.Tensor):
x_mean = np.mean(x.numpy(), 1)
x_sum = np.sum(x_mean)
x_sum_array = np.asarray(x_sum)
return torch.from_numpy(x_sum_array)
# Returns an ndarray (the transpose) directly from the compiled function.
@make_test
def test_return_numpy_ndarray(x):
a = x.numpy()
return a.T
# Returns several ndarrays (T, imag, real) directly from the compiled function.
@make_test
def test_return_multiple_numpy_ndarray(x):
a = x.numpy()
return a.T, a.imag, a.real
# Calls the ndarray copy() method and returns the result.
@make_test
def test_ndarray_method(x):
a = x.numpy()
return a.copy()
# Calls ndarray.transpose with explicit axes (0, 1).
@make_test
def test_ndarray_transpose(x):
a = x.numpy()
return a.transpose(0, 1)
# Reshapes the ndarray into a single row using its own size attribute.
@make_test
def test_ndarray_reshape(x):
a = x.numpy()
return a.reshape([1, a.size])
# Calls the ndarray reductions max() and all() along axis 0.
@make_test
def test_ndarray_methods_returning_scalar(x):
a = x.numpy()
return a.max(axis=0), a.all(axis=0)
# Binary operators (+ and -) applied directly to ndarrays.
@make_test
def test_ndarray_builtin_functions(x):
a = x.numpy()
return a + a, a - a
# Passes a NumPy scalar type (np.float64) as the dtype keyword of a NumPy function.
@make_test
def test_numpy_dtype_argument_to_function(x):
return np.ones_like(x, dtype=np.float64)
# Constructs an np.dtype from a string inside the function and passes it
# as the dtype keyword of np.full_like.
@make_test
def test_numpy_dtype_call_in_function(x):
dt = np.dtype("float")
return np.full_like(x, 2.4, dtype=dt)
# Calls a function from the np.linalg submodule (norm along axis 0).
@make_test
def test_numpy_linalg(x):
return np.linalg.norm(x.numpy(), axis=0)
# Calls a function from the np.fft submodule (fftshift).
@make_test
def test_numpy_fft(x):
return np.fft.fftshift(x.numpy())
# Calls np.random.randn; returning x - x keeps the result deterministic
# (all zeros) regardless of the random draw.
@make_test
def test_numpy_random():
x = np.random.randn(2, 2)
return x - x
# functools.partial over a torch op with the bound value passed as a keyword.
@make_test
def test_partials_torch_op_kwarg(x):
par_mul = functools.partial(torch.mul, other=torch.ones(10, 10))
return par_mul(x)
# functools.partial over a torch op with the bound value passed positionally.
@make_test
def test_partials_torch_op_arg(x):
par_mul = functools.partial(torch.mul, torch.ones(10, 10))
return par_mul(x)
# functools.partial over udf_mul (defined elsewhere in this file) with a
# positionally bound tensor.
@make_test
def test_partials_udf_arg(x):
par_mul = functools.partial(udf_mul, torch.ones(10, 10))
return par_mul(x)
# Builds a new list with + (mixing constants and tensors), then mutates it
# with append() before summing.
@make_test
def test_list_add_then_mutate(x):
my_list = [1, x]
y = x / 4.0
my_list = my_list + [x / 2.0, 4]
my_list.append(y)
return sum(my_list)
# List repetition with the integer on the left-hand side (4 * [x]).
@make_test
def test_list_expand_lhs(x):
return sum(4 * [x])
# `in` on a list mixing constants with a tensor, and `not in` on a
# constant-only list.
@make_test
def test_in_not_in(x):
mylist = [1, 2, 3, 4, 5, x]
myotherlist = [1, 2, 3, 4, 5]
assert 3 in mylist
assert 6 not in myotherlist
return sum(mylist)
# Branches on torch._C._are_functorch_transforms_active().
@make_test
def test_are_functorch_transforms_active(x):
if torch._C._are_functorch_transforms_active():
return x + 1
else:
return x - 1
# functools.partial over udf_mul (defined elsewhere in this file) with a
# tensor bound by keyword.
@make_test
def test_partials_udf_kwarg(x):
par_mul = functools.partial(udf_mul, y=torch.ones(10, 10))
return par_mul(x)
# functools.partial binding a module instance (SmallNN) as a keyword
# argument of udf_module; both names are defined elsewhere in this file.
@make_test
def test_partials_udf_kwarg_module(x, y):
par_mod = functools.partial(udf_module, mod=SmallNN())
return par_mod(x=x, y=y)
# Same as the module variant, but binds the bound method SmallNN().forward
# instead of the module instance.
@make_test
def test_partials_udf_kwarg_method(x, y):
par_mod = functools.partial(udf_module, mod=SmallNN().forward)
return par_mod(x=x, y=y)
# functools.partial over a lambda defined inside the function, with a
# constant bound by keyword.
@make_test
def test_partials_lambda(x):
multiply = lambda x, y: x * y
triple = functools.partial(multiply, y=3)
return triple(x)
# Branches on the FSDP helper flat_param._same_storage_size for two
# different size arguments; skipped when torch.distributed is unavailable.
@unittest.skipUnless(torch.distributed.is_available(), "requires torch.distributed")
@make_test
def test_flat_param_same_storage_size(x, y):
import torch.distributed.fsdp._flat_param as flat_param
if flat_param._same_storage_size(x, 100):
x = x + 1
else:
x = x - 1
if flat_param._same_storage_size(y, 123):
y = y + 1
else:
y = y - 1
return x, y
# Parametrized over attribute names: fn builds a functools.partial and
# branches on hasattr(p, attr). The compiled result must equal eager and at
# least one frame must have been compiled.
@parametrize(
"attr",
(
# True
"__subclasshook__",
"__lt__",
"__hash__",
"__ge__",
"__le__",
"__gt__",
"__dict__",
"__getattribute__",
"__setattr__",
"__doc__",
"__repr__",
"__dir__",
"__init__",
"__new__",
"__class__",
"__eq__",
"__delattr__",
"__reduce__",
"__module__",
"__format__",
"__str__",
"__sizeof__",
"__ne__",
"__call__",
"__reduce_ex__",
"__init_subclass__",
"args",
"keywords",
"func",
# False
"__code__",
"__kwdefaults__",
"__defaults__",
"__name__",
"__annotations__",
"__get__",
"__builtins__",
"__qualname__",
"__globals__",
"__closure__",
),
)
def test_partials_hasattr(self, attr):
def fn(t):
f = lambda x, y: torch.sin(x) + torch.cos(y)
p = functools.partial(f, y=t)
if hasattr(p, attr):
return p(t)
else:
return torch.zeros_like(t)
t = torch.randn(3, 4)
counter = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fullgraph=True, backend=counter)(fn)
self.assertEqual(opt_fn(t), fn(t))
self.assertGreater(counter.frame_count, 0)
# Sets an attribute (__name__) on a functools.partial and then checks it
# with hasattr under fullgraph compilation. Marked as an expected failure.
@unittest.expectedFailure
def test_partials_hasattr_set_attr(self):
def fn(t):
f = lambda x, y: torch.sin(x) + torch.cos(y)
p = functools.partial(f, y=t)
p.__name__ = "test"
if hasattr(p, "__name__"):
return p(t)
else:
return torch.zeros_like(t)
t = torch.randn(3, 4)
counter = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fullgraph=True, backend=counter)(fn)
self.assertEqual(opt_fn(t), fn(t))
# filter() with a lambda that reads requires_grad; only input2 has
# requires_grad set. Compiled with fullgraph=True and compared to eager.
def test_filter(self):
def fn(inputs):
out = inputs[0]
for inp in filter(lambda x: (x.requires_grad), inputs):
out = out * inp
return out
input1 = torch.arange(2, dtype=torch.bfloat16)
input2 = torch.arange(2, dtype=torch.bfloat16).requires_grad_(True)
inputs = [input1, input2]
opt_fn = torch.compile(fullgraph=True)(fn)
self.assertEqual(opt_fn(inputs), fn(inputs))
# filter() with a predicate that compares a tensor element (x[0] == 1).
# Without fullgraph the compiled result must match eager; with
# fullgraph=True the same function is expected to raise Unsupported.
def test_filter_fallback(self):
def fn(inputs):
out = inputs[0]
for inp in filter(lambda x: x[0] == 1, inputs):
out = out * inp
return out
input1 = torch.ones(2, dtype=torch.bfloat16)
input2 = torch.arange(2, dtype=torch.bfloat16)
inputs = [input1, input2]
opt_fn = torch.compile()(fn)
self.assertEqual(opt_fn(inputs), fn(inputs))
# Reset so the fullgraph attempt starts from a clean compile state.
torch._dynamo.reset()
with self.assertRaises(torch._dynamo.exc.Unsupported):
opt_fn = torch.compile(fullgraph=True)(fn)
opt_fn(inputs)
# torch.pow with a Python int exponent under dynamic=True and fullgraph=True.
def test_pow_int(self):
def fn(a, b):
return torch.pow(a, b)
x = torch.ones(2, 2)
opt_fn = torch.compile(fullgraph=True, backend="eager", dynamic=True)(fn)
self.assertEqual(opt_fn(x, 2), fn(x, 2))
# Uses one tensor's last dimension (2 here) as the index into another
# tensor's shape (y.shape[2] is 6).
def test_tensor_size_indexed_by_symint(self):
def fn(x, y):
index = x.shape[-1]
return x + y.shape[index]
x = torch.rand(10, 2)
y = torch.rand(10, 8, 6)
opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
self.assertEqual(opt_fn(x, y), fn(x, y))
# Passes functools.partial objects over a lambda as inputs to a
# nopython-compiled function; expects a single compiled frame.
def test_partials_as_input_partials_lambda(self):
def fn(f0, f1, x):
return f0(x) * f1(x)
multiply = lambda x, y: x * y
lambda0 = functools.partial(multiply, y=3)
lambda1 = functools.partial(multiply, y=2)
cnts = torch._dynamo.testing.CompileCounter()
torch._dynamo.optimize(cnts, nopython=True)(fn)(
lambda0, lambda1, torch.randn(2, 2)
)
self.assertEqual(cnts.frame_count, 1)
# Passes functools.partial objects over SmallNN instances (defined elsewhere
# in this file) as inputs; expects one compiled frame and eager-equal output.
def test_partials_as_input_partials_mod(self):
def fn(f0, f1, x):
return f0(x) * f1(x)
lambda0 = functools.partial(SmallNN(), y=torch.randn(2, 2))
lambda1 = functools.partial(SmallNN(), y=torch.randn(2, 2))
cnts = torch._dynamo.testing.CompileCounter()
x = torch.randn(2, 2)
dynamo_result = torch._dynamo.optimize(cnts, nopython=True)(fn)(
lambda0, lambda1, x
)
self.assertEqual(cnts.frame_count, 1)
eager_result = fn(lambda0, lambda1, x)
self.assertEqual(eager_result, dynamo_result)
# Passes functools.partial objects over udf_mul (defined elsewhere in this
# file) as inputs; expects one compiled frame and eager-equal output.
def test_partials_as_input_UDF(self):
def fn(f0, f1, x):
return f0(x) * f1(x)
lambda0 = functools.partial(udf_mul, y=torch.randn(2, 2))
lambda1 = functools.partial(udf_mul, y=torch.randn(2, 2))
cnts = torch._dynamo.testing.CompileCounter()
x = torch.randn(2, 2)
dynamo_result = torch._dynamo.optimize(cnts, nopython=True)(fn)(
lambda0, lambda1, x
)
self.assertEqual(cnts.frame_count, 1)
eager_result = fn(lambda0, lambda1, x)
self.assertEqual(eager_result, dynamo_result)
# Creates two keyword-bound partials, hits print("break") (the intended
# graph-break point, per the test name) and then calls the partials.
# Checks the eager-equal result and the readable form of the first recorded
# graph, with separate expectations for static and dynamic shape defaults.
# NOTE(review): the local `gm` is assigned but never read.
def test_partials_graph_break_reconstruct(self):
def fn(udf_mul_0, udf_mul_1, x):
lambda0 = functools.partial(udf_mul_0, y=x)
lambda1 = functools.partial(udf_mul_1, y=x)
print("break")
return torch.mul(lambda0(x), lambda1(x))
backend = EagerAndRecordGraphs()
cnts = CompileCounterWithBackend(backend)
x = torch.randn(2, 2)
dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, udf_mul, x)
eager_result = fn(udf_mul, udf_mul, x)
gm = backend.graphs[0]
self.assertEqual(eager_result, dynamo_result)
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
l_lambda0_keywords_y_ = L_lambda0_keywords_y_
mul: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
mul_1: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul_2: "f32[2, 2]" = torch.mul(mul, mul_1); mul = mul_1 = None
return (mul_2,)
""",
)
else:
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
l_lambda0_keywords_y_ = L_lambda0_keywords_y_
mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
mul_1: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul_2: "f32[s0, s0]" = torch.mul(mul, mul_1); mul = mul_1 = None
return (mul_2,)
""",
)
# Like the all-keyword variant, but mixes a keyword-bound partial (udf_mul)
# with a positionally bound one (udf_add) around print("break").
# Checks the eager-equal result and the readable form of the first recorded graph.
# NOTE(review): the local `gm` is assigned but never read.
def test_partials_graph_break_reconstruct_mix(self):
def fn(udf_mul_0, udf_add_1, x):
lambda0 = functools.partial(udf_mul_0, y=x)
lambda1 = functools.partial(udf_add_1, x)
print("break")
return torch.mul(lambda0(x), lambda1(x))
backend = EagerAndRecordGraphs()
cnts = CompileCounterWithBackend(backend)
x = torch.randn(2, 2)
dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, udf_add, x)
eager_result = fn(udf_mul, udf_add, x)
gm = backend.graphs[0]
self.assertEqual(eager_result, dynamo_result)
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
l_lambda0_keywords_y_ = L_lambda0_keywords_y_
mul: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
add: "f32[2, 2]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul_1: "f32[2, 2]" = torch.mul(mul, add); mul = add = None
return (mul_1,)
""",
)
else:
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
l_lambda0_keywords_y_ = L_lambda0_keywords_y_
mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None
return (mul_1,)
""",
)
# Variant of the mixed keyword/positional partial test in which the add
# function is a lambda created inside fn rather than passed in as an argument.
# Checks the eager-equal result and the readable form of the first recorded graph.
# NOTE(review): the local `gm` is assigned but never read.
def test_partials_graph_break_reconstruct_mix_no_source(self):
def fn(udf_mul_0, x):
udf_add_1 = lambda x, y: x + y
lambda0 = functools.partial(udf_mul_0, y=x)
lambda1 = functools.partial(udf_add_1, x)
print("break")
return torch.mul(lambda0(x), lambda1(x))
backend = EagerAndRecordGraphs()
cnts = CompileCounterWithBackend(backend)
x = torch.randn(2, 2)
dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul, x)
eager_result = fn(udf_mul, x)
gm = backend.graphs[0]
self.assertEqual(eager_result, dynamo_result)
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
l_lambda0_keywords_y_ = L_lambda0_keywords_y_
mul: "f32[2, 2]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
add: "f32[2, 2]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul_1: "f32[2, 2]" = torch.mul(mul, add); mul = add = None
return (mul_1,)
""",
)
else:
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
l_lambda0_keywords_y_ = L_lambda0_keywords_y_
mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
add: "f32[s0, s0]" = l_lambda0_keywords_y_ + l_lambda0_keywords_y_; l_lambda0_keywords_y_ = None
mul_1: "f32[s0, s0]" = torch.mul(mul, add); mul = add = None
return (mul_1,)
""",
)
# Partials over udf_mul2 (defined elsewhere in this file) that bind both
# positional and keyword arguments; one is called with no extra arguments
# and the other with one extra positional argument.
# Checks the eager-equal result and the readable form of the first recorded graph.
# NOTE(review): the local `gm` is assigned but never read.
def test_partials_graph_break_reconstruct_args_and_kwargs(self):
def fn(udf_mul_0, x):
lambda0 = functools.partial(udf_mul_0, x, 4, z=x)
lambda1 = functools.partial(udf_mul_0, 4, z=x)
return torch.mul(lambda0(), lambda1(5))
backend = EagerAndRecordGraphs()
cnts = CompileCounterWithBackend(backend)
x = torch.randn(2, 2)
dynamo_result = torch._dynamo.optimize(cnts)(fn)(udf_mul2, x)
eager_result = fn(udf_mul2, x)
gm = backend.graphs[0]
self.assertEqual(eager_result, dynamo_result)
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[2, 2]"):
l_x_ = L_x_
mul: "f32[2, 2]" = l_x_ * 4
mul_1: "f32[2, 2]" = mul * l_x_; mul = None
mul_2: "f32[2, 2]" = 20 * l_x_; l_x_ = None
mul_3: "f32[2, 2]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None
return (mul_3,)
""",
)
else:
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, s0]"):
l_x_ = L_x_
mul: "f32[s0, s0]" = l_x_ * 4
mul_1: "f32[s0, s0]" = mul * l_x_; mul = None
mul_2: "f32[s0, s0]" = 20 * l_x_; l_x_ = None
mul_3: "f32[s0, s0]" = torch.mul(mul_1, mul_2); mul_1 = mul_2 = None
return (mul_3,)
""",
)
# Tracks when passing different functools.partial inputs triggers a
# recompile, using the cumulative frame count after each call. The inline
# comments give the expected reason for every (non-)recompile.
# NOTE(review): `dynamo_result` is assigned twice but never read.
def test_partials_recompilation(self):
def fn(f0, f1, x):
return f0(x) * f1(x)
lambda0 = functools.partial(udf_mul, y=torch.randn(2, 2))
lambda1 = functools.partial(udf_mul, y=torch.randn(2, 2))
cnts = torch._dynamo.testing.CompileCounter()
x = torch.randn(2, 2)
fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
dynamo_result = fn(lambda0, lambda1, x)
self.assertEqual(cnts.frame_count, 1)
fn(lambda1, lambda0, x)
self.assertEqual(
cnts.frame_count, 1
) # No recompile! Tensor and udf_mul guarded
lambda2 = functools.partial(udf_mul, y=torch.randn(3, 3))
x = torch.randn(3, 3)
fn(lambda2, lambda2, x)
self.assertEqual(cnts.frame_count, 2) # Recompile! Tensor size changed
multiply = lambda x, y: x * y
lambda3 = functools.partial(multiply, y=torch.randn(3, 3))
x = torch.randn(3, 3)
fn(lambda3, lambda3, x)
self.assertEqual(cnts.frame_count, 3) # Recompile! func id changed
# Second scenario: a fresh counter and a function that splats its args.
def fn2(f0, f1, args):
return f0(*args) * f1(*args)
cnts = torch._dynamo.testing.CompileCounter()
x = torch.randn(2, 2)
fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2)
dynamo_result = fn2(lambda0, lambda1, [x])
self.assertEqual(cnts.frame_count, 1) # start over
lambda4 = functools.partial(multiply, y=3, x=torch.randn(3, 3))
fn2(lambda4, lambda4, [])
self.assertEqual(cnts.frame_count, 2) # Recompile! Different kwarg keys
lambda5 = functools.partial(multiply, 1)
x = torch.randn(3, 3)
fn2(lambda5, lambda5, [x])
self.assertEqual(cnts.frame_count, 3) # Recompile! Different arg keys
lambda6 = lambda x: x + x
fn2(lambda6, lambda6, [x])
self.assertEqual(
cnts.frame_count, 4
) # Recompile! input is no longer a functools partial
# The compiled function reseeds with torch.manual_seed(3) on every call, so
# repeated calls must return identical random integers.
def test_manual_seed(self):
@torch.compile
def foo():
torch.manual_seed(3)
return torch.randint(0, 5, (5,))
self.assertEqual(foo(), foo())
self.assertEqual(foo(), foo())
# A closure variable `op` is rebound to different functools.partial objects
# between calls to an inner function that contains a print (the graph-break
# point named in its message). There are no assertions; the test only
# checks that the call completes.
def test_partial_across_graph_break_uninvoked(self):
from functools import partial
def bar(x, **kwargs):
return x + x
@torch.compile(backend="eager", dynamic=True)
def foo(x, i):
def inner():
print("this is a graph_break")
return op(x)
op = partial(bar, dim=10)
x = inner()
op = partial(bar, other=10)
return inner() + x
foo(torch.rand(1), 10)
# forward defines a nested function on every call and then prints (the
# graph-break point named in its message). Three identical calls must
# leave the frame count at 1.
def test_no_recompile_inner_function(self):
def forward(inp):
def g(y):
return inp + y
print("graph break")
return g(torch.rand([1]))
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(forward)
input = torch.rand([2])
_ = opt_fn(input)
_ = opt_fn(input)
_ = opt_fn(input)
# Should not have recompiled
self.assertEqual(cnts.frame_count, 1)
# Same as the inner-function variant, but the nested callable is a lambda.
# Three identical calls must leave the frame count at 1.
def test_no_recompile_inner_lambda(self):
def forward(inp):
g = lambda y: inp + y
print("graph break")
return g(torch.rand([1]))
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(forward)
input = torch.rand([2])
_ = opt_fn(input)
_ = opt_fn(input)
_ = opt_fn(input)
# Should not have recompiled
self.assertEqual(cnts.frame_count, 1)
# The compiled function returns a doubly nested closure that captures y;
# calling the returned closure with input2 must give input1 + input2.
def test_complex_closure(self):
@torch.compile
def forward(y):
def a():
def x(z):
return y + z
return x
return a()
input1 = torch.rand([2])
input2 = torch.rand([2])
res = forward(input1)(input2)
self.assertTrue(same(res, input1 + input2))
# `one` is rebound between two calls to inner(): the first call adds and
# the second subtracts, so the total is (x + y) + (x - y) = 2 * x, which is
# compared against input1 + input1.
def test_non_inlined_closure(self):
@torch.compile()
def program(x, y):
one = lambda x, y: x + y
def inner():
# Force no inlining
torch._dynamo.graph_break()
return one(x, y)
res = inner()
one = lambda x, y: x - y
res += inner()
return res
input1 = torch.randn(1)
input2 = torch.randn(1)
self.assertTrue(same(program(input1, input2), input1 + input1))
# Passes equivalent spellings of a 16-bit dtype (string, NumPy scalar type,
# np.dtype objects, and np.iinfo/np.finfo objects) to a compiled function
# that reads info_func(...).max; every variant must match the eager result
# computed from the string form.
# NOTE(review): parameter `b` of func is unused.
@parametrize("int_or_float", ("int", "float"))
def test_np_constant_collections_as_input(self, int_or_float):
# Resolves to np.iinfo for "int" and np.finfo for "float".
info_func = getattr(np, f"{int_or_float[0]}info")
dt_string_arg = f"{int_or_float}16"
np_dt_attr = getattr(np, dt_string_arg)
dt_args = [dt_string_arg, np_dt_attr]
arg_variants_iter = itertools.chain(
dt_args, map(np.dtype, dt_args), map(info_func, dt_args)
)
def func(a, b, info_or_dt):
return a + info_func(info_or_dt).max
opt_fn = torch.compile(func)
a = torch.randn(2)
b = torch.randn(2)
eager_result = func(a, b, dt_args[0])
for arg in arg_variants_iter:
opt_result = opt_fn(a, b, arg)
self.assertTrue(same(opt_result, eager_result))
# Guard behaviour for NumPy dtype and iinfo/finfo inputs: four equivalent
# ways of spelling the same dtype must share one compiled frame, while a
# different (16-bit) dtype must trigger exactly one more compile and still
# match eager.
@parametrize(
"typ, info_func",
[
(int, np.iinfo),
(float, np.finfo),
],
name_fn=lambda t, _: t.__name__,
)
def test_np_constant_collections_guards(self, typ, info_func):
def func_info(a, info):
return a + info.max
def func_dtype(a, dt):
return a + info_func(dt).max
dt_args = [
np.dtype(typ),
np.ones((1,), dtype=typ).dtype,
np.dtype(np.dtype(typ).name),
np.dtype(typ.__name__),
]
cnts_1 = torch._dynamo.testing.CompileCounter()
opt_fn_dtype = torch._dynamo.optimize(cnts_1)(func_dtype)
a = torch.zeros(3, dtype=typ)
for arg in dt_args:
r = opt_fn_dtype(a, arg)
# each should produce an identical arg
self.assertEqual(cnts_1.frame_count, 1)
cnts_2 = torch._dynamo.testing.CompileCounter()
opt_fn_info = torch._dynamo.optimize(cnts_2)(func_info)
info_args = [info_func(dt) for dt in dt_args]
for arg in info_args:
r = opt_fn_info(a, arg)
# each should produce an identical arg
self.assertEqual(cnts_2.frame_count, 1)
# A 16-bit dtype of the same kind serves as the "different" input.
if typ is float:
dt_extra = np.dtype(np.float16)
else:
dt_extra = np.dtype(np.int16)
info_extra = info_func(dt_extra)
eager_result_dtype = func_dtype(a, dt_extra)
compile_result_dtype = opt_fn_dtype(a, dt_extra)
self.assertEqual(cnts_1.frame_count, 2)
self.assertEqual(eager_result_dtype, compile_result_dtype)
eager_result_info = func_info(a, info_extra)
compile_result_info = opt_fn_info(a, info_extra)
self.assertEqual(cnts_2.frame_count, 2)
self.assertEqual(eager_result_info, compile_result_info)
# For each comparison operator (including is_ / is_not), compiles
# op(-10, x) with the constant on the left and compares against eager.
def test_compare_constant_and_tensor(self):
for op in [
operator.lt,
operator.le,
operator.gt,
operator.ge,
operator.ne,
operator.eq,
operator.is_,
operator.is_not,
]:
with self.subTest(op=op):
def fn(x):
return op(-10, x)
opt_fn = torch.compile(fullgraph=True)(fn)
x = torch.randn(10)
self.assertEqual(opt_fn(x), fn(x))
# operator.pos and unary plus over tensors, ints, floats and bools,
# compiled with fullgraph=True and dynamic=True.
def test_pos(self):
def fn(x, y):
return operator.pos(x) * +y
opt_fn = torch.compile(fullgraph=True, dynamic=True)(fn)
def test(x, y):
self.assertEqual(opt_fn(x, y), fn(x, y))
test(torch.ones(4), 1)
test(1, torch.ones(4))
test(-1, -1)
test(-1.1, 1.1)
test(True, False)
test(torch.ones(4, dtype=torch.float32), 1.1)
def test_index(self):
    # operator.index() on int and bool inputs, compiled with dynamic shapes
    # both on and off; the compiled result must match eager.
    def fn(x, t):
        v = operator.index(x)
        # Return the product: without the return both sides of the
        # assertion below are None and the comparison checks nothing.
        return torch.mul(t, v)

    def test(a, b):
        # opt_fn is rebound by the loop below before every use.
        self.assertEqual(opt_fn(a, b), fn(a, b))

    for dynamic in [True, False]:
        # Start each mode from a clean compile cache.
        torch._dynamo.reset()
        opt_fn = torch._dynamo.optimize(dynamic=dynamic)(fn)
        t = torch.ones(1)
        test(10, t)
        test(-100, t)
        test(10, t)
        test(False, t)
        test(True, t)
# operator.truth combined with bool() over ints, floats, bools and
# single-element tensors, compiled with fullgraph=True and dynamic=False.
def test_truth(self):
def fn(x, y):
return operator.truth(x) and bool(y)
opt_fn = torch.compile(fullgraph=True, dynamic=False)(fn)
def test(x, y):
self.assertEqual(opt_fn(x, y), fn(x, y))
test(1, 100)
test(-1.1, True)
test(-1.1, 1.1)
test(True, False)
test(torch.ones(1), 1)
test(torch.zeros(1), 1)
test(torch.ones(1), torch.ones(1))
# Maps each unary operator over range(-10, 10) inside a nopython-compiled
# function and compares the resulting list with eager.
def test_unary_fold_op(self):
for op in (operator.abs, abs, operator.neg, operator.pos, operator.truth):
with self.subTest(op=op):
def fn():
a = range(-10, 10)
return list(map(op, a))
opt_fn = torch._dynamo.optimize(nopython=True)(fn)
self.assertEqual(opt_fn(), fn())
# Maps operator.length_hint over a list of constant tuples of increasing
# length inside a nopython-compiled function and compares with eager.
def test_unary_fold_op_seq(self):
for op in (operator.length_hint,):
with self.subTest(op=op):
def fn():
a = [tuple(range(-10, i)) for i in range(10)]
return tuple(map(op, a))
opt_fn = torch._dynamo.optimize(nopython=True)(fn)
self.assertEqual(opt_fn(), fn())
def gen_random_range_args(self):
    # Produce one to three random integers in [-10, 10], suitable for
    # unpacking into range(*args).
    count = random.randint(1, 3)
    picked = []
    for _ in range(count):
        picked.append(random.randint(-10, 10))
    # range() rejects a zero step, so replace it with 1 when a step is present.
    has_step = len(picked) == 3
    if has_step and picked[2] == 0:
        picked[2] = 1
    return picked
def test_range_length(self):
    # RangeVariable.range_length() must agree with len(range(...)) for a set
    # of fixed cases and for randomly generated range arguments.
    def check(*range_args, expected=None):
        py_len = len(range(*range_args))
        constants = []
        for value in range_args:
            constants.append(ConstantVariable.create(value))
        self.assertEqual(py_len, RangeVariable(constants).range_length())
        if expected is None:
            return
        self.assertEqual(py_len, expected)

    # (range arguments, expected length or None when only parity is checked)
    fixed_cases = (
        ((1, 1, 1), 0),
        ((1, 0), 0),
        ((-10,), 0),
        ((4,), 4),
        ((10,), 10),
        # step >1
        ((1, 10, 2), 5),
        # negative step
        ((10, 1, -1), 9),
        ((10, 1, -3), None),
    )
    for range_args, want in fixed_cases:
        check(*range_args, expected=want)

    # Fuzz testing
    for _ in range(100):
        fuzz_args = self.gen_random_range_args()
        print("testing :", fuzz_args)
        check(*fuzz_args)
def test_indexed_range(self):
    # RangeVariable.apply_index() must agree with indexing a real range, for
    # fixed cases and for random non-empty ranges with a random valid index.
    def check(rng, idx, expected=None):
        bounds = (rng.start, rng.stop, rng.step)
        variable = RangeVariable([ConstantVariable.create(b) for b in bounds])
        eager_item = rng[idx]
        traced_item = variable.apply_index(idx).as_python_constant()
        self.assertEqual(eager_item, traced_item)
        if expected is None:
            return
        self.assertEqual(rng[idx], expected)

    check(range(10), 1, expected=1)
    check(range(10, 20, 2), 1, expected=12)

    # Fuzz testing
    for _ in range(100):
        candidate = range(*self.gen_random_range_args())
        # Empty ranges have no valid index to test.
        if len(candidate) == 0:
            continue
        position = random.randint(0, len(candidate) - 1)
        print("testing:", candidate, position)
        check(candidate, position)
def test_sliced_range(self):
    # RangeVariable.apply_slice() must agree with slicing a real range, for
    # fixed cases and for random ranges combined with random slices.
    def check(rng, slc, expected=None):
        bounds = (rng.start, rng.stop, rng.step)
        variable = RangeVariable([ConstantVariable.create(b) for b in bounds])
        eager_slice = rng[slc]
        traced_slice = variable.apply_slice(slc).as_python_constant()
        self.assertEqual(eager_slice, traced_slice)
        if expected is None:
            return
        self.assertEqual(rng[slc], expected)

    # (range, slice, expected sliced range)
    fixed_cases = (
        (range(10), slice(1, 10, 2), range(1, 10, 2)),
        (range(10), slice(None, 10, None), range(0, 10)),
        (range(10), slice(-1, 7, None), range(9, 7)),
        (range(10), slice(-1, 7, 2), range(9, 7, 2)),
        (range(1, 10, 2), slice(3, 7, 2), range(7, 11, 4)),
        (range(1, 10, 2), slice(-3, 7, 2), range(5, 11, 4)),
        (range(-1, -5, -3), slice(5, None, -3), range(-4, 2, 9)),
    )
    for rng, slc, want in fixed_cases:
        check(rng, slc, expected=want)

    def rand_slice():
        def bound(allow_zero=True):
            value = random.randint(-10, 10)
            if value == 0 and not allow_zero:
                value = 1
            # 1 out of 10: drop the bound entirely
            if random.randint(1, 10) == 5:
                value = None
            return value

        parts = random.randint(1, 3)
        if parts == 1:
            return slice(bound())
        if parts == 2:
            return slice(bound(), bound())
        # A slice step must not be zero.
        return slice(bound(), bound(), bound(False))

    # Fuzz testing
    for _ in range(100):
        candidate = range(*self.gen_random_range_args())
        # generate random slice
        picked = rand_slice()
        print("testing:", candidate, picked)
        check(candidate, picked)
# Iterates over a sliced range inside a fullgraph-compiled function;
# range(2)[1::2] yields only 1.
def test_range_with_slice_index(self):
def fn(x):
acc = 1
for k in range(2)[1::2]:
acc *= acc * k
return x * acc
opt_fn = torch.compile(fullgraph=True)(fn)
x = torch.ones(1)
self.assertEqual(opt_fn(x), fn(x))
# Indexes a range inside a fullgraph-compiled function;
# range(10, 20, 2)[2] is 14.
def test_range_with_index(self):
def fn(x):
acc = 1
acc *= acc * range(10, 20, 2)[2]
return x * acc
opt_fn = torch.compile(fullgraph=True)(fn)
x = torch.ones(1)
self.assertEqual(opt_fn(x), fn(x))
# Uses random.randint for list indices, sizes and a tensor shape inside a
# function compiled with dynamic=True. fn returns nothing and there are no
# assertions; the test only checks that the call completes.
def test_rand_inlined(self):
@torch.compile(backend="eager", dynamic=True)
def fn():
idx_size = [10]
idx_size[random.randint(0, 0)] = random.randint(1, 8)
t = tuple(idx_size)
src_size = [random.randint(1, 5) + s for s in idx_size]
idx = torch.empty(t)
fn()
# Nested functools.partial objects over torch.rand, with the shape supplied
# as a namedtuple. func returns nothing and there are no assertions; the
# test only checks that the call completes.
def test_rand_tensor_partial(self):
from collections import namedtuple
from functools import partial
SdpaShape = namedtuple(
"Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"]
)
@torch.compile(backend="eager")
def func():
make_tensor = partial(
torch.rand, device="cpu", dtype=torch.float16, requires_grad=True
)
bsz, num_heads, seq_len_q, seq_len_kv, head_dim = (16, 16, 128, 128, 16)
make_q_tensor = partial(
make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim)
)
make_kv_tensor = partial(
make_tensor, SdpaShape(bsz, num_heads, seq_len_kv, head_dim)
)
t1 = make_q_tensor()
t2 = make_kv_tensor()
t3 = t1 + t2
func()
def test_to(self):
    """Calling ``.to("meta")`` on a tensor created inside a compiled function runs."""

    def fn():
        ones = torch.ones(2)
        moved = ones.to("meta")

    compiled = torch.compile(fn, backend="eager")
    compiled()
def test_elipsis(self):
    """Item assignment on numpy arrays with ``...`` in the index works under fullgraph compile."""

    @torch.compile(backend="eager", fullgraph=True)
    def fn(a, ind, val):
        a[ind] = val
        return a

    # (array to mutate, index, value to assign, expected array afterwards)
    cases = [
        (np.zeros(4), np.s_[...], np.ones(4), np.ones(4)),
        (
            np.array([[1, 1], [2, 2]]),
            np.s_[0, ...],
            np.zeros(2),
            np.array([[0, 0], [2, 2]]),
        ),
        (
            np.array([[1, 1], [2, 2]]),
            np.s_[1, ...],
            np.zeros(2),
            np.array([[1, 1], [0, 0]]),
        ),
        (
            np.array([[1, 1], [2, 2]]),
            np.s_[..., 0],
            np.array([3, 3]),
            np.array([[3, 1], [3, 2]]),
        ),
        (
            np.array([[1, 1], [2, 2]]),
            np.s_[..., 1],
            np.array([3, 3]),
            np.array([[1, 3], [2, 3]]),
        ),
    ]
    for arr, index, value, expected in cases:
        self.assertEqual(fn(arr, index, value), expected)
def test_map_return(self):
    """A ``map`` returned from a fullgraph-compiled function is still a ``map`` object."""

    def fn(a, b):
        return map(lambda x: x + 1, [a, b])

    compiled = torch.compile(fn, backend="eager", fullgraph=True)
    result = compiled(torch.randn(3, 3), torch.randn(3, 3))
    self.assertIsInstance(result, map)
@make_test
def test_map_max(a, b):
    """``max`` over a ``map`` that reduces each tensor with ``sum``."""
    sums = map(lambda t: t.sum(), [a, b])
    return max(sums)
# max(map(...)) graph breaks
@unittest.expectedFailure
@make_test
def test_map_max_const(a):
    """``max`` over a ``map`` of Python ints, returned alongside a tensor op."""
    return max(map(lambda item: item, [1, 2, 3])), a + 1
@make_test
def test_map_list(a, b):
    """``list`` consumes a ``map`` over tensors."""
    mapped = map(lambda t: t + 1, [a, b])
    return list(mapped)
@make_test
def test_map_tuple(a, b):
    """``tuple`` consumes a ``map`` over tensors."""
    mapped = map(lambda t: t + 1, [a, b])
    return tuple(mapped)
@make_test
def test_map_iter(a, b):
    """``iter`` on a ``map`` followed by a single ``next``."""
    mapped = map(lambda t: t + 1, [a, b])
    iterator = iter(mapped)
    return next(iterator)
@make_test
def test_map_zip_dict(a):
    """Build a dict from ``zip`` of a ``map`` of keys and a list of ``map`` values."""
    keys = map(lambda k: k + 1, [0, 1, 2])
    values = [map(lambda v: v - 1, [y]) for y in [3, 4, 5]]
    table = dict(zip(keys, values))
    return list(table[3])[0], a + 1  # noqa: RUF015
@make_test
def test_map_dict_fromkeys(a):
    """``dict.fromkeys`` fed by a ``map`` of Python ints."""
    keys = map(lambda k: k + 1, [0, 1])
    return dict.fromkeys(keys), a + 1
@make_test
def test_map_set(a):
    """``set`` fed by a ``map`` of Python ints."""
    members = map(lambda m: m + 1, [0, 1])
    return set(members), a + 1
# test_map_sum defined earlier
@make_test
def test_map_reduce(a, b):
    """``functools.reduce`` folding a ``map`` over tensors with addition."""
    incremented = map(lambda t: t + 1, [a, b])
    return functools.reduce(lambda acc, item: acc + item, incremented)
@make_test
def test_map_sorted(a):
    """``sorted`` over a ``map`` of Python ints."""
    shifted = map(lambda v: v + 1, [0, 4, 3, 1, 2])
    return sorted(shifted), a + 1
@make_test
def test_map_list_extend(a, b, c):
    """``list.extend`` with a ``map`` argument."""
    items = [a]
    tail = map(lambda t: t + 1, [b, c])
    items.extend(tail)
    return items
@make_test
def test_map_list_slice_assign(a, b, c, d, e):
    """Slice assignment on a list with a ``map`` on the right-hand side."""
    items = [a, b, c]
    replacement = map(lambda t: t + 1, [d, e])
    items[1:2] = replacement
    return items
@make_test
def test_map_deque_extendleft(a, b, c):
    """``deque.extendleft`` with a ``map`` argument."""
    queue = collections.deque([a])
    prefix = map(lambda t: t + 1, [b, c])
    queue.extendleft(prefix)
    return queue
@make_test
def test_map_str_join(a):
    """``str.join`` over a ``map`` of strings."""
    pieces = map(lambda s: s, ["a", "b", "c"])
    return "".join(pieces), a + 1
def test_map_with_graph_break(self):
    """A partially consumed ``map`` with a side-effecting callback survives a graph break."""

    def f(a):
        a += 1

        def g(x):
            nonlocal a
            a += 1
            return x + 1

        m = map(g, [1, 2, 3, 4, 5])
        a += next(m)  # won't graph break
        torch._dynamo.graph_break()
        a += next(m)  # will graph break
        return a

    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch.compile(f, backend=counter)
    eager_out = f(torch.ones(3, 3))
    compiled_out = compiled(torch.ones(3, 3))
    self.assertEqual(eager_out, compiled_out)
    self.assertEqual(counter.frame_count, 3)
def test_map_reconstruct(self):
    """A ``map`` over a ``zip`` returned from compiled code yields the same items as eager."""

    def fn(a):
        return map(lambda x: x[0] + x[1], zip([1, 2, 3], [1, 2, 3])), a + 1

    compiled = torch.compile(fn, backend="eager", fullgraph=True)
    compiled_map = compiled(torch.ones(3, 3))[0]
    self.assertIsInstance(compiled_map, map)
    compiled_items = list(compiled_map)
    eager_items = list(fn(torch.ones(3, 3))[0])
    self.assertEqual(compiled_items, eager_items)
def test_zip_reconstruct(self):
    """A ``zip`` containing a ``map`` returned from compiled code yields the same items as eager."""

    def fn(a):
        return zip([1, 2, 3], map(lambda x: x + 1, [1, 2, 3])), a + 1

    compiled = torch.compile(fn, backend="eager", fullgraph=True)
    compiled_zip = compiled(torch.ones(3, 3))[0]
    self.assertIsInstance(compiled_zip, zip)
    compiled_items = list(compiled_zip)
    eager_items = list(fn(torch.ones(3, 3))[0])
    self.assertEqual(compiled_items, eager_items)
@make_test
def test_map_partial_unpack(a, b):
    """``zip`` of a short list with a longer side-effecting ``map``; the counter feeds the result."""
    calls = 1

    def bump(x):
        nonlocal calls
        calls += 1
        return x

    pairs = list(zip([a, b], map(bump, [1, 2, 3, 4])))
    return a + calls
@make_test
def test_map_call_function_ex(a, b):
    """Star-unpack a ``map`` into positional call arguments."""

    def add_pair(x, y):
        return x + y

    mapped = map(lambda t: t + 1, [a, b])
    return add_pair(*mapped)
@make_test
def test_map_unpack_twice(a, b):
    """Call ``list`` twice on the same ``map`` object and return both lists."""
    mapped = map(lambda t: t + 1, [a, b])
    first_pass = list(mapped)
    second_pass = list(mapped)
    return first_pass, second_pass
@make_test
def test_enumerate(a, b):
    """``enumerate`` over tensors with a non-zero ``start``."""
    numbered = list(enumerate([a, b], start=1))
    return numbered, a + 1
@make_test
def test_map_enumerate(a, b):
    """``enumerate`` wrapped around a ``map`` with a non-zero ``start``."""
    mapped = map(lambda t: t + 1, [a, b])
    numbered = list(enumerate(mapped, start=1))
    return numbered, a + 1
@make_test
def test_map_infinite(a, b):
    """Two-iterable ``map`` where the second iterable is an unbounded ``itertools.count``."""
    summed = map(lambda t, n: t + n, [a, b], itertools.count(3))
    return list(summed)
@make_test
def test_map_unpack_vars(a, b):
    """Tuple-unpack a ``map`` directly into two names."""
    first, second = map(lambda t: t + 1, [a, b])
    return first + second
def test_enumerate_custom(self):
    """``enumerate`` over a user-defined iterator class matches eager under fullgraph compile."""

    class MyClass:
        # Hand-written iterator: ``__iter__`` resets the counter, ``__next__``
        # stops once the counter exceeds 3.
        def __iter__(self):
            self.a = 1
            return self

        def __next__(self):
            if self.a > 3:
                raise StopIteration
            self.a += 1
            return self.a

    def fn(x):
        for position, value in enumerate(MyClass()):
            x += position + value
        return x

    compiled = torch.compile(fn, backend="eager", fullgraph=True)
    eager_out = fn(torch.ones(3, 3))
    compiled_out = compiled(torch.ones(3, 3))
    self.assertEqual(eager_out, compiled_out)
def test_enumerate_reconstruct(self):
    """An ``enumerate`` returned from compiled code is an ``enumerate`` with the same items as eager."""

    def fn(a, b):
        return enumerate([a, b], start=1)

    compiled = torch.compile(fn, backend="eager", fullgraph=True)
    inputs = (torch.randn(3, 3), torch.randn(3, 3))
    eager_it = fn(*inputs)
    compiled_it = compiled(*inputs)
    self.assertIsInstance(compiled_it, enumerate)
    self.assertEqual(list(eager_it), list(compiled_it))
def udf_mul(x, y):
    """Return ``x * y``."""
    product = x * y
    return product
def udf_mul2(x, y, z):
    """Return ``x * y * z``, multiplying left to right."""
    partial_product = x * y
    return partial_product * z
def udf_add(x, y):
    """Return ``x + y``."""
    total = x + y
    return total
class SmallNN(torch.nn.Module):
    """Module that concatenates two inputs along dim 1 and applies ReLU twice."""

    def forward(self, x, y):
        joined = torch.cat((x, y), dim=1)
        # A fresh ReLU module is constructed for each application.
        activated = torch.nn.ReLU()(joined)
        return torch.nn.ReLU()(activated)
def udf_module(mod, x, y):
    """Call ``mod`` with ``x`` and ``y`` and return its result."""
    result = mod(x, y)
    return result
def global_func_with_default_tensor_args(
    x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))
):
    """Add 1 in place to ``x`` and ``kw_x`` and return both.

    The defaults are tensor objects created once at definition time, so calls
    that rely on them mutate the same tensors every time.
    """
    for tensor in (x, kw_x):
        tensor.add_(1)
    return x, kw_x
class ModuleWithDefaultTensorArgsMethod(torch.nn.Module):
    """Module whose ``forward`` has tensor-valued default arguments that it mutates in place."""

    def forward(self, x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))):
        # Both adds are in place, so calls relying on the defaults keep
        # mutating the same default tensors.
        for tensor in (x, kw_x):
            tensor.add_(1)
        return x, kw_x
class WrapperModule(torch.nn.Module):
    """Module that owns a ``ModuleWithDefaultTensorArgsMethod`` and calls it with no arguments."""

    def __init__(self) -> None:
        super().__init__()
        inner = ModuleWithDefaultTensorArgsMethod()
        self.m = inner

    def forward(self):
        # No arguments are passed, so the inner forward uses its defaults.
        outputs = self.m()
        return outputs
class DefaultsTests(torch._dynamo.test_case.TestCase):
def test_func_default_tensor_args(self):
    """
    Tests that we indeed reference (and mutate) "the one" default tensor arg
    stored on the globally allocated function object, both from the orig and
    compiled function
    """

    def func():
        return global_func_with_default_tensor_args()

    counter = torch._dynamo.testing.CompileCounter()
    compiled_func = torch.compile(func, backend=counter)
    for i in range(4):
        # Even iterations run eager, odd iterations run compiled.
        runner = func if i % 2 == 0 else compiled_func
        x, kw_x = runner()
        # the inner func mutates += 1 each call
        self.assertTrue(same(x, torch.ones_like(x) + i))
        self.assertTrue(same(kw_x, torch.ones_like(kw_x) + i))

    # Calling compiled_func twice does not recompile
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 2)

    # But with a change to the guarded default tensor, we do recompile
    replaced_defaults = (torch.ones((3, 4, 5)),)
    with patch.object(
        global_func_with_default_tensor_args, "__defaults__", replaced_defaults
    ):
        x, kw_x = compiled_func()
    self.assertEqual(counter.frame_count, 2)
    self.assertEqual(counter.op_count, 4)

    replaced_kwdefaults = {"kw_x": torch.ones((3, 4, 5))}
    with patch.object(
        global_func_with_default_tensor_args, "__kwdefaults__", replaced_kwdefaults
    ):
        x, kw_x = compiled_func()
    self.assertEqual(counter.frame_count, 3)
    self.assertEqual(counter.op_count, 6)
def test_meth_default_tensor_args(self):
    """
    Tests that we indeed reference (and mutate) "the one" default tensor arg
    stored on the module's ``forward`` function, both from the orig and
    compiled module
    """
    mod = WrapperModule()
    counter = torch._dynamo.testing.CompileCounter()
    compiled_mod = torch.compile(mod, backend=counter)
    for i in range(4):
        # Even iterations run eager, odd iterations run compiled.
        runner = mod if i % 2 == 0 else compiled_mod
        x, kw_x = runner()
        # the inner func mutates += 1 each call
        self.assertTrue(same(x, torch.ones_like(x) + i))
        self.assertTrue(same(kw_x, torch.ones_like(kw_x) + i))

    # Calling compiled_mod twice does not recompile
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 2)

    # But with a change to the guarded default tensor, we do recompile
    replaced_defaults = (torch.ones((3, 4, 5)),)
    with patch.object(
        ModuleWithDefaultTensorArgsMethod.forward, "__defaults__", replaced_defaults
    ):
        x, kw_x = compiled_mod()
    self.assertEqual(counter.frame_count, 2)
    self.assertEqual(counter.op_count, 4)

    replaced_kwdefaults = {"kw_x": torch.ones((3, 4, 5))}
    with patch.object(
        ModuleWithDefaultTensorArgsMethod.forward,
        "__kwdefaults__",
        replaced_kwdefaults,
    ):
        x, kw_x = compiled_mod()
    self.assertEqual(counter.frame_count, 3)
    self.assertEqual(counter.op_count, 6)
def test_func_default_torch_args(self):
    """
    Tests other types of torch types as function default (size, dtype, device)
    """

    def func_with_default_torch_args(
        dt=torch.float16, ds=torch.Size((1, 2, 3)), dd=torch.device("cpu")
    ):
        return torch.ones(ds, dtype=dt, device=dd)

    def func():
        return func_with_default_torch_args()

    counter = torch._dynamo.testing.CompileCounter()
    compiled_func = torch.compile(func, backend=counter)
    eager_out = func()
    compiled_out = compiled_func()
    self.assertEqual(eager_out.dtype, compiled_out.dtype)
    self.assertEqual(eager_out.device, compiled_out.device)
    self.assertEqual(eager_out.size(), compiled_out.size())
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 1)
def test_dataclass_factory(self):
    """Dataclasses with ``default_factory`` fields can be built, mutated and returned under fullgraph compile."""

    @dataclass
    class Output:
        scalar: int = 2
        named_tensors: Dict[str, torch.Tensor] = field(default_factory=dict)
        lists: List[torch.Tensor] = field(default_factory=list)

        def scale(self):
            return self.scalar * 2

    def fn(x):
        # Positional scalar, factory defaults for the containers
        a = Output(1)
        # Dataclass methods should be inlinable
        scaled_value = a.scale()
        # Explicit keyword assignment of a container field
        b = Output(5, named_tensors={"x": x})
        # Every field left at its default
        c = Output()
        # The factory default must have produced a real dict
        if isinstance(a.named_tensors, dict):
            x = torch.sin(x)
        # Mutate the dataclass after construction
        c.scalar = 6
        c.named_tensors["x"] = x
        # Return the dataclass as well to check reconstruction
        return c, torch.cos(x) * scaled_value + b.named_tensors["x"] + c.scalar

    counter = torch._dynamo.testing.CompileCounter()
    compiled_fn = torch.compile(fn, backend=counter, fullgraph=True)
    inp = torch.randn(4)
    eager_dataclass, eager_out = fn(inp)
    compiled_dataclass, compiled_out = compiled_fn(inp)
    self.assertEqual(eager_dataclass.scalar, compiled_dataclass.scalar)
    self.assertEqual(
        eager_dataclass.named_tensors["x"], compiled_dataclass.named_tensors["x"]
    )
    self.assertTrue(same(eager_out, compiled_out))
    self.assertEqual(counter.frame_count, 1)
    self.assertEqual(counter.op_count, 5)
def test_dataclass_nested(self):
    """A dataclass inheriting from another dataclass can be constructed under fullgraph compile."""

    @dataclass
    class Base:
        outer_a: int
        outer_b: int

    @dataclass
    class Derived(Base):
        inner_a: Any = field(default_factory=list)

    def fn(x):
        instance = Derived(1, 2)
        return instance.outer_a * x

    compiled = torch.compile(fn, backend="eager", fullgraph=True)
    inp = torch.randn(4)
    eager_out = fn(inp)
    compiled_out = compiled(inp)
    self.assertEqual(compiled_out, eager_out)
def test_listlike_of_tensors_contains_constant(self):
    """``1 in container`` for a set/list holding a tensor gives the same answer compiled and eager."""
    for listlike in (set, list):

        def fn(x):
            x.add_(1)
            container = listlike([x])
            return 1 in container

        compiled = torch.compile(fn, backend="eager", fullgraph=True)
        inp = torch.randn(1)
        compiled_out = compiled(inp)
        eager_out = fn(inp)
        self.assertEqual(compiled_out, eager_out)
def test_cast_tensor_single_elem(self):
    """``float``/``int`` casts of single-element tensors match eager; two-element tensors raise."""
    cases = [
        (float, 1.0),
        (float, 1),
        (float, True),
        (int, 1),
        (int, False),
        # (int, 1.0), # fails due to a >= 0 comparison in sym_int
    ]  # , bool, complex]: no casting for sym_bool, no sym_complex
    with torch._dynamo.config.patch({"capture_scalar_outputs": True}):
        for t, val in cases:

            def fn(x):
                x = x + 1
                return t(x)

            compiled = torch.compile(
                fn, backend="eager", fullgraph=True, dynamic=False
            )
            single = torch.tensor([val])
            eager_out = fn(single)
            compiled_out = compiled(single)
            self.assertEqual(compiled_out, eager_out)

            # Cannot handle non single-elem
            with self.assertRaises(ValueError):
                fn(torch.tensor([val] * 2))
            with self.assertRaises(torch._dynamo.exc.TorchRuntimeError):
                compiled(torch.tensor([val] * 2))
def test_set_construction(self):
    """Building a set from a tensor and adding the result of its in-place op gives the same length as eager."""

    def fn(x):
        y = x.add_(1)
        members = set({x})
        members.add(y)
        return len(members)

    compiled = torch.compile(fn, backend="eager", fullgraph=True)
    inp = torch.randn(4)
    eager_out = fn(inp)
    compiled_out = compiled(inp)
    self.assertEqual(compiled_out, eager_out)
def test_frozenset_construction(self):
    """A ``frozenset`` built from a set and from another ``frozenset`` has the same length as eager."""

    def fn(x):
        inner = frozenset({x})
        outer = frozenset(inner)
        return len(outer)

    compiled = torch.compile(fn, backend="eager", fullgraph=True)
    inp = torch.randn(4)
    eager_out = fn(inp)
    compiled_out = compiled(inp)
    self.assertEqual(compiled_out, eager_out)
def test_frozenset_reconstruction(self):
    """An empty ``frozenset`` created before a graph break still works as a dict key afterwards."""
    table = {frozenset(): torch.randn(4)}

    def fn(x):
        key = frozenset()
        torch._dynamo.graph_break()
        return table[key] * x

    compiled = torch.compile(fn, backend="eager")
    inp = torch.randn(4)
    eager_out = fn(inp)
    compiled_out = compiled(inp)
    self.assertEqual(compiled_out, eager_out)
def test_frozenset_illegal_call_method(self):
    """Calling a mutating set method on a ``frozenset`` raises ``InternalTorchDynamoError`` under compile."""

    def fn_add():
        frozen = frozenset((1, 2, 3))
        frozen.add({2})
        return len(frozen)

    def fn_pop():
        frozen = frozenset((1, 2, 3))
        frozen.pop()
        return len(frozen)

    def fn_update():
        frozen = frozenset((1, 2, 3))
        frozen.update({4, 5, 6})
        return len(frozen)

    def fn_remove():
        frozen = frozenset((1, 2, 3))
        frozen.remove(2)
        return len(frozen)

    def fn_discard():
        frozen = frozenset((1, 2, 3))
        frozen.discard(2)
        return len(frozen)

    def fn_clear():
        frozen = frozenset((1, 2, 3))
        frozen.clear()
        return len(frozen)

    illegal_calls = (fn_add, fn_pop, fn_update, fn_remove, fn_discard, fn_clear)
    for fn in illegal_calls:
        # Dynamo state is reset before each function is compiled.
        torch._dynamo.reset()
        compiled = torch.compile(fn, backend="eager", fullgraph=True)
        with self.assertRaises(torch._dynamo.exc.InternalTorchDynamoError):
            compiled()
def test_is_tensor_tensor(self):
    """``x is y`` between tensor arguments branches the same way compiled and eager."""

    def fn(x, y):
        if x is y:
            return x * 2
        return x + y

    compiled = torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)
    zeros = torch.zeros(2)
    ones = torch.ones(2)
    # Distinct tensors first, then the same tensor passed twice.
    self.assertEqual(fn(zeros, ones), compiled(zeros, ones))
    self.assertEqual(fn(zeros, zeros), compiled(zeros, zeros))
def test_is_not_tensor_tensor(self):
    """``x is not y`` between tensor arguments branches the same way compiled and eager."""

    def fn(x, y):
        if x is not y:
            return x * 2
        return x + y

    compiled = torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)
    zeros = torch.zeros(2)
    ones = torch.ones(2)
    # Distinct tensors first, then the same tensor passed twice.
    self.assertEqual(fn(zeros, ones), compiled(zeros, ones))
    self.assertEqual(fn(zeros, zeros), compiled(zeros, zeros))
def test_is_mutated_tensor_tensor_x_is_y(self):
    """``x is y`` holds when ``y`` is the return value of an in-place op on ``x``.

    This method previously shared the name ``test_is_mutated_tensor_tensor``
    with a later definition in the same class, which shadowed it so it never
    ran. It has a unique name so that both variants run; the later definition
    keeps the original name.
    """

    def fn(x):
        y = x.add_(1)
        return x is y

    fn_opt = torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)
    z = torch.ones(4)
    self.assertEqual(fn(z), fn_opt(z))
def test_is_mutated_tensor_tensor_across_graph_break(self):
    """Identity between a tensor and its in-place result is the same before and after a graph break."""

    def fn(x):
        y = x.add_(1)
        cond = x is y
        x.add_(1)
        # The real tensor values are recovered when graph breaking.
        # Hence we recover the invariant.
        torch._dynamo.graph_break()
        x.add_(1)
        return x is y, cond

    compiled = torch.compile(fn, backend="eager", dynamic=True)
    inp = torch.ones(4)
    self.assertEqual(fn(inp), compiled(inp))
def test_is_mutated_tensor_tensor(self):
    """``y is x`` holds when ``y`` is the return value of an in-place op on ``x``."""

    def fn(x):
        y = x.add_(1)
        return y is x

    compiled = torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)
    inp = torch.ones(4, 1)
    self.assertEqual(fn(inp), compiled(inp))
def test_is_init_in_compile_mutated_tensor_tensor(self):
    """Identity check on a tensor cloned inside the compiled function and then mutated in place."""

    def fn(x):
        z = x.clone()
        y = z.add_(1)
        return y is z

    compiled = torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)
    inp = torch.ones(4, 1)
    self.assertEqual(fn(inp), compiled(inp))
def test_is_init_in_compile_vmapped_mutated_tensor_tensor(self):
    """Identity check on a clone mutated in place through ``torch.vmap``."""

    def fn(z):
        x = z.clone()
        y = torch.vmap(torch.Tensor.acos_)(x)
        # Result is discarded; only the comparison itself is exercised.
        _ = y is z
        return y is x

    compiled = torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)
    inp = torch.ones(4, 1)
    self.assertEqual(fn(inp), compiled(inp))
def test_is_vmapped_mutated_tensor_tensor(self):
    """Identity check on an input tensor mutated in place through ``torch.vmap``."""

    def fn(x):
        y = torch.vmap(torch.Tensor.acos_)(x)
        return y is x

    compiled = torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)
    inp = torch.ones(4, 1)
    self.assertEqual(fn(inp), compiled(inp))
def test_is_init_in_compile_vmapped_mutated_tensor_tensor_multi_arg(self):
    """Chained identity checks on two clones mutated in place by a two-argument vmapped function."""

    def fn(y, z):
        a = y.clone()
        b = z.clone()

        def g(a, b):
            return a.acos_(), b.acos_()

        c, d = torch.vmap(g)(a, b)
        return a is c is b is d

    compiled = torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)
    narrow = torch.ones(4, 2)
    wide = torch.ones(4, 10)
    # Distinct inputs first, then the same tensor for both arguments.
    self.assertEqual(fn(narrow, wide), compiled(narrow, wide))
    self.assertEqual(fn(narrow, narrow), compiled(narrow, narrow))
def test_in_set_would_fail_broadcast(self):
    """``tensor in set`` with tensors of shapes (5,) and (5, 10) matches eager, including the aliased call."""
    param = torch.zeros(5)
    param2 = torch.zeros(5, 10)

    # Eager check: membership is False for the distinct tensors.
    tensor_list = set()
    tensor_list.add(param2)
    assert param not in tensor_list

    def fn(param, param2):
        param.add_(1)
        tensor_list = set([param2])
        return param in tensor_list

    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(counter, nopython=True)(fn)
    self.assertEqual(compiled(param, param2), fn(param, param2))
    self.assertEqual(counter.frame_count, 1)
    # Test aliased
    self.assertEqual(compiled(param, param), fn(param, param))
    self.assertEqual(counter.frame_count, 2)  # Recompiles
def test_in_set_inplace(self):
    """``tensor in set`` on the return values of in-place ops matches eager, including the aliased call."""
    param = torch.zeros(5)
    param2 = torch.zeros(5, 10)

    # Eager check: membership is False for the distinct tensors.
    tensor_list = set()
    tensor_list.add(param2)
    assert param not in tensor_list

    def fn(param, param2):
        y = param.add_(1)  # Tensor method
        z = torch.Tensor.add_(y, 1)  # torch function
        tensor_list = set([param2])
        return y in tensor_list and z in tensor_list

    counter = torch._dynamo.testing.CompileCounter()
    compiled = torch._dynamo.optimize(counter, nopython=True)(fn)
    self.assertEqual(compiled(param, param2), fn(param, param2))
    self.assertEqual(counter.frame_count, 1)
    # Test aliased
    self.assertEqual(compiled(param, param), fn(param, param))
    self.assertEqual(counter.frame_count, 2)  # Recompiles
def test_reconstructed_name(self):
    """A nested function passed to a dynamo-disabled callee has the same ``__name__`` compiled and eager.

    ``disallowed`` records ``g.__name__`` on each call; one compiled run and one
    eager run of ``f`` must record two equal names.
    """
    lst = []

    @torch._dynamo.disable
    def disallowed(g):
        lst.append(g.__name__)

    def f():
        def g():
            return ()

        disallowed(g)

    # The unused ``f_opt = torch._dynamo`` assignment that used to sit here was
    # dead code and has been removed.
    opt_f = torch._dynamo.optimize(backend="eager")(f)
    opt_f()
    f()
    self.assertEqual(len(lst), 2)
    self.assertEqual(lst[0], lst[1])
@unittest.skipIf(
    sys.version_info < (3, 10),
    "zip strict kwargs not implemented for Python < 3.10",
)
def test_zip_strict(self):
    """``zip(..., strict=True)`` matches eager for equal lengths and raises on a length mismatch.

    With ``nopython=True`` the mismatch must surface as a ``UserError``; when
    graph breaks are allowed it must surface as a ``ValueError``.
    """

    def fn(x, ys, zs):
        x = x.clone()
        for y, z in zip(ys, zs, strict=True):
            x += y * z
        return x

    opt_fn = torch._dynamo.optimize(backend="eager")(fn)
    nopython_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
    x = torch.ones(3)
    ys = [1.0, 2.0, 3.0]
    zs = [2.0, 5.0, 8.0]
    self.assertEqual(opt_fn(x, ys, zs), fn(x, ys, zs))

    # The parentheses are escaped so the pattern matches the literal text
    # "zip()". Unescaped, "zip()" is "zip" plus an empty regex group and would
    # match any message that merely contains "zip".
    zip_msg = r"zip\(\)"

    # If nopython, should raise UserError
    with self.assertRaisesRegex(torch._dynamo.exc.UserError, zip_msg):
        nopython_fn(x, ys[:1], zs)
    with self.assertRaisesRegex(torch._dynamo.exc.UserError, zip_msg):
        nopython_fn(x, ys, zs[:1])

    # Should cause fallback if allow graph break
    with self.assertRaisesRegex(ValueError, zip_msg):
        opt_fn(x, ys[:1], zs)
    with self.assertRaisesRegex(ValueError, zip_msg):
        opt_fn(x, ys, zs[:1])
def test_fn_with_attr(self):
    """Changing an attribute stored on the compiled function itself selects a different branch and adds a frame."""

    def fn(x):
        if fn.pred:
            return torch.relu(x * 2)
        else:
            return torch.abs(x + 3)

    t = torch.ones(3)
    counter = torch._dynamo.testing.CompileCounter()
    # First pass with pred=True, second with pred=False; the frame count is
    # expected to grow by one each time.
    for expected_frames, pred in enumerate((True, False), start=1):
        fn.pred = pred
        compiled = torch.compile(fn, fullgraph=True, backend=counter)
        self.assertEqual(compiled(t), fn(t))
        self.assertEqual(counter.frame_count, expected_frames)
def test_str_handler_for_user_defined_object(self):
    """
    Confirms handler behaviour for `str` is the same between eager and dynamo.
    Compares a user defined object with custom `__str__` method and without.
    """

    class CustomStr:
        def __str__(self):
            return "ok"

    def foo_custom_str(x):
        obj = CustomStr()
        return x, str(obj)

    eager_custom = foo_custom_str(torch.ones(4))
    compiled_custom_fn = torch.compile(foo_custom_str, fullgraph=True)
    dynamo_custom = compiled_custom_fn(torch.ones(4))
    self.assertEqual(eager_custom[1], dynamo_custom[1])
    self.assertEqual(eager_custom[1], "ok")

    class DefaultStr:
        pass

    def foo_default_str(x):
        obj = DefaultStr()
        return x, str(obj)

    eager_default = foo_default_str(torch.ones(4))
    compiled_default_fn = torch.compile(foo_default_str, fullgraph=True)
    dynamo_default = compiled_default_fn(torch.ones(4))

    # Check that the tensor output from eager and dynamo modes are the same
    self.assertEqual(eager_default[0], dynamo_default[0])

    # Check that the class name (without memory address) is the same in both modes
    separator = " object at"
    eager_class_name = eager_default[1].split(separator)[0]
    dynamo_class_name = dynamo_default[1].split(separator)[0]
    self.assertEqual(eager_class_name, dynamo_class_name)
def test_pybind_object(self):
    """Branching on the ``result`` attribute of a ``GuardDebugInfo`` object matches eager for both values."""

    def fn(x, pybind_obj):
        if pybind_obj.result:
            return torch.cos(x)
        return torch.sin(x)

    compiled = torch.compile(fn, backend="eager", fullgraph=True)

    truthy = torch._C._dynamo.guards.GuardDebugInfo(True, ["a==1"], 0)
    inp = torch.randn(4)
    self.assertEqual(compiled(inp, truthy), fn(inp, truthy))

    falsy = torch._C._dynamo.guards.GuardDebugInfo(False, ["a==1"], 1)
    inp = torch.randn(4)
    self.assertEqual(compiled(inp, falsy), fn(inp, falsy))
instantiate_parametrized_tests(FunctionTests)


# Run the dynamo test runner when this file is executed directly.
if __name__ == "__main__":
    import torch._dynamo.test_case as _dynamo_test_case

    _dynamo_test_case.run_tests()