blob: 8b1c1c109ddba4f3cba6bb9083b855cae3bb5e17 [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import contextlib
import copy
import itertools
import math
import operator
import unittest
import numpy as np
import sympy
import torch
import torch.fx
import torch.nn.functional as F
from torch import sym_int, SymBool, SymFloat, SymInt
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental import sym_node
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
DimConstraints,
DimDynamic,
expect_true,
guard_bool,
guard_float,
guard_int,
GuardOnDataDependentSymNode,
hint_int,
is_symbolic,
ShapeEnv,
StatelessSymbolicContext,
statically_known_true,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skipIfTorchDynamo,
TestCase,
)
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._sympy.functions import (
FloorDiv,
IsNonOverlappingAndDenseIndicator,
Mod,
)
# Shorthand for the ATen operator namespace used throughout this file.
aten = torch.ops.aten
# Registry of meta kernels, keyed by operator overload.  Populated by
# ``register_meta`` and consulted by ``FakeSymbolicTensor.__torch_dispatch__``.
meta_funcs = dict()
def register_meta(op):
    """Return a decorator that registers a meta kernel for ``op``.

    ``op`` may be a single operator overload or a pytree (e.g. a list) of
    overloads; the decorated function is stored in ``meta_funcs`` under every
    leaf and then returned unchanged.
    """

    def _register(fn):
        def _store(single_op):
            meta_funcs[single_op] = fn

        # Visit every leaf of the (possibly nested) op container.
        pytree.tree_map_(_store, op)
        return fn

    return _register
@register_meta((aten.add.Tensor, aten.sub.Tensor))
def binary_meta(a, b):
    """Meta kernel for ``aten.add.Tensor`` / ``aten.sub.Tensor``.

    The result takes ``a``'s shape; ``b``'s shape is not consulted, so no
    broadcasting is computed here.
    """
    out_shape = a.shape
    return a.new_empty(out_shape)
@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
    """Meta kernel for ``aten.cat``: concatenate shapes along ``dim``.

    Every tensor must match ``tensors[0]`` on all dimensions except ``dim``
    (checked with ``assert``).  The output size along ``dim`` is the sum of the
    inputs' sizes there.  A negative ``dim`` counts from the end, as it does
    for ``torch.cat``.
    """
    shape = tensors[0].shape
    # Wrap negative dims.  Without this, ``idx == dim`` never matches, every
    # dimension is asserted equal, and ``new_shape[dim]`` is overwritten with 0.
    if dim < 0:
        dim += len(shape)
    concat_length = 0
    for tensor in tensors:
        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                assert length == common_length
    new_shape = list(shape)
    new_shape[dim] = concat_length
    return tensors[0].new_empty(new_shape)
@register_meta([aten.narrow_copy.default])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
    """Meta kernel for ``aten.narrow_copy``.

    Returns an empty tensor with ``a``'s shape, except that dimension ``dim``
    is replaced by ``length``.  ``start`` and any extra kwargs do not affect
    the output shape and are ignored.  A negative ``dim`` counts from the end.
    """
    # Wrap negative dims.  Otherwise ``i == dim`` below never matches and the
    # narrowed dimension would silently keep its original size.
    if dim < 0:
        dim += len(a.shape)
    shape = [length if i == dim else x for i, x in enumerate(a.shape)]
    return a.new_empty(tuple(shape))
@register_meta([aten.expand.default])
def expand_symint_meta(a, size, implicit=False):
    """Meta kernel for ``aten.expand``: allocate a tensor whose shape is ``size``.

    ``size`` is passed through verbatim (no ``-1`` resolution is done), and
    ``implicit`` is accepted only for signature compatibility.
    """
    expanded = a.new_empty(size)
    return expanded
def create_contiguous(shape):
strides = [1]
for dim in reversed(shape[:-1]):
strides.append(dim * strides[-1])
return list(reversed(strides))
class FakeSymbolicTensor(torch.Tensor):
    """Minimal wrapper-subclass tensor whose metadata may contain SymInts.

    Every ATen call is routed through ``__torch_dispatch__``.  Only ops that
    have an entry in ``meta_funcs``, plus ``aten.new_empty.default``, are
    supported; anything else raises ``RuntimeError``.
    """

    @staticmethod
    def __new__(
        cls,
        sym_shape,
        sym_strides,
        dtype,
        layout,
        requires_grad,
        device,
        storage_offset=0,
    ):
        # TODO: this is wrong in general
        # NB: ``sym_strides`` is ignored; strides are always recomputed as
        # contiguous strides for ``sym_shape``.
        sym_stride = create_contiguous(sym_shape)
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            sym_shape,
            sym_stride,
            storage_offset,
            dtype=dtype,
            layout=layout,
            requires_grad=requires_grad,
            device=device,
        )
        return r

    __torch_function__ = _disabled_torch_function_impl

    def new_empty(self, shape):
        """Allocate a new ``FakeSymbolicTensor`` of ``shape`` with this tensor's dtype/layout/requires_grad/device."""
        return FakeSymbolicTensor(
            shape, None, self.dtype, self.layout, self.requires_grad, self.device
        )

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        # ``kwargs`` defaults to None; normalize it so the ``**kwargs``
        # expansion below cannot raise TypeError.
        kwargs = kwargs or {}
        if func_overload in meta_funcs:
            return meta_funcs[func_overload](*args, **kwargs)
        if func_overload == torch.ops.aten.new_empty.default:
            self = args[0]
            shape = args[1]
            return FakeSymbolicTensor(
                shape,
                self.stride(),
                self.dtype,
                self.layout,
                self.requires_grad,
                self.device,
            )
        raise RuntimeError(f"operator {func_overload} not supported")
def create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None):
    """Wrap concrete tensor ``arg`` as a ``FakeSymbolicTensor`` with symbolic metadata.

    Sizes, strides and storage offset come from
    ``shape_env.create_symbolic_sizes_strides_storage_offset``.  When
    ``source`` is None, ``ConstantSource(name)`` is used.  When
    ``dynamic_dims`` is None, every dimension is ``DimDynamic.DUCK``.  No
    dimension carries a constraint.
    """
    from torch._dynamo.source import ConstantSource

    ndim = arg.dim()
    resolved_source = ConstantSource(name) if source is None else source
    sizes_policy = [DimDynamic.DUCK] * ndim if dynamic_dims is None else dynamic_dims
    context = StatelessSymbolicContext(
        dynamic_sizes=sizes_policy, constraint_sizes=[None] * ndim
    )
    sym_sizes, sym_strides, sym_offset = (
        shape_env.create_symbolic_sizes_strides_storage_offset(
            arg, source=resolved_source, symbolic_context=context
        )
    )
    return FakeSymbolicTensor(
        sym_sizes,
        sym_strides,
        arg.dtype,
        arg.layout,
        arg.requires_grad,
        arg.device,
        sym_offset,
    )
def create_symtype(cls, pytype, shape_env, val, duck=True, **kwargs):
from torch._dynamo.source import ConstantSource
symbol = shape_env.create_symbol(
val,
source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"),
dynamic_dim=DimDynamic.DUCK if duck else DimDynamic.DYNAMIC,
constraint_dim=None,
**kwargs,
)
return cls(SymNode(symbol, shape_env, pytype, hint=val))
# TODO: default duck to False
def create_symint(shape_env, i: int, duck=True, **kwargs) -> SymInt:
    """Create a ``SymInt`` for value ``i`` in ``shape_env`` (thin wrapper over ``create_symtype``).

    ``duck`` and any extra kwargs are forwarded unchanged.
    """
    return create_symtype(
        cls=SymInt,
        pytype=int,
        shape_env=shape_env,
        val=i,
        duck=duck,
        **kwargs,
    )
def create_symbool(shape_env, b: bool) -> SymBool:
    """Create a ``SymBool`` for value ``b`` in ``shape_env`` (default duck setting of ``create_symtype``)."""
    return create_symtype(cls=SymBool, pytype=bool, shape_env=shape_env, val=b)
