# blob: d5dd3e8ee8a90830384e8bde954fce28671660be [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import contextlib
import copy
import itertools
import inspect
import math
import operator
import re
import sympy
import torch
import torch.fx
import torch.nn.functional as F
from torch import sym_int, SymBool, SymFloat, SymInt
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental import symbolic_shapes
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
DimConstraints,
DimDynamic,
guard_bool,
guard_float,
guard_int,
GuardOnDataDependentSymNode,
ShapeEnv,
sym_float,
sym_sqrt,
SymNode,
to_node,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skipIfTorchDynamo,
TestCase,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
from torch.utils._sympy.functions import FloorDiv, Mod
# Shorthand for the ATen operator namespace used by the registrations below.
aten = torch.ops.aten

# Maps an ATen overload to the Python "meta" function that
# FakeSymbolicTensor.__torch_dispatch__ calls for it; filled in by
# the @register_meta decorator.
meta_funcs = {}
def register_meta(op):
    """Build a decorator that records a function in ``meta_funcs``.

    ``op`` may be a single overload or a pytree (e.g. a list) of overloads;
    the decorated function is stored under every leaf of ``op``.  The
    decorated function itself is returned unchanged.
    """
    def decorator(f):
        # tree_map visits each leaf of ``op``; its return value is unused,
        # only the registration side effect matters.
        tree_map(lambda overload: meta_funcs.__setitem__(overload, f), op)
        return f

    return decorator
@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
    """Meta function for ``add``/``sub``: the result takes the shape of ``a``.

    ``b`` is not inspected, so no broadcasting is computed here.
    """
    out_shape = a.shape
    return a.new_empty(out_shape)
@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
    """Meta function for ``cat``: compute the shape of the concatenation.

    Args:
        tensors: non-empty sequence of tensors to concatenate.
        dim: dimension to concatenate along; negative values count from the
            last dimension.

    Returns:
        ``tensors[0].new_empty(shape)`` where ``shape`` equals the first
        tensor's shape with ``dim`` replaced by the sum of every input's
        extent along ``dim``.

    Raises:
        AssertionError: if a non-concatenated dimension differs from the
            first tensor's.
    """
    shape = tensors[0].shape
    # Normalize a negative dim so the ``idx == dim`` comparison below can
    # match; otherwise the concatenated extent would silently stay 0.
    if dim < 0:
        dim += len(shape)

    concat_length = 0
    for tensor in tensors:
        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                assert length == common_length

    new_shape = list(shape)
    new_shape[dim] = concat_length
    return tensors[0].new_empty(new_shape)
@register_meta([aten.narrow_copy.default])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
    """Meta function for ``narrow_copy``.

    The output keeps the shape of ``a`` except that dimension ``dim`` becomes
    ``length``.  ``start`` and any extra keyword arguments are not used.
    """
    out_shape = tuple(
        length if axis == dim else extent
        for axis, extent in enumerate(a.shape)
    )
    return a.new_empty(out_shape)
@register_meta([aten.expand.default])
def expand_symint_meta(a, size, implicit=False):
    """Meta function for ``expand``: the result simply has shape ``size``.

    ``implicit`` is accepted for signature compatibility and is not used.
    """
    expanded = a.new_empty(size)
    return expanded
def create_contiguous(shape):
    """Return the row-major (C-contiguous) strides for ``shape``.

    Args:
        shape: sequence of dimension sizes (ints or symbolic ints).

    Returns:
        A list with one stride per dimension.  The innermost stride is 1 and
        each outer stride is the product of all sizes to its right, e.g.
        ``(5, 4, 3) -> [12, 3, 1]``.  A zero-dimensional shape yields ``[]``.
    """
    if len(shape) == 0:
        return []

    # Walk the sizes from the innermost dimension outwards.  The outermost
    # size (shape[0]) never contributes to any stride, so it is skipped.
    strides = [1]
    for dim in reversed(shape[1:]):
        strides.append(dim * strides[-1])
    strides.reverse()
    return strides
class FakeSymbolicTensor(torch.Tensor):
    """Wrapper tensor subclass carrying (possibly symbolic) sizes for tests.

    Operators reaching ``__torch_dispatch__`` are answered by the functions
    registered in ``meta_funcs``; anything else (except ``new_empty``) raises.
    """

    @staticmethod
    def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device, storage_offset=0):
        # NOTE: ``sym_strides`` is accepted but ignored; strides are always
        # recomputed from the shape via create_contiguous.
        # TODO: this is wrong in general
        sym_stride = create_contiguous(sym_shape)
        r = torch.Tensor._make_wrapper_subclass(
            cls, sym_shape,
            sym_stride, storage_offset,
            dtype=dtype, layout=layout, requires_grad=requires_grad,
            device=device,
        )
        return r

    __torch_function__ = _disabled_torch_function_impl

    def new_empty(self, shape):
        """Return a new FakeSymbolicTensor of ``shape`` with this tensor's
        dtype, layout, requires_grad and device."""
        return FakeSymbolicTensor(shape, None, self.dtype, self.layout, self.requires_grad, self.device)

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        """Route ``func_overload`` to its registered meta function.

        Raises:
            RuntimeError: if the operator has no registered meta function and
                is not ``aten.new_empty.default``.
        """
        # The signature defaults kwargs to None; ``**None`` would raise a
        # TypeError, so substitute an empty mapping.
        if kwargs is None:
            kwargs = {}

        if func_overload in meta_funcs:
            return meta_funcs[func_overload](*args, **kwargs)

        if func_overload == torch.ops.aten.new_empty.default:
            self = args[0]
            shape = args[1]
            return FakeSymbolicTensor(shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device)

        raise RuntimeError(f"operator {func_overload} not supported")
