| # Owner(s): ["oncall: pt2"] |
| |
| import itertools |
| import math |
| import sys |
| |
| import sympy |
| from typing import Callable, List, Tuple, Type |
| from torch.testing._internal.common_device_type import skipIf |
| from torch.testing._internal.common_utils import ( |
| TEST_Z3, |
| instantiate_parametrized_tests, |
| parametrize, |
| run_tests, |
| TestCase, |
| ) |
| from torch.utils._sympy.functions import FloorDiv, simple_floordiv_gcd |
| from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve |
| from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges |
| from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis |
| from torch.utils._sympy.interp import sympy_interp |
| from torch.utils._sympy.singleton_int import SingletonInt |
| from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity |
| from sympy.core.relational import is_ge, is_le, is_gt, is_lt |
| import functools |
| import torch.fx as fx |
| |
| |
| |
# Names of the analysis methods exercised by the parametrized tests below.
# Order matters only for the generated test ids.
UNARY_OPS = [
    "reciprocal", "square", "abs", "neg",
    "exp", "log", "sqrt",
    "floor", "ceil",
]
BINARY_OPS = [
    "truediv",
    "floordiv",
    # "truncdiv",  # TODO
    # NB: pow is float_pow
    "add",
    "mul",
    "sub",
    "pow",
    "pow_by_natural",
    "minimum",
    "maximum",
    "mod",
]

# Boolean-valued counterparts and the comparison operators.
UNARY_BOOL_OPS = ["not_"]
BINARY_BOOL_OPS = ["or_", "and_"]
COMPARE_OPS = ["eq", "ne", "lt", "gt", "le", "ge"]
| |
# Sample operands: small integers, powers of two, a couple of larger odd
# values, and the platform's largest native integers.
CONSTANTS = [
    -1, 0, 1, 2, 3, 4, 5,
    8, 16, 32, 64,
    100, 101,
    2**24, 2**32, 2**37 - 1,
    sys.maxsize - 1, sys.maxsize,
]
# Fewer constants, for tests whose cost grows quadratically with the list.
LESS_CONSTANTS = [-1, 0, 1, 2, 100]
# The SymPy relational classes the solver tests are parametrized over.
RELATIONAL_TYPES = [
    sympy.Eq,
    sympy.Ne,
    sympy.Gt,
    sympy.Ge,
    sympy.Lt,
    sympy.Le,
]
| |
| |
def valid_unary(fn, v):
    """Return True when the unary op named ``fn`` may be tested on ``v``.

    Only three ops restrict their operand: ``log`` needs ``v > 0``,
    ``reciprocal`` needs ``v != 0`` and ``sqrt`` needs ``v >= 0``.  Every
    other op name accepts any value.
    """
    outside_domain = (
        (fn == "log" and v <= 0)
        or (fn == "reciprocal" and v == 0)
        or (fn == "sqrt" and v < 0)
    )
    return not outside_domain
| |
| |
def valid_binary(fn, a, b):
    """Return True when the binary op named ``fn`` may be tested on ``(a, b)``.

    Ops not listed below accept every operand pair.
    """
    if fn == "pow":
        rejected = (
            # sympy will expand to x*x*... for integral b; don't do it if it's big
            b > 4
            # no imaginary numbers
            or a <= 0
            # 0**0 is undefined
            or (a == b == 0)
        )
        return not rejected
    if fn == "pow_by_natural":
        rejected = (
            # sympy will expand to x*x*... for integral b; don't do it if it's big
            b > 4
            # the exponent must be a natural number
            or b < 0
            # 0**0 is undefined
            or (a == b == 0)
        )
        return not rejected
    if fn == "mod":
        # Only a non-negative dividend with a strictly positive divisor is tested.
        return not (a < 0 or b <= 0)
    if fn in ["div", "truediv", "floordiv"]:
        # No division by zero.
        return not (b == 0)
    return True
| |
| |
def generate_range(vals):
    """Yield a ``ValueRanges(lower, upper)`` for each usable pair from ``vals``.

    Pairs are drawn from the Cartesian square of ``vals``.  A pair is dropped
    when its bounds are inverted (for SymPy booleans that means
    ``true .. false``) or when the range could only contain infinite values.
    """
    bool_atoms = [sympy.true, sympy.false]
    for lower, upper in itertools.product(vals, repeat=2):
        if lower in bool_atoms:
            # Booleans are not compared with ``>``; the only inverted pair is
            # (true, false).
            inverted = lower == sympy.true and upper == sympy.false
        else:
            inverted = lower > upper
        if inverted:
            continue
        # ranges that only admit infinite values are not interesting
        if lower == sympy.oo or upper == -sympy.oo:
            continue
        yield ValueRanges(lower, upper)
