Symintify repeat_interleave.self_int (#89111)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89111
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
index e31b36d..05ee8d0 100644
--- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
+++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp
@@ -184,7 +184,7 @@
OP_DECOMPOSE(positive);
OP_DECOMPOSE(qr);
OP_DECOMPOSE(ravel);
- OP_DECOMPOSE2(repeat_interleave, self_int);
+ m.impl("repeat_interleave.self_int", native::repeat_interleave_symint);
OP_DECOMPOSE2(repeat_interleave, self_Tensor);
m.impl("reshape", native::reshape_symint);
OP_DECOMPOSE(resolve_conj);
diff --git a/aten/src/ATen/native/Repeat.cpp b/aten/src/ATen/native/Repeat.cpp
index b671a22..c8c4e13 100644
--- a/aten/src/ATen/native/Repeat.cpp
+++ b/aten/src/ATen/native/Repeat.cpp
@@ -75,11 +75,11 @@
}
Tensor repeats_ = repeats;
- if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.size(0) == 1)) {
- repeats_ = repeats.reshape({1}).expand({input.size(dim.value())});
+ if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.sym_size(0) == 1)) {
+ repeats_ = repeats.reshape({1}).expand_symint({input.sym_size(dim.value())});
} else if (repeats.dim() == 1) {
TORCH_CHECK(
- repeats.size(0) == input.size(dim.value()),
+ repeats.sym_size(0) == input.sym_size(dim.value()),
"repeats must have the same size as input along dim")
} else {
AT_ERROR("repeats must be 0-dim or 1-dim tensor");
@@ -102,10 +102,17 @@
int64_t repeats,
c10::optional<int64_t> dim,
c10::optional<int64_t> output_size) {
- at::Tensor repeats_ =
- at::empty(1, self.options().dtype(at::kLong)).fill_(repeats);
+ at::Tensor repeats_ = at::empty(1, self.options().dtype(at::kLong)).fill_(repeats);
return at::native::repeat_interleave(self, repeats_, dim, output_size);
}
+Tensor repeat_interleave_symint( // Kernel registered for repeat_interleave.self_int, whose `repeats` argument is a SymInt.
+ const Tensor& self,
+ c10::SymInt repeats,
+ c10::optional<int64_t> dim,
+ c10::optional<int64_t> output_size) {
+ return at::native::repeat_interleave(self, repeats.guard_int(__FILE__, __LINE__), dim, output_size); // Delegates to the existing repeat_interleave overload; guard_int presumably guards on the concrete value of `repeats` (cf. SymFloat::guard_float) and may overspecialize — confirm. __FILE__/__LINE__ identify this call site.
+ }
+
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b1d1094..5cf0e75 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4320,8 +4320,10 @@
- func: repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, int? output_size=None) -> Tensor
variants: function, method
-- func: repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> Tensor
+- func: repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, int? output_size=None) -> Tensor
variants: function, method
+ dispatch:
+ CompositeImplicitAutograd: repeat_interleave_symint
- func: reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
variants: function, method
diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
index 411ebdb..c2bf4e0 100644
--- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
+++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp
@@ -152,8 +152,8 @@
if (t_dim == 3 && nt_input->opt_size(2) && (*nt_input->opt_size(2) > 0) &&
!(output_size.has_value())) {
Tensor nt_sizes = nt_input->get_nested_size_tensor();
- Tensor sizes_dim1 = at::native::narrow(nt_sizes, 1, 0, 1);
- Tensor sizes_dim2 = at::native::narrow(nt_sizes, 1, 1, 1);
+ Tensor sizes_dim1 = at::native::narrow_symint(nt_sizes, 1, 0, 1);
+ Tensor sizes_dim2 = at::native::narrow_symint(nt_sizes, 1, 1, 1);
Tensor result = at::detail::make_tensor<NestedTensorImpl>(
nt_input->get_buffer(), sizes_dim1 * sizes_dim2[0]);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.dim() == 2);
diff --git a/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp b/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp
index 700b3b1..aac039f 100644
--- a/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp
+++ b/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp
@@ -122,10 +122,10 @@
const Tensor& dY,
const Tensor& mask) {
TORCH_CHECK(mask.scalar_type() == ScalarType::Bool);
- TORCH_CHECK(mask.numel() == dY.numel(),
+ TORCH_CHECK(mask.sym_numel() == dY.sym_numel(),
"`mask` and `dY` are not the same size: ",
- "`mask` is size ", mask.numel(), " and `dY` is size ", dY.numel());
- if (dY.numel() <= 0) {
+ "`mask` is size ", mask.sym_numel(), " and `dY` is size ", dY.sym_numel());
+ if (dY.sym_numel() <= 0) {
return dY;
}
// Note: no additional kernels needed, since mask is pre-computed
diff --git a/c10/core/SymFloat.cpp b/c10/core/SymFloat.cpp
index 81e8f25..511c50e 100644
--- a/c10/core/SymFloat.cpp
+++ b/c10/core/SymFloat.cpp
@@ -70,4 +70,12 @@
return os;
}
+double SymFloat::guard_float(const char* file, int64_t line) const { // Returns this SymFloat as a double; `file`/`line` are the caller's location and are forwarded to the symbolic node.
+ if (!is_symbolic()) { // Non-symbolic case: hand back the stored value directly, the node is never consulted.
+ return data_;
+ }
+ SymNode a = toSymNodeImpl(); // Symbolic case: obtain the backing node.
+ return a->guard_float(file, line); // The node's guard_float produces the returned double and receives the caller's location.
+}
+
} // namespace c10
diff --git a/c10/core/SymFloat.h b/c10/core/SymFloat.h
index 7da364c..ff9e101 100644
--- a/c10/core/SymFloat.h
+++ b/c10/core/SymFloat.h
@@ -40,6 +40,16 @@
SymFloat operator*(const SymFloat&) const;
SymFloat operator/(const SymFloat&) const;
+ // Insert a guard for the float to be its concrete value, and then return
+ // that value. This operation always works, even if the float is symbolic,
+ // so long as we know what the underlying value is. Don't blindly put this
+ // everywhere; you can cause overspecialization of PyTorch programs with
+ // this method.
+ //
+ // It should be called as guard_float(__FILE__, __LINE__). The file and line
+ // number can be used to diagnose overspecialization.
+ double guard_float(const char* file, int64_t line) const;
+
// N.B. It's important to keep this definition in the header
// as we expect if checks to be folded for mobile builds
// where `is_symbolic` is always false
diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py
index f3964a7..2eb1678 100644
--- a/test/dynamo/test_dynamic_shapes.py
+++ b/test/dynamo/test_dynamic_shapes.py
@@ -106,11 +106,6 @@
DynamicShapesUnspecTests.test_unspec_float_precision_dynamic_shapes
)
-unittest.expectedFailure(
- DynamicShapesReproTests.test_reformer_sorting_dynamic_shapes
- # Unable to cast Python instance to C++ type
-)
-
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 8dc42be..21682ac 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1122,6 +1122,8 @@
xfail('multinomial'),
xfail('cholesky'),
xfail('cholesky_inverse'),
+ # cannot do these as they rely on tensor data
+ xfail('repeat_interleave'),
# ASAN failures due to divide by 0
skip('nn.functional.nll_loss'),
}
@@ -1283,7 +1285,6 @@
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
- xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco...
xfail('nonzero', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition
xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index da8d9af..22917ec 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -2323,10 +2323,12 @@
step != 0,
lambda: "step must be nonzero",
)
- utils.check(
- math.isfinite(start) and math.isfinite(end),
- lambda: f"unsupported range: {start} -> {end}",
- )
+ # SymInts can't represent inf
+ if not isinstance(start, torch.SymInt) and not isinstance(end, torch.SymInt):
+ utils.check(
+ math.isfinite(start) and math.isfinite(end),
+ lambda: f"unsupported range: {start} -> {end}",
+ )
utils.check(
(step > 0 and end >= start) or (step < 0 and end <= start),
lambda: "upper bound and lower bound inconsistent with step sign",
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index 707ebeb..8350634 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -79,10 +79,10 @@
c10::TensorOptions options,
at::ScalarType scalar_type,
const optional<Device>& device,
- IntArrayRef sizes) {
+ c10::SymIntArrayRef sizes) {
maybe_initialize_cuda(options.device());
pybind11::gil_scoped_release no_gil;
- return torch::empty(sizes, build_options(options, scalar_type, device));
+ return at::empty_symint(sizes, build_options(options, scalar_type, device));
}
Tensor new_with_storage(
@@ -124,6 +124,12 @@
}
ScalarType infer_scalar_type(PyObject* obj) {
+ if (torch::is_symint(obj)) {
+ return ScalarType::Long;
+ }
+ if (torch::is_symfloat(obj)) {
+ return ScalarType::Double;
+ }
#ifdef USE_NUMPY
if (is_numpy_available()) {
if (PyArray_Check(obj)) {
@@ -204,7 +210,21 @@
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(data != nullptr);
int64_t ndim = sizes.size();
+ bool is_symfloat = torch::is_symfloat(obj);
+ bool is_symint = torch::is_symint(obj);
if (dim == ndim) {
+ if (is_symfloat) {
+ auto new_obj = py::reinterpret_borrow<py::object>(obj);
+ auto val = new_obj.cast<c10::SymFloat>();
+ *(double*)data = val.guard_float(__FILE__, __LINE__);
+ return;
+ }
+ if (is_symint) {
+ auto new_obj = py::reinterpret_borrow<py::object>(obj);
+ auto val = new_obj.cast<c10::SymInt>();
+ *(int64_t*)data = val.guard_int(__FILE__, __LINE__);
+ return;
+ }
torch::utils::store_scalar(data, scalarType, obj);
return;
}
@@ -531,7 +551,7 @@
"new(*, int64_t cdata)|hidden",
"new(Tensor indices, Tensor values, *, Device? device=None)",
"new(Tensor indices, Tensor values, IntArrayRef size, *, Device? device=None)",
- "new(IntArrayRef size, *, Device? device=None)",
+ "new(SymIntArrayRef size, *, Device? device=None)",
});
if (ctor_or_new == CtorOrNew::NEW)
check_base_legacy_new(dispatch_key, c10::kSparse);
@@ -577,7 +597,7 @@
}
}
return new_with_sizes(
- options, scalar_type, r.deviceOptional(1), r.intlist(0));
+ options, scalar_type, r.deviceOptional(1), r.symintlist(0));
}
throw std::runtime_error("new(): invalid arguments");
}
@@ -615,7 +635,7 @@
// matching with
// IntArrayRef,
// PyObject*
- "new(IntArrayRef size, *, Device? device=None)",
+ "new(SymIntArrayRef size, *, Device? device=None)",
"new(PyObject* data, *, Device? device=None)",
});
@@ -690,7 +710,7 @@
options, scalar_type, deviceOptional, r.pyobject(0));
}
return new_with_sizes(
- options, scalar_type, r.deviceOptional(1), r.intlist(0));
+ options, scalar_type, r.deviceOptional(1), r.symintlist(0));
} else if (r.idx == 6) {
auto deviceOptional = r.deviceOptional(1);
check_legacy_ctor_device(dispatch_key, deviceOptional);