Guard static shapes alongside tensors, instead of from shape_env, in dynamic_shapes=True (#99566)
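
With dynamic_shapes=True, dimensions that turn out to be static were previously
guarded through shape_env expressions such as "L['x'].size()[0] == 2". This PR
checks them alongside the other per-tensor guards instead: TensorGuards now takes
per-tensor dynamic_dims_sizes / dynamic_dims_strides lists, where a concrete int
pins the value and None marks a dynamic dim that is skipped, and produce_guards
gains an ignore_static flag (default True) so the shape env no longer emits
duplicate guards for values the C++ check already covers. Export passes
ignore_static=False and keeps the static guards, and frame_state now records None
for dims the user explicitly marked dynamic. As the updated tests show,
guard-failure reasons for static dims switch from the symbolic form to the
tensor-check form.

A minimal repro sketch of the observable difference (illustrative only, mirroring
the updated test_misc.py tests; the config handling and function names below are
not part of this diff):

```python
import torch
import torch._dynamo

torch._dynamo.config.dynamic_shapes = True
torch._dynamo.config.assume_static_by_default = True

guard_failure = None

def guard_failures(failure):
    global guard_failure
    guard_failure = failure

def compare_shapes(a):
    return a.cos() if a.shape[0] == 3 else a.sin()

opt_fn = torch._dynamo.optimize("eager", guard_fail_fn=guard_failures)(compare_shapes)
opt_fn(torch.randn([3, 4]))  # first compile specializes size(0) == 3
opt_fn(torch.randn([4, 3]))  # recompiles; the failed guard is recorded

# Previously: "L['a'].size()[0] == 3"  (shape_env guard)
# Now:        "tensor 'L['a']' size mismatch at index 0. expected 3, actual 4"
print(guard_failure.reason)
```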

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99566
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py
index 7a87666..7f2ca72 100644
--- a/test/dynamo/test_aot_autograd.py
+++ b/test/dynamo/test_aot_autograd.py
@@ -426,7 +426,7 @@
         self.assertEqual(cc.frame_count, 2)
         self.assertExpectedInline(
             failure_reason,
-            """tensor 'L['a']' strides mismatch at index 0. expected 3, actual 1""",
+            """tensor 'L['a']' stride mismatch at index 0. expected 3, actual 1""",
         )
 
         torch._dynamo.reset()
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index bbbadc6..1c782fe 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -436,7 +436,10 @@
         )(compare_shapes)
         opt_fn(torch.randn([3, 4]))
         opt_fn(torch.randn([4, 3]))
-        self.assertExpectedInline(guard_failure.reason, """L['a'].size()[0] == 3""")
+        self.assertExpectedInline(
+            guard_failure.reason,
+            """tensor 'L['a']' size mismatch at index 0. expected 3, actual 4""",
+        )
 
     def test_builtin_isinstance(self):
         def fn(x):
@@ -3937,7 +3940,10 @@
 
         self.assertTrue(guard_failure is not None)
         if torch._dynamo.config.assume_static_by_default:
-            self.assertExpectedInline(guard_failure[0], """L['x'].size()[0] == 2""")
+            self.assertExpectedInline(
+                guard_failure[0],
+                """tensor 'L['x']' size mismatch at index 0. expected 2, actual 5""",
+            )
         else:
             self.assertExpectedInline(guard_failure[0], """L['x'].size()[0] < 3""")
 
@@ -3970,7 +3976,7 @@
             if torch._dynamo.config.assume_static_by_default:
                 self.assertExpectedInline(
                     guard_failure[0],
-                    """L['x'].size()[0] == 2""",
+                    """tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""",
                 )
             else:
                 self.assertTrue(guard_failure is None)
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 822e914..5003765 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1261,9 +1261,13 @@
         fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8))
         from torch._dynamo.source import LocalSource
         self.assertExpectedInline(
-            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")])),
+            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=False)),  # noqa: B950
             """["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]"""  # noqa: B950
         )
+        self.assertExpectedInline(
+            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=True)),  # noqa: B950
+            """["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]"""  # noqa: B950
+        )
 
     def test_sym_storage_offset(self):
         def f(x, y):
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index efadf1e..d269f15 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -24,7 +24,7 @@
     GuardSource,
     Source,
 )
-from torch.fx.experimental.symbolic_shapes import SYMPY_INTERP
+from torch.fx.experimental.symbolic_shapes import is_concrete_int, SYMPY_INTERP
 
 from . import config, convert_frame, mutation_guard
 from .eval_frame import set_guard_error_hook, set_guard_fail_hook
@@ -477,6 +477,8 @@
             [a.source for a in fs],
             constraint_inputs=constraint_inputs,
             source_ref=self.source_ref,
+            # Export still wants static shape guards from the shape env; eager skips them (TensorGuards covers them).
+            ignore_static=(not self.check_fn_manager.output_graph.export),
         )
         output_graph.shape_env.freeze()
         for shape_guard in guards:
@@ -574,7 +576,6 @@
                     code.append(
                         f"hasattr({tensor_name}, '_dynamo_dynamic_indices') == False"
                     )
-
             if len(code) > 0:
                 self._produce_guard_code(guard, code)
 
@@ -738,8 +739,40 @@
                 local_builder.tensor_check_examples
                 + global_builder.tensor_check_examples
             )
+            dynamic_dims_sizes = None
+            dynamic_dims_strides = None
+            if config.dynamic_shapes:
+
+                def convert(size_or_stride):
+                    converted: List[Optional[int]] = []
+                    for dim in size_or_stride:
+                        if is_concrete_int(dim):
+                            converted.append(int(dim))
+                        else:
+                            converted.append(None)
+                    return converted
+
+                dynamic_dims_sizes = [
+                    convert(
+                        self.output_graph.tracing_context.fake_mode.from_tensor(
+                            t
+                        ).size()
+                    )
+                    for t in tensor_check_examples
+                ]
+                dynamic_dims_strides = [
+                    convert(
+                        self.output_graph.tracing_context.fake_mode.from_tensor(
+                            t
+                        ).stride()
+                    )
+                    for t in tensor_check_examples
+                ]
+
             tensor_guards = TensorGuards(
-                *tensor_check_examples, dynamic_shapes=config.dynamic_shapes
+                *tensor_check_examples,
+                dynamic_dims_sizes=dynamic_dims_sizes,
+                dynamic_dims_strides=dynamic_dims_strides,
             )
             check_tensors_fn = tensor_guards.check
             check_tensors_verbose_fn = tensor_guards.check_verbose
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index ffac547..0fe5a30 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -1160,8 +1160,6 @@
                         if e.size()[i] != dim:
                             curr_sizes[i] = None
 
-        tx.output.frame_state[name] = curr_sizes
-
         # TODO: index export_constraints ahead of time so we don't have to
         # do a linear scan every time here
         t_id = id(e)
@@ -1196,6 +1194,12 @@
                 automatic_dynamic = config.automatic_dynamic_shapes and (
                     curr_sizes is None or curr_sizes[i] is None
                 )
+
+                # Reflect the user directive in the frame_state:
+                # dims marked dynamic are always recorded as None.
+                if marked_dynamic:
+                    curr_sizes[i] = None
+
                 # We will process constraints first, as they will imply that we
                 # have a dynamic dimension
                 # Precedence: export constraints > eager constraints
@@ -1219,6 +1223,8 @@
                     dynamic = DimDynamic.DUCK
                 dynamic_dims.append(dynamic)
 
+        tx.output.frame_state[name] = curr_sizes
+
         fake_e = wrap_fake_exception(
             lambda: tx.fake_mode.from_tensor(
                 e,
diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp
index 1436e5f..8076637 100644
--- a/torch/csrc/dynamo/guards.cpp
+++ b/torch/csrc/dynamo/guards.cpp
@@ -26,24 +26,18 @@
       const LocalState& state,
       PyTypeObject* pt,
       const at::Tensor& v,
-      bool dynamic_shapes)
+      std::vector<std::optional<int64_t>> dynamic_dims_sizes,
+      std::vector<std::optional<int64_t>> dynamic_dims_strides)
       : pytype(pt),
         dispatch_key_(state.apply(v.key_set()).raw_repr()),
         dtype_(v.dtype().toScalarType()),
         device_index_(v.device().index()),
         requires_grad_(state.grad_mode_enabled && v.requires_grad()),
-        dynamic_shapes_(dynamic_shapes) {
-    dim_ = v.ndimension();
-    if (!dynamic_shapes_) {
-      const auto& sizes = v.sizes();
-      const auto& strides = v.strides();
-      sizes_.reserve(dim_);
-      strides_.reserve(dim_);
-      for (auto i : c10::irange(dim_)) {
-        sizes_.emplace_back(sizes[i]);
-        strides_.emplace_back(strides[i]);
-      }
-    }
+        sizes_(std::move(dynamic_dims_sizes)),
+        strides_(std::move(dynamic_dims_strides)) {
+    // TODO(voz): In cases where sizes_ and strides_ are fully dynamic, should
+    // we just treat this as optional?
+    dim_ = sizes_.size();
   }
 
   // See note in guards.py [Note - On Export Tensor Guards]
@@ -59,11 +53,18 @@
     if (ndim != dim_) {
       return false;
     }
-    if (!dynamic_shapes_) {
-      const auto& sizes = v.sizes();
-      const auto& strides = v.strides();
-      for (auto i : c10::irange(ndim)) {
-        if (sizes_[i] != sizes[i] || strides_[i] != strides[i]) {
+    const auto& sizes = v.sizes();
+    const auto& strides = v.strides();
+    for (auto i : c10::irange(ndim)) {
+      auto known_size = sizes_[i];
+      auto known_stride = strides_[i];
+      if (known_size.has_value()) {
+        if (known_size.value() != sizes[i]) {
+          return false;
+        }
+      }
+      if (known_stride.has_value()) {
+        if (known_stride.value() != strides[i]) {
           return false;
         }
       }
@@ -112,23 +113,20 @@
                   << ndim;
       return fail_reason.str();
     }
-    if (!dynamic_shapes_) {
-      const auto& sizes = v.sizes();
-      const auto& strides = v.strides();
-      for (auto i : c10::irange(ndim)) {
-        if (sizes_[i] != sizes[i]) {
-          // return fmt::format("tensor size mismatch at index {}. expected {},
-          // actual {}", i, sizes_[i], sizes[i]);
-          fail_reason << "size mismatch at index " << i << ". expected "
-                      << sizes_[i] << ", actual " << sizes[i];
-          return fail_reason.str();
-        } else if (strides_[i] != strides[i]) {
-          // return fmt::format("tensor strides mismatch at index {}. expected
-          // {}, actual {}", i, strides_[i]);
-          fail_reason << "strides mismatch at index " << i << ". expected "
-                      << strides_[i] << ", actual " << strides[i];
-          return fail_reason.str();
-        }
+    const auto& sizes = v.sizes();
+    const auto& strides = v.strides();
+    for (auto i : c10::irange(ndim)) {
+      auto known_size = sizes_[i];
+      auto known_stride = strides_[i];
+      if (known_size.has_value() && (known_size.value() != sizes[i])) {
+        fail_reason << "size mismatch at index " << i << ". expected "
+                    << known_size.value() << ", actual " << sizes[i];
+        return fail_reason.str();
+      }
+      if (known_stride.has_value() && known_stride.value() != strides[i]) {
+        fail_reason << "stride mismatch at index " << i << ". expected "
+                    << known_stride.value() << ", actual " << strides[i];
+        return fail_reason.str();
       }
     }
     return "";
@@ -144,10 +142,9 @@
   // necessarily capture device indices correctly.
   at::DeviceIndex device_index_;
   bool requires_grad_;
-  bool dynamic_shapes_;
   // NB: These are unset if dynamic shapes is enabled.
-  std::vector<int64_t> sizes_;
-  std::vector<int64_t> strides_;
+  std::vector<std::optional<int64_t>> sizes_;
+  std::vector<std::optional<int64_t>> strides_;
   // Not strictly required for dense tensors, but nested tensors need it.
   int64_t dim_;
 };
@@ -178,6 +175,51 @@
   return (PyObject*)self;
 }
 
+static std::vector<std::optional<int64_t>> wrapIntegersInOptional(
+    const c10::IntArrayRef& intArray) {
+  std::vector<std::optional<int64_t>> optVec(intArray.size());
+  std::transform(
+      intArray.begin(), intArray.end(), optVec.begin(), [](int64_t value) {
+        return std::make_optional(value);
+      });
+  return optVec;
+}
+
+static std::vector<std::optional<int64_t>> pyListToVecOptInt(PyObject* pyList) {
+  std::vector<std::optional<int64_t>> vec;
+  Py_ssize_t size = PyList_Size(pyList);
+  for (Py_ssize_t i = 0; i < size; i++) {
+    PyObject* item = PyList_GetItem(pyList, i);
+    if (item == Py_None) {
+      vec.push_back(std::nullopt);
+    } else {
+      int64_t value = PyLong_AsLongLong(item);
+      if (value == -1 && PyErr_Occurred()) {
+        PyErr_SetString(
+            PyExc_TypeError,
+            "Size or stride list item is not a valid integer.");
+        TORCH_CHECK(false, "Size or stride list item is not a valid integer.");
+      }
+      vec.push_back(value);
+    }
+  }
+  return vec;
+}
+
+static std::vector<std::vector<std::optional<int64_t>>> get_dynamic_dims(
+    PyObject* dynamic_dims_py) {
+  std::vector<std::vector<std::optional<int64_t>>> per_tensor_dynamic_dims;
+  if (dynamic_dims_py != Py_None) {
+    Py_ssize_t size = PyList_Size(dynamic_dims_py);
+    for (Py_ssize_t i = 0; i < size; i++) {
+      PyObject* py_list = PyList_GetItem(dynamic_dims_py, i);
+      std::vector<std::optional<int64_t>> vec = pyListToVecOptInt(py_list);
+      per_tensor_dynamic_dims.push_back(std::move(vec));
+    }
+  }
+  return per_tensor_dynamic_dims;
+}
+
 static int TensorGuards_init(
     TensorGuards* self,
     PyObject* args,
@@ -186,12 +228,27 @@
     PyErr_SetString(PyExc_TypeError, "expected tuple()");
     return -1;
   }
-  PyObject* dynamic_shapes_py = PyDict_GetItemString(kwds, "dynamic_shapes");
-  if (dynamic_shapes_py == NULL) {
-    PyErr_SetString(PyExc_TypeError, "missing dynamic_shapes=...");
+  // Top level structure is List[List[Union[int, None]]]
+  PyObject* dynamic_dims_sizes_py =
+      PyDict_GetItemString(kwds, "dynamic_dims_sizes");
+  if (dynamic_dims_sizes_py == NULL) {
+    PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_sizes=...");
     return -1;
   }
-  bool dynamic_shapes = PyObject_IsTrue(dynamic_shapes_py);
+  PyObject* dynamic_dims_strides_py =
+      PyDict_GetItemString(kwds, "dynamic_dims_strides");
+  if (dynamic_dims_strides_py == NULL) {
+    PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_strides=...");
+    return -1;
+  }
+
+  // dynamic_dims_sizes_py/strides_py are None when dynamic_shapes=False - an
+  // optimization to avoid needlessly invoking .size()/.stride() in Python.
+  std::vector<std::vector<std::optional<int64_t>>>
+      per_tensor_dynamic_dims_sizes = get_dynamic_dims(dynamic_dims_sizes_py);
+  std::vector<std::vector<std::optional<int64_t>>>
+      per_tensor_dynamic_dims_strides =
+          get_dynamic_dims(dynamic_dims_strides_py);
 
   auto& checks = *self->checks;
   auto len = PyTuple_GET_SIZE(args);
@@ -203,8 +260,21 @@
       PyErr_SetString(PyExc_TypeError, "expected Tensor()");
       return -1;
     }
+    auto tensor = THPVariable_Unpack(item);
+    std::vector<std::optional<int64_t>> tensor_dims_size =
+        per_tensor_dynamic_dims_sizes.size() == 0
+        ? wrapIntegersInOptional(tensor.sizes())
+        : per_tensor_dynamic_dims_sizes[i];
+    std::vector<std::optional<int64_t>> tensor_dims_stride =
+        per_tensor_dynamic_dims_strides.size() == 0
+        ? wrapIntegersInOptional(tensor.strides())
+        : per_tensor_dynamic_dims_strides[i];
     checks.emplace_back(
-        state, Py_TYPE(item), THPVariable_Unpack(item), dynamic_shapes);
+        state,
+        Py_TYPE(item),
+        std::move(tensor),
+        std::move(tensor_dims_size),
+        std::move(tensor_dims_stride));
   }
   return 0;
 }
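
For illustration, a hypothetical direct use of the rebound extension showing the
None-means-dynamic convention (TensorGuards is a private API and dynamo normally
builds these lists from fake-tensor sizes/strides as in torch/_dynamo/guards.py
above; the import path is the one guards.py itself uses):

```python
import torch
from torch._C._dynamo.guards import TensorGuards

x = torch.randn(4, 3)  # pretend dim 0 is dynamic
y = torch.randn(8)

guards = TensorGuards(
    x, y,
    # One inner list per tensor; None entries are dynamic dims and are not pinned.
    dynamic_dims_sizes=[[None, 3], [8]],
    dynamic_dims_strides=[[None, 1], [1]],
)

assert guards.check(torch.randn(7, 3), torch.randn(8))          # dynamic dim may change
assert not guards.check(torch.randn(4, 4), torch.randn(8))      # pinned size 3 violated
assert not guards.check(torch.randn(3, 4).t(), torch.randn(8))  # pinned stride 1 violated
```
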
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index caf0975..7db1236 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -415,8 +415,8 @@
 # Given a GraphModule and arguments to run it with, evaluate that the guards
 # for its associated ShapeEnv are satisfied by the passed arguments.  This
 # WILL check for duck sizing.
-def eval_guards(gm, *args):
-    return gm.shape_env.evaluate_guards_for_args(fx_placeholder_vals(gm), args)
+def eval_guards(gm, *args, ignore_static=True):
+    return gm.shape_env.evaluate_guards_for_args(fx_placeholder_vals(gm), args, ignore_static=ignore_static)
 
 def bind_symbols(gm, *args):
     return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args)
@@ -2020,7 +2020,9 @@
         # DimList[DimConstraint]).  Whenever Optional is accepted, that
         # just means there are no constraints
         constraint_inputs: Optional[InputList[Union[DimConstraint, Optional[DimList[DimConstraint]]]]] = None,
-        _simplified=False
+        _simplified=False,
+        # If True, skip guards on tensor properties whose values are static; TensorGuards checks those instead.
+        ignore_static=True,
     ) -> List[str]:
         self.log.info("produce_guards")
 
@@ -2209,8 +2211,18 @@
                     source == symbol_to_source[expr][0]
                 ):
                     continue
+
+                # This logic excludes static values found on tensors from guarding, because
+                # dynamo's check_tensor_fn does that (see guards.cpp).
+                # However, for non-tensor sources, we still need to guard here.
+                if ignore_static and isinstance(source, TensorPropertySource):
+                    if len(expr.free_symbols) == 0:
+                        self.log.debug("Skipping guard %s", f"{source_ref(source)} == {expr}")
+                        continue
+
                 if is_dim(source):
                     dim_constraints.add_equality(source, expr)
+
                 sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
                 exprs.append(f"{source_ref(source)} == {sexpr}")
                 # NB: Not necessary to report constraint violations here:
@@ -2328,10 +2340,10 @@
 
         return exprs
 
-    def evaluate_guards_for_args(self, placeholders, args):
+    def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True):
         from torch._dynamo.source import LocalSource
         arg_names = [f"t{i}" for i in range(len(args))]
-        guards = self.produce_guards(placeholders, [LocalSource(a) for a in arg_names])
+        guards = self.produce_guards(placeholders, [LocalSource(a) for a in arg_names], ignore_static=ignore_static)
         if guards:
             code = " and ".join(guards)
             return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))})
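
A companion sketch at the ShapeEnv level (f is a hypothetical function chosen so
that tracing records a relation between the two inputs; the exact guard set may
differ from the test_proxy_tensor.py expectations above):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import eval_guards

def f(a, b):
    # view() forces a.size(0) == 2 * b.size(0) during symbolic tracing
    return a.view(b.shape[0], 2) + b.unsqueeze(1)

gm = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8))

# ignore_static=True (the new default): only the symbolic relations are evaluated;
# sizes/strides known to be static are left to dynamo's TensorGuards.
assert eval_guards(gm, torch.randn(30), torch.randn(15))
assert not eval_guards(gm, torch.randn(17), torch.randn(8))

# ignore_static=False re-checks the static strides/offsets as well, as export does.
assert eval_guards(gm, torch.randn(16), torch.randn(8), ignore_static=False)
```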