| |
| |
class TestNumbers(TestCase):
    """Checks for ``int_oo``, the integer infinity, and its negation."""

    def test_int_infinity(self):
        """Cover the types, arithmetic results and comparisons of ``int_oo``."""
        self.assertIsInstance(int_oo, IntInfinity)
        self.assertIsInstance(-int_oo, NegativeIntInfinity)
        self.assertTrue(int_oo.is_integer)

        # The identity checks below are about singleton-ness of the results;
        # do not use `is` for comparisons against ordinary numbers.
        identical_pairs = [
            # addition and subtraction
            (int_oo + int_oo, int_oo),
            (int_oo + 1, int_oo),
            (int_oo - 1, int_oo),
            (-int_oo - 1, -int_oo),
            (-int_oo + 1, -int_oo),
            (-int_oo + (-int_oo), -int_oo),
            (-int_oo - int_oo, -int_oo),
            (1 + int_oo, int_oo),
            (1 - int_oo, -int_oo),
            # multiplication
            (int_oo * int_oo, int_oo),
            (2 * int_oo, int_oo),
            (int_oo * 2, int_oo),
            (-1 * int_oo, -int_oo),
            (-int_oo * int_oo, -int_oo),
            (2 * -int_oo, -int_oo),
            (-int_oo * 2, -int_oo),
            (-1 * -int_oo, int_oo),
            # true division yields sympy's ordinary infinity
            (int_oo / 2, sympy.oo),
            # negation and absolute value
            (-(-int_oo), int_oo),  # noqa: B002
            (abs(int_oo), int_oo),
            (abs(-int_oo), int_oo),
            # powers
            (int_oo ** 2, int_oo),
            ((-int_oo) ** 2, int_oo),
            ((-int_oo) ** 3, -int_oo),
            (int_oo ** int_oo, int_oo),
        ]
        for actual, expected in identical_pairs:
            self.assertIs(actual, expected)

        # Negative exponents give zero.
        self.assertEqual(int_oo ** -1, 0)
        self.assertEqual((-int_oo) ** -1, 0)

        # Equality and inequality operators.
        self.assertTrue(int_oo == int_oo)
        self.assertFalse(int_oo != int_oo)
        self.assertTrue(-int_oo == -int_oo)
        self.assertFalse(int_oo == 2)
        self.assertTrue(int_oo != 2)
        self.assertFalse(int_oo == sys.maxsize)

        # Ordering against finite values and against negative infinity.
        self.assertTrue(int_oo >= sys.maxsize)
        self.assertTrue(int_oo >= 2)
        self.assertTrue(int_oo >= -int_oo)

    def test_relation(self):
        """``sympy.Add`` absorbs finite terms; ``-int_oo`` is not above 2."""
        total = sympy.Add(2, int_oo)
        self.assertIs(total, int_oo)
        self.assertFalse(-int_oo > 2)

    def test_lt_self(self):
        """Strict ordering is irreflexive and ``min`` picks ``-int_oo``."""
        self.assertFalse(int_oo < int_oo)
        for other in (-4, -int_oo):
            self.assertIs(min(-int_oo, other), -int_oo)

    def test_float_cast(self):
        """``float()`` maps the integer infinities onto IEEE infinities."""
        for value, expected in ((int_oo, math.inf), (-int_oo, -math.inf)):
            self.assertEqual(float(value), expected)

    def test_mixed_oo_int_oo(self):
        """``int_oo`` orders strictly inside ``sympy.oo`` on both ends."""
        # Arbitrary choice
        self.assertTrue(int_oo < sympy.oo)
        self.assertFalse(int_oo > sympy.oo)
        self.assertTrue(sympy.oo > int_oo)
        self.assertFalse(sympy.oo < int_oo)
        self.assertIs(max(int_oo, sympy.oo), sympy.oo)

        # Mirror image on the negative side.
        self.assertTrue(-int_oo > -sympy.oo)
        self.assertIs(min(-int_oo, -sympy.oo), -sympy.oo)
