| # -*- coding: utf-8 -*- |
| # Owner(s): ["oncall: pt2"] |
| |
| import itertools |
| |
| import sympy |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| parametrize, |
| run_tests, |
| TestCase, |
| ) |
| from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges |
| from torch.utils._sympy.reference import ReferenceAnalysis |
| from torch.utils._sympy.interp import sympy_interp |
| |
| |
# Names of the ops under test; each name is looked up with getattr on both
# ValueRangeAnalysis and ReferenceAnalysis.
UNARY_OPS = ["reciprocal", "square", "abs", "neg", "exp", "log", "sqrt", "floor", "ceil"]
BINARY_OPS = [
    "truediv",
    "div",
    "add",
    "mul",
    "sub",
    "pow",
    "minimum",
    "maximum",
    "mod",
]

# Boolean logic ops (exercised with truth values as inputs).
UNARY_BOOL_OPS = ["not_"]
BINARY_BOOL_OPS = ["or_", "and_"]
# Comparison ops (exercised with the integer constants below as inputs).
COMPARE_OPS = ["eq", "ne", "lt", "gt", "le", "ge"]

# Sample integer inputs: a mix of small values, powers of two, a prime and a
# few large values.
CONSTANTS = [
    # small values around zero
    -1, 0, 1, 2, 3, 4, 5,
    # powers of two
    8, 16, 32, 64,
    # a round number and a prime
    100, 101,
    # large values
    2**24, 2**32, 2**37 - 1,
]
# Fewer constants, for tests whose cost grows quadratically (or worse) with
# the number of values.
LESS_CONSTANTS = [-1, 0, 1, 2, 100]
| |
| |
def valid_unary(fn, v):
    """Return whether unary op *fn* may be evaluated at the concrete value *v*.

    ``log`` needs a positive argument, ``reciprocal`` a non-zero one and
    ``sqrt`` a non-negative one; every other op accepts any value.
    """
    # Maps an op name to a predicate that is true when *v* is outside the
    # op's accepted domain.
    out_of_domain = {
        "log": lambda x: x <= 0,
        "reciprocal": lambda x: x == 0,
        "sqrt": lambda x: x < 0,
    }
    rejects = out_of_domain.get(fn)
    return rejects is None or not rejects(v)
| |
| |
def valid_binary(fn, a, b):
    """Return whether binary op *fn* may be evaluated at the concrete pair (*a*, *b*).

    Rejects a handful of ``pow`` inputs plus division-like ops with a zero
    divisor; everything else is accepted.
    """
    if fn == "pow":
        # sympy expands integral powers into x*x*...; skip big exponents.
        if b > 4:
            return False
        # Avoid non-real results such as 0**-1 (sympy's complex infinity);
        # non-positive bases with exponent -1 are excluded altogether
        # (original note: "no imaginary numbers").
        if a <= 0 and b == -1:
            return False
        # 0**0 is mathematically indeterminate.
        if a == b == 0:
            return False
        return True
    if fn in ("mod", "div", "truediv"):
        # Division-like ops are undefined for a zero divisor.
        return b != 0
    return True
| |
| |
def generate_range(vals):
    """Yield a ``ValueRanges(lo, hi)`` for every ordered pair drawn from *vals*.

    Pairs come from the Cartesian square of *vals* and are skipped when
    ``lo > hi``; when ``lo`` is one of the boolean atoms ``sympy.true`` /
    ``sympy.false`` the only rejected pair is ``(true, false)``. Ranges whose
    lower bound is ``+oo`` or whose upper bound is ``-oo`` are skipped too.
    """
    bool_atoms = (sympy.true, sympy.false)
    for lo, hi in itertools.product(vals, repeat=2):
        if lo in bool_atoms:
            ordered = not (lo == sympy.true and hi == sympy.false)
        else:
            ordered = not (lo > hi)
        if not ordered:
            continue
        # Ranges that only admit infinite values are not interesting.
        if lo == sympy.oo or hi == -sympy.oo:
            continue
        yield ValueRanges(lo, hi)
