| # Owner(s): ["oncall: jit"] |
| |
| import contextlib |
| import copy |
| import itertools |
| import inspect |
| import math |
| import operator |
| import re |
| |
| import sympy |
| import torch |
| import torch.fx |
| import torch.nn.functional as F |
| from torch import sym_int, SymBool, SymFloat, SymInt |
| from torch._C import _disabled_torch_function_impl |
| from torch.fx.experimental import symbolic_shapes |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch.fx.experimental.symbolic_shapes import ( |
| DimConstraints, |
| DimDynamic, |
| guard_bool, |
| guard_float, |
| guard_int, |
| GuardOnDataDependentSymNode, |
| ShapeEnv, |
| sym_float, |
| sym_sqrt, |
| SymNode, |
| to_node, |
| ) |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| parametrize, |
| run_tests, |
| skipIfTorchDynamo, |
| TestCase, |
| ) |
| from torch.utils._python_dispatch import TorchDispatchMode |
| from torch.utils._pytree import tree_map |
| from torch.utils._sympy.functions import FloorDiv, Mod |
| |
# Shorthand for the ATen operator namespace used by the meta registrations below.
aten = torch.ops.aten

# Registry from ATen operator overload to the Python meta function that
# FakeSymbolicTensor.__torch_dispatch__ calls for it; filled by register_meta.
meta_funcs = {}
| |
| |
def register_meta(op):
    """Return a decorator that records the decorated function as the meta
    function for ``op``.

    ``op`` may be a single operator overload or a container of overloads;
    every leaf visited by ``tree_map`` is mapped to the decorated function in
    ``meta_funcs``. The decorated function itself is returned unchanged.
    """
    def decorator(f):
        def _register_leaf(overload):
            meta_funcs[overload] = f

        # tree_map is used only to visit every leaf of ``op``; its result is discarded.
        tree_map(_register_leaf, op)
        return f

    return decorator
| |
| |
@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
    """Meta function for ``aten.add.Tensor`` / ``aten.sub.Tensor``.

    Returns an empty tensor shaped like ``a``; ``b``'s shape is never consulted.
    NOTE(review): broadcasting is therefore only modeled correctly when ``a``
    already has the broadcast result shape — confirm callers rely on that.
    """
    result_shape = a.shape
    return a.new_empty(result_shape)
| |
| |
@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
    """Meta function for ``aten.cat.default``.

    Sums the sizes of ``tensors`` along ``dim`` and asserts that every other
    dimension matches the first tensor's. A negative ``dim`` counts from the
    end, as it does for ``torch.cat``.

    Returns an empty tensor (via ``new_empty`` on the first input) with the
    concatenated shape.
    """
    shape = tensors[0].shape
    # Normalize a negative dim. Without this, ``idx == dim`` below never
    # matches: the concat size stays 0 and the concat dim is wrongly asserted
    # to be equal across inputs.
    if dim < 0:
        dim += len(shape)
    concat_length = 0
    for tensor in tensors:
        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                assert length == common_length
    new_shape = list(shape)
    new_shape[dim] = concat_length
    return tensors[0].new_empty(new_shape)
| |
| |
@register_meta([aten.narrow_copy.default])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
    """Meta function for ``aten.narrow_copy.default``.

    The result has ``a``'s shape with dimension ``dim`` replaced by ``length``.
    ``start`` does not affect the shape. A negative ``dim`` counts from the end.
    """
    # Without normalization a negative dim never equals any index, so no
    # dimension would be narrowed.
    if dim < 0:
        dim += len(a.shape)
    shape = tuple(length if i == dim else size for i, size in enumerate(a.shape))
    return a.new_empty(shape)
| |
| |
@register_meta([aten.expand.default])
def expand_symint_meta(a, size, implicit=False):
    """Meta function for ``aten.expand.default``.

    Returns an empty tensor of shape ``size``. As with ``Tensor.expand``, a
    literal ``-1`` entry keeps the size of the corresponding right-aligned
    dimension of ``a``. ``implicit`` is accepted for signature compatibility
    and ignored.

    Raises:
        RuntimeError: if ``-1`` is used for a new leading dimension.
    """
    offset = len(size) - len(a.shape)
    resolved = []
    for i, s in enumerate(size):
        # Only a plain Python ``-1`` is special. SymInt entries are never
        # compared here, so resolving sizes cannot introduce a guard.
        if isinstance(s, int) and s == -1:
            if i < offset:
                raise RuntimeError(
                    f"expand: -1 is not allowed for new leading dimension {i}"
                )
            s = a.shape[i - offset]
        resolved.append(s)
    return a.new_empty(resolved)
| |
| |
| def create_contiguous(shape): |
| strides = [1] |
| for dim in reversed(shape[:-1]): |
| strides.append(dim * strides[-1]) |
| return list(reversed(strides)) |
| |
| |
class FakeSymbolicTensor(torch.Tensor):
    """Data-less wrapper tensor whose sizes may be SymInts.

    Operators reaching ``__torch_dispatch__`` are served from ``meta_funcs``
    (filled by ``register_meta``). ``aten.new_empty.default`` is handled
    inline; any other operator raises ``RuntimeError``.
    """

    @staticmethod
    def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device, storage_offset=0):
        # TODO: this is wrong in general
        # ``sym_strides`` is currently ignored; strides are always recomputed
        # as contiguous for ``sym_shape``.
        sym_stride = create_contiguous(sym_shape)
        r = torch.Tensor._make_wrapper_subclass(
            cls, sym_shape,
            sym_stride, storage_offset,
            dtype=dtype, layout=layout, requires_grad=requires_grad,
            device=device,
        )
        return r

    # Standard opt-out of __torch_function__ for subclasses that only
    # customize __torch_dispatch__.
    __torch_function__ = _disabled_torch_function_impl

    def new_empty(self, shape):
        """Return a new FakeSymbolicTensor of ``shape`` with this tensor's
        dtype, layout, requires_grad and device."""
        return FakeSymbolicTensor(shape, None, self.dtype, self.layout, self.requires_grad, self.device)

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        # ``kwargs`` defaults to None; normalize it so ``**kwargs`` below
        # cannot raise TypeError.
        if kwargs is None:
            kwargs = {}

        if func_overload in meta_funcs:
            return meta_funcs[func_overload](*args, **kwargs)

        if func_overload == torch.ops.aten.new_empty.default:
            self = args[0]
            shape = args[1]
            return FakeSymbolicTensor(shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device)

        raise RuntimeError(f"operator {func_overload} not supported")
