Guard static shapes alongside tensors, instead of from shape_env, in dynamic_shapes=True (#99566)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99566
Approved by: https://github.com/ezyang
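With `dynamic_shapes=True`, guards on dimensions known to be static were previously emitted from `shape_env` as symbolic expressions (e.g. `L['x'].size()[0] == 2`). This PR instead records the known-static sizes and strides on the C++ `TensorGuards` check: `produce_guards` gains an `ignore_static` flag (export still keeps the static shape_env guards), and `TensorGuards` now takes per-tensor `dynamic_dims_sizes`/`dynamic_dims_strides` lists in which `None` marks a dynamic dimension and an `int` is compared directly against the runtime tensor. Guard-failure messages change accordingly, e.g. `tensor 'L['x']' size mismatch at index 0. expected 2, actual 5`.

A minimal sketch of the encoding, mirroring `convert()` in `torch/_dynamo/guards.py`; the example tensor shape, strides, and the name `fake` below are illustrative only:

```python
from typing import List, Optional

from torch.fx.experimental.symbolic_shapes import is_concrete_int

def convert(size_or_stride) -> List[Optional[int]]:
    # Static dims stay as ints and are compared directly by the C++ TensorGuards;
    # symbolic dims become None and remain the responsibility of shape_env guards.
    return [int(d) if is_concrete_int(d) else None for d in size_or_stride]

# For a fake tensor with a dynamic first dim, e.g. size (s0, 4) and strides (4, 1):
#   convert(fake.size())   -> [None, 4]
#   convert(fake.stride()) -> [4, 1]
# TensorGuards(t, dynamic_dims_sizes=[[None, 4]], dynamic_dims_strides=[[4, 1]])
# then checks only the non-None entries against the runtime tensor.
```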
diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py
index 7a87666..7f2ca72 100644
--- a/test/dynamo/test_aot_autograd.py
+++ b/test/dynamo/test_aot_autograd.py
@@ -426,7 +426,7 @@
self.assertEqual(cc.frame_count, 2)
self.assertExpectedInline(
failure_reason,
- """tensor 'L['a']' strides mismatch at index 0. expected 3, actual 1""",
+ """tensor 'L['a']' stride mismatch at index 0. expected 3, actual 1""",
)
torch._dynamo.reset()
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index bbbadc6..1c782fe 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -436,7 +436,10 @@
)(compare_shapes)
opt_fn(torch.randn([3, 4]))
opt_fn(torch.randn([4, 3]))
- self.assertExpectedInline(guard_failure.reason, """L['a'].size()[0] == 3""")
+ self.assertExpectedInline(
+ guard_failure.reason,
+ """tensor 'L['a']' size mismatch at index 0. expected 3, actual 4""",
+ )
def test_builtin_isinstance(self):
def fn(x):
@@ -3937,7 +3940,10 @@
self.assertTrue(guard_failure is not None)
if torch._dynamo.config.assume_static_by_default:
- self.assertExpectedInline(guard_failure[0], """L['x'].size()[0] == 2""")
+ self.assertExpectedInline(
+ guard_failure[0],
+ """tensor 'L['x']' size mismatch at index 0. expected 2, actual 5""",
+ )
else:
self.assertExpectedInline(guard_failure[0], """L['x'].size()[0] < 3""")
@@ -3970,7 +3976,7 @@
if torch._dynamo.config.assume_static_by_default:
self.assertExpectedInline(
guard_failure[0],
- """L['x'].size()[0] == 2""",
+ """tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""",
)
else:
self.assertTrue(guard_failure is None)
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 822e914..5003765 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1261,9 +1261,13 @@
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8))
from torch._dynamo.source import LocalSource
self.assertExpectedInline(
- str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")])),
+ str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=False)), # noqa: B950
"""["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]""" # noqa: B950
)
+ self.assertExpectedInline(
+ str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=True)), # noqa: B950
+ """["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]""" # noqa: B950
+ )
def test_sym_storage_offset(self):
def f(x, y):
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index efadf1e..d269f15 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -24,7 +24,7 @@
GuardSource,
Source,
)
-from torch.fx.experimental.symbolic_shapes import SYMPY_INTERP
+from torch.fx.experimental.symbolic_shapes import is_concrete_int, SYMPY_INTERP
from . import config, convert_frame, mutation_guard
from .eval_frame import set_guard_error_hook, set_guard_fail_hook
@@ -477,6 +477,8 @@
[a.source for a in fs],
constraint_inputs=constraint_inputs,
source_ref=self.source_ref,
+ # Export keeps static shape guards; in eager they are checked by the C++ tensor guards instead.
+ ignore_static=(not self.check_fn_manager.output_graph.export),
)
output_graph.shape_env.freeze()
for shape_guard in guards:
@@ -574,7 +576,6 @@
code.append(
f"hasattr({tensor_name}, '_dynamo_dynamic_indices') == False"
)
-
if len(code) > 0:
self._produce_guard_code(guard, code)
@@ -738,8 +739,40 @@
local_builder.tensor_check_examples
+ global_builder.tensor_check_examples
)
+ dynamic_dims_sizes = None
+ dynamic_dims_strides = None
+ if config.dynamic_shapes:
+
+ def convert(size_or_stride):
+ converted: List[Optional[int]] = []
+ for dim in size_or_stride:
+ if is_concrete_int(dim):
+ converted.append(int(dim))
+ else:
+ converted.append(None)
+ return converted
+
+ dynamic_dims_sizes = [
+ convert(
+ self.output_graph.tracing_context.fake_mode.from_tensor(
+ t
+ ).size()
+ )
+ for t in tensor_check_examples
+ ]
+ dynamic_dims_strides = [
+ convert(
+ self.output_graph.tracing_context.fake_mode.from_tensor(
+ t
+ ).stride()
+ )
+ for t in tensor_check_examples
+ ]
+
tensor_guards = TensorGuards(
- *tensor_check_examples, dynamic_shapes=config.dynamic_shapes
+ *tensor_check_examples,
+ dynamic_dims_sizes=dynamic_dims_sizes,
+ dynamic_dims_strides=dynamic_dims_strides,
)
check_tensors_fn = tensor_guards.check
check_tensors_verbose_fn = tensor_guards.check_verbose
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index ffac547..0fe5a30 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -1160,8 +1160,6 @@
if e.size()[i] != dim:
curr_sizes[i] = None
- tx.output.frame_state[name] = curr_sizes
-
# TODO: index export_constraints ahead of time so we don't have to
# do a linear scan every time here
t_id = id(e)
@@ -1196,6 +1194,12 @@
automatic_dynamic = config.automatic_dynamic_shapes and (
curr_sizes is None or curr_sizes[i] is None
)
+
+ # Reflect the user directive in the frame_state:
+ # a dim explicitly marked dynamic is always recorded as None.
+ if marked_dynamic:
+ curr_sizes[i] = None
+
# We will process constraints first, as they will imply that we
# have a dynamic dimension
# Precedence: export constraints > eager constraints
@@ -1219,6 +1223,8 @@
dynamic = DimDynamic.DUCK
dynamic_dims.append(dynamic)
+ tx.output.frame_state[name] = curr_sizes
+
fake_e = wrap_fake_exception(
lambda: tx.fake_mode.from_tensor(
e,
diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp
index 1436e5f..8076637 100644
--- a/torch/csrc/dynamo/guards.cpp
+++ b/torch/csrc/dynamo/guards.cpp
@@ -26,24 +26,18 @@
const LocalState& state,
PyTypeObject* pt,
const at::Tensor& v,
- bool dynamic_shapes)
+ std::vector<std::optional<int64_t>> dynamic_dims_sizes,
+ std::vector<std::optional<int64_t>> dynamic_dims_strides)
: pytype(pt),
dispatch_key_(state.apply(v.key_set()).raw_repr()),
dtype_(v.dtype().toScalarType()),
device_index_(v.device().index()),
requires_grad_(state.grad_mode_enabled && v.requires_grad()),
- dynamic_shapes_(dynamic_shapes) {
- dim_ = v.ndimension();
- if (!dynamic_shapes_) {
- const auto& sizes = v.sizes();
- const auto& strides = v.strides();
- sizes_.reserve(dim_);
- strides_.reserve(dim_);
- for (auto i : c10::irange(dim_)) {
- sizes_.emplace_back(sizes[i]);
- strides_.emplace_back(strides[i]);
- }
- }
+ sizes_(std::move(dynamic_dims_sizes)),
+ strides_(std::move(dynamic_dims_strides)) {
+ // TODO(voz): In cases where sizes_ and strides_ are fully dynamic, should
+ // we just treat this as optional?
+ dim_ = sizes_.size();
}
// See note in guards.py [Note - On Export Tensor Guards]
@@ -59,11 +53,18 @@
if (ndim != dim_) {
return false;
}
- if (!dynamic_shapes_) {
- const auto& sizes = v.sizes();
- const auto& strides = v.strides();
- for (auto i : c10::irange(ndim)) {
- if (sizes_[i] != sizes[i] || strides_[i] != strides[i]) {
+ const auto& sizes = v.sizes();
+ const auto& strides = v.strides();
+ for (auto i : c10::irange(ndim)) {
+ auto known_size = sizes_[i];
+ auto known_stride = strides_[i];
+ if (known_size.has_value()) {
+ if (known_size.value() != sizes[i]) {
+ return false;
+ }
+ }
+ if (known_stride.has_value()) {
+ if (known_stride.value() != strides[i]) {
return false;
}
}
@@ -112,23 +113,20 @@
<< ndim;
return fail_reason.str();
}
- if (!dynamic_shapes_) {
- const auto& sizes = v.sizes();
- const auto& strides = v.strides();
- for (auto i : c10::irange(ndim)) {
- if (sizes_[i] != sizes[i]) {
- // return fmt::format("tensor size mismatch at index {}. expected {},
- // actual {}", i, sizes_[i], sizes[i]);
- fail_reason << "size mismatch at index " << i << ". expected "
- << sizes_[i] << ", actual " << sizes[i];
- return fail_reason.str();
- } else if (strides_[i] != strides[i]) {
- // return fmt::format("tensor strides mismatch at index {}. expected
- // {}, actual {}", i, strides_[i]);
- fail_reason << "strides mismatch at index " << i << ". expected "
- << strides_[i] << ", actual " << strides[i];
- return fail_reason.str();
- }
+ const auto& sizes = v.sizes();
+ const auto& strides = v.strides();
+ for (auto i : c10::irange(ndim)) {
+ auto known_size = sizes_[i];
+ auto known_stride = strides_[i];
+ if (known_size.has_value() && (known_size.value() != sizes[i])) {
+ fail_reason << "size mismatch at index " << i << ". expected "
+ << known_size.value() << ", actual " << sizes[i];
+ return fail_reason.str();
+ }
+ if (known_stride.has_value() && known_stride.value() != strides[i]) {
+ fail_reason << "stride mismatch at index " << i << ". expected "
+ << known_stride.value() << ", actual " << strides[i];
+ return fail_reason.str();
}
}
return "";
@@ -144,10 +142,9 @@
// necessarily capture device indices correctly.
at::DeviceIndex device_index_;
bool requires_grad_;
- bool dynamic_shapes_;
// NB: These are unset if dynamic shapes is enabled.
- std::vector<int64_t> sizes_;
- std::vector<int64_t> strides_;
+ std::vector<std::optional<int64_t>> sizes_;
+ std::vector<std::optional<int64_t>> strides_;
// Not strictly required for dense tensors, but nested tensors need it.
int64_t dim_;
};
@@ -178,6 +175,51 @@
return (PyObject*)self;
}
+static std::vector<std::optional<int64_t>> wrapIntegersInOptional(
+ const c10::IntArrayRef& intArray) {
+ std::vector<std::optional<int64_t>> optVec(intArray.size());
+ std::transform(
+ intArray.begin(), intArray.end(), optVec.begin(), [](int64_t value) {
+ return std::make_optional(value);
+ });
+ return optVec;
+}
+
+static std::vector<std::optional<int64_t>> pyListToVecOptInt(PyObject* pyList) {
+ std::vector<std::optional<int64_t>> vec;
+ Py_ssize_t size = PyList_Size(pyList);
+ for (Py_ssize_t i = 0; i < size; i++) {
+ PyObject* item = PyList_GetItem(pyList, i);
+ if (item == Py_None) {
+ vec.push_back(std::nullopt);
+ } else {
+ int64_t value = PyLong_AsLongLong(item);
+ if (value == -1 && PyErr_Occurred()) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ "Size or stride list item is not a valid integer.");
+ TORCH_CHECK(false, "Size or stride list item is not a valid integer.");
+ }
+ vec.push_back(value);
+ }
+ }
+ return vec;
+}
+
+static std::vector<std::vector<std::optional<int64_t>>> get_dynamic_dims(
+ PyObject* dynamic_dims_py) {
+ std::vector<std::vector<std::optional<int64_t>>> per_tensor_dynamic_dims;
+ if (dynamic_dims_py != Py_None) {
+ Py_ssize_t size = PyList_Size(dynamic_dims_py);
+ for (Py_ssize_t i = 0; i < size; i++) {
+ PyObject* py_list = PyList_GetItem(dynamic_dims_py, i);
+ std::vector<std::optional<int64_t>> vec = pyListToVecOptInt(py_list);
+ per_tensor_dynamic_dims.push_back(std::move(vec));
+ }
+ }
+ return per_tensor_dynamic_dims;
+}
+
static int TensorGuards_init(
TensorGuards* self,
PyObject* args,
@@ -186,12 +228,27 @@
PyErr_SetString(PyExc_TypeError, "expected tuple()");
return -1;
}
- PyObject* dynamic_shapes_py = PyDict_GetItemString(kwds, "dynamic_shapes");
- if (dynamic_shapes_py == NULL) {
- PyErr_SetString(PyExc_TypeError, "missing dynamic_shapes=...");
+ // Top level structure is List[List[Union[int, None]]]
+ PyObject* dynamic_dims_sizes_py =
+ PyDict_GetItemString(kwds, "dynamic_dims_sizes");
+ if (dynamic_dims_sizes_py == NULL) {
+ PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_sizes=...");
return -1;
}
- bool dynamic_shapes = PyObject_IsTrue(dynamic_shapes_py);
+ PyObject* dynamic_dims_strides_py =
+ PyDict_GetItemString(kwds, "dynamic_dims_strides");
+ if (dynamic_dims_strides_py == NULL) {
+ PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_strides=...");
+ return -1;
+ }
+
+ // dynamic_dims_sizes_py/strides_py are None when dynamic_shapes=False; this is
+ // an optimization to avoid needlessly invoking .size()/.stride() in Python.
+ std::vector<std::vector<std::optional<int64_t>>>
+ per_tensor_dynamic_dims_sizes = get_dynamic_dims(dynamic_dims_sizes_py);
+ std::vector<std::vector<std::optional<int64_t>>>
+ per_tensor_dynamic_dims_strides =
+ get_dynamic_dims(dynamic_dims_strides_py);
auto& checks = *self->checks;
auto len = PyTuple_GET_SIZE(args);
@@ -203,8 +260,21 @@
PyErr_SetString(PyExc_TypeError, "expected Tensor()");
return -1;
}
+ auto tensor = THPVariable_Unpack(item);
+ std::vector<std::optional<int64_t>> tensor_dims_size =
+ per_tensor_dynamic_dims_sizes.size() == 0
+ ? wrapIntegersInOptional(tensor.sizes())
+ : per_tensor_dynamic_dims_sizes[i];
+ std::vector<std::optional<int64_t>> tensor_dims_stride =
+ per_tensor_dynamic_dims_strides.size() == 0
+ ? wrapIntegersInOptional(tensor.strides())
+ : per_tensor_dynamic_dims_strides[i];
checks.emplace_back(
- state, Py_TYPE(item), THPVariable_Unpack(item), dynamic_shapes);
+ state,
+ Py_TYPE(item),
+ std::move(tensor),
+ std::move(tensor_dims_size),
+ std::move(tensor_dims_stride));
}
return 0;
}
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index caf0975..7db1236 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -415,8 +415,8 @@
# Given a GraphModule and arguments to run it with, evaluate that the guards
# for its associated ShapeEnv are satisfied by the passed arguments. This
# WILL check for duck sizing.
-def eval_guards(gm, *args):
- return gm.shape_env.evaluate_guards_for_args(fx_placeholder_vals(gm), args)
+def eval_guards(gm, *args, ignore_static=True):
+ return gm.shape_env.evaluate_guards_for_args(fx_placeholder_vals(gm), args, ignore_static=ignore_static)
def bind_symbols(gm, *args):
return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args)
@@ -2020,7 +2020,9 @@
# DimList[DimConstraint]). Whenever Optional is accepted, that
# just means there are no constraints
constraint_inputs: Optional[InputList[Union[DimConstraint, Optional[DimList[DimConstraint]]]]] = None,
- _simplified=False
+ _simplified=False,
+ # If True, skip producing guards for known static values; they are checked elsewhere (see guards.cpp).
+ ignore_static=True,
) -> List[str]:
self.log.info("produce_guards")
@@ -2209,8 +2211,18 @@
source == symbol_to_source[expr][0]
):
continue
+
+ # This logic excludes static values found on tensors from guarding, because
+ # dynamo's check_tensor_fn does that (see guards.cpp).
+ # However, for non-tensor sources, we still need to guard here.
+ if ignore_static and isinstance(source, TensorPropertySource):
+ if len(expr.free_symbols) == 0:
+ self.log.debug("Skipping guard %s", f"{source_ref(source)} == {expr}")
+ continue
+
if is_dim(source):
dim_constraints.add_equality(source, expr)
+
sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
exprs.append(f"{source_ref(source)} == {sexpr}")
# NB: Not necessary to report constraint violations here:
@@ -2328,10 +2340,10 @@
return exprs
- def evaluate_guards_for_args(self, placeholders, args):
+ def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True):
from torch._dynamo.source import LocalSource
arg_names = [f"t{i}" for i in range(len(args))]
- guards = self.produce_guards(placeholders, [LocalSource(a) for a in arg_names])
+ guards = self.produce_guards(placeholders, [LocalSource(a) for a in arg_names], ignore_static=ignore_static)
if guards:
code = " and ".join(guards)
return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))})