Use maxint to bound integers. (#96121)

We don't actually support arbitrary precision integers.

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96121
Approved by: https://github.com/tugsbayasgalan, https://github.com/lezcano
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index d1f5de6..d8f2d6f 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -970,6 +970,16 @@
     index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_2);  crop_camera_1 = mask_1 = view_2 = None
     return None""")
 
+    def test_unbacked_slice(self):
+        def f(x, m):
+            x = x[m]
+            return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)]
+
+        make_fx(f, tracing_mode="symbolic")(
+            torch.randn((12, 3, 3)),
+            torch.randint(0, 2, (12,), dtype=torch.bool)
+        )
+
     @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
     def test_unbacked_batch_resnet(self):
         mod = torchvision.models.resnet18()
diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py
index 03b0bc9..1200a6d 100644
--- a/test/test_sympy_utils.py
+++ b/test/test_sympy_utils.py
@@ -2,6 +2,7 @@
 # Owner(s): ["oncall: pt2"]
 
 import itertools
+import sys
 
 import sympy
 from torch.testing._internal.common_utils import (
@@ -50,6 +51,8 @@
     2**24,
     2**32,
     2**37 - 1,
+    sys.maxsize - 1,
+    sys.maxsize,
 ]
 # less constants for N^2 situations
 LESS_CONSTANTS = [-1, 0, 1, 2, 100]
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 359885a..42d002c 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -426,6 +426,8 @@
         raise DynamicOutputShapeException(func)
 
     if arg.nonzero_memo is None:
+        import sys
+
         from torch.fx.experimental.symbolic_shapes import constrain_range
 
         nnz = fake_mode.shape_env.create_unbacked_symint()
@@ -438,9 +440,7 @@
         # disjoint with what can actually occur.  But this is fine:
         # remember, the hypothesis is that if your later code works
         # with N >= 2, it will work with N = 1 and N = 0.
-        lower = 2
-        upper = None
-        constrain_range(nnz, min=lower, max=upper)
+        constrain_range(nnz, min=2, max=sys.maxsize - 1)
 
         arg._nonzero_memo = nnz
         arg._nonzero_memo_vc = arg._version
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index c3cb7a1..6ddc3b7 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -1334,7 +1334,7 @@
     def create_unbacked_symint(self):
         symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
         self.var_to_stack[symbol] = ''.join(traceback.format_list(traceback.extract_stack()[:-1]))
-        self.var_to_range[symbol] = ValueRanges.unknown()
+        self.var_to_range[symbol] = ValueRanges(-sys.maxsize - 1, sys.maxsize)
         return SymInt(SymNode(symbol, self, int, None))
 
     # This is guaranteed to return a symbol or its negation is a sympy.Symbol,
@@ -1361,7 +1361,10 @@
 
             # We also infer that it must be not 0/1
             lower = 2 if self.specialize_zero_one else 0
-            self.var_to_range[sympy_expr] = ValueRanges(lower, sympy.oo)
+            # NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT
+            # as a sentinel sometimes.  Your sizevar isn't going to be
+            # anywhere near the max 64-bit integer anyway.
+            self.var_to_range[sympy_expr] = ValueRanges(lower, sys.maxsize - 1)
 
         if not dyn and self.duck_shape:
             # This implements duck-shaping: input sizes that match are assigned
@@ -1577,12 +1580,20 @@
         if not _simplified:
             for symbol, sources in symbol_to_source.items():
                 assert sources
+                assert symbol.is_integer
                 r = self.var_to_range[symbol]
                 bounds = []
                 if r.lower != -sympy.oo:
                     bounds.append(str(r.lower))
                 bounds.append(source_ref(sources[0]))
-                if r.upper != sympy.oo:
+                # NB: This looks like an off-by-one error but it's not: the
+                # upper bound may be sys.maxsize - 1 because we intentionally
+                # exclude sys.maxsize from our bounds to deal with direct
+                # == INT_MAX guards, but it's still dumb to actually test it.
+                # Note that you can be off by a pretty large constant and it
+                # won't matter because sizes in practice will be no where near
+                # the 64-bit limit.
+                if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
                     bounds.append(str(r.upper))
                 if len(bounds) > 1:
                     exprs.append(" <= ".join(bounds))