Switched over to using FakeTensor in ProxyTensor (#79634)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79634
Approved by: https://github.com/albanD
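
For context, a minimal sketch of how the new use_fake path is meant to be used from the caller's side, assuming only the make_fx signature added in this diff; the traced function below is illustrative and not part of the change:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # clone + in-place metadata mutation, mirroring test_inplace_metadata below
    y = x.clone()
    y.unsqueeze_(-1)
    return y.cos()

inp = torch.randn(5)
# use_fake=True wraps the inputs via FakeTensorMode.from_tensor, so tracing
# records the aten graph without running real kernels on the real data.
traced = make_fx(f, use_fake=True)(inp)
# The resulting GraphModule can then be re-run on real tensors.
torch.testing.assert_close(traced(inp), f(inp))
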
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 40bec21..76bd4ae 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -7,6 +7,7 @@
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
+from torch._subclasses.fake_tensor import DynamicOutputShapeException
from torch.testing._internal.common_device_type import ops
from torch.fx.experimental.proxy_tensor import make_fx
@@ -59,7 +60,7 @@
class TestProxyTensor(TestCase):
- def test_make_fx(self, device):
+ def test_make_fx_simple(self, device):
def f(x):
return torch.sin(x)
inp = torch.randn(3)
@@ -110,6 +111,17 @@
assert inp.grad is None
torch.testing.assert_close(traced_graph_out, f(inp))
+ def test_inplace_metadata(self):
+ def f(x):
+ x = x.clone()
+ x.unsqueeze_(-1)
+ assert x.shape[-1] == 1
+ return x
+
+ inps = [torch.randn(5)]
+ fx_f = make_fx(f)(*inps)
+ self.assertEqual(fx_f(*inps), f(*inps))
+
def test_mode_tracing_factory_function(self):
def f(x):
return x + torch.randn(x.shape)
@@ -136,6 +148,7 @@
)
make_fx_failures = {
+ # unknown
xfail('allclose'),
xfail('equal'),
xfail('linalg.eigvals'),
@@ -150,6 +163,7 @@
skip('nn.functional.max_unpool2d', '', device_type='cpu'),
skip('nn.functional.max_unpool3d', '', device_type='cpu'),
skip('linalg.lstsq'), # flaky, probably just a precision issue
+
# data-dependent control flow
xfail('cov'),
xfail('istft'),
@@ -182,35 +196,67 @@
# Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
xfail('sparse.sampled_addmm'),
- # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
+ # ???
xfail('nn.functional.ctc_loss'),
+ # Sparse tensors are not supported with faketensors for now
+ xfail('to_sparse'),
+ # segfaults
+ skip('block_diag'),
+}
+
+fake_tensor_failures = {
+ # Needs complex-value support
+ xfail('polar'),
+ xfail('complex'),
+ xfail('linalg.eig'),
+ # FakeTensor fallback doesn't work
+ xfail('linalg.matrix_power'),
+ xfail('segment_reduce', 'lengths'),
+ xfail('multinomial'),
+ xfail('mvlgamma', 'mvlgamma_p_1'),
+ xfail('mvlgamma', 'mvlgamma_p_3'),
+ xfail('mvlgamma', 'mvlgamma_p_5'),
+ xfail('cholesky'),
+ xfail('cholesky_inverse'),
+ # ASAN failures due to divide by 0
+ skip('nn.functional.nll_loss'),
}
+def _test_make_fx_helper(self, device, dtype, op, use_fake):
+ def f(args, kwargs):
+ return op.op(*args, **kwargs)
+ sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
+ new_f = None
+ for sample_input in sample_inputs_itr:
+ args = [sample_input.input] + list(sample_input.args)
+ kwargs = sample_input.kwargs
+
+ try:
+ new_f = make_fx(f, use_fake=use_fake)(args, kwargs)
+ except DynamicOutputShapeException as e:
+ self.skipTest("Dynamic output shape operation in trace")
+
+ for arg in args:
+ if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
+ arg.uniform_(0, 1)
+ try:
+ old_out = f(args, kwargs)
+ except Exception:
+ continue
+ new_out = wrapper_set_seed(new_f, args, kwargs)
+ self.assertEqual(new_out, old_out)
+
class TestProxyTensorOpInfo(TestCase):
@ops(op_db, allowed_dtypes=(torch.float,))
- @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures
- )
+ @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
def test_make_fx_exhaustive(self, device, dtype, op):
+ _test_make_fx_helper(self, device, dtype, op, False)
- def f(args, kwargs):
- return op.op(*args, **kwargs)
- sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
- new_f = None
- for sample_input in sample_inputs_itr:
- args = [sample_input.input] + list(sample_input.args)
- kwargs = sample_input.kwargs
-
- new_f = make_fx(f, trace_factory_functions=True)(args, kwargs)
- for arg in args:
- if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
- arg.uniform_(0, 1)
- try:
- old_out = f(args, kwargs)
- except Exception:
- continue
- new_out = wrapper_set_seed(new_f, args, kwargs)
- self.assertEqual(new_out, old_out)
+ @ops(op_db, allowed_dtypes=(torch.float,))
+ @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', make_fx_failures.union(fake_tensor_failures))
+ def test_make_fx_fake_exhaustive(self, device, dtype, op):
+ _test_make_fx_helper(self, device, dtype, op, True)
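
As a side note on the skip path in _test_make_fx_helper above: ops whose output shape depends on input values cannot be traced through fake tensors, and FakeTensor signals this with DynamicOutputShapeException, which the helper converts into a skipTest. A minimal sketch of that failure mode, assuming torch.nonzero is such an op (illustrative only):

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._subclasses.fake_tensor import DynamicOutputShapeException

def g(x):
    # The number of rows in the result depends on the values in x.
    return torch.nonzero(x)

try:
    make_fx(g, use_fake=True)(torch.randn(4))
except DynamicOutputShapeException:
    # The test helper above turns this into self.skipTest(...)
    print("data-dependent output shape; skipped under fake tensors")
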
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 441fe4a..82559ec 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -9,12 +9,13 @@
from torch._C import _disabled_torch_function_impl
import torch.utils._pytree as pytree
from torch.fx import Tracer, GraphModule
+from torch._subclasses.fake_tensor import FakeTensorMode
import torch.fx as fx
from torch.utils._mode_utils import no_dispatch
from torch.fx.passes.shape_prop import _extract_tensor_metadata
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
-from torch.utils._python_dispatch import push_torch_dispatch_mode, TorchDispatchMode
+from torch.utils._python_dispatch import TorchDispatchMode
__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "enable_strict"]
aten = torch.ops.aten
@@ -22,7 +23,6 @@
CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}
-
@contextmanager
def decompose(decomposition_table):
global CURRENT_DECOMPOSITION_TABLE
@@ -39,9 +39,9 @@
global IS_STRICT
IS_STRICT = val
-def wrap_output(real_out, proxy_out):
+def wrap_output(inner_res, proxy_res):
def wrap_with_proxy(e, proxy):
- if type(e) == torch.Tensor:
+ if isinstance(e, torch.Tensor):
with no_dispatch():
return ProxyTensor(e, proxy)
else:
@@ -50,17 +50,19 @@
# Unfortunately, tree_map cannot directly be used here. As the resulting
# object may be a proxy that represents a tuple, we may need to
# explicitly unwrap the proxy by simulating the flattening operations.
- if isinstance(real_out, tuple):
- return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out))
- elif isinstance(real_out, list):
- return list([wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)])
- elif isinstance(real_out, torch.Tensor):
- return wrap_with_proxy(real_out, proxy_out)
+ if isinstance(inner_res, tuple):
+ return tuple(wrap_with_proxy(e, proxy_res[idx]) for idx, e in enumerate(inner_res))
+ elif isinstance(inner_res, list):
+ return list([wrap_with_proxy(e, proxy_res[idx]) for idx, e in enumerate(inner_res)])
+ elif isinstance(inner_res, torch.Tensor):
+ return wrap_with_proxy(inner_res, proxy_res)
else:
- return real_out
+ return inner_res
def proxy_call(func_overload, args, kwargs=None):
+ if kwargs is None:
+ kwargs = {}
func = func_overload.overloadpacket
if func_overload in CURRENT_DECOMPOSITION_TABLE:
return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
@@ -72,48 +74,61 @@
def unwrap_proxy(e):
return e.proxy if isinstance(e, ProxyTensor) else e
+ def unwrap_elem(e):
+ if isinstance(e, ProxyTensor):
+ return e.elem
+ return e
+
proxy_args = pytree.tree_map(unwrap_proxy, args)
proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
- proxy_out = func_overload(*proxy_args, **proxy_kwargs)
+ proxy_res = func_overload(*proxy_args, **proxy_kwargs)
# Kind of a hacky way to test if an op is in-place or not
if func.__name__[-1] == "_" and func.__name__[0] != "_":
- args[0].proxy = proxy_out
- proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
+ args[0].proxy = proxy_res
+ proxy_res.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
- with no_dispatch():
- real_out = func_overload(*args, **kwargs)
+ inner_res = func_overload(*pytree.tree_map(unwrap_elem, args), **pytree.tree_map(unwrap_elem, kwargs))
+ # Needed to sync up metadata for in-place operators that modify metadata
+ if torch.Tag.inplace_view in func_overload.tags: # type: ignore[attr-defined]
+ with no_dispatch():
+ func_overload(*args, **kwargs)
- return wrap_output(real_out, proxy_out)
+ # TODO(chilli): Enable this after it's been refactored to work with wrapper tensor subclasses in general
+ # pytree.tree_map(lambda x: check_metadata_consistency(x, ProxyTensor), (inner_res, args, kwargs))
+ return wrap_output(inner_res, proxy_res)
+
class ProxyTensor(torch.Tensor):
proxy: fx.Proxy
+ elem: torch.Tensor
+
@staticmethod
def __new__(cls, elem, proxy, *, requires_grad=None):
- # Hack to deal with super().__new__ not working for sparse tensors
- if elem.is_sparse or requires_grad is not None:
- if requires_grad is None:
- requires_grad = False
- r = torch.Tensor._make_subclass(cls, elem, requires_grad)
- else:
- r = super().__new__(cls, elem) # type: ignore[call-arg]
+ r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
+ cls,
+ elem.shape, dtype=elem.dtype, layout=elem.layout, device=elem.device,
+ requires_grad=requires_grad if requires_grad is not None else False, strides=elem.stride(),
+ storage_offset=elem.storage_offset()
+ )
+ return r
+ def __init__(self, elem, proxy, *, requires_grad=None):
if elem.is_sparse:
proxy.node.meta['tensor_meta'] = {}
else:
- proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
- r.proxy = proxy # type: ignore[attr-defined]
-
- return r
+ proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(self)
+ self.elem = elem
+ self.proxy = proxy
def __deepcopy__(self, memo):
return self.clone()
def __repr__(self):
with no_dispatch():
- return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})" # type: ignore[arg-type]
+ return f"ProxyTensor({self.elem}, proxy={self.proxy})"
__torch_function__ = _disabled_torch_function_impl
@@ -156,15 +171,10 @@
def dispatch_trace(
root: Union[torch.nn.Module, Callable],
+ tracer: Tracer,
concrete_args: Optional[Tuple[Any, ...]] = None,
- trace_factory_functions: bool = False,
) -> GraphModule:
- tracer = PythonKeyTracer()
- if trace_factory_functions:
- with push_torch_dispatch_mode(functools.partial(ProxyTorchDispatchMode, tracer)):
- graph = tracer.trace(root, concrete_args)
- else:
- graph = tracer.trace(root, concrete_args)
+ graph = tracer.trace(root, concrete_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return GraphModule(tracer.root, graph, name)
@@ -179,9 +189,11 @@
for idx, arg in enumerate(flat_args):
if isinstance(flat_inps[idx], torch.Tensor):
with no_dispatch():
- flat_args[idx] = ProxyTensor(flat_inps[idx], arg, requires_grad=(
- flat_inps[idx].is_leaf and flat_inps[idx].requires_grad
- ))
+ flat_args[idx] = ProxyTensor(
+ flat_inps[idx],
+ arg,
+ requires_grad=(flat_inps[idx].is_leaf and flat_inps[idx].requires_grad)
+ )
else:
flat_args[idx] = flat_inps[idx]
@@ -205,25 +217,36 @@
if any(tuple(isinstance(arg, ProxyTensor) for arg in pytree.tree_flatten(args)[0])):
return proxy_call(func_overload, args, kwargs)
else:
- proxy_out = self.tracer.create_proxy('call_function', func, args, kwargs,
+ proxy_res = self.tracer.create_proxy('call_function', func, args, kwargs,
name=self.tracer.graph._target_to_str(func.__name__))
- with no_dispatch():
- real_out = func_overload(*args, **kwargs)
+ inner_res = func_overload(*args, **kwargs)
- return wrap_output(real_out, proxy_out)
+ return wrap_output(inner_res, proxy_res)
-def make_fx(f, decomposition_table=None, trace_factory_functions=True):
+def make_fx(f, decomposition_table=None, trace_factory_functions=True, use_fake=False):
if decomposition_table is None:
decomposition_table = {}
@functools.wraps(f)
def wrapped(*args):
- phs = pytree.tree_map(lambda x: fx.PH, args) # type: ignore[attr-defined]
- with decompose(decomposition_table):
- t = dispatch_trace(wrap_key(f, args), concrete_args=tuple(phs),
- trace_factory_functions=trace_factory_functions)
+ phs = pytree.tree_map(lambda _: fx.PH, args) # type: ignore[attr-defined]
+ fx_tracer = PythonKeyTracer()
+ fake_tensor_mode = FakeTensorMode() if use_fake else nullcontext()
+ proxy_mode = ProxyTorchDispatchMode(fx_tracer) if trace_factory_functions else nullcontext()
+
+ def wrap_fake(x):
+ if isinstance(x, torch.Tensor):
+ return fake_tensor_mode.from_tensor(x) # type: ignore[attr-defined]
+
+ return x
+
+ if use_fake: # type: ignore[attr-defined]
+ args = pytree.tree_map(wrap_fake, args)
+
+ with decompose(decomposition_table), fake_tensor_mode, proxy_mode: # type: ignore[attr-defined]
+ t = dispatch_trace(wrap_key(f, args), tracer=fx_tracer, concrete_args=tuple(phs))
return t
return wrapped
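
For reference, a stripped-down sketch of the _make_wrapper_subclass pattern that ProxyTensor now follows: the subclass carries no storage of its own and keeps the real (or fake) data on .elem, unwrapping and rewrapping around each dispatched op. This is an illustrative standalone class under those assumptions, not the ProxyTensor implementation above:

import torch
from torch.utils._pytree import tree_map

class WrapperTensor(torch.Tensor):
    elem: torch.Tensor

    @staticmethod
    def __new__(cls, elem):
        # Storage-less tensor mirroring elem's metadata, as ProxyTensor.__new__ now does.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            elem.shape, dtype=elem.dtype, layout=elem.layout, device=elem.device,
            requires_grad=elem.requires_grad,
            strides=elem.stride(), storage_offset=elem.storage_offset(),
        )

    def __init__(self, elem):
        self.elem = elem

    def __repr__(self):
        return f"WrapperTensor({self.elem})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(x):
            return x.elem if isinstance(x, WrapperTensor) else x

        def wrap(x):
            return WrapperTensor(x) if isinstance(x, torch.Tensor) else x

        # Run the real op on the inner tensors and rewrap any tensor outputs.
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return tree_map(wrap, out)

print(WrapperTensor(torch.randn(3)).sin())
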