def create_symbolic_tensor(name, arg, shape_env):
    """Wrap ``arg`` in a FakeSymbolicTensor with symbolic metadata.

    Sizes, strides and storage offset are obtained from
    ``shape_env.create_symbolic_sizes_strides_storage_offset`` using a
    ``ConstantSource(name)``; every dimension is marked ``DimDynamic.DUCK``
    with no constraint.
    """
    from torch._dynamo.source import ConstantSource

    unconstrained = [None] * arg.dim()
    duck_dims = [DimDynamic.DUCK] * arg.dim()
    symbolic_meta = shape_env.create_symbolic_sizes_strides_storage_offset(
        arg,
        source=ConstantSource(name),
        dynamic_dims=duck_dims,
        constraint_dims=unconstrained,
    )
    sizes, strides, offset = symbolic_meta
    return FakeSymbolicTensor(
        sizes, strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, offset
    )
def create_symint(shape_env, i: int):
from torch._dynamo.source import ConstantSource
return shape_env.create_symintnode(
shape_env.create_symbol(
i,
source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"),
dynamic_dim=DimDynamic.DUCK,
constraint_dim=None,
),
hint=i
)
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestPySymInt(TestCase):
    """Tests for Python-level SymInt/SymFloat behaviour and ShapeEnv guards."""

    def test_arith_ops(self):
        """Binary arithmetic on SymInts agrees with the same op on their hints."""
        shape_env = ShapeEnv()
        symints = []
        for i in range(2, 5):
            symints.append((i, create_symint(shape_env, i)))

        ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]
        division_like = (operator.floordiv, operator.mod)

        for op in ops:
            for (lhs_val, lhs_sym), (rhs_val, rhs_sym) in itertools.permutations(symints, 2):
                # Only exercise the symbolic path for the left operand.
                if isinstance(lhs_sym, int):
                    continue
                # Division-like ops are undefined for a zero divisor.
                if op in division_like and rhs_val == 0:
                    continue
                self.assertTrue(op(lhs_sym, rhs_sym) == op(lhs_val, rhs_val))

    def test_reverse_arith_ops(self):
        """Reflected operators (int on the left, SymInt on the right) work."""
        shape_env = ShapeEnv()

        a = create_symint(shape_env, 2)
        self.assertTrue(5 // a == 5 // 2)

        a = create_symint(shape_env, 2)
        self.assertTrue(5 * a == 5 * 2)

    def test_roundtrip(self):
        """Sizes and storage offset of a symbolic tensor compare to their hints."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)

        self.assertTrue(not isinstance(x.shape[0], SymNode))
        self.assertTrue(isinstance(x.shape[0], SymInt))

        self.assertTrue(x.shape[0] == 5)
        self.assertTrue(x.shape[1] == 4)
        # These were previously ``assertTrue(value, expected)``, which treats
        # ``expected`` as the failure message and compares nothing.
        self.assertTrue(x.shape[2] == 3)

        self.assertTrue(x.size()[0] == 5)
        self.assertTrue(x.size()[1] == 4)
        self.assertTrue(isinstance(x.size()[1], int))  # due to guard above
        self.assertTrue(x.size()[2] == 3)

        self.assertTrue(x.size(0) == 5)
        self.assertTrue(x.size(1) == 4)
        self.assertTrue(x.size(2) == 3)
        self.assertTrue(isinstance(x.size(2), int))

        y = create_symbolic_tensor("y", torch.randn(5, 4, 3)[1:], shape_env)
        self.assertTrue(isinstance(y.storage_offset(), SymInt))
        self.assertTrue(y.storage_offset() == 12)

    def test_binary(self):
        """``+`` on symbolic tensors yields the expected result shape."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)

        z = x + y
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # broadcasting
        y = create_symbolic_tensor("y2", torch.randn(1, 4, 1), shape_env)
        z = x + y
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

    def test_symint_args(self):
        """SymInts and SymInt expressions are accepted as operator arguments."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
        LAST_DIM = 2

        z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
        self.assertTrue(z.shape[2] == y.shape[2])

        # arithmetic expr with two symints
        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
        self.assertTrue(z.shape[2] == 2)

        # arithmetic expr with a symint and python int
        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
        self.assertTrue(z.shape[2] == 2)

    def test_symint_vargs(self):
        """``expand`` accepts SymInts as varargs, as a list, and mixed with ints."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)

        # varargs
        z = y.expand(x.shape[0], y.shape[1], x.shape[2])
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # shape list
        z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints
        z = y.expand(x.shape[0], y.shape[1], 3)
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints in a list
        z = y.expand((x.shape[0], y.shape[1], 3))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints
        z = y.expand(5, y.shape[1], x.shape[2])
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python ints and symints in a list
        z = y.expand((5, y.shape[1], x.shape[2]))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # Single-element forms only need to run without raising.
        z = y.expand((y.shape[1],))
        z = y.expand(y.shape[1])

    def test_stride(self):
        """Strides of a symbolic tensor are SymInts."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
        self.assertIsInstance(x.stride()[0], SymInt)

    def test_size_expressions(self):
        """Branching on a symbolic size records a ``>`` guard on that size."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        expand_x = x.expand(x.shape[0], x.shape[0])
        # Both branches do the same work; the ``if`` exists to evaluate the
        # comparison and thereby install a guard.
        if expand_x.shape[0] > 3:
            result = expand_x + expand_x
        else:
            result = expand_x + expand_x

        gt_op, _bt = shape_env.guards[-1]
        self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
        # Previously ``assertTrue(a, b)``, which never compared a with b.
        self.assertEqual(str(x.shape[0]), str(gt_op.args[0]))
        self.assertEqual(str(expand_x.shape[1]), str(x.shape[0]))
        self.assertEqual(str(expand_x.shape[1]), str(result.shape[0]))

    def test_numel(self):
        """``numel`` is a SymInt for symbolic tensors and an int otherwise."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        self.assertIsInstance(x.numel(), torch.SymInt)
        self.assertIsInstance(torch.numel(x), torch.SymInt)

        x = torch.rand(3, 3)
        self.assertIsInstance(x.numel(), int)
        self.assertIsInstance(torch.numel(x), int)

    def test_int_to_float(self):
        """``sym_float`` of a SymInt returns a SymFloat."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        r = sym_float(x.shape[0])
        self.assertIsInstance(r, torch.SymFloat, msg=type(r))

    def test_aten_ops(self):
        """ATen overloads can be called directly with SymInt arguments."""
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])

        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x2", torch.randn(5, 4, 3), shape_env)
        torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])

    def test_fx_trace_intlist(self):
        """FX symbolic tracing handles an int-list argument built from proxies."""
        class CustomModule(torch.nn.Module):
            def forward(self, x):
                bs, c, h, w = x.shape
                return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))

        m = CustomModule()
        x = torch.rand(1, 3, 4, 4)
        # should not TypeError: pad(): argument 'pad' (position 2) must be
        # tuple of ints, not tuple
        torch.fx.symbolic_trace(m)

    def test_meta_symint(self):
        """A meta tensor created from a SymInt keeps a symbolic size."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        r = torch.empty(a0, device='meta')
        self.assertIsInstance(r.shape[0], SymInt)

    def test_guard_int(self):
        """``guard_int`` returns the hint and records an equality guard."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        self.assertEqual(guard_int(a0), 2)
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

    def test_sym_int(self):
        """``sym_int`` on SymInt/SymFloat inputs returns a SymInt and guards."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = sym_int(a0)
        self.assertEqual(r, 5)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 5)""")

        a1 = create_symint(shape_env, 7)
        r = sym_int(a1 / 2)
        self.assertEqual(guard_int(r), 3)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(floor(s1/2), 3)""")

        a3 = create_symint(shape_env, 3)
        r = sym_int(2.0 * sym_float(a3))
        self.assertEqual(guard_int(r), 6)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[2][0]), """Eq(2*s2, 6)""")

    def test_sym_sqrt(self):
        """``sym_sqrt`` of a SymInt returns a SymFloat and guards on the value."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 4)
        r = sym_sqrt(a0)
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymFloat, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(sqrt(s0), 2)""")

    def test_sym_floor(self):
        """``math.floor`` on symbolic floats returns a SymInt and guards."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = math.floor(a0 / 2)
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s0/2), 2)""")

        r = math.floor(3.0 * a0)
        self.assertEqual(r, 15)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")

    def test_sym_ceil(self):
        """``math.ceil`` on a symbolic float returns a SymInt and guards."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = math.ceil(a0 / 2)
        self.assertEqual(r, 3)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)""")

        # NOTE(review): this second half calls math.floor, identical to
        # test_sym_floor; it looks like a copy-paste and may have been meant
        # to be math.ceil -- confirm before changing the expected guard.
        r = math.floor(3.0 * a0)
        self.assertEqual(r, 15)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")

    def test_int_conversion(self):
        """``int(SymInt)`` records an equality guard."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        int(a0)
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

    def test_data_dependent_guard(self):
        """Guarding on an unbacked SymInt raises GuardOnDataDependentSymNode."""
        shape_env = ShapeEnv()
        s0 = shape_env.create_unbacked_symint()
        self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))

    def test_non_overlapping_and_dense(self):
        """``is_non_overlapping_and_dense`` works with symbolic sizes/strides."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = torch.empty_strided((a0, 7), (1, a0), device='meta')
        self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r))

    def test_specialize_zero_one(self):
        """``a0 != 1`` needs a guard only when 0/1 specialization is disabled."""
        # self.assertTrue is used instead of a bare ``assert`` so the guarded
        # comparison still executes under ``python -O``.
        shape_env = ShapeEnv(specialize_zero_one=True)
        a0 = create_symint(shape_env, 5)
        self.assertTrue(a0 != 1)
        self.assertEqual(len(shape_env.guards), 0)

        shape_env = ShapeEnv(specialize_zero_one=False)
        a0 = create_symint(shape_env, 5)
        self.assertTrue(a0 != 1)
        self.assertEqual(len(shape_env.guards), 1)

    def test_duck_shape(self):
        """Equal-hint SymInts compare equal guard-free only with duck shaping."""
        shape_env = ShapeEnv(duck_shape=True)
        a0 = create_symint(shape_env, 5)
        a1 = create_symint(shape_env, 5)
        self.assertTrue(a0 == a1)
        self.assertEqual(len(shape_env.guards), 0)

        shape_env = ShapeEnv(duck_shape=False)
        a0 = create_symint(shape_env, 5)
        a1 = create_symint(shape_env, 5)
        self.assertTrue(a0 == a1)
        self.assertEqual(len(shape_env.guards), 1)

    def test_int_bool(self):
        """Truth-testing a SymInt with hint 5 adds no guard."""
        # See https://github.com/pytorch/pytorch/issues/95981
        shape_env = ShapeEnv(duck_shape=True)
        a0 = create_symint(shape_env, 5)
        self.assertTrue(a0)
        self.assertEqual(len(shape_env.guards), 0)

    def test_symint_as_scalar(self):
        """A SymInt passed as ``alpha`` reaches ``__torch_dispatch__`` intact."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)

        sym_int_encountered = False

        class TestSymInt(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                assert func == torch.ops.aten.add.Tensor

                nonlocal sym_int_encountered
                # WARNING: do not do identity tests on the outer
                # SymInt/SymFloat, they are NOT STABLE
                sym_int_encountered = kwargs["alpha"].node is a0.node
                kwargs["alpha"] = 0
                return func(*args)

        x = torch.rand([4, 4])
        with TestSymInt():
            y = torch.add(x, x, alpha=a0)

        self.assertTrue(sym_int_encountered)

    def test_deepcopy(self):
        """Deep-copying a ShapeEnv preserves its recorded guards."""
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        # Evaluating the comparison is what installs the guard.
        self.assertTrue(a0 < 4)
        new_shape_env = copy.deepcopy(shape_env)
        self.assertEqual(len(new_shape_env.guards), 1)

    def test_print_readable_with_symints(self):
        """``print_readable`` of a symbolically traced graph shows Sym shapes."""
        def f(a, b):
            dim0 = a.shape[0] + b.shape[0]
            dim1 = a.shape[1] + b.shape[1]
            d = a.new_empty(dim0, dim1)
            d = torch.ops.aten.native_dropout(d, 0.5, train=True)
            return d

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
        out = fx_g.print_readable(print_output=False)

        # NOTE(review): the expected text below is reproduced verbatim; confirm
        # its leading whitespace against real print_readable() output.
        self.assertExpectedInline(out.strip(), """\
class f(torch.nn.Module):
def forward(self, a_1: f32[s0, s1], b_1: f32[s2, s1]):
# No stacktrace found for following nodes
sym_size: Sym(s0) = torch.ops.aten.sym_size(a_1, 0)
sym_size_1: Sym(s2) = torch.ops.aten.sym_size(b_1, 0)
add: Sym(s0 + s2) = sym_size + sym_size_1; sym_size = sym_size_1 = None
sym_size_2: Sym(s1) = torch.ops.aten.sym_size(a_1, 1)
sym_size_3: Sym(s1) = torch.ops.aten.sym_size(b_1, 1); b_1 = None
add_1: Sym(2*s1) = sym_size_2 + sym_size_3; sym_size_2 = sym_size_3 = None
new_empty: f32[s0 + s2, 2*s1] = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
getitem: f32[s0 + s2, 2*s1] = native_dropout[0]
getitem_1: b8[s0 + s2, 2*s1] = native_dropout[1]; native_dropout = None
return (getitem, getitem_1)""")  # noqa: B950
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestSymNumberMagicMethods(TestCase):
    """Checks each symbolic magic method against its plain-Python result."""

    def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
        """Run ``fn`` on plain and symbolic operands and compare results.

        The reference result is computed on the plain inputs, then the same
        operation is repeated with the first, the second, and both operands
        wrapped as symbolic values; each guarded result must equal the
        reference.  Unary functions only run the first-operand variant.
        """
        # Helper function
        # NB: don't use one as that will get specialized
        seed_node = (create_symint(shape_env, 2) / 2.).node
        bool_seed_node = (create_symint(shape_env, 2) == 2).node

        def get_sym_inp(inp):
            """Wrap a plain bool/int/float as SymBool/SymInt/SymFloat."""
            # NB: this must come before int
            if isinstance(inp, bool):
                return torch.SymBool(to_node(bool_seed_node, inp))
            elif isinstance(inp, int):
                return torch.SymInt(to_node(seed_node, inp))
            else:
                return torch.SymFloat(to_node(seed_node, inp))

        def maybe_xfail(inp1, inp2):
            """Return an assertRaises context for inputs expected to fail,
            otherwise a no-op context."""
            if fn == "sym_sqrt" and inp1 < 0:
                # ValueError: math domain error
                return self.assertRaises((ValueError,))
            elif fn in ("truediv", "floordiv", "mod") and inp2 == 0:
                # ZeroDivisionError: division by zero
                return self.assertRaises((ZeroDivisionError,))
            elif fn == "pow" and inp1 == 0 and inp2 < 0:
                # ZeroDivisionError: 0.0 cannot be raised to a negative power
                return self.assertRaises((ZeroDivisionError,))
            elif fn == "pow" and inp1 < 0 and inp2 in (2.5, -2.5) and (
                type(inp1) in (SymFloat, SymInt) or
                type(inp2) in (SymFloat, SymInt)
            ):
                # Complex result, which we do not support:
                # TypeError: Cannot convert complex to float
                return self.assertRaises((TypeError,))
            elif fn in ("lshift", "rshift") and not (
                isinstance(inp1, (SymInt, int)) and
                isinstance(inp2, (SymInt, int))
            ):
                # TypeError: unsupported operand type(s)
                return self.assertRaises((TypeError,))
            elif fn in ("lshift", "rshift") and inp2 < 0:
                # ValueError: math domain error
                return self.assertRaises((ValueError,))
            else:
                return contextlib.nullcontext()

        # Resolve the callable for ``fn`` from the module that hosts it.
        if fn in symbolic_shapes.magic_methods_on_math:
            lambda_apply = getattr(math, fn)
        elif fn in symbolic_shapes.magic_methods_on_submodule:
            lambda_apply = getattr(symbolic_shapes, fn)
        elif fn in symbolic_shapes.magic_methods_on_operator_with_trailing_underscore:
            lambda_apply = getattr(operator, f"{fn}_")
        else:
            lambda_apply = getattr(operator, fn)

        def guard_fn(v):
            """Concretize ``v`` with the guard function matching its type."""
            if type(v) in (SymBool, bool):
                return guard_bool(v)
            elif type(v) in (SymFloat, float):
                return guard_float(v)
            else:  # SymInt, int
                return guard_int(v)

        # Get reference result
        with maybe_xfail(inp1, inp2):
            if is_unary_fn:
                ref_out = lambda_apply(inp1)
            else:
                ref_out = lambda_apply(inp1, inp2)

        # Symified first arg
        sym_inp1 = get_sym_inp(inp1)
        with maybe_xfail(sym_inp1, inp2):
            if is_unary_fn:
                out = lambda_apply(sym_inp1)
            else:
                out = lambda_apply(sym_inp1, inp2)
            out = guard_fn(out)
            self.assertEqual(out, ref_out)

        if is_unary_fn:
            return

        # Symified second arg
        sym_inp2 = get_sym_inp(inp2)
        with maybe_xfail(inp1, sym_inp2):
            out = lambda_apply(inp1, sym_inp2)
            out = guard_fn(out)
            self.assertEqual(out, ref_out)

        # Symified both args
        with maybe_xfail(sym_inp1, sym_inp2):
            out = lambda_apply(sym_inp1, sym_inp2)
            out = guard_fn(out)
            self.assertEqual(out, ref_out)

    @parametrize("fn", list(symbolic_shapes.magic_methods.keys()))
    def test_bool_method(self, fn):
        """Exercise each bool magic method on ``True``/``False``."""
        if fn not in symbolic_shapes.bool_magic_methods:
            self.skipTest(f"{fn} is non-bool")

        is_unary_fn = fn in symbolic_shapes.unary_magic_methods
        shape_env = ShapeEnv()
        self._do_test(fn, True, False, shape_env, is_unary_fn)

    @parametrize("fn", list(symbolic_shapes.magic_methods.keys()))
    @parametrize("first_type", ["int", "float"])
    @parametrize("second_type", ["int", "float"])
    def test_method(self, fn, first_type, second_type):
        """Exercise each non-bool magic method over signed sample values."""
        if first_type == "float":
            # TODO: Hmm, this looks like we skip all floats
            self.skipTest(f"{fn} is not a float magic method")

        is_unary_fn = fn in symbolic_shapes.unary_magic_methods
        # Second argument is ignored for unary function. So only run for one type
        if is_unary_fn and second_type == "float":
            self.skipTest(f"{fn} is unary and already tested")

        if fn in symbolic_shapes.bool_magic_methods:
            self.skipTest(f"{fn} is bool")

        # Only floats here since these will be converted to int if necessary.
        # We also ignore complex and bool.
        values = (
            0.0,
            1.0,
            2.5,
        )

        neg_values = tuple(-x for x in values)

        # Cover every sign combination of the two operands.
        for inp1, inp2 in itertools.chain(
            itertools.product(values, values),
            itertools.product(values, neg_values),
            itertools.product(neg_values, values),
            itertools.product(neg_values, neg_values),
        ):
            if first_type == "int":
                inp1 = int(inp1)
            if second_type == "int":
                inp2 = int(inp2)

            # A fresh ShapeEnv per input pair.
            shape_env = ShapeEnv()

            self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)