| |
| |
class TestValueRanges(TestCase):
    """Cross-checks ``ValueRangeAnalysis`` against ``ReferenceAnalysis``.

    The ``*_ref`` tests feed single constants to both analyses and expect the
    resulting range to collapse to the reference value.  The ``*_ref_range``
    tests feed real ranges and expect every reference result computed from a
    member of the input range(s) to lie inside the computed output range.
    """

    @parametrize("fn", UNARY_OPS)
    @parametrize("dtype", ("int", "float"))
    def test_unary_ref(self, fn, dtype):
        """A unary op on a constant yields a degenerate range equal to the reference."""
        to_number = {"int": sympy.Integer, "float": sympy.Float}[dtype]
        for v in CONSTANTS:
            if not valid_unary(fn, v):
                continue
            with self.subTest(v=v):
                operand = to_number(v)
                expected = getattr(ReferenceAnalysis, fn)(operand)
                result = getattr(ValueRangeAnalysis, fn)(operand)
                self.assertEqual(result.lower.is_integer, result.upper.is_integer)
                self.assertEqual(result.lower, result.upper)
                self.assertEqual(expected.is_integer, result.upper.is_integer)
                self.assertEqual(expected, result.lower)

    def test_pow_half(self):
        """Smoke test: raising an unknown range to the power 0.5 must not raise."""
        ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5))

    @parametrize("fn", BINARY_OPS)
    @parametrize("dtype", ("int", "float"))
    def test_binary_ref(self, fn, dtype):
        """A binary op on two constants yields a degenerate range equal to the reference."""
        # Don't test float on int only methods
        if dtype == "float" and fn in ["pow_by_natural", "mod"]:
            return
        to_number = {"int": sympy.Integer, "float": sympy.Float}[dtype]
        for a, b in itertools.product(CONSTANTS, repeat=2):
            if not valid_binary(fn, a, b):
                continue
            a = to_number(a)
            b = to_number(b)
            with self.subTest(a=a, b=b):
                result = getattr(ValueRangeAnalysis, fn)(a, b)
                # An unknown range makes no claim, so there is nothing to check.
                if result == ValueRanges.unknown():
                    continue
                expected = getattr(ReferenceAnalysis, fn)(a, b)

                self.assertEqual(result.lower.is_integer, result.upper.is_integer)
                self.assertEqual(expected.is_integer, result.upper.is_integer)
                self.assertEqual(result.lower, result.upper)
                self.assertEqual(expected, result.lower)

    def test_mul_zero_unknown(self):
        """Multiplying a zero constant by an unknown range gives that zero."""
        for zero in (0, 0.0):
            self.assertEqual(
                ValueRangeAnalysis.mul(ValueRanges.wrap(zero), ValueRanges.unknown()),
                ValueRanges.wrap(zero),
            )

    @parametrize("fn", UNARY_BOOL_OPS)
    def test_unary_bool_ref_range(self, fn):
        """Boolean unary ops: the output range is exact, not merely a superset."""
        vals = [sympy.false, sympy.true]
        for a in generate_range(vals):
            with self.subTest(a=a):
                result_range = getattr(ValueRangeAnalysis, fn)(a)
                seen = set()
                for a0 in vals:
                    if a0 not in a:
                        continue
                    with self.subTest(a0=a0):
                        expected = getattr(ReferenceAnalysis, fn)(a0)
                        self.assertIn(expected, result_range)
                        seen.add(expected)
                # A degenerate range must come from exactly one distinct
                # reference value, a full range from both booleans.
                if result_range.lower == result_range.upper:
                    self.assertEqual(len(seen), 1)
                else:
                    self.assertEqual(len(seen), 2)

    @parametrize("fn", BINARY_BOOL_OPS)
    def test_binary_bool_ref_range(self, fn):
        """Boolean binary ops: the output range is exact, not merely a superset."""
        vals = [sympy.false, sympy.true]
        for a, b in itertools.product(generate_range(vals), repeat=2):
            with self.subTest(a=a, b=b):
                result_range = getattr(ValueRangeAnalysis, fn)(a, b)
                seen = set()
                for a0, b0 in itertools.product(vals, repeat=2):
                    if a0 not in a or b0 not in b:
                        continue
                    with self.subTest(a0=a0, b0=b0):
                        expected = getattr(ReferenceAnalysis, fn)(a0, b0)
                        self.assertIn(expected, result_range)
                        seen.add(expected)
                # A degenerate range must come from exactly one distinct
                # reference value, a full range from both booleans.
                if result_range.lower == result_range.upper:
                    self.assertEqual(len(seen), 1)
                else:
                    self.assertEqual(len(seen), 2)

    @parametrize("fn", UNARY_OPS)
    def test_unary_ref_range(self, fn):
        """Every reference result for a member of the range lies in the output range."""
        # TODO: bring back sympy.oo testing for float unary fns
        vals = CONSTANTS
        for a in generate_range(vals):
            with self.subTest(a=a):
                result_range = getattr(ValueRangeAnalysis, fn)(a)
                for a0 in CONSTANTS:
                    if a0 not in a:
                        continue
                    if not valid_unary(fn, a0):
                        continue
                    with self.subTest(a0=a0):
                        expected = getattr(ReferenceAnalysis, fn)(sympy.Integer(a0))
                        self.assertIn(expected, result_range)

    # This takes about 4s for all the variants
    @parametrize("fn", BINARY_OPS + COMPARE_OPS)
    def test_binary_ref_range(self, fn):
        """Every finite reference result for members of both ranges lies in the output range."""
        # TODO: bring back sympy.oo testing for float binary fns
        vals = LESS_CONSTANTS
        for a, b in itertools.product(generate_range(vals), repeat=2):
            # don't attempt pow on exponents that are too large (but oo is OK)
            if fn == "pow" and b.upper > 4 and b.upper != sympy.oo:
                continue
            with self.subTest(a=a, b=b):
                # The range result depends only on (a, b), so it is computed at
                # most once per pair of ranges.  It is computed lazily, on the
                # first valid constant pair, because range pairs without any
                # valid constants were never passed to the analysis before.
                result_range = None
                for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2):
                    if a0 not in a or b0 not in b:
                        continue
                    if not valid_binary(fn, a0, b0):
                        continue
                    with self.subTest(a0=a0, b0=b0):
                        if result_range is None:
                            result_range = getattr(ValueRangeAnalysis, fn)(a, b)
                        expected = getattr(ReferenceAnalysis, fn)(
                            sympy.Integer(a0), sympy.Integer(b0)
                        )
                        if expected.is_finite:
                            self.assertIn(expected, result_range)