| |
| |
def create_symbolic_tensor(name, arg, shape_env):
    """Wrap ``arg`` as a FakeSymbolicTensor with symbolic sizes/strides/offset.

    The symbols are allocated by ``shape_env`` from a ``ConstantSource(name)``,
    with every dimension marked ``DimDynamic.DUCK`` and no constraints. dtype,
    layout, requires_grad and device are copied from ``arg``.
    """
    from torch._dynamo.source import ConstantSource

    ndim = arg.dim()
    sizes, strides, offset = shape_env.create_symbolic_sizes_strides_storage_offset(
        arg,
        source=ConstantSource(name),
        dynamic_dims=[DimDynamic.DUCK for _ in range(ndim)],
        constraint_dims=[None for _ in range(ndim)],
    )
    return FakeSymbolicTensor(
        sizes, strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, offset
    )
| |
| def create_symint(shape_env, i: int): |
| from torch._dynamo.source import ConstantSource |
| return shape_env.create_symintnode( |
| shape_env.create_symbol( |
| i, |
| source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"), |
| dynamic_dim=DimDynamic.DUCK, |
| constraint_dim=None, |
| ), |
| hint=i |
| ) |
| |
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestPySymInt(TestCase):
    """Exercise SymInt/SymFloat arithmetic, guards and tracing against a ShapeEnv."""

    def test_arith_ops(self):
        shape_env = ShapeEnv()
        symints = []
        for i in range(2, 5):
            symints.append((i, create_symint(shape_env, i)))

        ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]

        for op in ops:
            for (lhs_val, lhs_sym), (rhs_val, rhs_sym) in itertools.permutations(symints, 2):
                # floordiv/mod by zero raise, so skip that combination.
                if op in (operator.floordiv, operator.mod) and rhs_val == 0:
                    continue
                self.assertTrue(op(lhs_sym, rhs_sym) == op(lhs_val, rhs_val))

    def test_reverse_arith_ops(self):
        shape_env = ShapeEnv()

        a = create_symint(shape_env, 2)
        self.assertTrue(5 // a == 5 // 2)

        a = create_symint(shape_env, 2)
        self.assertTrue(5 * a == 5 * 2)

    def test_roundtrip(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)

        self.assertTrue(not isinstance(x.shape[0], SymNode))
        self.assertTrue(isinstance(x.shape[0], SymInt))

        # Compare with ``==``: ``assertTrue(a, b)`` treats ``b`` as the failure
        # message and only checks that ``a`` is truthy.
        self.assertTrue(x.shape[0] == 5)
        self.assertTrue(x.shape[1] == 4)
        self.assertTrue(x.shape[2] == 3)

        self.assertTrue(x.size()[0] == 5)
        self.assertTrue(x.size()[1] == 4)
        self.assertTrue(isinstance(x.size()[1], int))  # due to guard above
        self.assertTrue(x.size()[2] == 3)

        self.assertTrue(x.size(0) == 5)
        self.assertTrue(x.size(1) == 4)
        self.assertTrue(x.size(2) == 3)
        self.assertTrue(isinstance(x.size(2), int))

        y = create_symbolic_tensor("y", torch.randn(5, 4, 3)[1:], shape_env)
        self.assertTrue(isinstance(y.storage_offset(), SymInt))
        self.assertTrue(y.storage_offset() == 12)

    def test_binary(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)

        z = x + y
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # broadcasting
        y = create_symbolic_tensor("y2", torch.randn(1, 4, 1), shape_env)
        z = x + y
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

    def test_symint_args(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
        LAST_DIM = 2
        z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
        self.assertTrue(z.shape[2] == y.shape[2])

        # arithmetic expr with two symints
        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
        self.assertTrue(z.shape[2] == 2)

        # arithmetic expr with a symint and python int
        z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
        self.assertTrue(z.shape[2] == 2)

    def test_symint_vargs(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
        y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)

        # varargs
        z = y.expand(x.shape[0], y.shape[1], x.shape[2])
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # shape list
        z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints
        z = y.expand(x.shape[0], y.shape[1], 3)
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints in a list
        z = y.expand((x.shape[0], y.shape[1], 3))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python symints and ints
        z = y.expand(5, y.shape[1], x.shape[2])
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        # mixed python ints and symints in a list
        z = y.expand((5, y.shape[1], x.shape[2]))
        self.assertTrue(z.shape[0] == 5)
        self.assertTrue(z.shape[1] == 4)
        self.assertTrue(z.shape[2] == 3)

        z = y.expand((y.shape[1],))
        z = y.expand(y.shape[1])

    def test_stride(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
        self.assertIsInstance(x.stride()[0], SymInt)

    def test_size_expressions(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        expand_x = x.expand(x.shape[0], x.shape[0])
        if expand_x.shape[0] > 3:
            result = expand_x + expand_x
        else:
            result = expand_x + expand_x

        gt_op, _bt = shape_env.guards[-1]
        self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
        # assertEqual (not assertTrue(a, b), which ignores b) so the symbolic
        # expressions are actually compared.
        self.assertEqual(str(x.shape[0]), str(gt_op.args[0]))
        self.assertEqual(str(expand_x.shape[1]), str(x.shape[0]))
        self.assertEqual(str(expand_x.shape[1]), str(result.shape[0]))

    def test_numel(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        self.assertIsInstance(x.numel(), torch.SymInt)
        self.assertIsInstance(torch.numel(x), torch.SymInt)

        x = torch.rand(3, 3)
        self.assertIsInstance(x.numel(), int)
        self.assertIsInstance(torch.numel(x), int)

    def test_int_to_float(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        r = sym_float(x.shape[0])
        self.assertIsInstance(r, torch.SymFloat, msg=type(r))

    def test_aten_ops(self):
        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x", torch.randn(5), shape_env)
        torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])

        shape_env = ShapeEnv()
        x = create_symbolic_tensor("x2", torch.randn(5, 4, 3), shape_env)
        torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])

    def test_fx_trace_intlist(self):
        class CustomModule(torch.nn.Module):
            def forward(self, x):
                bs, c, h, w = x.shape
                return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))

        m = CustomModule()
        # should not TypeError: pad(): argument 'pad' (position 2) must be
        # tuple of ints, not tuple
        torch.fx.symbolic_trace(m)

    def test_meta_symint(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        r = torch.empty(a0, device='meta')
        self.assertIsInstance(r.shape[0], SymInt)

    def test_guard_int(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        self.assertEqual(guard_int(a0), 2)
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

    def test_sym_int(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = sym_int(a0)
        self.assertEqual(r, 5)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 5)""")

        a1 = create_symint(shape_env, 7)
        r = sym_int(a1 / 2)
        self.assertEqual(guard_int(r), 3)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(floor(s1/2), 3)""")

        a3 = create_symint(shape_env, 3)
        r = sym_int(2.0 * sym_float(a3))
        self.assertEqual(guard_int(r), 6)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[2][0]), """Eq(2*s2, 6)""")

    def test_sym_sqrt(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 4)
        r = sym_sqrt(a0)
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymFloat, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(sqrt(s0), 2)""")

    def test_sym_floor(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = math.floor(a0 / 2)
        self.assertEqual(r, 2)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s0/2), 2)""")
        r = math.floor(3.0 * a0)
        self.assertEqual(r, 15)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")

    def test_sym_ceil(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = math.ceil(a0 / 2)
        self.assertEqual(r, 3)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)""")
        # NOTE(review): this second check uses math.floor (copied from
        # test_sym_floor) rather than math.ceil; switching it requires
        # re-deriving the expected guard string below.
        r = math.floor(3.0 * a0)
        self.assertEqual(r, 15)
        self.assertIsInstance(r, torch.SymInt, msg=type(r))
        self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")

    def test_int_conversion(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        int(a0)
        self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

    def test_data_dependent_guard(self):
        shape_env = ShapeEnv()
        s0 = shape_env.create_unbacked_symint()
        self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))

    def test_non_overlapping_and_dense(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 5)
        r = torch.empty_strided((a0, 7), (1, a0), device='meta')
        self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r))

    def test_specialize_zero_one(self):
        shape_env = ShapeEnv(specialize_zero_one=True)
        a0 = create_symint(shape_env, 5)
        assert a0 != 1
        self.assertEqual(len(shape_env.guards), 0)

        shape_env = ShapeEnv(specialize_zero_one=False)
        a0 = create_symint(shape_env, 5)
        assert a0 != 1
        self.assertEqual(len(shape_env.guards), 1)

    def test_duck_shape(self):
        shape_env = ShapeEnv(duck_shape=True)
        a0 = create_symint(shape_env, 5)
        a1 = create_symint(shape_env, 5)
        assert a0 == a1
        self.assertEqual(len(shape_env.guards), 0)

        shape_env = ShapeEnv(duck_shape=False)
        a0 = create_symint(shape_env, 5)
        a1 = create_symint(shape_env, 5)
        assert a0 == a1
        self.assertEqual(len(shape_env.guards), 1)

    def test_int_bool(self):
        # See https://github.com/pytorch/pytorch/issues/95981
        shape_env = ShapeEnv(duck_shape=True)
        a0 = create_symint(shape_env, 5)
        assert a0
        self.assertEqual(len(shape_env.guards), 0)

    def test_symint_as_scalar(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)

        sym_int_encountered = False

        class TestSymInt(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                assert func == torch.ops.aten.add.Tensor

                nonlocal sym_int_encountered
                # WARNING: do not do identity tests on the outer
                # SymInt/SymFloat, they are NOT STABLE
                sym_int_encountered = kwargs["alpha"].node is a0.node
                kwargs["alpha"] = 0
                return func(*args)

        x = torch.rand([4, 4])
        with TestSymInt():
            torch.add(x, x, alpha=a0)

        self.assertTrue(sym_int_encountered)

    def test_deepcopy(self):
        shape_env = ShapeEnv()
        a0 = create_symint(shape_env, 2)
        assert a0 < 4
        new_shape_env = copy.deepcopy(shape_env)
        self.assertEqual(len(new_shape_env.guards), 1)

    def test_print_readable_with_symints(self):
        def f(a, b):
            dim0 = a.shape[0] + b.shape[0]
            dim1 = a.shape[1] + b.shape[1]
            d = a.new_empty(dim0, dim1)
            d = torch.ops.aten.native_dropout(d, 0.5, train=True)
            return d

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
        out = fx_g.print_readable(print_output=False)

        self.assertExpectedInline(out.strip(), """\