def create_symfloat(shape_env, f: float) -> SymFloat:
    """Create a ``SymFloat`` for value ``f`` in ``shape_env`` (default duck setting of ``create_symtype``)."""
    return create_symtype(cls=SymFloat, pytype=float, shape_env=shape_env, val=f)
@skipIfTorchDynamo(
    "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
)
class TestPySymInt(TestCase):
    """Tests for Python-level SymInt/SymFloat/SymBool behavior and ShapeEnv guards."""

    def test_arith_ops(self):
        shape_env = ShapeEnv()
        symints = []
        for i in range(2, 5):
            symints.append((i, create_symint(shape_env, i)))

        ops = [
            operator.add,
            operator.sub,
            operator.floordiv,
            operator.mul,
            operator.mod,
        ]

        for op in ops:
            for (lhs_val, lhs_sym), (rhs_val, rhs_sym) in itertools.permutations(
                symints, 2
            ):
                # Skip division by zero for the division-like operators.
                if op in (operator.mod, operator.floordiv) and rhs_val == 0:
                    continue
                self.assertTrue(op(lhs_sym, rhs_sym) == op(lhs_val, rhs_val))

    def test_reverse_arith_ops(self):
        shape_env = ShapeEnv()

        a = create_symint(shape_env, 2)
        self.assertTrue(5 // a == 5 // 2)

        a = create_symint(shape_env, 2)
        self.assertTrue(5 * a == 5 * 2)

    def test_sympify_symint(self):
        shape_env = ShapeEnv()
        a = create_symint(shape_env, 2)
        self.assertIs(sympy.sympify(a), a.node.expr)
        b = create_symfloat(shape_env, 3.0)
        self.assertIs(sympy.sympify(b), b.node.expr)
        c = create_symbool(shape_env, True)
        self.assertIs(sympy.sympify(c), c.node.expr)

    def test_roundtrip(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)

        self.assertTrue(not isinstance(x.shape[0], SymNode))
        self.assertTrue(isinstance(x.shape[0], SymInt))

        self.assertTrue(x.shape[0] == 5)
        self.assertTrue(x.shape[1] == 4)
        # Compare with ``==``: ``assertTrue(a, b)`` treats ``b`` as the failure
        # message and only checks that ``a`` is truthy.
        self.assertTrue(x.shape[2] == 3)

        self.assertTrue(x.size()[0] == 5)
        self.assertTrue(x.size()[1] == 4)
        # Should be simplifiable to an integer.
        # Ref: https://github.com/pytorch/pytorch/pull/107492
        self.assertTrue(isinstance(x.size()[1], SymInt))
        self.assertTrue(
            isinstance(x.size()[1].node.maybe_as_int(), int)
        )  # due to guard above
        self.assertTrue(x.size()[2] == 3)

        self.assertTrue(x.size(0) == 5)
        self.assertTrue(x.size(1) == 4)
        self.assertTrue(x.size(2) == 3)
        self.assertTrue(isinstance(x.size(2), SymInt))
        self.assertTrue(isinstance(x.size(2).node.maybe_as_int(), int))

        y = create_symbolic_tensor("y", torch.randn(5, 4, 3)[1:], shape_env)
        self.assertTrue(isinstance(y.storage_offset(), SymInt))
        self.assertTrue(y.storage_offset() == 12)

    def test_binary(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)

        z = x + y
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # broadcasting
        y = create_symbolic_tensor("y2", torch.randn(1, 4, 1), shape_env)
        z = x + y
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

    def test_symint_args(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
        LAST_DIM = 2
        z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
        self.assertTrue(z.shape[2] == y.shape[2])

        # arithmetic expr with two symints
        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
        self.assertTrue(z.shape[2] == 2)

        # arithmetic expr with a symint and python int
        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
        self.assertTrue(z.shape[2] == 2)

    def test_symint_vargs(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)

        # varargs
        z = y.expand(x.shape[0], y.shape[1], x.shape[2])
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # shape list
        z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints
        z = y.expand(x.shape[0], y.shape[1], 3)
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints in a list
        z = y.expand((x.shape[0], y.shape[1], 3))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints
        z = y.expand(5, y.shape[1], x.shape[2])
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python ints and symints in a list
        z = y.expand((5, y.shape[1], x.shape[2]))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        z = y.expand((y.shape[1],))
        z = y.expand(y.shape[1])

    def test_stride(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
        self.assertIsInstance(x.stride()[0], SymInt)

    def test_size_expressions(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        expand_x = x.expand(x.shape[0], x.shape[0])
        if expand_x.shape[0] > 3:
            result = expand_x + expand_x
        else:
            result = expand_x + expand_x

        gt_op, _bt = shape_env.guards[-1]
        self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
        # assertEqual, not assertTrue: assertTrue(a, b) ignores b (it is the msg).
        self.assertEqual(str(x.shape[0]), str(gt_op.args[0]))
        self.assertEqual(str(expand_x.shape[1]), str(x.shape[0]))
        self.assertEqual(str(expand_x.shape[1]), str(result.shape[0]))

    def test_floordiv_static(self):
        shape_env = ShapeEnv()
        s0 = create_symint(shape_env, 8)
        # This was extracted from
        # python test/inductor/test_cuda_cpp_wrapper.py -k
        # DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper
        bool(s0 % 2 == 0)
        bool(s0 % (s0 // 2) == 0)
        bool(2 * (s0 // 2) == s0)
        self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2))

    def test_numel(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        self.assertIsInstance(x.numel(), torch.SymInt)
        self.assertIsInstance(torch.numel(x), torch.SymInt)

        x = torch.rand(3, 3)
        self.assertIsInstance(x.numel(), int)
        self.assertIsInstance(torch.numel(x), int)

    def test_int_to_float(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        r = torch.sym_float(x.shape[0])
        self.assertIsInstance(r, torch.SymFloat, msg=type(r))

    def test_aten_ops(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])

        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x2", torch.randn(5, 4, 3), shape_env)
        torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])

    def test_fx_trace_intlist(self):
        class CustomModule(torch.nn.Module):
            def forward(self, x):
                bs, c, h, w = x.shape
                return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))

        m = CustomModule()
        x = torch.rand(1, 3, 4, 4)
        # should not TypeError: pad(): argument 'pad' (position 2) must be
        # tuple of ints, not tuple
        torch.fx.symbolic_trace(m)

    def test_meta_symint(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        r = torch.empty(a0, device="meta")
        self.assertIsInstance(r.shape[0], SymInt)

    def test_guard_int(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        self.assertEqual(guard_int(a0), 2)
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

    def test_prefer_deferred_runtime_assertions_over_guards(self):
        shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
        s0 = create_symint(shape_env, 2)
        self.assertEqual(guard_int(s0), 2)
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

        shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
        s0 = create_symint(shape_env, 2)
        self.assertTrue(expect_true(s0 == 2))
        self.assertEqual(len(shape_env.guards), 0)
        self.assertExpectedInline(
            str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]),
            """[Eq(s0, 2)]""",
        )

    def test_sym_int(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = sym_int(a0)
        self.assertEqual(r, 5)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 5)""")

        a1 = create_symint(shape_env, 7)
        r = sym_int(a1 / 2)
        self.assertEqual(guard_int(r), 3)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)"""
        )

        a3 = create_symint(shape_env, 3)
        r = sym_int(2.0 * torch.sym_float(a3))
        self.assertEqual(guard_int(r), 6)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)"""
        )

    def test_sym_sqrt(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 4)
        r = torch._sym_sqrt(a0)
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymFloat, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)"""
        )

    def test_sym_floor(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = math.floor(a0 / 2)
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[0][0]),
            """Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""",
        )
        r = math.floor(3.0 * a0)
        self.assertEqual(r, 15)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[1][0]),
            """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
        )

    def test_sym_trunc(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = math.trunc(a0 / 2)
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)"""
        )
        r = torch.sym_int(torch.sym_sqrt(a0))
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)"""
        )

    def test_sym_ceil(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = math.ceil(a0 / 2)
        self.assertEqual(r, 3)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[0][0]),
            """Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""",
        )
        r1 = 3.0 * a0
        r = math.floor(r1)
        self.assertEqual(r, 15)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(
            str(shape_env.guards[1][0]),
            """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
        )

    def test_sym_ite(self):
        shape_env = ShapeEnv()
        t = create_symint(shape_env, 5)
        f = create_symint(shape_env, 4)
        b1 = True
        r1 = torch.sym_ite(b1, t, f)
        self.assertTrue(r1 is t)
        b2 = False
        r2 = torch.sym_ite(b2, t, f)
        self.assertTrue(r2 is f)
        b3 = t == 5
        r3 = torch.sym_ite(b3, t, f)
        self.assertEqual(len(shape_env.guards), 0)
        self.assertEqual(r3, 5)
        self.assertEqual(type(t), type(r3))
        self.assertExpectedInline(
            str(shape_env.guards[0][0]),
            """Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""",
        )
        b4 = f == 5
        r4 = torch.sym_ite(b4, t, f)
        self.assertEqual(len(shape_env.guards), 1)
        self.assertEqual(r4, 4)
        self.assertEqual(type(f), type(r4))
        self.assertExpectedInline(
            str(shape_env.guards[1][0]),
            """Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""",
        )

    def test_tracing_sym_ite(self):
        def f(x):
            b = x.shape[0] == 5
            ret = torch.sym_ite(b, x.shape[0], x.shape[1])
            return ret

        gm = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5))
        self.assertEqual(len(gm.shape_env.guards), 0)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    eq = sym_size_int == 5
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None
    sym_ite = torch.sym_ite(eq, sym_size_int, sym_size_int_1); eq = sym_size_int = sym_size_int_1 = None
    return sym_ite""",
        )
        r1 = gm(torch.ones(4, 5))
        self.assertIsInstance(r1, int)
        self.assertEqual(r1, 5)
        r2 = gm(torch.ones(5, 4))
        self.assertIsInstance(r2, int)
        self.assertEqual(r2, 5)

    def test_int_conversion(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        int(a0)
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

    def test_data_dependent_guard(self):
        shape_env = ShapeEnv()
        s0 = shape_env.create_unbacked_symint()
        self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))

    def test_data_dependent_guard_propagate_real_tensors(self):
        shape_env = ShapeEnv()
        s0 = shape_env.create_unbacked_symint()
        shape_env.set_unbacked_var_to_val(s0.node.expr, 0)
        self.assertEqual(bool(s0 == 0), True)

    def test_expect_true_basic(self):
        shape_env = ShapeEnv()
        i0 = shape_env.create_unbacked_symint()
        i0_sym = i0.node.expr
        # This doesn't error
        self.assertTrue(expect_true(i0 == 0))
        # This generates a deferred runtime assert via replacement
        self.assertEqual(shape_env.replacements[i0_sym], 0)
        # After expecting true, guards now resolve given the runtime assert
        bool(i0 == 0)

    def test_expect_true_with_s0(self):
        shape_env = ShapeEnv()
        s0 = create_symint(shape_env, 5)
        i0 = shape_env.create_unbacked_symint()
        self.assertTrue(expect_true(i0 < s0))
        self.assertExpectedInline(
            str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
            """[u0 < s0]""",
        )
        self.assertTrue(i0 < s0)
        self.assertTrue(i0 != s0)
        self.assertFalse(i0 > s0)
        self.assertFalse(i0 >= s0)

    def test_expect_true_prefer_later(self):
        shape_env = ShapeEnv()
        i0 = shape_env.create_unbacked_symint()
        i1 = shape_env.create_unbacked_symint()
        i1_sym = i1.node.expr
        self.assertTrue(expect_true(i0 + i1 == 10))
        # Importantly, this is put in i1, not i0!
        self.assertExpectedInline(
            str([ra.expr for ra in shape_env.deferred_runtime_asserts[i1_sym]]),
            """[Eq(u0 + u1, 10)]""",
        )
        self.assertTrue(i0 + i1 == 10)
        # NB: We currently don't support deriving that we can substitute
        # i0 + i1 with 10; maybe we should, but this means our rewriting
        # system is no longer confluent (it's probably OK though, because
        # you're unlikely to get other equalities like this on the
        # unbacked SymInts.)

    def test_unbacked_substitution(self):
        shape_env = ShapeEnv()
        i0 = shape_env.create_unbacked_symint()
        i1 = shape_env.create_unbacked_symint()
        _constrain_range_for_size(i0)
        _constrain_range_for_size(i1)
        self.assertTrue(expect_true(i0 == i1 * 4))
        self.assertExpectedInline(str(i0), """u0""")

        i2 = shape_env.create_unbacked_symint()
        i3 = shape_env.create_unbacked_symint()
        _constrain_range_for_size(i2)
        _constrain_range_for_size(i3)
        self.assertTrue(expect_true(i2 * 4 == i3))
        self.assertExpectedInline(str(i3), """u3""")

    def test_avoid_unbacked_substitution(self):
        shape_env = ShapeEnv()
        i0 = shape_env.create_unbacked_symint()
        _constrain_range_for_size(i0)
        i1 = shape_env.create_unbacked_symint()
        _constrain_range_for_size(i1)
        self.assertTrue(expect_true(i0 == 10 - i1))
        self.assertExpectedInline(str(i0), """u0""")

    def test_expect_true_double_digits(self):
        shape_env = ShapeEnv()
        ia = [
            shape_env.create_unbacked_symint() for _ in range(11)
        ]  # allocate 11 (u0 through u10)
        self.assertEqual(str(ia[-1]), "u10")
        self.assertTrue(expect_true(sum(ia) == 20))
        self.assertEqual(len(shape_env.deferred_runtime_asserts[ia[-1].node.expr]), 1)

    def test_expect_true_refine_range(self):
        shape_env = ShapeEnv()
        for i, rel in enumerate(
            [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
        ):
            with self.subTest(f"i = {i}"):
                i0 = shape_env.create_unbacked_symint()
                self.assertTrue(expect_true(rel(i0)))
                self.assertTrue(statically_known_true(i0 != 3))
                self.assertTrue(statically_known_true(i0 != 4))
                self.assertFalse(statically_known_true(i0 != 5))
                self.assertFalse(statically_known_true(i0 != 6))
                self.assertTrue(statically_known_true(i0 > 4))
                self.assertTrue(statically_known_true(i0 >= 5))

        for i, rel in enumerate(
            [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
        ):
            with self.subTest(f"i = {i}"):
                i0 = shape_env.create_unbacked_symint()
                self.assertTrue(expect_true(rel(i0)))
                self.assertFalse(statically_known_true(i0 != 2))
                self.assertFalse(statically_known_true(i0 != 3))
                self.assertTrue(statically_known_true(i0 != 4))
                self.assertTrue(statically_known_true(i0 != 5))
                self.assertTrue(statically_known_true(i0 < 4))
                self.assertTrue(statically_known_true(i0 <= 5))

    def test_guard_refine_range(self):
        shape_env = ShapeEnv()
        for i, rel in enumerate(
            [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
        ):
            with self.subTest(f"i = {i}"):
                i0 = create_symint(shape_env, 10, duck=False)
                self.assertTrue(bool(rel(i0)))
                self.assertTrue(statically_known_true(i0 != 3))
                self.assertTrue(statically_known_true(i0 != 4))
                self.assertFalse(statically_known_true(i0 != 5))
                self.assertFalse(statically_known_true(i0 != 6))
                self.assertTrue(statically_known_true(i0 > 4))
                self.assertTrue(statically_known_true(i0 >= 5))

        for i, rel in enumerate(
            [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
        ):
            with self.subTest(f"i = {i}"):
                i0 = create_symint(shape_env, 2, duck=False)
                self.assertFalse(bool(rel(i0)))
                self.assertFalse(statically_known_true(i0 != 3))
                self.assertFalse(statically_known_true(i0 != 4))
                self.assertTrue(statically_known_true(i0 != 5))
                self.assertTrue(statically_known_true(i0 != 6))
                self.assertTrue(statically_known_true(i0 <= 4))
                self.assertTrue(statically_known_true(i0 < 5))

        for i, rel in enumerate(
            [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
        ):
            with self.subTest(f"i = {i}"):
                i0 = create_symint(shape_env, 2, duck=False)
                self.assertTrue(bool(rel(i0)))
                self.assertFalse(statically_known_true(i0 != 2))
                self.assertFalse(statically_known_true(i0 != 3))
                self.assertTrue(statically_known_true(i0 != 4))
                self.assertTrue(statically_known_true(i0 != 5))
                self.assertTrue(statically_known_true(i0 < 4))
                self.assertTrue(statically_known_true(i0 <= 3))

        for i, rel in enumerate(
            [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
        ):
            with self.subTest(f"i = {i}"):
                i0 = create_symint(shape_env, 10, duck=False)
                self.assertFalse(bool(rel(i0)))
                self.assertTrue(statically_known_true(i0 != 2))
                self.assertTrue(statically_known_true(i0 != 3))
                self.assertFalse(statically_known_true(i0 != 4))
                self.assertFalse(statically_known_true(i0 != 5))
                self.assertTrue(statically_known_true(i0 >= 4))
                self.assertTrue(statically_known_true(i0 > 3))

    def test_mul_int_oo_nan(self):
        shape_env = ShapeEnv()
        s0 = create_symint(shape_env, 5, duck=False)
        s1 = create_symint(shape_env, 6, duck=False)
        s2 = create_symint(shape_env, 5, duck=False)
        bool(s0 * (s1 // s0) == s2)

    def test_non_overlapping_and_dense(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = torch.empty_strided((a0, 7), (1, a0), device="meta")
        self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r))

    def test_non_overlapping_and_dense_unbacked(self):
        shape_env = ShapeEnv()
        u0 = shape_env.create_unbacked_symint()
        torch._check_is_size(u0)
        cf = torch.ops.aten.is_non_overlapping_and_dense.default

        self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1)
        self.assertEqual(IsNonOverlappingAndDenseIndicator(2, u0.node.expr, 1, 2), 1)
        self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")))
        self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta")))

        self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 1), 1)
        self.assertEqual(IsNonOverlappingAndDenseIndicator(1, u0.node.expr), 1)
        self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta")))
        self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta")))

        Max = torch.sym_max
        # NB: This only works because we're able to determine this tensor is
        # contiguous. transpose(0, 1) makes it stop working
        self.assertTrue(
            cf(
                torch.empty_strided(
                    (2, 3, 1, u0),
                    (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
                    device="meta",
                )
            )
        )

    def test_numpy_sym_max(self):
        self.assertEqual(torch.sym_max(np.int64(10), 12), 12)
        self.assertEqual(torch.sym_max(np.int64(12), 10), 12)
        self.assertEqual(torch.sym_max(np.int64(10), 12.5), 12.5)
        self.assertEqual(torch.sym_max(np.int64(14), 12.5), 14.0)
        self.assertEqual(torch.sym_max(np.float64(14.0), 12), 14.0)
        self.assertEqual(torch.sym_max(np.float64(14.0), 16), 16.0)

    def test_numpy_sym_min(self):
        self.assertEqual(torch.sym_min(np.int64(10), 12), 10)
        self.assertEqual(torch.sym_min(np.int64(12), 10), 10)
        self.assertEqual(torch.sym_min(np.int64(10), 12.5), 10.0)
        self.assertEqual(torch.sym_min(np.int64(14), 12.5), 12.5)
        self.assertEqual(torch.sym_min(np.float64(14.0), 12), 12.0)
        self.assertEqual(torch.sym_min(np.float64(14.0), 16), 14.0)

    def test_debug_has_internal_overlap_unbacked(self):
        shape_env = ShapeEnv()
        u0 = shape_env.create_unbacked_symint()
        torch._check_is_size(u0)
        cf = torch._debug_has_internal_overlap
        self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0)
        self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0)
        self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0)
        self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 0)
        Max = torch.sym_max
        self.assertEqual(
            cf(
                torch.empty_strided(
                    (2, 3, 1, u0),
                    (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
                    device="meta",
                )
            ),
            0,
        )

        # Wobbling these to zero is OK too
        self.assertEqual(cf(torch.empty_strided((u0, 2), (3, 1), device="meta")), 2)
        self.assertEqual(cf(torch.empty_strided((2, u0), (1, 3), device="meta")), 2)

    def test_specialize_zero_one(self):
        shape_env = ShapeEnv(specialize_zero_one=True)
        a0 = create_symint(shape_env, 5)
        assert a0 != 1
        self.assertEqual(len(shape_env.guards), 0)

        shape_env = ShapeEnv(specialize_zero_one=False)
        a0 = create_symint(shape_env, 5)
        assert a0 != 1
        self.assertEqual(len(shape_env.guards), 1)

    def test_duck_shape(self):
        shape_env = ShapeEnv(duck_shape=True)
        a0 = create_symint(shape_env, 5)
        a1 = create_symint(shape_env, 5)
        assert a0 == a1
        self.assertEqual(len(shape_env.guards), 0)

        shape_env = ShapeEnv(duck_shape=False)
        a0 = create_symint(shape_env, 5)
        a1 = create_symint(shape_env, 5)
        assert a0 == a1
        self.assertEqual(len(shape_env.guards), 1)

    def test_int_bool(self):
        # See https://github.com/pytorch/pytorch/issues/95981
        shape_env = ShapeEnv(duck_shape=True)
        a0 = create_symint(shape_env, 5)
        assert a0
        self.assertEqual(len(shape_env.guards), 0)

    def test_symint_as_scalar(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)

        sym_int_encountered = False

        class TestSymInt(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                assert func == torch.ops.aten.add.Tensor

                nonlocal sym_int_encountered
                # WARNING: do not do identity tests on the outer
                # SymInt/SymFloat, they are NOT STABLE
                sym_int_encountered = kwargs["alpha"].node is a0.node
                kwargs["alpha"] = 0
                return func(*args)

        x = torch.rand([4, 4])
        with TestSymInt():
            y = torch.add(x, x, alpha=a0)

        self.assertTrue(sym_int_encountered)

    def test_deepcopy(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        assert a0 < 4
        new_shape_env = copy.deepcopy(shape_env)
        self.assertEqual(len(new_shape_env.guards), 1)

    def test_print_readable_with_symints(self):
        def f(a, b):
            dim0 = a.shape[0] + b.shape[0]
            dim1 = a.shape[1] + b.shape[1]
            d = a.new_empty(dim0, dim1)
            d = torch.ops.aten.native_dropout(d, 0.5, train=True)
            return d

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
        out = fx_g.print_readable(print_output=False)

        self.assertExpectedInline(
            out.strip(),
            """\