| |
| |
class TestValueRanges(TestCase):
    """Check ValueRangeAnalysis against ReferenceAnalysis.

    Most tests fall into two groups:

    * ``test_*_ref``: apply an op to singleton ranges built from concrete
      constants; the result must be the singleton range holding exactly the
      value ReferenceAnalysis computes for those constants.
    * ``test_*_ref_range``: apply an op to many input ranges; for every
      concrete input point inside those ranges, the ReferenceAnalysis result
      must lie inside the computed output range. The boolean variants also
      require the computed range to be tight.

    Naming used throughout: ``expected`` / ``ref`` hold ReferenceAnalysis
    results, ``result`` / ``out_range`` hold ValueRangeAnalysis results.
    """

    @parametrize("fn", UNARY_OPS)
    def test_unary_ref(self, fn):
        for v in CONSTANTS:
            if not valid_unary(fn, v):
                continue
            with self.subTest(v=v):
                expected = getattr(ReferenceAnalysis, fn)(sympy.Integer(v))
                result = getattr(ValueRangeAnalysis, fn)(ValueRanges.wrap(v))
                # A singleton input must give a singleton output ...
                self.assertEqual(result.lower, result.upper)
                # ... holding exactly the reference value.
                self.assertEqual(expected, result.lower)

    def test_pow_half(self):
        # Smoke test only: an unknown range raised to the power 0.5 must not
        # raise. The returned range is not checked.
        ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5))

    @parametrize("fn", BINARY_OPS)
    def test_binary_ref(self, fn):
        for a, b in itertools.product(CONSTANTS, repeat=2):
            if not valid_binary(fn, a, b):
                continue
            with self.subTest(a=a, b=b):
                expected = getattr(ReferenceAnalysis, fn)(
                    sympy.Integer(a), sympy.Integer(b)
                )
                result = getattr(ValueRangeAnalysis, fn)(
                    ValueRanges.wrap(a),
                    ValueRanges.wrap(b),
                )
                self.assertEqual(result.lower, result.upper)
                self.assertEqual(expected, result.lower)

    def test_mul_zero_unknown(self):
        # Zero times an unknown range is still exactly zero.
        self.assertEqual(
            ValueRangeAnalysis.mul(ValueRanges.wrap(0), ValueRanges.unknown()),
            ValueRanges.wrap(0),
        )

    @parametrize("fn", UNARY_BOOL_OPS)
    def test_unary_bool_ref_range(self, fn):
        vals = [sympy.false, sympy.true]
        for a in generate_range(vals):
            with self.subTest(a=a):
                out_range = getattr(ValueRangeAnalysis, fn)(a)
                unique = set()
                for a0 in vals:
                    if a0 not in a:
                        continue
                    with self.subTest(a0=a0):
                        ref = getattr(ReferenceAnalysis, fn)(a0)
                        self.assertIn(ref, out_range)
                        unique.add(ref)
                # Tightness: a singleton range must be hit by exactly one
                # reference value, a non-singleton one by both booleans.
                if out_range.lower == out_range.upper:
                    self.assertEqual(len(unique), 1)
                else:
                    self.assertEqual(len(unique), 2)

    @parametrize("fn", BINARY_BOOL_OPS)
    def test_binary_bool_ref_range(self, fn):
        vals = [sympy.false, sympy.true]
        for a, b in itertools.product(generate_range(vals), repeat=2):
            with self.subTest(a=a, b=b):
                out_range = getattr(ValueRangeAnalysis, fn)(a, b)
                unique = set()
                for a0, b0 in itertools.product(vals, repeat=2):
                    if a0 not in a or b0 not in b:
                        continue
                    with self.subTest(a0=a0, b0=b0):
                        ref = getattr(ReferenceAnalysis, fn)(a0, b0)
                        self.assertIn(ref, out_range)
                        unique.add(ref)
                # Same tightness check as the unary boolean test.
                if out_range.lower == out_range.upper:
                    self.assertEqual(len(unique), 1)
                else:
                    self.assertEqual(len(unique), 2)

    @parametrize("fn", UNARY_OPS)
    def test_unary_ref_range(self, fn):
        vals = [-sympy.oo, *CONSTANTS, sympy.oo]
        for a in generate_range(vals):
            with self.subTest(a=a):
                out_range = getattr(ValueRangeAnalysis, fn)(a)
                for a0 in CONSTANTS:
                    if a0 not in a:
                        continue
                    if not valid_unary(fn, a0):
                        continue
                    with self.subTest(a0=a0):
                        ref = getattr(ReferenceAnalysis, fn)(sympy.Integer(a0))
                        self.assertIn(ref, out_range)

    # This takes about 4s for all the variants
    @parametrize("fn", BINARY_OPS + COMPARE_OPS)
    def test_binary_ref_range(self, fn):
        vals = [-sympy.oo, *LESS_CONSTANTS, sympy.oo]
        for a, b in itertools.product(generate_range(vals), repeat=2):
            # Don't attempt pow on exponents that are too large (but an
            # unbounded exponent range is OK).
            if fn == "pow" and b.upper > 4 and b.upper != sympy.oo:
                continue
            with self.subTest(a=a, b=b):
                out_range = getattr(ValueRangeAnalysis, fn)(a, b)
                for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2):
                    if a0 not in a or b0 not in b:
                        continue
                    if not valid_binary(fn, a0, b0):
                        continue
                    with self.subTest(a0=a0, b0=b0):
                        ref = getattr(ReferenceAnalysis, fn)(
                            sympy.Integer(a0), sympy.Integer(b0)
                        )
                        self.assertIn(ref, out_range)
| |
| |
class TestSympyInterp(TestCase):
    """Check that sympy_interp over ReferenceAnalysis matches direct evaluation."""

    @parametrize(
        "fn",
        UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS,
    )
    def test_interp(self, fn):
        from sympy.abc import x, y

        # Boolean logic ops take truth values; every other op takes integers.
        is_bool_op = fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}
        vals = [True, False] if is_bool_op else CONSTANTS
        arity = 2 if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS} else 1
        symbols = [x, y][:arity]
        is_valid = valid_unary if arity == 1 else valid_binary
        # The symbolic expression depends only on ``fn``: build it once rather
        # than once per argument tuple.
        sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
        for args in itertools.product(vals, repeat=arity):
            if not is_valid(fn, *args):
                continue
            with self.subTest(args=args):
                sargs = [sympy.sympify(a) for a in args]
                expected = getattr(ReferenceAnalysis, fn)(*sargs)
                # Yes, this is a long-winded way of saying xreplace; the
                # point is to test sympy_interp.
                result = sympy_interp(
                    ReferenceAnalysis, dict(zip(symbols, sargs)), sympy_expr
                )
                self.assertEqual(expected, result)
| |
| |
# Expand the @parametrize-decorated methods of each suite into concrete tests.
for _suite in (TestValueRanges, TestSympyInterp):
    instantiate_parametrized_tests(_suite)
# Drop the loop variable so test loaders do not collect the last suite twice
# under a second module-level name.
del _suite


if __name__ == "__main__":
    run_tests()