# Expand the @parametrize-decorated methods of the class above into tests.
instantiate_parametrized_tests(TestSymNumberMagicMethods)
class TestFloorDiv(TestCase):
@staticmethod
def python_floordiv(x, y):
return x // y
@staticmethod
def torch_floordiv(x, y):
# Note: we fully evaluate here since FloorDiv might not always do
# that.
shape_env = ShapeEnv()
return shape_env.evaluate_expr(FloorDiv(x, y))
@staticmethod
def yield_test_cases(values, negate=True):
for x, y in values:
yield (x, y)
if negate:
yield (-x, y)
yield (x, -y)
yield (-x, -y)
def test_floordiv_float_int(self):
values = (
(2.5, 2.1),
(2.1, 2.5),
(2.0, 2.1),
(7, 2.5),
(2.1, 7),
(7, 2),
)
for x, y in TestFloorDiv.yield_test_cases(values):
self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y))
def test_floordiv_bool(self):
values = (
(False, True),
(True, 2.5),
(2.5, True),
(False, 7),
(7, True),
)
for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
# Compares to int since our FloorDiv has no bool support
self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(int(x), int(y)))
# Tests that our impl throws
self.assertRaisesRegex(
TypeError,
(rf"unsupported operand type\(s\) for //: "
rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'"
rf", expected integer or real"),
lambda: TestFloorDiv.torch_floordiv(x, y))
def test_floordiv_complex(self):
values = (
(1.5 + 2.5j, 1.3 + 3.5j),
(1.5 + 2.5j, 2.5),
(2.5, 1.5 + 2.5j),
(1.5 + 2.5j, 7),
(7, 1.5 + 2.5j),
)
for x, y in TestFloorDiv.yield_test_cases(values):
# We don't test error messages to avoid depending on Python
# interpreter version
self.assertRaises(TypeError, lambda: TestFloorDiv.python_floordiv(x, y))
self.assertRaisesRegex(
TypeError,
(rf"unsupported operand type\(s\) for //: "
rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'"
rf", expected integer or real"),
lambda: TestFloorDiv.torch_floordiv(x, y))
def test_floordiv_div_by_zero(self):
values = (
(2.5, 0),
(2.1, 0.0),
(2.3, sympy.Symbol("s", zero=True)),
)
for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
# We don't test error messages to avoid depending on Python
# interpreter version
if type(y) is not sympy.Symbol:
self.assertRaises(ZeroDivisionError, lambda: TestFloorDiv.python_floordiv(x, y))
self.assertRaisesRegex(
ZeroDivisionError,
"division by zero",
lambda: TestFloorDiv.torch_floordiv(x, y))
def test_floordiv_zero_base(self):
values = (
(0, 2.5),
(0.0, 2.1),
(sympy.Symbol("s", zero=True), 2.3),
)
for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
if type(x) is not sympy.Symbol:
self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y))
else:
self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y))
def test_floordiv_div_by_one(self):
values = (
(2.5, 1),
(2.1, 1.0),
(2, 1.0),
(2, 1),
)
for x, y in TestFloorDiv.yield_test_cases(values):
self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y))
def test_floordiv_simplify(self):
# Tests how we simplify or evaluate FloorDiv without free variables
shape_env = ShapeEnv()
result = 21
exprs = (
7 * FloorDiv(6, 2),
7 * FloorDiv(6.28, 2),
7 * FloorDiv(6.28, 2.0),
7 * FloorDiv(6.28, (FloorDiv(6.28, 3.14))),
)
for expr in exprs:
self.assertEqual(expr, result)
self.assertEqual(expr.doit(deep=False), result)
self.assertEqual(expr.doit(deep=True), result)
self.assertEqual(sympy.simplify(expr), result)
self.assertEqual(shape_env.simplify(expr), result)
self.assertEqual(shape_env.evaluate_expr(expr), result)
def test_floordiv_assumptions(self):
# We define two Symbols (with different names) for each type to make
# sure the behavior is consistent regardless of whether both arguments
# are the same object or not.
cases = (
sympy.Symbol("i1", integer=True),
sympy.Symbol("i2", integer=True),
sympy.Symbol("r1", real=True),
sympy.Symbol("r2", real=True),
sympy.Symbol("c1", complex=True, real=False, integer=False),
sympy.Symbol("c2", complex=True, real=False, integer=False),
sympy.Symbol("s1"),
sympy.Symbol("s2"),
)
for base, divisor in itertools.product(cases, repeat=2):
def op():
return FloorDiv(base, divisor)
def is_complex(x):
return x.is_integer is False and x.is_real is False and x.is_complex
if is_complex(base) or is_complex(divisor):
self.assertRaisesRegex(
TypeError,
(r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol',"
r" expected integer or real"),
op)
continue
op = op()
# In regular Python, x//x == 1.0 if x is a float, but FloorDiv
# always returns an integer 1 when both args are the same object.
# This even works for Symbols with no assumptions specified.
if base is divisor:
self.assertTrue(op.is_integer)
self.assertTrue(op.is_real)
elif base.is_integer and divisor.is_integer:
self.assertTrue(op.is_integer)
self.assertTrue(op.is_real)
else:
self.assertEqual(op.is_integer, None)
self.assertTrue(op.is_real)
class TestDimConstraints(TestCase):
def test_dim_constraints_reduce_congruences_simple(self):
    """Several congruences on one symbol reduce to a single ``% 64`` one."""
    from sympy import Symbol
    from torch.fx.experimental.symbolic_shapes import DimConstraints

    s = Symbol("s", positive=True, integer=True)
    half = s / 2
    raw_congruences = {half % modulus for modulus in (2, 8, 4)}
    raw_congruences |= {s % 2, ((s / 16) + 2) % 4}

    dim_constraints = DimConstraints({}, {}, set())
    dim_constraints._congruences[s] = raw_congruences

    reduced = dim_constraints.reduce_congruences()
    self.assertEqual(reduced[s], {(s + 32) % 64})
