Use maxint to bound integers. (#96121)
We don't actually support arbitrary precision integers.
Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96121
Approved by: https://github.com/tugsbayasgalan, https://github.com/lezcano
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index d1f5de6..d8f2d6f 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -970,6 +970,16 @@
index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_2); crop_camera_1 = mask_1 = view_2 = None
return None""")
+ def test_unbacked_slice(self):
+ def f(x, m):
+ x = x[m]
+ return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)]
+
+ make_fx(f, tracing_mode="symbolic")(
+ torch.randn((12, 3, 3)),
+ torch.randint(0, 2, (12,), dtype=torch.bool)
+ )
+
@unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
def test_unbacked_batch_resnet(self):
mod = torchvision.models.resnet18()
diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py
index 03b0bc9..1200a6d 100644
--- a/test/test_sympy_utils.py
+++ b/test/test_sympy_utils.py
@@ -2,6 +2,7 @@
# Owner(s): ["oncall: pt2"]
import itertools
+import sys
import sympy
from torch.testing._internal.common_utils import (
@@ -50,6 +51,8 @@
2**24,
2**32,
2**37 - 1,
+ sys.maxsize - 1,
+ sys.maxsize,
]
# less constants for N^2 situations
LESS_CONSTANTS = [-1, 0, 1, 2, 100]
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 359885a..42d002c 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -426,6 +426,8 @@
raise DynamicOutputShapeException(func)
if arg.nonzero_memo is None:
+ import sys
+
from torch.fx.experimental.symbolic_shapes import constrain_range
nnz = fake_mode.shape_env.create_unbacked_symint()
@@ -438,9 +440,7 @@
# disjoint with what can actually occur. But this is fine:
# remember, the hypothesis is that if your later code works
# with N >= 2, it will work with N = 1 and N = 0.
- lower = 2
- upper = None
- constrain_range(nnz, min=lower, max=upper)
+ constrain_range(nnz, min=2, max=sys.maxsize - 1)
arg._nonzero_memo = nnz
arg._nonzero_memo_vc = arg._version
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index c3cb7a1..6ddc3b7 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -1334,7 +1334,7 @@
def create_unbacked_symint(self):
symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
self.var_to_stack[symbol] = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
- self.var_to_range[symbol] = ValueRanges.unknown()
+ self.var_to_range[symbol] = ValueRanges(-sys.maxsize - 1, sys.maxsize)
return SymInt(SymNode(symbol, self, int, None))
# This is guaranteed to return a symbol or its negation is a sympy.Symbol,
@@ -1361,7 +1361,10 @@
# We also infer that it must be not 0/1
lower = 2 if self.specialize_zero_one else 0
- self.var_to_range[sympy_expr] = ValueRanges(lower, sympy.oo)
+ # NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT
+ # as a sentinel sometimes. Your sizevar isn't going to be
+ # anywhere near the max 64-bit integer anyway.
+ self.var_to_range[sympy_expr] = ValueRanges(lower, sys.maxsize - 1)
if not dyn and self.duck_shape:
# This implements duck-shaping: input sizes that match are assigned
@@ -1577,12 +1580,20 @@
if not _simplified:
for symbol, sources in symbol_to_source.items():
assert sources
+ assert symbol.is_integer
r = self.var_to_range[symbol]
bounds = []
if r.lower != -sympy.oo:
bounds.append(str(r.lower))
bounds.append(source_ref(sources[0]))
- if r.upper != sympy.oo:
+ # NB: This looks like an off-by-one error but it's not: the
+ # upper bound may be sys.maxsize - 1 because we intentionally
+ # exclude sys.maxsize from our bounds to deal with direct
+ # == INT_MAX guards, but it's still dumb to actually test it.
+ # Note that you can be off by a pretty large constant and it
+ # won't matter because sizes in practice will be no where near
+ # the 64-bit limit.
+ if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
bounds.append(str(r.upper))
if len(bounds) > 1:
exprs.append(" <= ".join(bounds))