class f(torch.nn.Module):
    def forward(self, a_1: f32[s0, s1], b_1: f32[s2, s1]):
        # No stacktrace found for following nodes
        sym_size: Sym(s0) = torch.ops.aten.sym_size(a_1, 0)
        sym_size_1: Sym(s2) = torch.ops.aten.sym_size(b_1, 0)
        add: Sym(s0 + s2) = sym_size + sym_size_1;  sym_size = sym_size_1 = None
        sym_size_2: Sym(s1) = torch.ops.aten.sym_size(a_1, 1)
        sym_size_3: Sym(s1) = torch.ops.aten.sym_size(b_1, 1);  b_1 = None
        add_1: Sym(2*s1) = sym_size_2 + sym_size_3;  sym_size_2 = sym_size_3 = None
        new_empty: f32[s0 + s2, 2*s1] = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False);  a_1 = add = add_1 = None
        native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True);  new_empty = None
        getitem: f32[s0 + s2, 2*s1] = native_dropout[0]
        getitem_1: b8[s0 + s2, 2*s1] = native_dropout[1];  native_dropout = None
        return (getitem, getitem_1)""")  # noqa: B950
| |
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestSymNumberMagicMethods(TestCase):
    """Compare symbolic-number magic methods with their plain-Python results.

    For each operator name in ``symbolic_shapes.magic_methods``, a reference
    result is computed on plain Python values. The result is then recomputed
    with the first operand, the second operand, and both operands wrapped as
    SymBool/SymInt/SymFloat. Each symbolic result is guarded back to a
    concrete value and compared with the reference.
    """

    def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
        """Apply operator ``fn`` to plain and symbolic operands and compare.

        Args:
            fn: operator name, resolved below against ``math``,
                ``symbolic_shapes`` or ``operator``.
            inp1, inp2: plain Python operands; ``inp2`` is unused when
                ``is_unary_fn`` is true.
            shape_env: ShapeEnv that the seed symbols are created in.
            is_unary_fn: whether ``fn`` takes a single operand.
        """
        # Helper function
        # NB: don't use one as that will get specialized
        # Seed nodes for get_sym_inp: ``/`` is true division (a float-valued
        # node) and ``==`` yields a bool-valued node.
        seed_node = (create_symint(shape_env, 2) / 2.).node
        bool_seed_node = (create_symint(shape_env, 2) == 2).node

        def get_sym_inp(inp):
            # Wrap ``inp`` as SymBool/SymInt/SymFloat via to_node on the
            # matching seed node (presumably so it shares shape_env).
            # NB: this must come before int
            if isinstance(inp, bool):
                return torch.SymBool(to_node(bool_seed_node, inp))
            elif isinstance(inp, int):
                return torch.SymInt(to_node(seed_node, inp))
            else:
                return torch.SymFloat(to_node(seed_node, inp))

        def maybe_xfail(inp1, inp2):
            # Return an assertRaises context for operand combinations of the
            # current ``fn`` that are expected to raise; otherwise a no-op
            # context.
            if fn == "sym_sqrt" and inp1 < 0:
                # ValueError: math domain error
                return self.assertRaises((ValueError,))
            elif fn in ("truediv", "floordiv", "mod") and inp2 == 0:
                # ZeroDivisionError: division by zero
                return self.assertRaises((ZeroDivisionError,))
            elif fn == "pow" and inp1 == 0 and inp2 < 0:
                # ZeroDivisionError: 0.0 cannot be raised to a negative power
                return self.assertRaises((ZeroDivisionError,))
            elif fn == "pow" and inp1 < 0 and inp2 in (2.5, -2.5) and (
                type(inp1) in (SymFloat, SymInt) or
                type(inp2) in (SymFloat, SymInt)
            ):
                # Complex result, which we do not support:
                # TypeError: Cannot convert complex to float
                return self.assertRaises((TypeError,))
            elif fn in ("lshift", "rshift") and not (
                isinstance(inp1, (SymInt, int)) and
                isinstance(inp2, (SymInt, int))
            ):
                # TypeError: unsupported operand type(s)
                return self.assertRaises((TypeError,))
            elif fn in ("lshift", "rshift") and inp2 < 0:
                # ValueError: math domain error
                return self.assertRaises((ValueError,))
            else:
                return contextlib.nullcontext()

        # Resolve the operator name to a callable: math.<fn>,
        # symbolic_shapes.<fn>, operator.<fn>_ or operator.<fn>.
        if fn in symbolic_shapes.magic_methods_on_math:
            lambda_apply = getattr(math, fn)
        elif fn in symbolic_shapes.magic_methods_on_submodule:
            lambda_apply = getattr(symbolic_shapes, fn)
        elif fn in symbolic_shapes.magic_methods_on_operator_with_trailing_underscore:
            lambda_apply = getattr(operator, f"{fn}_")
        else:
            lambda_apply = getattr(operator, fn)

        def guard_fn(v):
            # Convert a (possibly symbolic) result to a concrete Python value,
            # choosing guard_bool/guard_float/guard_int by its type.
            if type(v) in (SymBool, bool):
                return guard_bool(v)
            elif type(v) in (SymFloat, float):
                return guard_float(v)
            else:  # SymInt, int
                return guard_int(v)

        # Get reference result
        with maybe_xfail(inp1, inp2):
            if is_unary_fn:
                ref_out = lambda_apply(inp1)
            else:
                ref_out = lambda_apply(inp1, inp2)

        # Symified first arg
        sym_inp1 = get_sym_inp(inp1)
        with maybe_xfail(sym_inp1, inp2):
            if is_unary_fn:
                out = lambda_apply(sym_inp1)
            else:
                out = lambda_apply(sym_inp1, inp2)
            out = guard_fn(out)
            self.assertEqual(out, ref_out)

        if is_unary_fn:
            return

        # Symified second arg
        sym_inp2 = get_sym_inp(inp2)
        with maybe_xfail(inp1, sym_inp2):
            out = lambda_apply(inp1, sym_inp2)
            out = guard_fn(out)
            self.assertEqual(out, ref_out)

        # Symified both args
        with maybe_xfail(sym_inp1, sym_inp2):
            out = lambda_apply(sym_inp1, sym_inp2)
            out = guard_fn(out)
            self.assertEqual(out, ref_out)

    @parametrize("fn", list(symbolic_shapes.magic_methods.keys()))
    def test_bool_method(self, fn):
        """Run each bool magic method on the operands (True, False)."""
        if fn not in symbolic_shapes.bool_magic_methods:
            self.skipTest(f"{fn} is non-bool")

        is_unary_fn = fn in symbolic_shapes.unary_magic_methods
        shape_env = ShapeEnv()
        self._do_test(fn, True, False, shape_env, is_unary_fn)

    @parametrize("fn", list(symbolic_shapes.magic_methods.keys()))
    @parametrize("first_type", ["int", "float"])
    @parametrize("second_type", ["int", "float"])
    def test_method(self, fn, first_type, second_type):
        """Run each non-bool magic method over signed combinations of 0, 1 and 2.5.

        Operands are truncated with ``int()`` when the corresponding
        ``*_type`` is "int". NOTE(review): every ``first_type == "float"``
        variant is skipped below, so only int first operands are exercised.
        """
        if first_type == "float":
            # TODO: Hmm, this looks like we skip all floats
            self.skipTest(f"{fn} is not a float magic method")

        is_unary_fn = fn in symbolic_shapes.unary_magic_methods
        # Second argument is ignored for unary function. So only run for one type
        if is_unary_fn and second_type == "float":
            self.skipTest(f"{fn} is unary and already tested")

        if fn in symbolic_shapes.bool_magic_methods:
            self.skipTest(f"{fn} is bool")

        # Only floats here since these will be converted to int if necessary.
        # We also ignore complex and bool.
        values = (
            0.0,
            1.0,
            2.5,
        )

        neg_values = tuple(-x for x in values)

        for inp1, inp2 in itertools.chain(
            itertools.product(values, values),
            itertools.product(values, neg_values),
            itertools.product(neg_values, values),
            itertools.product(neg_values, neg_values),
        ):
            if first_type == "int":
                inp1 = int(inp1)
            if second_type == "int":
                inp2 = int(inp2)

            # A fresh ShapeEnv is created for every operand pair.
            shape_env = ShapeEnv()

            self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)
| |
# Instantiate the @parametrize-decorated tests of TestSymNumberMagicMethods.
instantiate_parametrized_tests(TestSymNumberMagicMethods)
| |
| class TestFloorDiv(TestCase): |
| @staticmethod |
| def python_floordiv(x, y): |
| return x // y |
| |
| @staticmethod |
| def torch_floordiv(x, y): |
| # Note: we fully evaluate here since FloorDiv might not always do |
| # that. |
| shape_env = ShapeEnv() |
| return shape_env.evaluate_expr(FloorDiv(x, y)) |
| |
| @staticmethod |
| def yield_test_cases(values, negate=True): |
| for x, y in values: |
| yield (x, y) |
| if negate: |
| yield (-x, y) |
| yield (x, -y) |
| yield (-x, -y) |
| |
| def test_floordiv_float_int(self): |
| values = ( |
| (2.5, 2.1), |
| (2.1, 2.5), |
| (2.0, 2.1), |
| (7, 2.5), |
| (2.1, 7), |
| (7, 2), |
| ) |
| |
| for x, y in TestFloorDiv.yield_test_cases(values): |
| self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)) |
| |
| def test_floordiv_bool(self): |
| values = ( |
| (False, True), |
| (True, 2.5), |
| (2.5, True), |
| (False, 7), |
| (7, True), |
| ) |
| |
| for x, y in TestFloorDiv.yield_test_cases(values, negate=False): |
| # Compares to int since our FloorDiv has no bool support |
| self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(int(x), int(y))) |
| # Tests that our impl throws |
| self.assertRaisesRegex( |
| TypeError, |
| (rf"unsupported operand type\(s\) for //: " |
| rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" |
| rf", expected integer or real"), |
| lambda: TestFloorDiv.torch_floordiv(x, y)) |
| |
| def test_floordiv_complex(self): |
| values = ( |
| (1.5 + 2.5j, 1.3 + 3.5j), |
| (1.5 + 2.5j, 2.5), |
| (2.5, 1.5 + 2.5j), |
| (1.5 + 2.5j, 7), |
| (7, 1.5 + 2.5j), |
| ) |
| |
| for x, y in TestFloorDiv.yield_test_cases(values): |
| # We don't test error messages to avoid depending on Python |
| # interpreter version |
| self.assertRaises(TypeError, lambda: TestFloorDiv.python_floordiv(x, y)) |
| self.assertRaisesRegex( |
| TypeError, |
| (rf"unsupported operand type\(s\) for //: " |
| rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" |
| rf", expected integer or real"), |
| lambda: TestFloorDiv.torch_floordiv(x, y)) |
| |
| def test_floordiv_div_by_zero(self): |
| values = ( |
| (2.5, 0), |
| (2.1, 0.0), |
| (2.3, sympy.Symbol("s", zero=True)), |
| ) |
| |
| for x, y in TestFloorDiv.yield_test_cases(values, negate=False): |
| # We don't test error messages to avoid depending on Python |
| # interpreter version |
| if type(y) is not sympy.Symbol: |
| self.assertRaises(ZeroDivisionError, lambda: TestFloorDiv.python_floordiv(x, y)) |
| self.assertRaisesRegex( |
| ZeroDivisionError, |
| "division by zero", |
| lambda: TestFloorDiv.torch_floordiv(x, y)) |
| |
| def test_floordiv_zero_base(self): |
| values = ( |
| (0, 2.5), |
| (0.0, 2.1), |
| (sympy.Symbol("s", zero=True), 2.3), |
| ) |
| |
| for x, y in TestFloorDiv.yield_test_cases(values, negate=False): |
| if type(x) is not sympy.Symbol: |
| self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)) |
| else: |
| self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y)) |
| |
| def test_floordiv_div_by_one(self): |
| values = ( |
| (2.5, 1), |
| (2.1, 1.0), |
| (2, 1.0), |
| (2, 1), |
| ) |
| |
| for x, y in TestFloorDiv.yield_test_cases(values): |
| self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)) |
| |
| def test_floordiv_simplify(self): |
| # Tests how we simplify or evaluate FloorDiv without free variables |
| shape_env = ShapeEnv() |
| result = 21 |
| exprs = ( |
| 7 * FloorDiv(6, 2), |
| 7 * FloorDiv(6.28, 2), |
| 7 * FloorDiv(6.28, 2.0), |
| 7 * FloorDiv(6.28, (FloorDiv(6.28, 3.14))), |
| ) |
| |
| for expr in exprs: |
| self.assertEqual(expr, result) |
| self.assertEqual(expr.doit(deep=False), result) |
| self.assertEqual(expr.doit(deep=True), result) |
| self.assertEqual(sympy.simplify(expr), result) |
| self.assertEqual(shape_env.simplify(expr), result) |
| self.assertEqual(shape_env.evaluate_expr(expr), result) |
| |
| def test_floordiv_assumptions(self): |
| # We define two Symbols (with different names) for each type to make |
| # sure the behavior is consistent regardless of whether both arguments |
| # are the same object or not. |
| cases = ( |
| sympy.Symbol("i1", integer=True), |
| sympy.Symbol("i2", integer=True), |
| sympy.Symbol("r1", real=True), |
| sympy.Symbol("r2", real=True), |
| sympy.Symbol("c1", complex=True, real=False, integer=False), |
| sympy.Symbol("c2", complex=True, real=False, integer=False), |
| sympy.Symbol("s1"), |
| sympy.Symbol("s2"), |
| ) |
| |
| for base, divisor in itertools.product(cases, repeat=2): |
| def op(): |
| return FloorDiv(base, divisor) |
| |
| def is_complex(x): |
| return x.is_integer is False and x.is_real is False and x.is_complex |
| |
| if is_complex(base) or is_complex(divisor): |
| self.assertRaisesRegex( |
| TypeError, |
| (r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol'," |
| r" expected integer or real"), |
| op) |
| continue |
| |
| op = op() |
| |
| # In regular Python, x//x == 1.0 if x is a float, but FloorDiv |
| # always returns an integer 1 when both args are the same object. |
| # This even works for Symbols with no assumptions specified. |
| if base is divisor: |
| self.assertTrue(op.is_integer) |
| self.assertTrue(op.is_real) |
| elif base.is_integer and divisor.is_integer: |
| self.assertTrue(op.is_integer) |
| self.assertTrue(op.is_real) |
| else: |
| self.assertEqual(op.is_integer, None) |
| self.assertTrue(op.is_real) |
| |
| |
| class TestDimConstraints(TestCase): |
| def test_dim_constraints_reduce_congruences_simple(self): |
| from sympy import Symbol |
| from torch.fx.experimental.symbolic_shapes import DimConstraints |
| |
| s = Symbol("s", positive=True, integer=True) |
| dim_constraints = DimConstraints({}, {}, set()) |
| dim_constraints._congruences[s] = { |
| (s / 2) % 2, |
| (s / 2) % 8, |
| (s / 2) % 4, |
| s % 2, |
| ((s / 16) + 2) % 4, |
| } |
| congruences = dim_constraints.reduce_congruences() |
| self.assertEqual(congruences[s], {(s + 32) % 64}) |
| |
| def test_dim_constraints_reduce_inequalities_simple(self): |
| from sympy import Eq, Interval, Ne, Symbol |
| from sympy.solvers.inequalities import reduce_inequalities |
| |
| s = Symbol("s", positive=True, integer=True) |
| exprs = { |
| s >= 2, |
| Ne(8 * s, 16), |
| Ne(s / 2, 1), |
| Ne(16 * s, 32), |
| s < 16, |
| Ne(s, 2), |
| s / 2 < 16, |
| s / 2 > 1, |
| s / 2 >= 2, |
| Ne(3 * s / 2, 3), |
| } |
| solution = reduce_inequalities(exprs, s).as_set() |
| self.assertEqual(solution, Interval.Ropen(4, 16)) |
| |
| exprs.add(Eq(s / 2, 4)) |
| solution = reduce_inequalities(exprs, s).as_set() |
| self.assertEqual(solution, {8}) |
| |
| def test_dim_constraints_solve_full(self): |
| from sympy import Eq, Integer, Ne, Symbol |
| from torch._dynamo.source import LocalSource, TensorProperty, TensorPropertySource |
| |
| src0 = TensorPropertySource( |
| base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0 |
| ) |
| src2 = TensorPropertySource( |
| base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0 |
| ) |
| src3 = TensorPropertySource( |
| base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0 |
| ) |
| src4 = TensorPropertySource( |
| base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0 |
| ) |
| |
| src1 = TensorPropertySource( |
| base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2 |
| ) |
| src7 = TensorPropertySource( |
| base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3 |
| ) |
| |
| src5 = TensorPropertySource( |
| base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=1 |
| ) |
| src8 = TensorPropertySource( |
| base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1 |
| ) |
| |
| src6 = TensorPropertySource( |
| base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1 |
| ) |
| src9 = TensorPropertySource( |
| base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=1 |
| ) |
| src10 = TensorPropertySource( |
| base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1 |
| ) |
| |
| src11 = TensorPropertySource( |
| base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1 |
| ) |
| src12 = TensorPropertySource( |
| base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2 |
| ) |
| |
| s0 = Symbol("s0", positive=True, integer=True) |
| s1 = Symbol("s1", positive=True, integer=True) |
| s5 = Symbol("s5", positive=True, integer=True) |
| s6 = Symbol("s6", positive=True, integer=True) |
| symbol_to_source = { |
| s0: [src0, src2, src3, src4], |
| s1: [src1, src7], |
| s5: [src5, src8], |
| s6: [src6, src9, src10], |
| } |
| var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21} |
| marked_dynamic = {s0, s1, s5, s6} |
| dim_constraints = DimConstraints(symbol_to_source, var_to_val, marked_dynamic) |
| dim_constraints.add_equality(src2, s0) |
| dim_constraints.add_equality(src3, s0) |
| dim_constraints.add_equality(src4, s0) |
| dim_constraints.add_equality(src7, s1) |
| dim_constraints.add_equality(src8, s5) |
| dim_constraints.add_equality(src9, s6) |
| dim_constraints.add_equality(src10, s6) |
| dim_constraints.add_equality(src11, Integer(1)) |
| dim_constraints.add_equality(src12, Integer(3)) |
| |
| dim_constraints.add(s1**2 <= 2147483647) |
| dim_constraints.add(32 * s1**2 <= 2147483647) |
| dim_constraints.add(s0 < 16) |
| dim_constraints.add(Eq(Mod(s1, 2), 0)) |
| dim_constraints.add(Ne(FloorDiv(s1, 2), 1)) |
| dim_constraints.add(Ne((FloorDiv(s1, 2)) ** 2, 1)) |
| dim_constraints.add(32 * (FloorDiv(s1, 2)) ** 2 <= 2147483647) |
| dim_constraints.add((FloorDiv(s1, 2)) ** 2 > 1) |
| dim_constraints.add(Ne(FloorDiv(s1, 2), 1)) |
| dim_constraints.add( |
| 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2 |
| + 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) |
| + 64 |
| <= 2147483647 |
| ) |
| dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1)) |
| dim_constraints.add( |
| Ne( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2 |
| + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) |
| + 1, |
| 1, |
| ) |
| ) |
| dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1)) |
| dim_constraints.add( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2 |
| + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) |
| + 1 |
| > 1 |
| ) |
| dim_constraints.add( |
| 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2 |
| + 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) |
| + 128 |
| <= 2147483647 |
| ) |
| dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1)) |
| dim_constraints.add( |
| Ne( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2 |
| + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) |
| + 1, |
| 1, |
| ) |
| ) |
| dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1)) |
| dim_constraints.add( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2 |
| + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) |
| + 1 |
| > 1 |
| ) |
| dim_constraints.add( |
| 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| + 512 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) |
| + 256 |
| <= 2147483647 |
| ) |
| dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1)) |
| dim_constraints.add( |
| Ne( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) |
| + 1, |
| 1, |
| ) |
| ) |
| dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1)) |
| dim_constraints.add( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) |
| + 1 |
| > 1 |
| ) |
| dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 >= 3) |
| dim_constraints.add( |
| 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 |
| <= 2147483647 |
| ) |
| dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 0) |
| dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 1) |
| dim_constraints.add( |
| Ne( |
| 60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 * s0, |
| 0, |
| ) |
| ) |
| dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1)) |
| dim_constraints.add( |
| Ne( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 1, |
| 1, |
| ) |
| ) |
| dim_constraints.add( |
| Ne( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 1, |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 1 |
| >= 0 |
| ) |
| dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 0)) |
| dim_constraints.add( |
| 1 |
| < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 |
| ) |
| dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, -1)) |
| dim_constraints.add( |
| Ne( |
| 60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 * s0, |
| 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 120, |
| ) |
| ) |
| dim_constraints.add( |
| 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 120 |
| > 0 |
| ) |
| dim_constraints.add( |
| Eq( |
| 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 * (Mod(s0, 2)) |
| - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) * Mod(s0, 2) |
| + 60 * (Mod(s0, 2)), |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| Ne( |
| 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 120, |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| Ne( |
| 60 |
| * (FloorDiv(s0, 2)) |
| * (FloorDiv(s0, (FloorDiv(s0, 2)))) |
| * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 |
| * FloorDiv(s0, 2) |
| * FloorDiv(s0, (FloorDiv(s0, 2))) |
| * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))), |
| 0, |
| ) |
| ) |
| dim_constraints.add(Ne(FloorDiv(s0, 2), 1)) |
| dim_constraints.add( |
| Ne( |
| 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60, |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 |
| >= 0 |
| ) |
| dim_constraints.add( |
| 1 |
| < 60 |
| * (FloorDiv(s0, (FloorDiv(s0, 2)))) |
| * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) |
| ) |
| dim_constraints.add(Ne(16 * s0, 32)) |
| dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0)) |
| dim_constraints.add(Ne(16 * s0, 32)) |
| dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0)) |
| dim_constraints.add(FloorDiv(s0, 2) >= 2) |
| dim_constraints.add(Ne(FloorDiv(s0, 2), 1)) |
| dim_constraints.add(1 < FloorDiv(s0, 2)) |
| dim_constraints.add(Ne(s0, 2)) |
| dim_constraints.add( |
| 60 |
| * (FloorDiv(s0, (FloorDiv(s0, 2)))) |
| * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) |
| >= 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 |
| ) |
| dim_constraints.add( |
| 60 |
| * (FloorDiv(s0, 2)) |
| * (FloorDiv(s0, (FloorDiv(s0, 2)))) |
| * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 |
| * FloorDiv(s0, 2) |
| * FloorDiv(s0, (FloorDiv(s0, 2))) |
| * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))) |
| > 0 |
| ) |
| dim_constraints.add( |
| Ne( |
| 60 |
| * (FloorDiv(s0, 2)) |
| * (FloorDiv(s0, (FloorDiv(s0, 2)))) |
| * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 |
| * FloorDiv(s0, 2) |
| * FloorDiv(s0, (FloorDiv(s0, 2))) |
| * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))), |
| 3 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))), |
| ) |
| ) |
| dim_constraints.add( |
| Ne( |
| 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 20, |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 20 |
| >= 0 |
| ) |
| dim_constraints.add( |
| Ne( |
| 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 20, |
| 20, |
| ) |
| ) |
| dim_constraints.add( |
| Ne( |
| 20 |
| * ( |
| Mod( |
| 1, |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 1, |
| ) |
| ), |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| Ne( |
| 20 |
| * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) |
| * ( |
| Mod( |
| 1, |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) |
| - 2 |
| * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) |
| + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), |
| ) |
| ) |
| - 20 |
| * Mod( |
| 1, |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) |
| - 2 |
| * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) |
| + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), |
| ), |
| 0, |
| ) |
| ) |
| dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1)) |
| dim_constraints.add( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 1 |
| >= 1 |
| ) |
| dim_constraints.add( |
| 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 20 |
| >= 0 |
| ) |
| dim_constraints.add( |
| 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 20 |
| >= 1 |
| ) |
| dim_constraints.add( |
| 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 20 |
| >= 2 |
| ) |
| dim_constraints.add( |
| 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 20 |
| > 1 |
| ) |
| dim_constraints.add( |
| 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 20 |
| < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 |
| ) |
| dim_constraints.add( |
| Ne( |
| 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60, |
| 60, |
| ) |
| ) |
| dim_constraints.add( |
| Ne( |
| FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 1, |
| ) |
| ) |
| dim_constraints.add( |
| Eq( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) |
| * ( |
| Mod( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) |
| - 2 |
| * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) |
| + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), |
| 1, |
| ) |
| ) |
| - Mod( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) |
| - 2 |
| * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) |
| + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), |
| 1, |
| ), |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| Ne( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 1, |
| FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, |
| ) |
| ) |
| dim_constraints.add(Ne(8 * s0, 16)) |
| dim_constraints.add( |
| 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 |
| >= (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 1 |
| ) |
| dim_constraints.add( |
| 60 |
| * (FloorDiv(s0, (FloorDiv(s0, 2)))) |
| * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) |
| <= 2147483647 |
| ) |
| dim_constraints.add( |
| 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 90 |
| <= 2147483647 |
| ) |
| dim_constraints.add(FloorDiv(s0, 2) < 16) |
| dim_constraints.add(FloorDiv(s0, 2) > 1) |
| dim_constraints.add( |
| Ne( |
| 90 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 180 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 90 * (FloorDiv(s0, 2)), |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| 1 |
| < 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 90 |
| ) |
| dim_constraints.add( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 1 |
| > 1 |
| ) |
| dim_constraints.add( |
| 60 |
| * (FloorDiv(s0, (FloorDiv(s0, 2)))) |
| * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) |
| > 1 |
| ) |
| dim_constraints.add( |
| Ne( |
| 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 * (FloorDiv(s0, 2)), |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 90 |
| > 1 |
| ) |
| dim_constraints.add( |
| 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 |
| > 1 |
| ) |
| dim_constraints.add( |
| Ne( |
| 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 * (FloorDiv(s0, 2)), |
| 3 * (FloorDiv(s0, 2)), |
| ) |
| ) |
| dim_constraints.add( |
| 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 * (FloorDiv(s0, 2)) |
| > 0 |
| ) |
| dim_constraints.add( |
| 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 |
| > 0 |
| ) |
| dim_constraints.add( |
| Ne( |
| 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 120, |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| 1 |
| < 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 120 |
| ) |
| dim_constraints.add( |
| Ne( |
| 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 120, |
| 6, |
| ) |
| ) |
| dim_constraints.add( |
| 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 120 |
| > 0 |
| ) |
| dim_constraints.add( |
| Ne( |
| 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 120, |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 120 |
| <= 2147483647 |
| ) |
| dim_constraints.add( |
| 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 120 |
| <= 20480 |
| ) |
| dim_constraints.add( |
| Ne( |
| 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 90, |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 120 |
| > 1 |
| ) |
| dim_constraints.add( |
| 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 90 |
| <= 20480 |
| ) |
| dim_constraints.add( |
| 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 60 |
| <= 20480 |
| ) |
| dim_constraints.add( |
| Ne( |
| 240 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 480 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 240, |
| 0, |
| ) |
| ) |
| dim_constraints.add(Eq(6 * s5, 132)) |
| dim_constraints.add(Eq(4, FloorDiv(s0, 2))) |
| dim_constraints.add(Eq(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 4)) |
| dim_constraints.add( |
| Ne( |
| 64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 128 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 64 * (FloorDiv(s0, 2)), |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| 1 |
| < 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 64 |
| ) |
| dim_constraints.add( |
| 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 64 |
| <= 2147483647 |
| ) |
| dim_constraints.add( |
| 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 64 |
| > 1 |
| ) |
| dim_constraints.add( |
| 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 62 |
| <= 2147483647 |
| ) |
| dim_constraints.add( |
| Ne( |
| 62 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 124 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 62 * (FloorDiv(s0, 2)), |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| 1 |
| < 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 62 |
| ) |
| dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3)) |
| dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3)) |
| dim_constraints.add(Eq(FloorDiv(s0, 2), 4)) |
| dim_constraints.add(Eq(4, FloorDiv(s0, 2))) |
| dim_constraints.add(Eq(FloorDiv(s0, 2), 4)) |
| dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 3) |
| dim_constraints.add( |
| 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 576 |
| <= 2147483647 |
| ) |
| dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 0) |
| dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 1) |
| dim_constraints.add( |
| Ne( |
| 64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 576 * (FloorDiv(s0, 2)), |
| 0, |
| ) |
| ) |
| dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1)) |
| dim_constraints.add( |
| Ne( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 9, |
| 1, |
| ) |
| ) |
| dim_constraints.add( |
| Ne( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 9, |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 9 |
| >= 0 |
| ) |
| dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 0)) |
| dim_constraints.add( |
| 1 |
| < 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 576 |
| ) |
| dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1)) |
| dim_constraints.add( |
| Ne( |
| 64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 576 * (FloorDiv(s0, 2)), |
| 256, |
| ) |
| ) |
| dim_constraints.add( |
| Eq( |
| 64 |
| * ( |
| Mod( |
| (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 9 * (FloorDiv(s0, 2)), |
| 4, |
| ) |
| ), |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| Eq( |
| FloorDiv(s0, 2), |
| FloorDiv( |
| ( |
| (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 9 * (FloorDiv(s0, 2)) |
| ), |
| 4, |
| ), |
| ) |
| ) |
| dim_constraints.add( |
| Eq( |
| FloorDiv( |
| ( |
| (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 9 * (FloorDiv(s0, 2)) |
| ), |
| 4, |
| ), |
| FloorDiv(s0, 2), |
| ) |
| ) |
| dim_constraints.add(Ne(64 * (Mod(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4)), 0)) |
| dim_constraints.add( |
| Eq( |
| 64 |
| * ( |
| Mod( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 1, |
| 4, |
| ) |
| ), |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| 64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 576 * (FloorDiv(s0, 2)) |
| > 0 |
| ) |
| dim_constraints.add( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 9 |
| >= 1 |
| ) |
| dim_constraints.add( |
| Eq( |
| 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 576, |
| 256, |
| ) |
| ) |
| dim_constraints.add( |
| 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 540 |
| <= 2147483647 |
| ) |
| dim_constraints.add( |
| Ne( |
| 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 360 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 540 * (FloorDiv(s0, 2)), |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| 1 |
| < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 540 |
| ) |
| dim_constraints.add( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 9 |
| <= 2147483647 |
| ) |
| dim_constraints.add( |
| Ne( |
| (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 9 * (FloorDiv(s0, 2)), |
| 0, |
| ) |
| ) |
| dim_constraints.add( |
| 1 |
| < (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 9 |
| ) |
| dim_constraints.add( |
| (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 9 |
| > 1 |
| ) |
| dim_constraints.add( |
| 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 |
| - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8) |
| + 540 |
| > 1 |
| ) |
| dim_constraints.add(s0 >= 2) |
| dim_constraints.add(s1 >= 2) |
| dim_constraints.add(s6 >= 2) |
| dim_constraints.add(s5 >= 2) |
| |
| dim_constraints.solve() |
| self.assertEqual(dim_constraints._static_results, { |
| "L['c'].size()[0] == 8", |
| "L['d'].size()[0] == 8", |
| "L['a'].size()[2] == 96", |
| "L['f'].size()[1] == 1", |
| "L['a'].size()[3] == 96", |
| "L['b'].size()[2] == 3", |
| "L['b'].size()[1] == 22", |
| "L['b'].size()[0] == 8", |
| "L['a'].size()[1] == 22", |
| "L['a'].size()[0] == 8", |
| }) |
| self.assertEqual(dim_constraints._dynamic_results, { |
| "dynamic_dim(L['e'], 1) == dynamic_dim(L['c'], 1)", |
| "2 <= dynamic_dim(L['c'], 1)", |
| "dynamic_dim(L['d'], 1) == dynamic_dim(L['c'], 1)", |
| }) |
| |
| def dummy_fn(a, b, c, d, e, f): |
| pass |
| |
| action_code = dim_constraints.prettify_results(inspect.signature(dummy_fn)) |
| static_code, dynamic_code = re.findall(r"```(.*?)```", action_code, re.DOTALL) |
| expected_static = ''' |
| def specializations(a, b, c, d, e, f): |
| # a: |
| assert a.size()[0] == 8 |
| assert a.size()[1] == 22 |
| assert a.size()[2] == 96 |
| assert a.size()[3] == 96 |
| |
| # b: |
| assert b.size()[0] == 8 |
| assert b.size()[1] == 22 |
| assert b.size()[2] == 3 |
| |
| # c: |
| assert c.size()[0] == 8 |
| |
| # d: |
| assert d.size()[0] == 8 |
| |
| # f: |
| assert f.size()[1] == 1 |
| ''' |
| expected_dynamic = ''' |
| def specify_constraints(a, b, c, d, e, f): |
| return [ |
| # c: |
| dynamic_dim(c, 1), |
| |
| # d: |
| dynamic_dim(d, 1) == dynamic_dim(c, 1), |
| |
| # e: |
| dynamic_dim(e, 1) == dynamic_dim(c, 1), |
| ] |
| ''' |
| |
| self.assertEqual(static_code, expected_static) |
| self.assertEqual(dynamic_code, expected_dynamic) |
| |
| |
| |
if __name__ == "__main__":
    # Executed only when this file is run as a script (not on import):
    # hand control to the PyTorch test runner imported from common_utils.
    run_tests()