Symintify repeat_interleave.self_int (#89111)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89111
Approved by: https://github.com/ezyang
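
For context: `repeat_interleave.self_int` is the overload taking a plain integer repeat count, and this change widens that count to `SymInt` so it can stay symbolic under dynamic-shape tracing. A quick eager-mode illustration of the overload (not part of the patch, just for orientation):

```python
import torch

x = torch.tensor([1, 2, 3])
# self_int overload: every element is repeated `repeats` times.
torch.repeat_interleave(x, 2)  # tensor([1, 1, 2, 2, 3, 3])
```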
diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
index e31b36d..05ee8d0 100644
--- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
+++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
@@ -184,7 +184,7 @@
   OP_DECOMPOSE(positive);
   OP_DECOMPOSE(qr);
   OP_DECOMPOSE(ravel);
-  OP_DECOMPOSE2(repeat_interleave, self_int);
+  m.impl("repeat_interleave.self_int", native::repeat_interleave_symint);
   OP_DECOMPOSE2(repeat_interleave, self_Tensor);
   m.impl("reshape", native::reshape_symint);
   OP_DECOMPOSE(resolve_conj);
diff --git a/aten/src/ATen/native/Repeat.cpp b/aten/src/ATen/native/Repeat.cpp
index b671a22..c8c4e13 100644
--- a/aten/src/ATen/native/Repeat.cpp
+++ b/aten/src/ATen/native/Repeat.cpp
@@ -75,11 +75,11 @@
   }
 
   Tensor repeats_ = repeats;
-  if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.size(0) == 1)) {
-    repeats_ = repeats.reshape({1}).expand({input.size(dim.value())});
+  if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.sym_size(0) == 1)) {
+    repeats_ = repeats.reshape({1}).expand_symint({input.sym_size(dim.value())});
   } else if (repeats.dim() == 1) {
     TORCH_CHECK(
-        repeats.size(0) == input.size(dim.value()),
+        repeats.sym_size(0) == input.sym_size(dim.value()),
         "repeats must have the same size as input along dim")
   } else {
     AT_ERROR("repeats must be 0-dim or 1-dim tensor");
@@ -102,10 +102,17 @@
     int64_t repeats,
     c10::optional<int64_t> dim,
     c10::optional<int64_t> output_size) {
-  at::Tensor repeats_ =
-      at::empty(1, self.options().dtype(at::kLong)).fill_(repeats);
+  at::Tensor repeats_ = at::empty(1, self.options().dtype(at::kLong)).fill_(repeats);
   return at::native::repeat_interleave(self, repeats_, dim, output_size);
 }
 
+Tensor repeat_interleave_symint(
+    const Tensor& self,
+    c10::SymInt repeats,
+    c10::optional<int64_t> dim,
+    c10::optional<int64_t> output_size) {
+  return at::native::repeat_interleave(self, repeats.guard_int(__FILE__, __LINE__), dim, output_size);
+}
+
 } // namespace native
 } // namespace at
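
A minimal Python sketch of what the new wrapper does (the name `repeat_interleave_symint_sketch` is invented for illustration; the real implementation is the C++ above): guard the possibly-symbolic count to a concrete int, then defer to the existing integer path.

```python
import torch

def repeat_interleave_symint_sketch(self, repeats, dim=None, output_size=None):
    # Mirrors the C++ wrapper: int() on a SymInt installs a guard and
    # returns the concrete value, like guard_int(__FILE__, __LINE__).
    return torch.repeat_interleave(self, int(repeats), dim=dim, output_size=output_size)

repeat_interleave_symint_sketch(torch.arange(3), 2)  # tensor([0, 0, 1, 1, 2, 2])
```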
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b1d1094..5cf0e75 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4320,8 +4320,10 @@
 - func: repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, int? output_size=None) -> Tensor
   variants: function, method
 
-- func: repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> Tensor
+- func: repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, int? output_size=None) -> Tensor
   variants: function, method
+  dispatch:
+    CompositeImplicitAutograd: repeat_interleave_symint
 
 - func: reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
   variants: function, method
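
Changing `int` to `SymInt` in the schema changes the generated C++ signature, which is why the op now carries an explicit `CompositeImplicitAutograd` kernel instead of relying on the default composite path. From Python, both overloads behave as before (illustrative):

```python
import torch

x = torch.arange(3)
torch.repeat_interleave(x, 2)                        # self_int overload
torch.repeat_interleave(x, torch.tensor([1, 2, 3]))  # self_Tensor overload
x.repeat_interleave(2)                               # also available as a method
```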
diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
index 411ebdb..c2bf4e0 100644
--- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
+++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
@@ -152,8 +152,8 @@
     if (t_dim == 3 && nt_input->opt_size(2) && (*nt_input->opt_size(2) > 0) &&
         !(output_size.has_value())) {
       Tensor nt_sizes = nt_input->get_nested_size_tensor();
-      Tensor sizes_dim1 = at::native::narrow(nt_sizes, 1, 0, 1);
-      Tensor sizes_dim2 = at::native::narrow(nt_sizes, 1, 1, 1);
+      Tensor sizes_dim1 = at::native::narrow_symint(nt_sizes, 1, 0, 1);
+      Tensor sizes_dim2 = at::native::narrow_symint(nt_sizes, 1, 1, 1);
       Tensor result = at::detail::make_tensor<NestedTensorImpl>(
           nt_input->get_buffer(), sizes_dim1 * sizes_dim2[0]);
       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.dim() == 2);
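
`narrow_symint` is the `SymInt`-taking variant of `narrow`; the slicing behavior is unchanged. For reference, the two single-column slices taken above, in eager Python (illustrative data):

```python
import torch

nt_sizes = torch.tensor([[2, 3], [4, 5]])
# narrow(dim, start, length) returns a view of `length` slices
# starting at `start` along `dim`.
sizes_dim1 = torch.narrow(nt_sizes, 1, 0, 1)  # column 0 -> [[2], [4]]
sizes_dim2 = torch.narrow(nt_sizes, 1, 1, 1)  # column 1 -> [[3], [5]]
```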
diff --git a/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp b/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp
index 700b3b1..aac039f 100644
--- a/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp
+++ b/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp
@@ -122,10 +122,10 @@
     const Tensor& dY,
     const Tensor& mask) {
   TORCH_CHECK(mask.scalar_type() == ScalarType::Bool);
-  TORCH_CHECK(mask.numel() == dY.numel(),
+  TORCH_CHECK(mask.sym_numel() == dY.sym_numel(),
       "`mask` and `dY` are not the same size: ",
-      "`mask` is size ", mask.numel(), " and `dY` is size ", dY.numel());
-  if (dY.numel() <= 0) {
+      "`mask` is size ", mask.sym_numel(), " and `dY` is size ", dY.sym_numel());
+  if (dY.sym_numel() <= 0) {
     return dY;
   }
   // Note: no additional kernels needed, since mask is pre-computed
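
`sym_numel()` returns the element count as a `c10::SymInt`, so these checks also work when sizes are symbolic. The invariant being enforced, restated in eager Python (illustrative):

```python
import torch

dY = torch.randn(2, 3)
mask = dY > 0
# `mask` and `dY` must agree in element count; an empty gradient is
# returned as-is without doing any work.
assert mask.numel() == dY.numel()
```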
diff --git a/c10/core/SymFloat.cpp b/c10/core/SymFloat.cpp
index 81e8f25..511c50e 100644
--- a/c10/core/SymFloat.cpp
+++ b/c10/core/SymFloat.cpp
@@ -70,4 +70,12 @@
   return os;
 }
 
+double SymFloat::guard_float(const char* file, int64_t line) const {
+  if (!is_symbolic()) {
+    return data_;
+  }
+  SymNode a = toSymNodeImpl();
+  return a->guard_float(file, line);
+}
+
 } // namespace c10
diff --git a/c10/core/SymFloat.h b/c10/core/SymFloat.h
index 7da364c..ff9e101 100644
--- a/c10/core/SymFloat.h
+++ b/c10/core/SymFloat.h
@@ -40,6 +40,16 @@
   SymFloat operator*(const SymFloat&) const;
   SymFloat operator/(const SymFloat&) const;
 
+  // Insert a guard for the float to be its concrete value, and then return
+  // that value.  This operation always works, even if the float is symbolic,
+  // so long as we know what the underlying value is. Don't blindly put this
+  // everywhere; you can cause overspecialization of PyTorch programs with
+  // this method.
+  //
+  // It should be called as guard_float(__FILE__, __LINE__).  The file and line
+  // number can be used to diagnose overspecialization.
+  double guard_float(const char* file, int64_t line) const;
+
   // N.B. It's important to keep this definition in the header
   // as we expect if checks to be folded for mobile builds
   // where `is_symbolic` is always false
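
The header comment spells out the contract; here is a toy Python model of the same guard pattern (purely illustrative — the class and its fields are invented for this sketch):

```python
class SymFloatSketch:
    """Toy model of the c10::SymFloat::guard_float contract."""

    def __init__(self, hint, symbolic=False):
        self._hint = hint          # concrete value backing the float
        self._symbolic = symbolic  # whether it is a traced symbol

    def guard_float(self, file, line):
        # Concrete values pass straight through; symbolic values
        # record a guard attributed to (file, line), specializing the
        # trace, and then return the concrete hint.
        if self._symbolic:
            print(f"guard installed at {file}:{line}")
        return self._hint

assert SymFloatSketch(1.5).guard_float("example.py", 1) == 1.5
```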
diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py
index f3964a7..2eb1678 100644
--- a/test/dynamo/test_dynamic_shapes.py
+++ b/test/dynamo/test_dynamic_shapes.py
@@ -106,11 +106,6 @@
     DynamicShapesUnspecTests.test_unspec_float_precision_dynamic_shapes
 )
 
-unittest.expectedFailure(
-    DynamicShapesReproTests.test_reformer_sorting_dynamic_shapes
-    # Unable to cast Python instance to C++ type
-)
-
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
 
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 8dc42be..21682ac 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1122,6 +1122,8 @@
     xfail('multinomial'),
     xfail('cholesky'),
     xfail('cholesky_inverse'),
+    # cannot do these as they rely on tensor data
+    xfail('repeat_interleave'),
     # ASAN failures due to divide by 0
     skip('nn.functional.nll_loss'),
 }
@@ -1283,7 +1285,6 @@
     xfail('nn.functional.pixel_unshuffle', ''),  # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
     xfail('nn.functional.rrelu', ''),  # aten.empty_like.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.smooth_l1_loss', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
-    xfail('nn.functional.unfold', ''),  # aten.im2col.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.upsample_nearest', ''),  # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco...
     xfail('nonzero', ''),  # aten.nonzero.default - couldn't find symbolic meta function/decomposition
     xfail('norm', 'nuc'),  # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
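
The new `repeat_interleave` xfail above is there because the op's output shape depends on tensor *data*, not just input shapes, which symbolic tracing cannot see (illustrative):

```python
import torch

x = torch.arange(3)
r = torch.tensor([1, 2, 3])
# The output length is r.sum() = 6 -- a function of the values in r.
assert torch.repeat_interleave(x, r).shape == (6,)
```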
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index da8d9af..22917ec 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -2323,10 +2323,12 @@
         step != 0,
         lambda: "step must be nonzero",
     )
-    utils.check(
-        math.isfinite(start) and math.isfinite(end),
-        lambda: f"unsupported range: {start} -> {end}",
-    )
+    # SymInts can't represent inf
+    if not isinstance(start, torch.SymInt) and not isinstance(end, torch.SymInt):
+        utils.check(
+            math.isfinite(start) and math.isfinite(end),
+            lambda: f"unsupported range: {start} -> {end}",
+        )
     utils.check(
         (step > 0 and end >= start) or (step < 0 and end <= start),
         lambda: "upper bound and lower bound inconsistent with step sign",
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index 707ebeb..8350634 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -79,10 +79,10 @@
     c10::TensorOptions options,
     at::ScalarType scalar_type,
     const optional<Device>& device,
-    IntArrayRef sizes) {
+    c10::SymIntArrayRef sizes) {
   maybe_initialize_cuda(options.device());
   pybind11::gil_scoped_release no_gil;
-  return torch::empty(sizes, build_options(options, scalar_type, device));
+  return at::empty_symint(sizes, build_options(options, scalar_type, device));
 }
 
 Tensor new_with_storage(
@@ -124,6 +124,12 @@
 }
 
 ScalarType infer_scalar_type(PyObject* obj) {
+  if (torch::is_symint(obj)) {
+    return ScalarType::Long;
+  }
+  if (torch::is_symfloat(obj)) {
+    return ScalarType::Double;
+  }
 #ifdef USE_NUMPY
   if (is_numpy_available()) {
     if (PyArray_Check(obj)) {
@@ -204,7 +210,21 @@
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(data != nullptr);
 
   int64_t ndim = sizes.size();
+  bool is_symfloat = torch::is_symfloat(obj);
+  bool is_symint = torch::is_symint(obj);
   if (dim == ndim) {
+    if (is_symfloat) {
+      auto new_obj = py::reinterpret_borrow<py::object>(obj);
+      auto val = new_obj.cast<c10::SymFloat>();
+      *(double*)data = val.guard_float(__FILE__, __LINE__);
+      return;
+    }
+    if (is_symint) {
+      auto new_obj = py::reinterpret_borrow<py::object>(obj);
+      auto val = new_obj.cast<c10::SymInt>();
+      *(int64_t*)data = val.guard_int(__FILE__, __LINE__);
+      return;
+    }
     torch::utils::store_scalar(data, scalarType, obj);
     return;
   }
@@ -531,7 +551,7 @@
       "new(*, int64_t cdata)|hidden",
       "new(Tensor indices, Tensor values, *, Device? device=None)",
       "new(Tensor indices, Tensor values, IntArrayRef size, *, Device? device=None)",
-      "new(IntArrayRef size, *, Device? device=None)",
+      "new(SymIntArrayRef size, *, Device? device=None)",
   });
   if (ctor_or_new == CtorOrNew::NEW)
     check_base_legacy_new(dispatch_key, c10::kSparse);
@@ -577,7 +597,7 @@
       }
     }
     return new_with_sizes(
-        options, scalar_type, r.deviceOptional(1), r.intlist(0));
+        options, scalar_type, r.deviceOptional(1), r.symintlist(0));
   }
   throw std::runtime_error("new(): invalid arguments");
 }
@@ -615,7 +635,7 @@
                                                           // matching with
                                                           // IntArrayRef,
                                                           // PyObject*
-      "new(IntArrayRef size, *, Device? device=None)",
+      "new(SymIntArrayRef size, *, Device? device=None)",
       "new(PyObject* data, *, Device? device=None)",
   });
 
@@ -690,7 +710,7 @@
           options, scalar_type, deviceOptional, r.pyobject(0));
     }
     return new_with_sizes(
-        options, scalar_type, r.deviceOptional(1), r.intlist(0));
+        options, scalar_type, r.deviceOptional(1), r.symintlist(0));
   } else if (r.idx == 6) {
     auto deviceOptional = r.deviceOptional(1);
     check_legacy_ctor_device(dispatch_key, deviceOptional);
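
On the Python binding side, tensor construction now understands symbolic scalars: `infer_scalar_type` maps a `SymInt` to `Long` and a `SymFloat` to `Double`, scalar storage guards them to concrete values, and the legacy `new(size)` constructors accept `SymIntArrayRef`. For concrete scalars the inferred type is unchanged (illustrative):

```python
import torch

# A plain Python int already infers int64; with this patch a SymInt
# scalar follows the same rule instead of being rejected.
assert torch.tensor(3).dtype == torch.int64
```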