def test_dim_constraints_reduce_inequalities_simple(self):
    """sympy reduces the inequality set to [4, 16), then to {8} with an Eq."""
    from sympy import Eq, Interval, Ne, Symbol
    from sympy.solvers.inequalities import reduce_inequalities

    s = Symbol("s", positive=True, integer=True)
    half = s / 2
    exprs = {
        s >= 2,
        Ne(8 * s, 16),
        Ne(half, 1),
        Ne(16 * s, 32),
        s < 16,
        Ne(s, 2),
        half < 16,
        half > 1,
        half >= 2,
        Ne(3 * s / 2, 3),
    }

    def solve():
        # Reads ``exprs`` at call time, so it sees the equality added below.
        return reduce_inequalities(exprs, s).as_set()

    self.assertEqual(solve(), Interval.Ropen(4, 16))

    exprs.add(Eq(half, 4))
    self.assertEqual(solve(), {8})
def test_dim_constraints_solve_full(self):
from sympy import Eq, Integer, Ne, Symbol
from torch._dynamo.source import LocalSource, TensorProperty, TensorPropertySource
src0 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
)
src2 = TensorPropertySource(
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0
)
src3 = TensorPropertySource(
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0
)
src4 = TensorPropertySource(
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0
)
src1 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2
)
src7 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3
)
src5 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=1
)
src8 = TensorPropertySource(
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1
)
src6 = TensorPropertySource(
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1
)
src9 = TensorPropertySource(
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=1
)
src10 = TensorPropertySource(
base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1
)
src11 = TensorPropertySource(
base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1
)
src12 = TensorPropertySource(
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2
)
s0 = Symbol("s0", positive=True, integer=True)
s1 = Symbol("s1", positive=True, integer=True)
s5 = Symbol("s5", positive=True, integer=True)
s6 = Symbol("s6", positive=True, integer=True)
symbol_to_source = {
s0: [src0, src2, src3, src4],
s1: [src1, src7],
s5: [src5, src8],
s6: [src6, src9, src10],
}
var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21}
marked_dynamic = {s0, s1, s5, s6}
dim_constraints = DimConstraints(symbol_to_source, var_to_val, marked_dynamic)
dim_constraints.add_equality(src2, s0)
dim_constraints.add_equality(src3, s0)
dim_constraints.add_equality(src4, s0)
dim_constraints.add_equality(src7, s1)
dim_constraints.add_equality(src8, s5)
dim_constraints.add_equality(src9, s6)
dim_constraints.add_equality(src10, s6)
dim_constraints.add_equality(src11, Integer(1))
dim_constraints.add_equality(src12, Integer(3))
dim_constraints.add(s1**2 <= 2147483647)
dim_constraints.add(32 * s1**2 <= 2147483647)
dim_constraints.add(s0 < 16)
dim_constraints.add(Eq(Mod(s1, 2), 0))
dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
dim_constraints.add(Ne((FloorDiv(s1, 2)) ** 2, 1))
dim_constraints.add(32 * (FloorDiv(s1, 2)) ** 2 <= 2147483647)
dim_constraints.add((FloorDiv(s1, 2)) ** 2 > 1)
dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
+ 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
+ 64
<= 2147483647
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
+ 1,
1,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
+ 1
> 1
)
dim_constraints.add(
128 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
+ 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
+ 128
<= 2147483647
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
+ 1,
1,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
+ 1
> 1
)
dim_constraints.add(
256 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
+ 512 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
+ 256
<= 2147483647
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
+ 1,
1,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
+ 1
> 1
)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 >= 3)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
<= 2147483647
)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 0)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 1)
dim_constraints.add(
Ne(
60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * s0,
0,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
1,
)
)
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
0,
)
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
>= 0
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 0))
dim_constraints.add(
1
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, -1))
dim_constraints.add(
Ne(
60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * s0,
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
> 0
)
dim_constraints.add(
Eq(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 * (Mod(s0, 2))
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) * Mod(s0, 2)
+ 60 * (Mod(s0, 2)),
0,
)
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
0,
)
)
dim_constraints.add(
Ne(
60
* (FloorDiv(s0, 2))
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120
* FloorDiv(s0, 2)
* FloorDiv(s0, (FloorDiv(s0, 2)))
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
0,
)
)
dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
dim_constraints.add(
Ne(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60,
0,
)
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
>= 0
)
dim_constraints.add(
1
< 60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
)
dim_constraints.add(Ne(16 * s0, 32))
dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
dim_constraints.add(Ne(16 * s0, 32))
dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
dim_constraints.add(FloorDiv(s0, 2) >= 2)
dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
dim_constraints.add(1 < FloorDiv(s0, 2))
dim_constraints.add(Ne(s0, 2))
dim_constraints.add(
60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
>= 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
)
dim_constraints.add(
60
* (FloorDiv(s0, 2))
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120
* FloorDiv(s0, 2)
* FloorDiv(s0, (FloorDiv(s0, 2)))
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2))))
> 0
)
dim_constraints.add(
Ne(
60
* (FloorDiv(s0, 2))
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120
* FloorDiv(s0, 2)
* FloorDiv(s0, (FloorDiv(s0, 2)))
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
3 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
)
)
dim_constraints.add(
Ne(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20,
0,
)
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 0
)
dim_constraints.add(
Ne(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20,
20,
)
)
dim_constraints.add(
Ne(
20
* (
Mod(
1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
)
),
0,
)
)
dim_constraints.add(
Ne(
20
* (FloorDiv((FloorDiv(s1, 2) - 1), 8))
* (
Mod(
1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
)
)
- 20
* Mod(
1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
),
0,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
>= 1
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 0
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 1
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 2
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
> 1
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
)
dim_constraints.add(
Ne(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60,
60,
)
)
dim_constraints.add(
Ne(
FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
)
)
dim_constraints.add(
Eq(
(FloorDiv((FloorDiv(s1, 2) - 1), 8))
* (
Mod(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
1,
)
)
- Mod(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
1,
),
0,
)
)
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
)
)
dim_constraints.add(Ne(8 * s0, 16))
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
>= (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
)
dim_constraints.add(
60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
<= 2147483647
)
dim_constraints.add(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
<= 2147483647
)
dim_constraints.add(FloorDiv(s0, 2) < 16)
dim_constraints.add(FloorDiv(s0, 2) > 1)
dim_constraints.add(
Ne(
90 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
> 1
)
dim_constraints.add(
60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
> 1
)
dim_constraints.add(
Ne(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
> 1
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
> 1
)
dim_constraints.add(
Ne(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)),
3 * (FloorDiv(s0, 2)),
)
)
dim_constraints.add(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2))
> 0
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
> 0
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
0,
)
)
dim_constraints.add(
1
< 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
6,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
> 0
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
0,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
<= 2147483647
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
<= 20480
)
dim_constraints.add(
Ne(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90,
0,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
> 1
)
dim_constraints.add(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
<= 20480
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
<= 20480
)
dim_constraints.add(
Ne(
240 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 480 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 240,
0,
)
)
dim_constraints.add(Eq(6 * s5, 132))
dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
dim_constraints.add(Eq(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 4))
dim_constraints.add(
Ne(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64
)
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64
<= 2147483647
)
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64
> 1
)
dim_constraints.add(
62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 62
<= 2147483647
)
dim_constraints.add(
Ne(
62 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 124 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 62 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 62
)
dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 3)
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576
<= 2147483647
)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 0)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 1)
dim_constraints.add(
Ne(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9,
1,
)
)
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9,
0,
)
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
>= 0
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 0))
dim_constraints.add(
1
< 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
dim_constraints.add(
Ne(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576 * (FloorDiv(s0, 2)),
256,
)
)
dim_constraints.add(
Eq(
64
* (
Mod(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2)),
4,
)
),
0,
)
)
dim_constraints.add(
Eq(
FloorDiv(s0, 2),
FloorDiv(
(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2))
),
4,
),
)
)
dim_constraints.add(
Eq(
FloorDiv(
(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2))
),
4,
),
FloorDiv(s0, 2),
)
)
dim_constraints.add(Ne(64 * (Mod(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4)), 0))
dim_constraints.add(
Eq(
64
* (
Mod(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
4,
)
),
0,
)
)
dim_constraints.add(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576 * (FloorDiv(s0, 2))
> 0
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
>= 1
)
dim_constraints.add(
Eq(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576,
256,
)
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540
<= 2147483647
)
dim_constraints.add(
Ne(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
<= 2147483647
)
dim_constraints.add(
Ne(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
> 1
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540
> 1
)
dim_constraints.add(s0 >= 2)
dim_constraints.add(s1 >= 2)
dim_constraints.add(s6 >= 2)
dim_constraints.add(s5 >= 2)
dim_constraints.solve()
self.assertEqual(dim_constraints._static_results, {
"L['c'].size()[0] == 8",
"L['d'].size()[0] == 8",
"L['a'].size()[2] == 96",
"L['f'].size()[1] == 1",
"L['a'].size()[3] == 96",
"L['b'].size()[2] == 3",
"L['b'].size()[1] == 22",
"L['b'].size()[0] == 8",
"L['a'].size()[1] == 22",
"L['a'].size()[0] == 8",
})
self.assertEqual(dim_constraints._dynamic_results, {
"dynamic_dim(L['e'], 1) == dynamic_dim(L['c'], 1)",
"2 <= dynamic_dim(L['c'], 1)",
"dynamic_dim(L['d'], 1) == dynamic_dim(L['c'], 1)",
})
def dummy_fn(a, b, c, d, e, f):
pass
action_code = dim_constraints.prettify_results(inspect.signature(dummy_fn))
static_code, dynamic_code = re.findall(r"```(.*?)```", action_code, re.DOTALL)
expected_static = '''
def specializations(a, b, c, d, e, f):
# a:
assert a.size()[0] == 8
assert a.size()[1] == 22
assert a.size()[2] == 96
assert a.size()[3] == 96
# b:
assert b.size()[0] == 8
assert b.size()[1] == 22
assert b.size()[2] == 3
# c:
assert c.size()[0] == 8
# d:
assert d.size()[0] == 8
# f:
assert f.size()[1] == 1
'''
expected_dynamic = '''
def specify_constraints(a, b, c, d, e, f):
return [
# c:
dynamic_dim(c, 1),
# d:
dynamic_dim(d, 1) == dynamic_dim(c, 1),
# e:
dynamic_dim(e, 1) == dynamic_dim(c, 1),
]
'''
self.assertEqual(static_code, expected_static)
self.assertEqual(dynamic_code, expected_dynamic)
# Script entry point: hand control to the test runner imported from
# torch.testing._internal.common_utils, but only when this file is executed
# directly rather than imported as a module.
if __name__ == "__main__":
    run_tests()