Add context manager for conditional rewrites of torch.* to torch._refs.* calls (#81764)
Adds a new context manager, `TorchRefsNvfuserCapabilityMode`, that conditionally rewrites `torch.*` calls to `torch._refs.*` based on whether the resulting prim decomposition can be executed by nvFuser.
`TorchRefsMode` gains a new optional argument, `should_fallback_fn`: a callable that decides whether to fall back to the original `torch.foo` call or to use the `torch._refs.foo` replacement.
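For illustration, a minimal usage sketch (not taken from this PR's diff): it assumes the context managers are importable from `torch._prims.context`, that they compose with `make_fx` as ordinary torch function modes, and the argument list of the `should_fallback_fn` callable is also an assumption.

```python
# Minimal sketch: rewrite torch.* to torch._refs.* only when the prim
# decomposition is something nvFuser can execute (module path assumed).
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsMode, TorchRefsNvfuserCapabilityMode

def fn(x):
    return torch.sigmoid(x) + torch.digamma(x)

with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(fn)(torch.randn(3))
print(gm.graph)  # refs/prims for nvFuser-supported ops, original calls otherwise

# The same hook is exposed on TorchRefsMode through the new should_fallback_fn
# argument; returning True keeps the original torch.foo call.
# (The arguments passed to the callable are assumed, hence *args/**kwargs.)
def never_fallback(*args, **kwargs):
    return False

with TorchRefsMode(should_fallback_fn=never_fallback):
    gm_refs = make_fx(fn)(torch.randn(3))
```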
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81764
Approved by: https://github.com/ezyang
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index d6042ac..2d7caa8 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -13,7 +13,7 @@
from torch._decomp import decomposition_table
from torch.testing._internal.common_device_type import ops
-from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter
+from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
from torch.utils._pytree import tree_map
from torch import nn
import re
@@ -135,6 +135,109 @@
return a + b
self._test(f, [torch.randn(3, device=device), torch.tensor(5)])
+ def test_isolated_graphmodule(self):
+ def is_any_sum(gm):
+ return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes)
+
+ def is_any_digamma(gm):
+ return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes)
+
+ def is_any_sigmoid(gm):
+ return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes)
+
+ def inner(x):
+ return torch.sum(x)
+
+ def f(x):
+ gm = get_isolated_graphmodule(inner, (x,), {})
+ self.assertTrue(is_any_sum(gm))
+ return x + torch.randn(x.shape)
+
+ # get_isolated_graphmodule calls make_fx internally; that inner trace should not
+ # be captured by the outer make_fx call
+ traced = make_fx(f)(torch.randn(3))
+ self.assertFalse(is_any_sum(traced))
+
+ # When factory functions are used, they should not be traced
+ # by the outer make_fx call
+ def inner_with_factory():
+ val = torch.tensor(float(1))
+ val.add_(2)
+ return torch.full((10, 10), val).sum()
+
+ def f1(x):
+ gm = get_isolated_graphmodule(inner_with_factory, (), {})
+ self.assertTrue(is_any_sum(gm))
+ return torch.sigmoid(x)
+
+ def f2(x):
+ gm = get_isolated_graphmodule(f1, (x,), {})
+ self.assertFalse(is_any_sum(gm))
+ self.assertTrue(is_any_sigmoid(gm))
+ return torch.digamma(x)
+
+ traced = make_fx(f2)(torch.randn(3))
+ self.assertFalse(is_any_sum(traced))
+ self.assertFalse(is_any_sigmoid(traced))
+ self.assertTrue(is_any_digamma(traced))
+
+ # Verify nested make_fx calls don't leak factory functions
+ # into the outer graph
+ def f2(x):
+ gm = make_fx(f1)(x)
+ self.assertFalse(is_any_sum(gm))
+ self.assertTrue(is_any_sigmoid(gm))
+ return torch.digamma(x)
+
+ traced = make_fx(f2)(torch.randn(3))
+ self.assertFalse(is_any_sum(traced))
+ self.assertTrue(is_any_sigmoid(traced))
+ self.assertTrue(is_any_digamma(traced))
+
+ # Verify interaction with non-ProxyTensor modes
+ from torch.testing._internal.logging_tensor import LoggingTensorMode
+
+ def f1_logging(x):
+ with LoggingTensorMode():
+ gm = get_isolated_graphmodule(inner_with_factory, (), {})
+ self.assertTrue(is_any_sum(gm))
+ return torch.sigmoid(x)
+
+ def f2_logging(x):
+ with LoggingTensorMode(), LoggingTensorMode():
+ gm = get_isolated_graphmodule(f1_logging, (x,), {})
+ self.assertFalse(is_any_sum(gm))
+ self.assertTrue(is_any_sigmoid(gm))
+ return torch.digamma(x)
+
+ traced = make_fx(f2_logging)(torch.randn(3))
+ self.assertFalse(is_any_sum(traced))
+ self.assertFalse(is_any_sigmoid(traced))
+ self.assertTrue(is_any_digamma(traced))
+
+ # Verify interaction with another tensor subclass
+ # This case currently doesn't work and should raise an error
+ # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
+ from torch.testing._internal.logging_tensor import LoggingTensor
+
+ def f1_logging_tensor(x):
+ gm = get_isolated_graphmodule(inner_with_factory, (), {})
+ self.assertTrue(is_any_sum(gm))
+ return torch.sigmoid(x)
+
+ def f2_logging_tensor(x):
+ x = LoggingTensor(x)
+ gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {})
+ self.assertFalse(is_any_sum(gm))
+ self.assertTrue(is_any_sigmoid(gm))
+ return torch.digamma(x)
+
+ with self.assertRaisesRegex(AssertionError, "ProxyTensor is wrapped with another Tensor subclass"):
+ traced = make_fx(f2_logging_tensor)(torch.randn(3))
+ self.assertFalse(is_any_sum(traced))
+ self.assertFalse(is_any_sigmoid(traced))  # this fails: sigmoid is traced with LoggingTensor
+ self.assertTrue(is_any_digamma(traced))
+
@unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
def test_resnet18_backward_trace(self, device):
mod = torchvision.models.resnet18()