# (source-viewer header, not part of the module) blob: 03b0bc99dba873611f0f69780018038d206ce177 [file] [log] [blame]
# -*- coding: utf-8 -*-
# Owner(s): ["oncall: pt2"]
import itertools
import sympy
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.reference import ReferenceAnalysis
from torch.utils._sympy.interp import sympy_interp
# Operator names exercised by the tests in this file, grouped by arity and by
# operand kind.  The tests resolve each name with getattr() on the analysis
# classes, so every entry must be spelled exactly like the method it refers to.
UNARY_OPS = [
    "reciprocal", "square", "abs", "neg",
    "exp", "log", "sqrt",
    "floor", "ceil",
]
BINARY_OPS = [
    "truediv",
    "div",
    "add",
    "mul",
    "sub",
    "pow",
    "minimum",
    "maximum",
    "mod",
]
# Boolean operators get their own lists because they are sampled over
# true/false rather than over the integer constants.
UNARY_BOOL_OPS = ["not_"]
BINARY_BOOL_OPS = ["or_", "and_"]
COMPARE_OPS = ["eq", "ne", "lt", "gt", "le", "ge"]
# Integer sample points: a mix of small constants, powers of two, primes and a
# few large magnitudes.
CONSTANTS = [
    # small values around zero
    -1, 0, 1, 2, 3, 4, 5,
    # powers of two
    8, 16, 32, 64,
    # round number and a neighbouring prime
    100, 101,
    # large magnitudes
    2**24, 2**32, 2**37 - 1,
]
# Fewer constants, for tests whose cost grows quadratically (N^2) with the
# number of sample points.
LESS_CONSTANTS = [
    -1,
    0,
    1,
    2,
    100,
]
def valid_unary(fn, v):
    """Return whether the unary operator named ``fn`` should be sampled at ``v``.

    Only three operators restrict their input: ``log`` skips ``v <= 0``,
    ``reciprocal`` skips ``v == 0`` and ``sqrt`` skips ``v < 0``.  Every other
    operator name accepts every value.

    Args:
        fn: Operator name (one of the strings used to parametrize the tests).
        v: Candidate input value.

    Returns:
        ``False`` when the (operator, value) pair must be skipped, else ``True``.
    """
    if fn == "log":
        return not (v <= 0)
    if fn == "reciprocal":
        return not (v == 0)
    if fn == "sqrt":
        return not (v < 0)
    return True
def valid_binary(fn, a, b):
    """Return whether the binary operator named ``fn`` should be sampled at ``(a, b)``.

    Args:
        fn: Operator name (one of the strings used to parametrize the tests).
        a: Left operand.
        b: Right operand.

    Returns:
        ``False`` when the (operator, operands) combination must be skipped,
        else ``True``.
    """
    if fn == "pow":
        # sympy will expand to x*x*... for integral b; don't do it if it's big
        if b > 4:
            return False
        # no imaginary numbers
        if a <= 0 and b == -1:
            return False
        # 0**0 is undefined
        if a == b == 0:
            return False
        return True
    if fn == "mod":
        # modulo by zero is skipped
        return not (b == 0)
    if fn in ("div", "truediv"):
        # division by zero is skipped
        return not (b == 0)
    return True
def generate_range(vals):
    """Yield a ``ValueRanges(lower, upper)`` for each usable pair of ``vals``.

    Every ordered pair from ``vals`` x ``vals`` is considered.  A pair is
    dropped when its bounds are reversed, or when the range could only hold
    an infinite value.

    Args:
        vals: Candidate bounds; either the sympy booleans or numeric values
            (the callers in this file also include ``-oo`` and ``oo``).

    Yields:
        One ``ValueRanges`` per retained ``(lower, upper)`` pair, in
        ``itertools.product`` order.
    """
    for lower, upper in itertools.product(vals, repeat=2):
        if lower in (sympy.true, sympy.false):
            # Booleans are not compared with ">": the only reversed boolean
            # pair is [true, false].
            reversed_bounds = lower == sympy.true and upper == sympy.false
        else:
            reversed_bounds = lower > upper
        if reversed_bounds:
            continue
        # ranges that only admit infinite values are not interesting
        only_infinite = lower == sympy.oo or upper == -sympy.oo
        if only_infinite:
            continue
        yield ValueRanges(lower, upper)
