Don't invoke mode as overloaded argument in torch dispatch (#80992)

I noticed that in some situations torch dispatch modes were being
invoked while a mode was still active, which isn't supposed to happen
(we disable modes before calling into the user mode).  I also noticed
that I was getting a warning about a deprecated non-static definition of
torch dispatch on an argument even though there wasn't one.

It turns out this is because modes were being included in the overloaded
arguments list in the Python fallback kernel for torch dispatch.  This is
wrong; instead we should rely on the actual dispatching function to
consult modes.  This also makes the code simpler.

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80992
Approved by: https://github.com/zou3519
diff --git a/aten/src/ATen/core/PythonFallbackKernel.cpp b/aten/src/ATen/core/PythonFallbackKernel.cpp
index f9f3d6f..37b46ae 100644
--- a/aten/src/ATen/core/PythonFallbackKernel.cpp
+++ b/aten/src/ATen/core/PythonFallbackKernel.cpp
@@ -53,7 +53,7 @@
   // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
   const auto& maybe_torch_dispatch_mode_state = at::impl::TorchDispatchModeTLS::get_state();
   if (maybe_torch_dispatch_mode_state) {
-    maybe_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack, maybe_torch_dispatch_mode_state);
+    maybe_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack);
     return;
   }
 
@@ -69,7 +69,7 @@
     if (ivalue.isTensor()) {
       auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_interpreter();
       if (interpreter) {
-        interpreter->dispatch(op, stack, nullptr);
+        interpreter->dispatch(op, stack);
         return;
       }
     } else if (ivalue.isTensorList()) {
@@ -78,7 +78,7 @@
       for (const auto& nv : ivalue.toListRef()) {
         auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter();
         if (interpreter) {
-          interpreter->dispatch(op, stack, nullptr);
+          interpreter->dispatch(op, stack);
           return;
         }
       }
diff --git a/c10/core/impl/PyInterpreter.cpp b/c10/core/impl/PyInterpreter.cpp
index 1f70460..eec1d23 100644
--- a/c10/core/impl/PyInterpreter.cpp
+++ b/c10/core/impl/PyInterpreter.cpp
@@ -24,8 +24,7 @@
 static void noop_dispatch_fn(
     const PyInterpreter*,
     const c10::OperatorHandle& op,
-    torch::jit::Stack* stack,
-    const std::shared_ptr<SafePyObject>& type) {
+    torch::jit::Stack* stack) {
   TORCH_INTERNAL_ASSERT(
       0,
       "attempted to dispatch (__torch_dispatch__) an operator on Tensor with nontrivial PyObject after corresponding interpreter died");
diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h
index 510eb7b..db3d975 100644
--- a/c10/core/impl/PyInterpreter.h
+++ b/c10/core/impl/PyInterpreter.h
@@ -127,9 +127,7 @@
   using dispatch_sig = void(
       const PyInterpreter*,
       const c10::OperatorHandle&,
-      torch::jit::Stack* stack,
-      // This is a Tensor subclass type object
-      const std::shared_ptr<SafePyObject>& type);
+      torch::jit::Stack* stack);
   using is_contiguous_sig = bool(const PyInterpreter*, const TensorImpl*);
   using device_sig = c10::Device(const PyInterpreter*, const TensorImpl*);
   using dim_sig = int64_t(const PyInterpreter*, const TensorImpl*);
@@ -203,9 +201,8 @@
   // Invoke the Python boxed fallback dispatch to go back into Python
   __ubsan_ignore_function__ void dispatch(
       const c10::OperatorHandle& op,
-      torch::jit::Stack* stack,
-      const std::shared_ptr<SafePyObject>& type) const {
-    return (*dispatch_fn_)(this, op, stack, type);
+      torch::jit::Stack* stack) const {
+    return (*dispatch_fn_)(this, op, stack);
   }
 
   __ubsan_ignore_function__ bool is_contiguous(const TensorImpl* self) const {
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index dea1770..ea43f8e 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -1100,6 +1100,56 @@
 
         assert self.assertRaisesRegex(RuntimeError, "subclass Mode but.* associated to a python object of type Mode")
 
+    def test_notimplemented_mode(self):
+        sub_count = 0
+
+        class PoliteMode(TorchDispatchMode):
+            def __init__(self):
+                self.pre_count = 0
+                self.post_count = 0
+
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                self.pre_count += 1
+                if any(t is not torch.Tensor for t in types):
+                    return NotImplemented
+                self.post_count += 1
+                return func(*args, **kwargs)
+
+        class SubTensor(torch.Tensor):
+            def __new__(cls, elem):
+                r = torch.Tensor._make_wrapper_subclass(cls, elem.shape)
+                r.elem = elem
+                return r
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                nonlocal sub_count
+                sub_count += 1
+
+                def unwrap(t):
+                    if isinstance(t, SubTensor):
+                        return t.elem
+                    else:
+                        return t
+
+                return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
+
+            __torch_function__ = torch._C._disabled_torch_function_impl
+
+        a = SubTensor(torch.randn(2))
+        mode = PoliteMode()
+        with mode:
+            a.abs()
+
+        self.assertEqual(mode.pre_count, 2)
+        self.assertEqual(mode.post_count, 1)
+        self.assertEqual(sub_count, 1)
+
+        # make sure this doesn't error
+        with PoliteMode():
+            with PoliteMode():
+                a.abs()
+
     def test_make_wrapper_subclass_with_modes(self):
         class ModeTensor(torch.Tensor):
             def __new__(cls, elem, mode):
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index 1870e3c..2f3bc0b 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -240,8 +240,7 @@
 void concrete_dispatch_fn(
     const c10::impl::PyInterpreter*,
     const c10::OperatorHandle& op,
-    torch::jit::Stack* stack,
-    const std::shared_ptr<SafePyObject>& type);
+    torch::jit::Stack* stack);
 bool concrete_is_contiguous_fn(
     const c10::impl::PyInterpreter*,
     const c10::TensorImpl* self);
@@ -2124,19 +2123,10 @@
           TorchFunctionName::TorchDispatch));
 }
 
-// NOTE [dispatch_fn's type argument]
-// `type` is nullable and represents the TorchDispatchMode going on.
-// Right now we only support a single TorchDispatchMode, but in the future we
-// could change this to a stack of TorchDispatchModes.
-//
-// If `type` isn't null, then we consider the type for dispatch by prepending
-// it to the overloaded_args list. `handle_torch_funciton_no_python_arg_parser`
-// is responsible for doing overload resolution.
 void concrete_dispatch_fn(
     const c10::impl::PyInterpreter*,
     const c10::OperatorHandle& op,
-    torch::jit::Stack* stack,
-    const std::shared_ptr<SafePyObject>& type) {
+    torch::jit::Stack* stack) {
   const auto& schema = op.schema();
   const auto num_arguments = schema.arguments().size();
   auto arguments = torch::jit::pop(*stack, num_arguments);
@@ -2172,10 +2162,6 @@
   }
   std::string module_name_str = "torch.ops." + ns_str;
 
-  if (type) {
-    append_overloaded_type(&overloaded_args, type->ptr(getPyInterpreter()));
-  }
-
   // Find overloaded tensors
   for (const auto idx : c10::irange(arguments.size())) {
     const auto& ivalue = arguments[idx];