| |
class TestSympyInterp(TestCase):
    """Checks that ``sympy_interp`` agrees with direct evaluation of each op."""

    @staticmethod
    def _interp_setup(fn):
        """Return ``(x, y, vals, arity)`` for testing the op named ``fn``.

        ``x`` and ``y`` are fresh dummy symbols, declared integer only for
        ``pow_by_natural``.  ``vals`` is the pool of concrete operands
        (booleans for the boolean ops, ``CONSTANTS`` otherwise) and ``arity``
        is 1 or 2.
        """
        is_integer = True if fn == "pow_by_natural" else None
        x = sympy.Dummy("x", integer=is_integer)
        y = sympy.Dummy("y", integer=is_integer)

        vals = CONSTANTS
        if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
            vals = [True, False]

        arity = 2 if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS} else 1
        return x, y, vals, arity

    @staticmethod
    def _valid_args(fn, args):
        """Apply ``valid_unary`` or ``valid_binary`` depending on ``len(args)``."""
        if len(args) == 1:
            return valid_unary(fn, *args)
        return valid_binary(fn, *args)

    @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
    def test_interp(self, fn):
        """Interpreting ``fn(x[, y])`` under a substitution equals ``fn`` on the values."""
        # SymPy does not implement truncation for Expressions
        if fn in ("div", "truncdiv", "minimum", "maximum", "mod"):
            return

        x, y, vals, arity = self._interp_setup(fn)
        symbols = [x, y][:arity]

        # The symbolic expression does not depend on the concrete operands,
        # so build it once rather than once per argument tuple.
        reference_op = getattr(ReferenceAnalysis, fn)
        sympy_expr = reference_op(*symbols)

        for args in itertools.product(vals, repeat=arity):
            if not self._valid_args(fn, args):
                continue
            with self.subTest(args=args):
                sargs = [sympy.sympify(a) for a in args]
                ref_r = reference_op(*sargs)
                # Yes, I know this is a longwinded way of saying xreplace; the
                # point is to test sympy_interp
                r = sympy_interp(ReferenceAnalysis, dict(zip(symbols, sargs)), sympy_expr)
                self.assertEqual(ref_r, r)

    @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
    def test_python_interp_fx(self, fn):
        """An FX trace of the interpreter gives the same result as running it directly."""
        # These never show up from symbolic_shapes
        if fn in ("log", "exp"):
            return

        # Sympy does not support truncation on symbolic shapes
        if fn in ("truncdiv", "mod"):
            return

        x, y, vals, arity = self._interp_setup(fn)
        symbols = [x, y][:arity]

        # Workaround mpf from symbol error
        if fn == "minimum":
            sympy_expr = sympy.Min(x, y)
        elif fn == "maximum":
            sympy_expr = sympy.Max(x, y)
        else:
            sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)

        if arity == 1:
            def trace_f(px):
                return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr)
        else:
            def trace_f(px, py):
                return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr)

        # Tracing only sees the expression and proxy inputs, never the
        # concrete operands, so a single trace is reused for every tuple.
        gm = fx.symbolic_trace(trace_f)

        for args in itertools.product(vals, repeat=arity):
            if not self._valid_args(fn, args):
                continue
            # Defensive filters; valid_binary as currently written already
            # rejects these operand pairs.
            if fn in ("pow", "pow_by_natural") and (args[0] == 0 and args[1] <= 0):
                continue
            elif fn == "floordiv" and args[1] == 0:
                continue
            with self.subTest(args=args):
                self.assertEqual(
                    sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr),
                    gm(*args)
                )