class TestValueRanges(TestCase):
    """Compare ``ValueRangeAnalysis`` with ``ReferenceAnalysis`` on sample inputs.

    The ``*_ref`` tests feed single-value ranges and expect a single-value
    result equal to the reference.  The ``*_ref_range`` tests feed wider
    ranges and expect every sampled reference value to lie inside the
    computed range.
    """

    @parametrize("fn", UNARY_OPS)
    def test_unary_ref(self, fn):
        # A range wrapping one constant must come back with coinciding bounds
        # that equal the reference result for that constant.
        for v in CONSTANTS:
            if valid_unary(fn, v):
                with self.subTest(v=v):
                    expected = getattr(ReferenceAnalysis, fn)(sympy.Integer(v))
                    computed = getattr(ValueRangeAnalysis, fn)(ValueRanges.wrap(v))
                    self.assertEqual(computed.lower, computed.upper)
                    self.assertEqual(expected, computed.lower)

    def test_pow_half(self):
        # Smoke test: nothing is asserted, the call just has to complete.
        base = ValueRanges.unknown()
        exponent = ValueRanges.wrap(0.5)
        ValueRangeAnalysis.pow(base, exponent)

    @parametrize("fn", BINARY_OPS)
    def test_binary_ref(self, fn):
        # Same check as test_unary_ref, over every valid pair of constants.
        for a, b in itertools.product(CONSTANTS, repeat=2):
            if valid_binary(fn, a, b):
                with self.subTest(a=a, b=b):
                    expected = getattr(ReferenceAnalysis, fn)(
                        sympy.Integer(a), sympy.Integer(b)
                    )
                    computed = getattr(ValueRangeAnalysis, fn)(
                        ValueRanges.wrap(a),
                        ValueRanges.wrap(b),
                    )
                    self.assertEqual(computed.lower, computed.upper)
                    self.assertEqual(expected, computed.lower)

    def test_mul_zero_unknown(self):
        # Multiplying the zero range by an unknown range must give the zero range.
        product = ValueRangeAnalysis.mul(ValueRanges.wrap(0), ValueRanges.unknown())
        self.assertEqual(product, ValueRanges.wrap(0))

    @parametrize("fn", UNARY_BOOL_OPS)
    def test_unary_bool_ref_range(self, fn):
        truth_values = [sympy.false, sympy.true]
        for a in generate_range(truth_values):
            with self.subTest(a=a):
                bound = getattr(ValueRangeAnalysis, fn)(a)
                seen = set()
                for a0 in truth_values:
                    if a0 in a:
                        with self.subTest(a0=a0):
                            outcome = getattr(ReferenceAnalysis, fn)(a0)
                            self.assertIn(outcome, bound)
                            seen.add(outcome)
                # A single-value result range must correspond to exactly one
                # observed outcome; otherwise both outcomes must have occurred.
                expected_count = 1 if bound.lower == bound.upper else 2
                self.assertEqual(len(seen), expected_count)

    @parametrize("fn", BINARY_BOOL_OPS)
    def test_binary_bool_ref_range(self, fn):
        truth_values = [sympy.false, sympy.true]
        for a, b in itertools.product(generate_range(truth_values), repeat=2):
            with self.subTest(a=a, b=b):
                bound = getattr(ValueRangeAnalysis, fn)(a, b)
                seen = set()
                for a0, b0 in itertools.product(truth_values, repeat=2):
                    if a0 in a and b0 in b:
                        with self.subTest(a0=a0, b0=b0):
                            outcome = getattr(ReferenceAnalysis, fn)(a0, b0)
                            self.assertIn(outcome, bound)
                            seen.add(outcome)
                # Same tightness check as in the unary boolean test.
                expected_count = 1 if bound.lower == bound.upper else 2
                self.assertEqual(len(seen), expected_count)

    @parametrize("fn", UNARY_OPS)
    def test_unary_ref_range(self, fn):
        # Range bounds additionally include -oo and oo; the sampled points are
        # still the finite constants only.
        endpoints = [-sympy.oo, *CONSTANTS, sympy.oo]
        for a in generate_range(endpoints):
            with self.subTest(a=a):
                bound = getattr(ValueRangeAnalysis, fn)(a)
                for a0 in CONSTANTS:
                    if a0 in a and valid_unary(fn, a0):
                        with self.subTest(a0=a0):
                            outcome = getattr(ReferenceAnalysis, fn)(sympy.Integer(a0))
                            self.assertIn(outcome, bound)

    # This takes about 4s for all the variants
    @parametrize("fn", BINARY_OPS + COMPARE_OPS)
    def test_binary_ref_range(self, fn):
        endpoints = [-sympy.oo, *LESS_CONSTANTS, sympy.oo]
        for a, b in itertools.product(generate_range(endpoints), repeat=2):
            # don't attempt pow on exponents that are too large (but oo is OK)
            if fn == "pow" and b.upper > 4 and b.upper != sympy.oo:
                continue
            with self.subTest(a=a, b=b):
                bound = getattr(ValueRangeAnalysis, fn)(a, b)
                for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2):
                    if a0 in a and b0 in b and valid_binary(fn, a0, b0):
                        with self.subTest(a0=a0, b0=b0):
                            outcome = getattr(ReferenceAnalysis, fn)(
                                sympy.Integer(a0), sympy.Integer(b0)
                            )
                            self.assertIn(outcome, bound)
class TestSympyInterp(TestCase):
    """Check ``sympy_interp`` against direct evaluation with ``ReferenceAnalysis``."""

    @parametrize(
        "fn",
        UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS,
    )
    def test_interp(self, fn):
        from sympy.abc import x, y

        # Boolean operators are sampled over True/False, everything else over
        # the integer constants.
        is_boolean_op = fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}
        vals = [True, False] if is_boolean_op else CONSTANTS

        is_binary_op = fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}
        arity = 2 if is_binary_op else 1
        symbols = [x, y] if arity == 2 else [x]
        is_valid = valid_binary if arity == 2 else valid_unary

        for args in itertools.product(vals, repeat=arity):
            if not is_valid(fn, *args):
                continue
            with self.subTest(args=args):
                sargs = list(map(sympy.sympify, args))
                reference_op = getattr(ReferenceAnalysis, fn)
                # Build the symbolic expression first, then the expected value
                # obtained by applying the operator to the concrete arguments.
                sympy_expr = reference_op(*symbols)
                ref_r = reference_op(*sargs)
                # Yes, I know this is a longwinded way of saying xreplace; the
                # point is to test sympy_interp
                bindings = dict(zip(symbols, sargs))
                r = sympy_interp(ReferenceAnalysis, bindings, sympy_expr)
                self.assertEqual(ref_r, r)
# Hand both test classes to the parametrization helper imported from
# torch.testing._internal.common_utils (presumably this generates one concrete
# test per @parametrize value -- confirm against common_utils).
instantiate_parametrized_tests(TestValueRanges)
instantiate_parametrized_tests(TestSympyInterp)
# Invoke the imported run_tests() entry point only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    run_tests()