class f(torch.nn.Module):
    def forward(self, a_1: "f32[s0, s1]", b_1: "f32[s2, s1]"):
        # No stacktrace found for following nodes
        sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(a_1, 0)
        sym_size_int_1: "Sym(s2)" = torch.ops.aten.sym_size.int(b_1, 0)
        add: "Sym(s0 + s2)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None
        sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(a_1, 1)
        sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None
        add_1: "Sym(2*s1)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None
        new_empty: "f32[s0 + s2, 2*s1]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
        native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
        getitem: "f32[s0 + s2, 2*s1]" = native_dropout[0]
        getitem_1: "b8[s0 + s2, 2*s1]" = native_dropout[1]; native_dropout = None
        return (getitem, getitem_1)""",  # noqa: B950
        )

    def test_statically_known_true(self):
        shape_env = ShapeEnv()
        s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5))

        # Statically known true
        self.assertTrue(statically_known_true(True))
        self.assertTrue(statically_known_true(s2 == s2))
        self.assertTrue(statically_known_true(s2 * s3 > s3))
        self.assertTrue(statically_known_true(s3 * s4 > s4))
        self.assertTrue(statically_known_true((s3 + s3) % 2 == 0))

        # Statically known false
        self.assertFalse(statically_known_true(False))
        self.assertFalse(statically_known_true(s3 * s4 <= s4))
        self.assertFalse(statically_known_true((s3 + s3) % 2 == 1))

        # True for hints, but not known statically
        self.assertFalse(statically_known_true(s2 + s2 == s4))
        self.assertFalse(statically_known_true(s4 % s2 == 0))
        self.assertFalse(statically_known_true(s2 != s3))
        self.assertFalse(statically_known_true(s3 * s4 > s2))

        # False for hints, but not known statically
        self.assertFalse(statically_known_true(s2 == s3))
        self.assertFalse(statically_known_true(s2 > s3))
        self.assertFalse(statically_known_true(s3 + s3 == s4))

        # No guards should be generated
        self.assertEqual(len(shape_env.guards), 0)

    def test_ephemeral_source_simplification(self):
        from torch._dynamo.source import EphemeralSource

        # For full robustness, ensure the ephemeral source symbols are simplified out regardless
        # of construction order or check order.
        for construct_ephemeral_first, x_first_in_check in itertools.product(
            [False, True], [False, True]
        ):
            shape_env = ShapeEnv()
            shape = (5, 10)
            dynamic_dims = [DimDynamic.DYNAMIC for _ in shape]
            x = create_symbolic_tensor(
                "x",
                torch.randn(*shape),
                shape_env,
                source=(EphemeralSource() if construct_ephemeral_first else None),
                dynamic_dims=dynamic_dims,
            )
            y = create_symbolic_tensor(
                "y",
                torch.randn(*shape),
                shape_env,
                source=(EphemeralSource() if not construct_ephemeral_first else None),
                dynamic_dims=dynamic_dims,
            )
            t_with_ephemeral = x if construct_ephemeral_first else y

            def _get_ephemeral_source_symbols(t):
                return [
                    s.node.expr
                    for s in itertools.chain(t.shape, t.stride(), (t.storage_offset(),))
                    if isinstance(s, torch.SymInt)
                    and s.node.expr in shape_env.var_to_sources
                    and any(
                        source.is_ephemeral()
                        for source in shape_env.var_to_sources[s.node.expr]
                    )
                ]

            # these checks should simplify out the ephemeral symbols, regardless of the
            # ordering x == y or y == x
            self.assertTrue(len(_get_ephemeral_source_symbols(t_with_ephemeral)) > 0)
            if x_first_in_check:
                torch._check(x.size() == y.size())
                torch._check(x.stride() == y.stride())
                torch._check(x.storage_offset() == y.storage_offset())
            else:
                torch._check(y.size() == x.size())
                torch._check(y.stride() == x.stride())
                torch._check(y.storage_offset() == x.storage_offset())
            self.assertEqual(len(_get_ephemeral_source_symbols(t_with_ephemeral)), 0)

    def test_ephemeral_source_unified_with_non_ephemeral_source(self):
        from torch._dynamo.source import EphemeralSource

        for construct_ephemeral_first in (False, True):
            shape_env = ShapeEnv()
            shape = (5, 10)
            # use duck sizing here to ensure symbol reuse across x and y
            duck_dims = [DimDynamic.DUCK for _ in shape]
            x = create_symbolic_tensor(
                "x",
                torch.randn(*shape),
                shape_env,
                source=(EphemeralSource() if construct_ephemeral_first else None),
                dynamic_dims=duck_dims,
            )
            y = create_symbolic_tensor(
                "y",
                torch.randn(*shape),
                shape_env,
                source=(EphemeralSource() if not construct_ephemeral_first else None),
                dynamic_dims=duck_dims,
            )

            # regardless of construction order, non-ephemeral sources should be preferred
            # first in the var_to_sources list for potential guarding later on
            for source_list in shape_env.var_to_sources.values():
                self.assertFalse(source_list[0].is_ephemeral())

            self.assertEqual(x.size(), y.size())
            self.assertEqual(x.stride(), y.stride())
            self.assertEqual(x.storage_offset(), y.storage_offset())
@skipIfTorchDynamo(
"Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
)
class TestSymNumberMagicMethods(TestCase):
def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn):
return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn)
def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn):
# Helper function
# NB: don't use one as that will get specialized
# TODO: We don't have to circuitously create the float, can just
# create a symfloat directly
seed_node = (create_symint(shape_env, 2) / 2.0).node
bool_seed_node = (create_symint(shape_env, 2) == 2).node
def get_sym_inp(inp):
# NB: this must come before int
if isinstance(inp, bool):
return torch.SymBool(to_node(bool_seed_node, inp))
elif isinstance(inp, int):
return torch.SymInt(to_node(seed_node, inp))
else:
return torch.SymFloat(to_node(seed_node, inp))
if fn == "float_pow":
if inp1 < 0:
return
if fn == "pow_by_natural":
if isinstance(inp1, float) or isinstance(inp2, float):
return
if inp2 < 0:
return
def maybe_xfail(inp1, inp2):
if fn == "sym_sqrt" and inp1 < 0:
# ValueError: math domain error
return self.assertRaises((ValueError,))
elif (
fn in ("float_truediv", "int_truediv", "int_floordiv", "mod")
and inp2 == 0
):
# ZeroDivisionError: division by zero
return self.assertRaises((ZeroDivisionError,))
elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0:
# ZeroDivisionError: 0.0 cannot be raised to a negative power
return self.assertRaises((ZeroDivisionError,))
elif (
# TODO: dear catastrophe waitress,
# this doesn't work
fn in ["float_pow", "pow_by_natural"]
and inp1 < 0
and (
type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat)
)
and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float))
):
# Complex result, which we do not support:
# TypeError: Cannot convert complex to float
return self.assertRaises((RuntimeError,))
elif fn in ("lshift", "rshift") and not (
isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int))
):
# TypeError: unsupported operand type(s)
return self.assertRaises((TypeError,))
elif fn in ("lshift", "rshift") and inp2 < 0:
# ValueError: math domain error
return self.assertRaises((ValueError,))
else:
return contextlib.nullcontext()
lambda_apply = method_to_operator(fn)
def guard_fn(v):
if type(v) in (SymBool, bool):
return guard_bool(v)
elif type(v) in (SymFloat, float):
return guard_float(v)
else: # SymInt, int
return guard_int(v)
# Get reference result
with maybe_xfail(inp1, inp2):
if is_unary_fn:
ref_out = lambda_apply(inp1)
else:
ref_out = lambda_apply(inp1, inp2)
# Symified first arg
sym_inp1 = get_sym_inp(inp1)
with maybe_xfail(sym_inp1, inp2):
if is_unary_fn:
out = lambda_apply(sym_inp1)
else:
out = lambda_apply(sym_inp1, inp2)
if fn not in sym_node.alternate_impl_if_hinted_methods:
self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
out = guard_fn(out)
self.assertEqual(out, ref_out)
if is_unary_fn:
return
# Symified second arg
sym_inp2 = get_sym_inp(inp2)
with maybe_xfail(inp1, sym_inp2):
out = lambda_apply(inp1, sym_inp2)
if fn not in sym_node.alternate_impl_if_hinted_methods:
self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
out = guard_fn(out)
self.assertEqual(out, ref_out)
# Symified both args
with maybe_xfail(sym_inp1, sym_inp2):
out = lambda_apply(sym_inp1, sym_inp2)
if fn not in sym_node.alternate_impl_if_hinted_methods:
self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
out = guard_fn(out)
self.assertEqual(out, ref_out)
@parametrize("fn", list(sym_node.magic_methods.keys()))
def test_bool_method(self, fn):
# sym_ite has its own tests
if fn not in sym_node.bool_magic_methods or fn == "sym_ite":
self.skipTest(f"{fn} is non-bool")
is_unary_fn = fn in sym_node.unary_methods
shape_env = ShapeEnv()
self._do_test(fn, True, False, shape_env, is_unary_fn)
@parametrize("fn", list(sym_node.magic_methods.keys()))
@parametrize("first_type", ["int", "float"])
@parametrize("second_type", ["int", "float"])
def test_method(self, fn, first_type, second_type):
if first_type == "float":
# TODO: Hmm, this looks like we skip all floats
self.skipTest(f"{fn} is not a float magic method")
if (
first_type == "int" or second_type == "int"
) and fn in sym_node.only_float_magic_methods:
self.skipTest(f"{fn} is not an int method")
if second_type == "float" and fn in ["mod"]:
self.skipTest(f"{fn} only handles int")
is_unary_fn = fn in sym_node.unary_methods or fn == "round"
# Second argument is ignored for unary function. So only run for one type
if is_unary_fn and second_type == "float":
self.skipTest(f"{fn} is unary and already tested")
if fn in sym_node.bool_magic_methods:
self.skipTest(f"{fn} is bool")
# Only floats here since these will be converted to int if necessary.
# We also ignore complex and bool.
values = (
0.0,
1.0,
0.5 if fn in ("sym_acos", "sym_asin") else 2.5, # avoid math domain error
)
neg_values = tuple(-x for x in values)
for inp1, inp2 in itertools.chain(
itertools.product(values, values),
itertools.product(values, neg_values),
itertools.product(neg_values, values),
itertools.product(neg_values, neg_values),
):
if first_type == "int":
inp1 = int(inp1)
if second_type == "int":
inp2 = int(inp2)
shape_env = ShapeEnv()
self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)
def get_constant_bool(self, val):
return SymBool(torch._C._get_constant_bool_symnode(val))
@unittest.expectedFailure
def test_symint_hashing(self):
shape_env = ShapeEnv()
hash(create_symint(shape_env, 3))
def test_symnode_hashing(self):
shape_env = ShapeEnv()
# These all trigger specialization when hashed
hash(create_symbool(shape_env, True))
# We should be passing in float here, but create_symbol currently
# only supports int
hash(create_symfloat(shape_env, 3.0))
# NestedInt (SymInt), constant SymBool, SymNode are hashable
j1 = torch._C._get_nested_int(1, 1)
j1_copy = torch._C._get_nested_int(1, 1)
j2 = torch._C._get_nested_int(2, 1)
t = self.get_constant_bool(True)
t_copy = self.get_constant_bool(True)
f = self.get_constant_bool(False)
n = create_symint(shape_env, 3).node
m = self.get_constant_bool(True).node
self.assertIs(j1 == j1_copy, True)
self.assertEqual(hash(j1), hash(j1_copy))
self.assertIs(j1 == j2, False)
self.assertNotEqual(hash(j1), hash(j2))
self.assertIs(t == t_copy, True)
self.assertEqual(hash(t), hash(t_copy))
self.assertIs(t == f, False)
self.assertNotEqual(hash(t), hash(f))
hash(n)
hash(m)
def test_symint_deepcopy(self):
shape_env = ShapeEnv()
symnodes = (torch._C._get_nested_int(1, 1),)
deepcopied_symnodes = copy.deepcopy(symnodes)
self.assertEqual(symnodes, deepcopied_symnodes)
def test_non_symbolic_symnode(self):
j1 = torch._C._get_nested_int(1, 1)
j2 = torch._C._get_nested_int(1, 1)
j3 = torch._C._get_nested_int(3, 1)
self.assertIsInstance(j1, torch.SymInt)
self.assertNotIsInstance(j1, int)
with self.assertRaisesRegex(
RuntimeError, "add not supported by NestedIntSymNode"
):
j1 + 3
self.assertFalse(j1 == 3)
with self.assertRaisesRegex(RuntimeError, "indeterminate"):
self.assertFalse(3 >= j2)
self.assertIs(j1 == j1, True)
self.assertIs(j1 == j2, True)
self.assertIs(j1 == j3, False)
self.assertIs(j1 != j3, True)
self.assertIs(j1 != j2, False)
x = self.get_constant_bool(True)
#
# Unary
#
# op(constant SymBool)
self.assertIs(x.__sym_not__(), False)
#
# Binary
#
# op(constant SymBool, bool)
# op(constant SymBool, constant SymBool)
# op(bool, constant SymBool)
self.assertIs(operator.and_(x, True), True)
self.assertIs(operator.and_(x, x), True)
self.assertIs(operator.and_(True, x), True)
# op(symbolic SymBool, constant Symbool)
# op(constant SymBool, symbolic Symbool)
shape_env = ShapeEnv()
a = create_symint(shape_env, 2)
b = create_symint(shape_env, 2)
c = a == b # symbolic SymBool
d = self.get_constant_bool(True)
e = operator.and_(c, d)
f = operator.and_(d, c)
self.assertTrue(is_symbolic(e))
self.assertTrue(is_symbolic(f))
self.assertIs(e.node.guard_bool("", 0), True)
self.assertIs(f.node.guard_bool("", 0), True)
# Comparing sizes
sz1 = torch.Size([j1, j1, j1])
sz2 = torch.Size([j1, j1, j1])
self.assertIs(sz1 == sz2, True)
sz1 = torch.Size([3, j1, 4])
sz2 = torch.Size([3, j2, 4])
self.assertIs(sz1 == sz2, True)
self.assertIs(sz1 != sz2, False)
def test_stride_symnode(self):
from torch._subclasses.fake_tensor import FakeTensorMode
shape_env = ShapeEnv()
def _create_symbolic_tensor(x, dynamic_sizes, dynamic_strides):
with FakeTensorMode(shape_env=shape_env) as fake_mode:
return fake_mode.from_tensor(
x,
symbolic_context=StatelessSymbolicContext(
dynamic_sizes=dynamic_sizes,
dynamic_strides=dynamic_strides,
),
)
# check everything static
t = _create_symbolic_tensor(
x=torch.ones(3, 6),
dynamic_sizes=[
DimDynamic.STATIC,
DimDynamic.STATIC,
],
dynamic_strides=[
DimDynamic.INFER_STRIDE,
DimDynamic.INFER_STRIDE,
],
)
self.assertTrue(all(isinstance(size, int) for size in t.size()))
self.assertTrue(all(isinstance(stride, int) for stride in t.stride()))
# check dynamic size but static dims
t = _create_symbolic_tensor(
x=torch.ones(3, 6),
dynamic_sizes=[
DimDynamic.DYNAMIC,
DimDynamic.DYNAMIC,
],
dynamic_strides=[
DimDynamic.INFER_STRIDE,
DimDynamic.INFER_STRIDE,
],
)
# Expect stride to be inferred
s0, s1 = t.size()
s2, s3 = t.stride()
self.assertTrue(isinstance(s0, torch.SymInt))
self.assertTrue(isinstance(s1, torch.SymInt))
self.assertTrue(isinstance(s2, torch.SymInt))
self.assertTrue(s1 == s2)
self.assertEqual(s3, 1)
# Check dynamic stride but static dims
t = _create_symbolic_tensor(
x=torch.ones(3, 6),
dynamic_sizes=[
DimDynamic.STATIC,
DimDynamic.STATIC,
],
dynamic_strides=[
DimDynamic.DYNAMIC,
DimDynamic.INFER_STRIDE,
],
)
s0, s1 = t.size()
s2, s3 = t.stride()
self.assertTrue(isinstance(s0, int))
self.assertTrue(isinstance(s1, int))
self.assertTrue(isinstance(s2, torch.SymInt))
self.assertTrue(isinstance(s3, int))
# Check dynamic sizes and dims, and ensure different symbol
t = _create_symbolic_tensor(
x=torch.ones(3, 6),
dynamic_sizes=[
DimDynamic.DYNAMIC,
DimDynamic.DYNAMIC,
],
dynamic_strides=[
DimDynamic.DYNAMIC,
DimDynamic.INFER_STRIDE,
],
)
s0, s1 = t.size()
s2, s3 = t.stride()
self.assertTrue(isinstance(s0, torch.SymInt))
self.assertTrue(isinstance(s1, torch.SymInt))
self.assertTrue(isinstance(s2, torch.SymInt))
self.assertTrue(isinstance(s3, int))
self.assertTrue(str(s1.node.expr) != str(s2.node.expr))
# Generate the concrete test cases for the @parametrize-decorated methods
# (test_bool_method, test_method) of TestSymNumberMagicMethods.
instantiate_parametrized_tests(TestSymNumberMagicMethods)
class TestFloorDiv(TestCase):
@staticmethod
def python_floordiv(x, y):
return x // y
@staticmethod
def torch_floordiv(x, y):
# Note: we fully evaluate here since FloorDiv might not always do
# that.
shape_env = ShapeEnv()
return shape_env.evaluate_expr(FloorDiv(x, y))
@staticmethod
def yield_test_cases(values, negate=True):
for x, y in values:
yield (x, y)
if negate:
yield (-x, y)
yield (x, -y)
yield (-x, -y)
def test_floordiv_float_int(self):
values = ((7, 2),)
for x, y in TestFloorDiv.yield_test_cases(values):
self.assertEqual(
TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
)
def test_floordiv_div_by_one(self):
values = ((2, 1),)
for x, y in TestFloorDiv.yield_test_cases(values):
self.assertEqual(
TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
)
def test_floordiv_simplify(self):
# Tests how we simplify or evaluate FloorDiv without free variables
shape_env = ShapeEnv()
result = 21
exprs = (7 * FloorDiv(6, 2),)
for expr in exprs:
self.assertEqual(expr, result)
self.assertEqual(expr.doit(deep=False), result)
self.assertEqual(expr.doit(deep=True), result)
self.assertEqual(sympy.simplify(expr), result)
self.assertEqual(shape_env.simplify(expr), result)
self.assertEqual(shape_env.evaluate_expr(expr), result)
def test_floordiv_assumptions(self):
cases = (
sympy.Symbol("i1", integer=True),
sympy.Symbol("i2", integer=True),
)
for base, divisor in itertools.product(cases, repeat=2):
def op():
return FloorDiv(base, divisor)
def is_complex(x):
return x.is_integer is False and x.is_real is False and x.is_complex
if is_complex(base) or is_complex(divisor):
self.assertRaisesRegex(
TypeError,
(
r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol',"
r" expected integer or real"
),
op,
)
continue
op = op()
# In regular Python, x//x == 1.0 if x is a float, but FloorDiv
# always returns an integer 1 when both args are the same object.
# This even works for Symbols with no assumptions specified.
if base is divisor:
self.assertTrue(op.is_integer)
self.assertTrue(op.is_real)
elif base.is_integer and divisor.is_integer:
self.assertTrue(op.is_integer)
self.assertTrue(op.is_real)
else:
self.assertEqual(op.is_integer, None)
self.assertTrue(op.is_real)
class TestDimConstraints(TestCase):
def test_dim_constraints_reduce_congruences_simple(self):
from sympy import Symbol
s = Symbol("s", positive=True, integer=True)
dim_constraints = DimConstraints({}, {}, set(), {})
dim_constraints._congruences[s] = {
(s / 2) % 2,
(s / 2) % 8,
(s / 2) % 4,
s % 2,
((s / 16) + 2) % 4,
}
congruences = dim_constraints._reduce_congruences()
self.assertEqual(congruences[s], {(s + 32) % 64})
def test_dim_constraints_reduce_inequalities_simple(self):
from sympy import Eq, Interval, Ne, Symbol
from sympy.solvers.inequalities import reduce_inequalities
s = Symbol("s", positive=True, integer=True)
exprs = {
s >= 2,
Ne(8 * s, 16),
Ne(s / 2, 1),
Ne(16 * s, 32),
s < 16,
Ne(s, 2),
s / 2 < 16,
s / 2 > 1,
s / 2 >= 2,
Ne(3 * s / 2, 3),
}
solution = reduce_inequalities(exprs, s).as_set()
self.assertEqual(solution, Interval.Ropen(4, 16))
exprs.add(Eq(s / 2, 4))
solution = reduce_inequalities(exprs, s).as_set()
self.assertEqual(solution, {8})
def test_dim_constraints_reduce_inequalities_error(self):
from collections import defaultdict
from sympy import Symbol
from sympy.solvers.inequalities import reduce_inequalities
from torch._dynamo.source import (
LocalSource,
TensorProperty,
TensorPropertySource,
)
from torch.fx.experimental.symbolic_shapes import DynamicDimConstraintPrinter
s0 = Symbol("s0", positive=True, integer=True)
exprs = {
4 * s0**3 - 4 * s0**2 + s0 <= 2147483647,
s0 >= 2,
s0**3 <= 2147483647,
s0 <= 2147483647,
}
answer = reduce_inequalities(exprs, s0)
symbol_to_source = defaultdict(list)
symbol_to_source[s0].append(
TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
)
)
dcp = DynamicDimConstraintPrinter(symbol_to_source, {})
with self.assertRaisesRegex(
AssertionError,
"Unknown symbol.*created by constraints solver",
):
dcp.doprint(answer)
def test_dim_constraints_solve_full(self):
from sympy import Eq, Integer, Ne, Symbol
from torch._dynamo.source import (
LocalSource,
TensorProperty,
TensorPropertySource,
)
src0 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
)
src2 = TensorPropertySource(
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0
)
src3 = TensorPropertySource(
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0
)
src4 = TensorPropertySource(
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0
)
src1 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2
)
src7 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3
)
src5 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=1
)
src8 = TensorPropertySource(
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1
)
src6 = TensorPropertySource(
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1
)
src9 = TensorPropertySource(
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=1
)
src10 = TensorPropertySource(
base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1
)
src11 = TensorPropertySource(
base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1
)
src12 = TensorPropertySource(
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2
)
s0 = Symbol("s0", positive=True, integer=True)
s1 = Symbol("s1", positive=True, integer=True)
s5 = Symbol("s5", positive=True, integer=True)
s6 = Symbol("s6", positive=True, integer=True)
symbol_to_source = {
s0: [src0, src2, src3, src4],
s1: [src1, src7],
s5: [src5, src8],
s6: [src6, src9, src10],
}
var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21}
marked_dynamic = {s0, s1, s5, s6}
dim_constraints = DimConstraints(
symbol_to_source, var_to_val, marked_dynamic, {}
)
dim_constraints.add_equality(src2, s0)
dim_constraints.add_equality(src3, s0)
dim_constraints.add_equality(src4, s0)
dim_constraints.add_equality(src7, s1)
dim_constraints.add_equality(src8, s5)
dim_constraints.add_equality(src9, s6)
dim_constraints.add_equality(src10, s6)
dim_constraints.add_equality(src11, Integer(1))
dim_constraints.add_equality(src12, Integer(3))
dim_constraints.add(s1**2 <= 2147483647)
dim_constraints.add(32 * s1**2 <= 2147483647)
dim_constraints.add(s0 < 16)
dim_constraints.add(Eq(Mod(s1, 2), 0))
dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
dim_constraints.add(Ne((FloorDiv(s1, 2)) ** 2, 1))
dim_constraints.add(32 * (FloorDiv(s1, 2)) ** 2 <= 2147483647)
dim_constraints.add((FloorDiv(s1, 2)) ** 2 > 1)
dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
+ 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
+ 64
<= 2147483647
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
+ 1,
1,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
+ 1
> 1
)
dim_constraints.add(
128 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
+ 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
+ 128
<= 2147483647
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
+ 1,
1,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
+ 1
> 1
)
dim_constraints.add(
256 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
+ 512 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
+ 256
<= 2147483647
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
+ 1,
1,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
+ 1
> 1
)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 >= 3)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
<= 2147483647
)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 0)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 1)
dim_constraints.add(
Ne(
60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * s0,
0,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
1,
)
)
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
0,
)
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
>= 0
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 0))
dim_constraints.add(
1
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, -1))
dim_constraints.add(
Ne(
60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * s0,
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
> 0
)
dim_constraints.add(
Eq(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 * (Mod(s0, 2))
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) * Mod(s0, 2)
+ 60 * (Mod(s0, 2)),
0,
)
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
0,
)
)
dim_constraints.add(
Ne(
60
* (FloorDiv(s0, 2))
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120
* FloorDiv(s0, 2)
* FloorDiv(s0, (FloorDiv(s0, 2)))
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
0,
)
)
dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
dim_constraints.add(
Ne(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60,
0,
)
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
>= 0
)
dim_constraints.add(
1
< 60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
)
dim_constraints.add(Ne(16 * s0, 32))
dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
dim_constraints.add(Ne(16 * s0, 32))
dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
dim_constraints.add(FloorDiv(s0, 2) >= 2)
dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
dim_constraints.add(1 < FloorDiv(s0, 2))
dim_constraints.add(Ne(s0, 2))
dim_constraints.add(
60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
>= 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
)
dim_constraints.add(
60
* (FloorDiv(s0, 2))
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120
* FloorDiv(s0, 2)
* FloorDiv(s0, (FloorDiv(s0, 2)))
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2))))
> 0
)
dim_constraints.add(
Ne(
60
* (FloorDiv(s0, 2))
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120
* FloorDiv(s0, 2)
* FloorDiv(s0, (FloorDiv(s0, 2)))
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
3 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
)
)
dim_constraints.add(
Ne(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20,
0,
)
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 0
)
dim_constraints.add(
Ne(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20,
20,
)
)
dim_constraints.add(
Ne(
20
* (
Mod(
1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
)
),
0,
)
)
dim_constraints.add(
Ne(
20
* (FloorDiv((FloorDiv(s1, 2) - 1), 8))
* (
Mod(
1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
)
)
- 20
* Mod(
1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
),
0,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
>= 1
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 0
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 1
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 2
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
> 1
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
)
dim_constraints.add(
Ne(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60,
60,
)
)
dim_constraints.add(
Ne(
FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
)
)
dim_constraints.add(
Eq(
(FloorDiv((FloorDiv(s1, 2) - 1), 8))
* (
Mod(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
1,
)
)
- Mod(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
1,
),
0,
)
)
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
)
)
dim_constraints.add(Ne(8 * s0, 16))
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
>= (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
)
dim_constraints.add(
60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
<= 2147483647
)
dim_constraints.add(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
<= 2147483647
)
dim_constraints.add(FloorDiv(s0, 2) < 16)
dim_constraints.add(FloorDiv(s0, 2) > 1)
dim_constraints.add(
Ne(
90 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
> 1
)
dim_constraints.add(
60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
> 1
)
dim_constraints.add(
Ne(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
> 1
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
> 1
)
dim_constraints.add(
Ne(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)),
3 * (FloorDiv(s0, 2)),
)
)
dim_constraints.add(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2))
> 0
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
> 0
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
0,
)
)
dim_constraints.add(
1
< 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
6,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
> 0
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
0,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
<= 2147483647
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
<= 20480
)
dim_constraints.add(
Ne(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90,
0,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
> 1
)
dim_constraints.add(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
<= 20480
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
<= 20480
)
dim_constraints.add(
Ne(
240 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 480 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 240,
0,
)
)
dim_constraints.add(Eq(6 * s5, 132))
dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
dim_constraints.add(Eq(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 4))
dim_constraints.add(
Ne(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64
)
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64
<= 2147483647
)
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64
> 1
)
dim_constraints.add(
62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 62
<= 2147483647
)
dim_constraints.add(
Ne(
62 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 124 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 62 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 62
)
dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 3)
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576
<= 2147483647
)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 0)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 1)
dim_constraints.add(
Ne(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9,
1,
)
)
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9,
0,
)
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
>= 0
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 0))
dim_constraints.add(
1
< 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
dim_constraints.add(
Ne(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576 * (FloorDiv(s0, 2)),
256,
)
)
dim_constraints.add(
Eq(
64
* (
Mod(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2)),
4,
)
),
0,
)
)
dim_constraints.add(
Eq(
FloorDiv(s0, 2),
FloorDiv(
(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2))
),
4,
),
)
)
dim_constraints.add(
Eq(
FloorDiv(
(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2))
),
4,
),
FloorDiv(s0, 2),
)
)
dim_constraints.add(
Ne(64 * (Mod(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4)), 0)
)
dim_constraints.add(
Eq(
64
* (
Mod(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
4,
)
),
0,
)
)
dim_constraints.add(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576 * (FloorDiv(s0, 2))
> 0
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
>= 1
)
dim_constraints.add(
Eq(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576,
256,
)
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540
<= 2147483647
)
dim_constraints.add(
Ne(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
<= 2147483647
)
dim_constraints.add(
Ne(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
> 1
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540
> 1
)
dim_constraints.add(s0 >= 2)
dim_constraints.add(s1 >= 2)
dim_constraints.add(s6 >= 2)
dim_constraints.add(s5 >= 2)
dim_constraints.solve()
self.assertEqual(
dim_constraints._static_results,
{
"L['c'].size()[0] == 8",
"L['d'].size()[0] == 8",
"L['a'].size()[2] == 96",
"L['f'].size()[1] == 1",
"L['a'].size()[3] == 96",
"L['b'].size()[2] == 3",
"L['b'].size()[1] == 22",
"L['b'].size()[0] == 8",
"L['a'].size()[1] == 22",
"L['a'].size()[0] == 8",
},
)
self.assertEqual(
dim_constraints._dynamic_results,
{
"2 <= L['c'].size()[1]",
"L['d'].size()[1] == L['c'].size()[1]",
"L['e'].size()[1] == L['c'].size()[1]",
},
)
class TestGuardsExpressions(TestCase):
    """
    Guard-expression production and evaluation, as relied on by the inductor
    FX graph cache.
    """

    def _guards_hold(self, shape_env, guards, sym):
        # Evaluate the produced guards expression against sym's hint value.
        return shape_env.evaluate_guards_expression(guards, [hint_int(sym)])

    def test_guards_gt_lt(self):
        env = ShapeEnv()
        six, seven, five = (create_symint(env, hint) for hint in (6, 7, 5))
        # Guard on s0 > 5 and s0 < 7, so only the hint 6 should satisfy them.
        guard_int(sym_int(six > 5))
        guard_int(sym_int(six < 7))

        guards = env.produce_guards_expression([six])
        self.assertTrue(self._guards_hold(env, guards, six))
        for other in (seven, five):
            self.assertFalse(self._guards_hold(env, guards, other))

    def test_guards_float_print(self):
        env = ShapeEnv()
        three = create_symint(env, 3)
        guard_bool(2 / three == 2 / 3)

        guards = env.produce_guards_expression([three])
        self.assertTrue(self._guards_hold(env, guards, three))

    def test_guards_float_div(self):
        env = ShapeEnv()
        eight = create_symint(env, 8)
        seven = create_symint(env, 7)
        guard_int(sym_int(eight / 2.0))

        guards = env.produce_guards_expression([eight])
        # The float conversion and true division must appear in the guard text.
        for fragment in ("ToFloat", "FloatTrueDiv"):
            self.assertIn(fragment, guards)
        self.assertTrue(self._guards_hold(env, guards, eight))
        self.assertFalse(self._guards_hold(env, guards, seven))
# When executed directly, run this file's tests via PyTorch's run_tests helper.
if __name__ == "__main__":
    run_tests()