| |
| |
def type_name_fn(type: Type) -> str:
    """Return the ``__name__`` of the given class, for use as a test-id suffix."""
    name = type.__name__
    return name
| |
def parametrize_relational_types(*types):
    """Build a decorator that parametrizes a test's ``op`` argument.

    ``op`` ranges over ``types`` when any are given, otherwise over every
    entry of ``RELATIONAL_TYPES``.  Each generated test is named after the
    relational class via ``type_name_fn``.
    """
    def decorate(f: Callable):
        selected = types or RELATIONAL_TYPES
        apply_parametrize = parametrize("op", selected, name_fn=type_name_fn)
        return apply_parametrize(f)
    return decorate
| |
| |
class TestSympySolve(TestCase):
    """Tests for ``try_solve`` and for ``simple_floordiv_gcd``.

    As used here, ``try_solve(expr, thing)`` returns either ``None`` (it gave
    up) or a pair ``(relation, rhs)`` where ``relation`` has ``thing`` alone
    on its left-hand side and ``rhs`` on its right-hand side.
    """

    def _create_integer_symbols(self) -> Tuple[sympy.Symbol, ...]:
        """Return the integer-valued symbols ``a``, ``b`` and ``c``."""
        return sympy.symbols("a b c", integer=True)

    def test_give_up(self):
        """``try_solve`` returns ``None`` for inputs it cannot isolate ``a`` in."""
        from sympy import Eq, Ne

        a, b, c = self._create_integer_symbols()

        cases = [
            # Not a relational operation.
            a + b,
            # 'a' appears on both sides.
            Eq(a, a + 1),
            # 'a' appears on neither side.
            Eq(b, c + 1),
            # Result is a 'sympy.And'.
            Eq(FloorDiv(a, b), c),
            # Result is a 'sympy.Or'.
            Ne(FloorDiv(a, b), c),
        ]

        for case in cases:
            e = try_solve(case, a)
            self.assertIsNone(e)

    @parametrize_relational_types()
    def test_noop(self, op):
        """A relation already solved for ``a`` is returned unchanged."""
        a, b, _ = self._create_integer_symbols()

        lhs, rhs = a, 42 * b
        expr = op(lhs, rhs)

        r = try_solve(expr, a)
        self.assertIsNotNone(r)

        r_expr, r_rhs = r
        self.assertEqual(r_expr, expr)
        self.assertEqual(r_rhs, rhs)

    @parametrize_relational_types()
    def test_noop_rhs(self, op):
        """With ``a`` alone on the right, the relation is mirrored to put it on the left."""
        a, b, _ = self._create_integer_symbols()

        lhs, rhs = 42 * b, a

        mirror = mirror_rel_op(op)
        self.assertIsNotNone(mirror)

        expr = op(lhs, rhs)

        r = try_solve(expr, a)
        self.assertIsNotNone(r)

        r_expr, r_rhs = r
        self.assertEqual(r_expr, mirror(rhs, lhs))
        self.assertEqual(r_rhs, lhs)

    def _test_cases(self, cases: List[Tuple[sympy.Basic, sympy.Basic]], thing: sympy.Basic, op: Type[sympy.Rel], **kwargs):
        """Solve each case for ``thing`` and compare against its expectation.

        Args:
            cases: ``(source, expected)`` pairs.  ``expected`` is the expected
                right-hand side, or ``None`` when ``try_solve`` should give up.
            thing: the expression being solved for.
            op: relational class expected in the solved relation.
            **kwargs: forwarded unchanged to ``try_solve``.
        """
        for source, expected in cases:
            r = try_solve(source, thing, **kwargs)

            if expected is None:
                self.assertIsNone(r, msg=f"expected try_solve to give up on {source}")
                continue

            self.assertIsNotNone(r, msg=f"expected try_solve to solve {source}")
            r_expr, r_rhs = r
            self.assertEqual(r_rhs, expected)
            self.assertEqual(r_expr, op(thing, expected))

    def test_addition(self):
        """Additive terms are moved to the other side of an equality."""
        from sympy import Eq

        a, b, c = self._create_integer_symbols()

        cases = [
            (Eq(a + b, 0), -b),
            (Eq(a + 5, b - 5), b - 10),
            (Eq(a + c * b, 1), 1 - c * b),
        ]

        self._test_cases(cases, a, Eq)

    @parametrize_relational_types(sympy.Eq, sympy.Ne)
    def test_multiplication_division(self, op):
        """For Eq/Ne, a multiplicative factor is divided out."""
        a, b, c = self._create_integer_symbols()

        cases = [
            (op(a * b, 1), 1 / b),
            (op(a * 5, b - 5), (b - 5) / 5),
            (op(a * b, c), c / b),
        ]

        self._test_cases(cases, a, op)

    @parametrize_relational_types(*INEQUALITY_TYPES)
    def test_multiplication_division_inequality(self, op):
        """Inequalities are only scaled by factors whose sign is known."""
        a, b, _ = self._create_integer_symbols()
        intneg = sympy.Symbol("neg", integer=True, negative=True)
        intpos = sympy.Symbol("pos", integer=True, positive=True)

        cases = [
            # Divide/multiply both sides by positive number.
            (op(a * intpos, 1), 1 / intpos),
            (op(a / (5 * intpos), 1), 5 * intpos),
            (op(a * 5, b - 5), (b - 5) / 5),
            # 'b' is not strictly positive nor negative, so we can't
            # divide/multiply both sides by 'b'.
            (op(a * b, 1), None),
            (op(a / b, 1), None),
            (op(a * b * intpos, 1), None),
        ]

        mirror_cases = [
            # Divide/multiply both sides by negative number.
            (op(a * intneg, 1), 1 / intneg),
            (op(a / (5 * intneg), 1), 5 * intneg),
            (op(a * -5, b - 5), -(b - 5) / 5),
        ]
        # Scaling by a negative factor flips the inequality, so those cases
        # are checked against the mirrored operator.
        mirror_op = mirror_rel_op(op)
        self.assertIsNotNone(mirror_op)

        self._test_cases(cases, a, op)
        self._test_cases(mirror_cases, a, mirror_op)

    @parametrize_relational_types()
    def test_floordiv(self, op):
        """``FloorDiv`` relations are solved only for a positive divisor and integer bound."""
        from sympy import Eq, Ne, Gt, Ge, Lt, Le

        a, b, c = sympy.symbols("a b c")
        pos = sympy.Symbol("pos", positive=True)
        integer = sympy.Symbol("integer", integer=True)

        # Currently disabled cases:
        # (Eq(FloorDiv(a, pos), integer), And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))),
        # (Eq(FloorDiv(a + 5, pos), integer), And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))),
        # (Ne(FloorDiv(a, pos), integer), Or(Lt(a, integer * pos), Ge(a, (integer + 1) * pos))),

        special_case = {
            # 'FloorDiv' turns into 'And', which can't be simplified any further.
            Eq: (Eq(FloorDiv(a, pos), integer), None),
            # 'FloorDiv' turns into 'Or', which can't be simplified any further.
            Ne: (Ne(FloorDiv(a, pos), integer), None),
            Gt: (Gt(FloorDiv(a, pos), integer), (integer + 1) * pos),
            Ge: (Ge(FloorDiv(a, pos), integer), integer * pos),
            Lt: (Lt(FloorDiv(a, pos), integer), integer * pos),
            Le: (Le(FloorDiv(a, pos), integer), (integer + 1) * pos),
        }[op]

        cases: List[Tuple[sympy.Basic, sympy.Basic]] = [
            # 'b' is not strictly positive
            (op(FloorDiv(a, b), integer), None),
            # 'c' is not known to be an integer
            (op(FloorDiv(a, pos), c), None),
        ]

        # The result might change after 'FloorDiv' transformation.
        # Specifically:
        #   - [Ge, Gt] => Ge
        #   - [Le, Lt] => Lt
        if op in (sympy.Gt, sympy.Ge):
            r_op = sympy.Ge
        elif op in (sympy.Lt, sympy.Le):
            r_op = sympy.Lt
        else:
            r_op = op

        self._test_cases([special_case, *cases], a, r_op)
        # With the inequality transformation switched off, every case gives up.
        self._test_cases([(special_case[0], None), *cases], a, r_op, floordiv_inequality=False)

    def test_floordiv_eq_simplify(self):
        """An Eq on FloorDiv reduces to one bound when the other always holds."""
        from sympy import Eq, Lt, Le

        a = sympy.Symbol("a", positive=True, integer=True)

        def check(expr, expected):
            r = try_solve(expr, a)
            self.assertIsNotNone(r)
            r_expr, _ = r
            self.assertEqual(r_expr, expected)

        # (a + 10) // 3 == 3
        # =====================================
        # 3 * 3 <= a + 10         (always true)
        #          a + 10 < 4 * 3 (not sure)
        check(Eq(FloorDiv(a + 10, 3), 3), Lt(a, (3 + 1) * 3 - 10))

        # (10 - a) // 2 == 4
        # =====================================
        # 4 * 2 <= 10 - a         (not sure)
        #          10 - a < 5 * 2 (always true)
        check(Eq(FloorDiv(10 - a, 2), 4), Le(a, -(4 * 2 - 10)))

    @skipIf(not TEST_Z3, "Z3 not installed")
    def test_z3_proof_floordiv_eq_simplify(self):
        """Use Z3 to show the simplified relation is equivalent to the original."""
        import z3
        from sympy import Eq, Lt

        a = sympy.Symbol("a", positive=True, integer=True)
        a_ = z3.Int("a")

        # (a + 10) // 3 == 3
        # =====================================
        # 3 * 3 <= a + 10         (always true)
        #          a + 10 < 4 * 3 (not sure)
        solver = z3.SolverFor("QF_NRA")

        # Add assertions for 'a_'.
        solver.add(a_ > 0)

        expr = Eq(FloorDiv(a + 10, 3), 3)
        solved = try_solve(expr, a)
        # Fail with a clear assertion instead of a TypeError on unpacking.
        self.assertIsNotNone(solved)
        r_expr, _ = solved

        # Check 'try_solve' really returns the 'expected' below.
        expected = Lt(a, (3 + 1) * 3 - 10)
        self.assertEqual(r_expr, expected)

        # Check whether there is an integer 'a_' such that the
        # equation below is satisfied.
        solver.add(
            # expr
            (z3.ToInt((a_ + 10) / 3.0) == 3)
            !=
            # expected
            (a_ < (3 + 1) * 3 - 10)
        )

        # Assert that there is no such integer,
        # i.e. the transformation is sound.
        r = solver.check()
        self.assertEqual(r, z3.unsat)

    def test_simple_floordiv_gcd(self):
        """``simple_floordiv_gcd`` returns the common factor of numerator and denominator."""
        x, y, z = sympy.symbols("x y z")

        # positive tests
        self.assertEqual(simple_floordiv_gcd(x, x), x)
        self.assertEqual(simple_floordiv_gcd(128 * x, 2304), 128)
        self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y, 2304), 128)
        self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y + 8192 * z, 9216), 128)
        self.assertEqual(simple_floordiv_gcd(49152 * x, 96 * x), 96 * x)
        self.assertEqual(simple_floordiv_gcd(96 * x, 96 * x), 96 * x)
        self.assertEqual(simple_floordiv_gcd(x * y, x), x)
        self.assertEqual(simple_floordiv_gcd(384 * x * y, x * y), x * y)
        self.assertEqual(simple_floordiv_gcd(256 * x * y, 8 * x), 8 * x)

        # negative tests
        self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1)
