Add context manager for conditional rewrites of torch.* to torch._refs.* calls (#81764)

Adds a new context manager, `TorchRefsNvfuserCapabilityMode`, that conditionally rewrites `torch.*` calls to `torch._refs.*` depending on whether the prim decomposition of each call can be executed by nvFuser.
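
A minimal usage sketch (hypothetical: the import path `torch._prims.context` is assumed and is not shown in this PR) that traces a function with `make_fx` under the new mode:

```python
# Sketch only: the import path of TorchRefsNvfuserCapabilityMode is assumed
# (torch._prims.context); it is not part of this PR's description.
import torch
from torch._prims.context import TorchRefsNvfuserCapabilityMode
from torch.fx.experimental.proxy_tensor import make_fx


def fn(x):
    # Under the mode, this call is rewritten to torch._refs.sigmoid only if
    # its prim decomposition can be executed by nvFuser; otherwise the
    # original torch.sigmoid is traced unchanged.
    return torch.sigmoid(x) + 1


with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(fn)(torch.randn(3))

print(gm.graph)
```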

`TorchRefsMode` gains a new optional argument, `should_fallback_fn`: a callable that decides whether the original `torch.foo` or the replacement `torch._refs.foo` should be used.
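
A hedged sketch of how the fallback hook could be used (assumption: the callable receives the mode instance, the intercepted torch function, and its args/kwargs, and returns `True` to keep the original call; the exact signature is not spelled out in this description). Presumably `TorchRefsNvfuserCapabilityMode` builds on this hook by supplying a predicate that checks nvFuser support for the prim decomposition.

```python
# Sketch only: the signature (mode, func, args, kwargs) -> bool is assumed,
# with True meaning "fall back to the original torch.* call" and False
# meaning "use the torch._refs.* replacement". The import path is assumed too.
import torch
from torch._prims.context import TorchRefsMode
from torch.fx.experimental.proxy_tensor import make_fx


def fallback_except_add(mode, func, args, kwargs):
    # Keep the original op for everything except torch.add.
    return func is not torch.add


def fn(x):
    return torch.add(torch.sigmoid(x), 1.0)


with TorchRefsMode(should_fallback_fn=fallback_except_add):
    gm = make_fx(fn)(torch.randn(3))

print(gm.graph)  # expected: add traced through refs/prims, sigmoid kept as the original op
```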

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81764
Approved by: https://github.com/ezyang
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index d6042ac..2d7caa8 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -13,7 +13,7 @@
 
 from torch._decomp import decomposition_table
 from torch.testing._internal.common_device_type import ops
-from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter
+from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
 from torch.utils._pytree import tree_map
 from torch import nn
 import re
@@ -135,6 +135,109 @@
             return a + b
         self._test(f, [torch.randn(3, device=device), torch.tensor(5)])
 
+    def test_isolated_graphmodule(self):
+        def is_any_sum(gm):
+            return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes)
+
+        def is_any_digamma(gm):
+            return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes)
+
+        def is_any_sigmoid(gm):
+            return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes)
+
+        def inner(x):
+            return torch.sum(x)
+
+        def f(x):
+            gm = get_isolated_graphmodule(inner, (x,), {})
+            self.assertTrue(is_any_sum(gm))
+            return x + torch.randn(x.shape)
+
+        # get_isolated_graphmodule uses make_fx internally, which shouldn't be traced
+        # by the outer make_fx call
+        traced = make_fx(f)(torch.randn(3))
+        self.assertFalse(is_any_sum(traced))
+
+        # When factory functions are used, they should not be traced
+        # by the outer make_fx call
+        def inner_with_factory():
+            val = torch.tensor(float(1))
+            val.add_(2)
+            return torch.full((10, 10), val).sum()
+
+        def f1(x):
+            gm = get_isolated_graphmodule(inner_with_factory, (), {})
+            self.assertTrue(is_any_sum(gm))
+            return torch.sigmoid(x)
+
+        def f2(x):
+            gm = get_isolated_graphmodule(f1, (x,), {})
+            self.assertFalse(is_any_sum(gm))
+            self.assertTrue(is_any_sigmoid(gm))
+            return torch.digamma(x)
+
+        traced = make_fx(f2)(torch.randn(3))
+        self.assertFalse(is_any_sum(traced))
+        self.assertFalse(is_any_sigmoid(traced))
+        self.assertTrue(is_any_digamma(traced))
+
+        # Verify that nested make_fx calls don't leak factory functions
+        # into the outer graph
+        def f2(x):
+            gm = make_fx(f1)(x)
+            self.assertFalse(is_any_sum(gm))
+            self.assertTrue(is_any_sigmoid(gm))
+            return torch.digamma(x)
+
+        traced = make_fx(f2)(torch.randn(3))
+        self.assertFalse(is_any_sum(traced))
+        self.assertTrue(is_any_sigmoid(traced))
+        self.assertTrue(is_any_digamma(traced))
+
+        # Verify interaction with non-ProxyTensor modes
+        from torch.testing._internal.logging_tensor import LoggingTensorMode
+
+        def f1_logging(x):
+            with LoggingTensorMode():
+                gm = get_isolated_graphmodule(inner_with_factory, (), {})
+            self.assertTrue(is_any_sum(gm))
+            return torch.sigmoid(x)
+
+        def f2_logging(x):
+            with LoggingTensorMode(), LoggingTensorMode():
+                gm = get_isolated_graphmodule(f1_logging, (x,), {})
+            self.assertFalse(is_any_sum(gm))
+            self.assertTrue(is_any_sigmoid(gm))
+            return torch.digamma(x)
+
+        traced = make_fx(f2_logging)(torch.randn(3))
+        self.assertFalse(is_any_sum(traced))
+        self.assertFalse(is_any_sigmoid(traced))
+        self.assertTrue(is_any_digamma(traced))
+
+        # Verify interaction with another tensor subclass
+        # This case currently doesn't work and should raise an error
+        # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
+        from torch.testing._internal.logging_tensor import LoggingTensor
+
+        def f1_logging_tensor(x):
+            gm = get_isolated_graphmodule(inner_with_factory, (), {})
+            self.assertTrue(is_any_sum(gm))
+            return torch.sigmoid(x)
+
+        def f2_logging_tensor(x):
+            x = LoggingTensor(x)
+            gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {})
+            self.assertFalse(is_any_sum(gm))
+            self.assertTrue(is_any_sigmoid(gm))
+            return torch.digamma(x)
+
+        with self.assertRaisesRegex(AssertionError, "ProxyTensor is wrapped with another Tensor subclass"):
+            traced = make_fx(f2_logging_tensor)(torch.randn(3))
+            self.assertFalse(is_any_sum(traced))
+            self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor
+            self.assertTrue(is_any_digamma(traced))
+
     @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
     def test_resnet18_backward_trace(self, device):
         mod = torchvision.models.resnet18()