| |
| |
class TestSingletonInt(TestCase):
    """Checks equality, ordering and multiplication behaviour of ``SingletonInt``."""

    def test_basic(self):
        """Run the full set of relational and arithmetic checks."""
        j1 = SingletonInt(1, coeff=1)
        j1_copy = SingletonInt(1, coeff=1)
        j2 = SingletonInt(2, coeff=1)
        j1x2 = SingletonInt(1, coeff=2)

        def check_eq(lhs, rhs, expected):
            self.assertEqual(sympy.Eq(lhs, rhs), expected)
            self.assertEqual(sympy.Ne(rhs, lhs), not expected)

        # eq, ne
        equality_table = [
            (j1, j1, True),
            (j1, j1_copy, True),
            (j1, j2, False),
            (j1, j1x2, False),
            (j1, sympy.Integer(1), False),
            (j1, sympy.Integer(3), False),
        ]
        for lhs, rhs, expected in equality_table:
            check_eq(lhs, rhs, expected)

        def check_ineq(a, b, expected, *, strict=True):
            if strict:
                greater, less = (sympy.Gt, is_gt), (sympy.Lt, is_lt)
            else:
                greater, less = (sympy.Ge, is_ge), (sympy.Le, is_le)

            if not isinstance(expected, bool):
                # A string is a regex for the ValueError every variant must raise.
                for fn in greater:
                    with self.assertRaisesRegex(ValueError, expected):
                        fn(a, b)
                for fn in less:
                    with self.assertRaisesRegex(ValueError, expected):
                        fn(b, a)
                return

            # Boolean expectations used below are always True.
            for fn in greater:
                self.assertEqual(fn(a, b), expected)
                self.assertEqual(fn(b, a), not expected)
            for fn in less:
                self.assertEqual(fn(b, a), expected)
                self.assertEqual(fn(a, b), not expected)

        # ge, le, gt, lt
        ordering_table = [
            (j1, sympy.Integer(0), True),
            (j1, sympy.Integer(3), "indeterminate"),
            (j1, j2, "indeterminate"),
            (j1x2, j1, True),
        ]
        for strict in (True, False):
            for a, b, expected in ordering_table:
                check_ineq(a, b, expected, strict=strict)

        # Special cases for ge, le, gt, lt:
        for ge in (sympy.Ge, is_ge):
            self.assertTrue(ge(j1, j1))
            self.assertTrue(ge(j1, sympy.Integer(2)))
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                ge(sympy.Integer(2), j1)

        for le in (sympy.Le, is_le):
            self.assertTrue(le(j1, j1))
            self.assertTrue(le(sympy.Integer(2), j1))
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                le(j1, sympy.Integer(2))

        for gt in (sympy.Gt, is_gt):
            self.assertFalse(gt(j1, j1))
            self.assertFalse(gt(sympy.Integer(2), j1))
            # Only j1 >= 2 is known, so j1 > 2 is indeterminate.
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                gt(j1, sympy.Integer(2))

        for lt in (sympy.Lt, is_lt):
            self.assertFalse(lt(j1, j1))
            self.assertFalse(lt(j1, sympy.Integer(2)))
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                lt(sympy.Integer(2), j1)

        # mul
        self.assertEqual(j1 * 2, j1x2)
        # Unfortunately, this does not automatically simplify to 2*j1,
        # since sympy.Mul doesn't trigger __mul__ unlike the above.
        self.assertIsInstance(sympy.Mul(j1, 2), sympy.core.mul.Mul)

        with self.assertRaisesRegex(ValueError, "cannot be multiplied"):
            j1 * j2

        self.assertEqual(j1.free_symbols, set())
| |
| |
# Register the test classes that use @parametrize. TestNumbers and
# TestSingletonInt define no parametrized tests, so they are not listed.
instantiate_parametrized_tests(TestValueRanges)
instantiate_parametrized_tests(TestSympyInterp)
instantiate_parametrized_tests(TestSympySolve)


# Allow running this file directly as a script.
if __name__ == "__main__":
    run_tests()