Sparse fake tensor support (#82172)

Add support for sparse fake tensors.

- The testing strategy is to run a fake tensor cross ref test on `test_sparse.py`. This is necessary because OpInfo sparse coverage is completely nonexistent. We could have tried to turn on cross ref testing globally for all files, but that would be very time-consuming, and the tests I'm interested in are mostly in this file. A few tests are excluded for things that don't work yet. (A simplified sketch of the cross ref mode is included after this list.)
- I make the fake tensor converter raise an UnsupportedFakeTensorException if the meta converter fails to do a conversion (which can happen in a relatively large number of situations).
- I relax the fake tensor invariants so that you can make a fake tensor from a meta tensor. This is useful because the cross ref test sometimes operates on meta tensors. (A small illustration of the converter contract appears after the sketch below.)
- Fake tensor wrapping is improved to handle the case where a function doesn't return any tensors.
- The meta converter is taught how to convert sparse tensors to meta.
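
The cross ref mode added to `test_sparse.py` boils down to a dispatch mode that replays every op on fake tensors before running it for real. Below is a minimal sketch of that pattern; the class name is illustrative, and the skip list and tag checks from the actual test are omitted for brevity:

```python
import torch
from torch.utils._pytree import tree_map
from torch.utils._python_dispatch import TorchDispatchMode
from torch._subclasses.fake_tensor import FakeTensorMode, UnsupportedFakeTensorException


class CrossRefFakeSketch(TorchDispatchMode):
    """Re-run every dispatched op on fake tensors as a cross check."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def to_fake(fake_mode):
            def go(t):
                return fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            return go

        try:
            # allow_meta=True because the inner test may already be running
            # on meta tensors (see the relaxed invariant described above).
            with FakeTensorMode(allow_meta=True) as fake_mode:
                fake_args, fake_kwargs = tree_map(to_fake(fake_mode), (args, kwargs))
                func(*fake_args, **fake_kwargs)
        except UnsupportedFakeTensorException:
            # Conversion is allowed to fail; the real computation still runs.
            pass

        return func(*args, **kwargs)
```

The real `CrossRefSparseFakeMode` in the diff additionally skips a handful of ops (`aten.lift_fresh`, `aten.empty_like`, `aten.sspaddmm.out`, `aten._spdiags`, `aten._to_dense`, etc.) as well as anything tagged `dynamic_output_shape` or `inplace_view`.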

There's still a little more cleanup that needs to be done, but this is good for review.
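
The relaxed converter contract (documented in the new comment in `torch/_subclasses/fake_tensor.py`) can be exercised roughly as follows; this is an illustrative sketch, not code from the PR:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

mode = FakeTensorMode(allow_meta=True)
conv = mode.fake_tensor_converter

# Real tensor: omit the device; it is recovered from the tensor itself.
fake_from_real = conv(mode, torch.randn(2, 3))

# Manually constructed meta tensor: an explicit device MUST be supplied,
# and the tensor passed alongside it MUST be a meta tensor.
meta = torch.empty(2, 3, device="meta")
fake_from_meta = conv(mode, meta, device="cpu")
```

With `allow_meta=True`, the resulting fake device is even allowed to be `meta` itself, which is what the cross ref test relies on when the inner test already operates on meta tensors.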

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82172
Approved by: https://github.com/eellison
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 6ea0936..e989a66 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -291,7 +291,7 @@
                 for ten in out:
                     if i == 1:
                         self.assertTrue(isinstance(ten, FakeTensor))
-                    self.assertTrue(ten.device.type == 'cuda')
+                    self.assertEqual(ten.device.type, 'cuda')
 
     @skipIfRocm
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index a4503b1..ab33104 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -556,8 +556,8 @@
 
     # ???
     xfail('nn.functional.ctc_loss'),
-    # Sparse tensors are not supported with faketensors for now
-    xfail('to_sparse'),
+    # proxy tensor doesn't support sparse correctly right now
+    skip('to_sparse'),
     # segfaults
     skip('block_diag'),
 }
diff --git a/test/test_sparse.py b/test/test_sparse.py
index e0b50e1..30bb6f3 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -9,7 +9,7 @@
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
     do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
-    DeterministicGuard, first_sample
+    DeterministicGuard, first_sample, TEST_WITH_CROSSREF
 from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
 from numbers import Number
 from typing import Dict, Any
@@ -25,6 +25,7 @@
     all_types, all_types_and_complex, all_types_and_complex_and, floating_and_complex_types,
     floating_and_complex_types_and, integral_types, floating_types_and,
 )
+from torch.utils._python_dispatch import TorchDispatchMode
 
 if TEST_SCIPY:
     import scipy.sparse
@@ -40,7 +41,53 @@
     IS_WINDOWS and torch.version.cuda and LooseVersion(torch.version.cuda) > "11.2"
 ) or (not IS_WINDOWS and CUDA11OrLater)
 
-class TestSparse(TestCase):
+class CrossRefSparseFakeMode(TorchDispatchMode):
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        def on_tensor(f):
+            def go(t):
+                if isinstance(t, torch.Tensor):
+                    return f(t)
+                else:
+                    return t
+            return go
+
+        # empty_like is excluded for now due to sparse complex
+        # aten._to_dense.default is getting called with csc
+        if (
+            func not in [
+                torch.ops.aten.lift_fresh.default,
+                torch.ops.aten.empty_like.default,
+                torch.ops.aten.set_.source_Storage_storage_offset,
+                torch.ops.aten.sspaddmm.out,
+                torch.ops.aten._spdiags.default,
+                torch.ops.aten._to_dense.default
+            ]
+            and torch.Tag.dynamic_output_shape not in func.tags
+            and torch.Tag.inplace_view not in func.tags
+        ):
+            from torch._subclasses.fake_tensor import FakeTensorMode, UnsupportedFakeTensorException
+            from torch.utils._pytree import tree_map
+            try:
+                with FakeTensorMode(allow_meta=True) as fake_mode:
+                    fake_args, fake_kwargs = tree_map(on_tensor(fake_mode.from_tensor), (args, kwargs))
+                    fake_r = func(*fake_args, **fake_kwargs)
+            except UnsupportedFakeTensorException:
+                pass
+
+        r = func(*args, **kwargs)
+        return r
+
+class TestSparseBase(TestCase):
+    def run(self, result=None):
+        if TEST_WITH_CROSSREF:
+            with CrossRefSparseFakeMode():
+                return super().run(result)
+        else:
+            return super().run(result)
+
+class TestSparse(TestSparseBase):
 
     def setUp(self):
         TestCase.setUp(self)
@@ -1641,6 +1688,7 @@
 
     @coalescedonoff
     @dtypes(torch.double)
+    @unittest.skipIf(TEST_WITH_CROSSREF, "fallback triggers cuda device error")
     def test_sparse_sum(self, device, dtype, coalesced):
 
         def run_tests(S, td=None):
@@ -3413,6 +3461,7 @@
                                       *[torch.bfloat16] if CUDA11OrLater and SM80OrLater else [],
                                       *[torch.complex64] if CUDA11OrLater else [],
                                       *[torch.complex128] if CUSPARSE_SPMM_COMPLEX128_SUPPORTED else []))
+    @unittest.skipIf(TEST_WITH_CROSSREF, "not working with fake tensor")
     @precisionOverride({torch.bfloat16: 1e-2, torch.float16: 1e-2, torch.complex64: 1e-2, torch.float32: 1e-2})
     def test_sparse_matmul(self, device, dtype, coalesced):
         """
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index e4dbc6c..f126fe1 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -129,12 +129,13 @@
             return maybe_memo
         existing_device = t.device
         # not yet supported in metatensors
-        if t.is_sparse:
-            raise UnsupportedFakeTensorException("sparse nyi in meta tensors")
         if t.is_quantized:
             raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
         with no_dispatch():
-            out = FakeTensor(fake_mode, self.meta_converter(t), existing_device)
+            meta_t = self.meta_converter(t)
+            if meta_t.device.type != "meta":
+                raise UnsupportedFakeTensorException("meta converter nyi")
+            out = FakeTensor(fake_mode, meta_t, existing_device)
         if type(t) is torch.nn.Parameter:
             out = torch.nn.Parameter(out, requires_grad=out.requires_grad)  # type: ignore[assignment]
         if t.grad is not None:
@@ -150,11 +151,21 @@
         self.set_tensor_memo(t, out)
         return out
 
+    # There are two ways to call this.  First, you can have manually constructed
+    # a meta tensor and you need to turn it into a fake tensor.  In that case,
+    # pass a meta tensor and a device argument.  Alternately, you can have a
+    # real tensor that you need to convert into a fake tensor; in that case,
+    # omit the device.
+    #
+    # The disallowed case: if you specify the device, it MUST be a meta tensor.
+    # However, you're allowed to pass a meta tensor to be turned into a fake
+    # tensor; although an odd thing to do, this can occur if you're doing
+    # cross ref testing and the inner test is already operating on meta tensors
     def __call__(self, fake_mode, t, device=None):
-        assert t.device.type != "meta" or device is not None
-        if t.device.type != "meta":
+        if device is None:
             return self.from_real_tensor(fake_mode, t)
         else:
+            assert t.device.type == "meta"
             return self.from_meta_and_device(fake_mode, t, device)
 
 
@@ -216,6 +227,12 @@
     return func(*args, **kwargs)
 
 
+@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
+def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
+    # TODO: remove me
+    return constructors(fake_mode, func, *args, **kwargs)
+
+
 # _to_copy fails when run with FakeTensors to cuda device
 # TODO: debug
 @register_op_impl(aten._to_copy.default)
@@ -345,14 +362,20 @@
         )
 
     def __init__(self, fake_mode, elem, device: Union[torch.device, str]):
-        # elem does not need to be recorded, because FakeTensor *is a* elem
-        assert elem.device.type == "meta", elem
+        assert elem.device.type == "meta", elem.device.type
         device = device if isinstance(device, torch.device) else torch.device(device)
-        # normalize cuda device
+        # NB: it is fine, if a little confusing, for device to be meta
+        # (we are faking a meta tensor in that case).  However, it often
+        # indicates some sort of confusion (e.g., you accidentally passed
+        # in a meta tensor when you should have passed in the real tensor).
+        # So by default we disallow meta, and if you are working in a situation
+        # where it is helpful (e.g., crossref testing) you can turn it back
+        # on
+        if not fake_mode.allow_meta:
+            assert device.type != "meta"
+        # normalize cuda device.
         if device.type == "cuda" and device.index is None:
             device = torch.device(f"cuda:{torch.cuda.current_device()}")
-        assert device.type != "meta"
-
         self.fake_device = device
         self.fake_mode = fake_mode
         self.has_sym_ints = symbolic_shapes.has_symbolic_sizes_strides(elem)
@@ -360,11 +383,14 @@
     @staticmethod
     def from_tensor(t, fake_mode):
         existing_device = t.device
+        # TODO: this should use meta converter
         return FakeTensor(fake_mode, t.to(device="meta"), existing_device)
 
     # TODO: resolve error in default __repr__
     def __repr__(self):
-        return f"FakeTensor({self.fake_device}, {self.size()}, {self.dtype})"
+        with in_kernel_invocation_manager(self.fake_mode):
+            self_repr = super().__repr__()
+        return f"FakeTensor({self.fake_mode}, {self_repr}, {self.fake_device})"
 
     def stride(self):
         if self.has_sym_ints:
@@ -404,6 +430,14 @@
                 return torch.device("meta")
             else:
                 return args[0].fake_device
+        # Need this to handle infinite recursion with sparse tensors.
+        # Sparse tensors have a custom stride policy, which means that
+        # stride() calls will dispatch here, and we need to trigger
+        # the default behavior.
+        # TODO: when we get other tensor types online they will also
+        # need to get entries here.
+        elif func == torch.ops.aten.stride.default:
+            return None
 
         # Because fake mode can return NotImplemented (if it sees a subclass
         # it doesn't know how to deal with), this test here is important
@@ -489,9 +523,10 @@
 
 
 class FakeTensorMode(TorchDispatchMode):
-    def __init__(self, allow_fallback_kernels=True):
+    def __init__(self, *, allow_fallback_kernels=True, allow_meta=False):
         self.allow_fallback_kernels = allow_fallback_kernels
         self.fake_tensor_converter = FakeTensorConverter()
+        self.allow_meta = allow_meta
 
         # [in_kernel_invocation]
         # when FakeTensor is invoked in user code, .device should return
@@ -637,7 +672,9 @@
             except NotImplementedError as not_implemented_error:
                 if not self.allow_fallback_kernels:
                     raise not_implemented_error
-                r = run_fallback_kernel(func, args, kwargs, not_implemented_error)
+                return run_fallback_kernel(
+                    self, func, args, kwargs, not_implemented_error
+                )
 
             # TODO: handle non-kwarg devices
             assert func not in _device_not_kwarg_ops, f"NYI: {func}"
@@ -666,7 +703,8 @@
         return self.fake_tensor_converter(self, tensor)
 
 
-def run_fallback_kernel(func, args, kwargs, orig_not_implemented_exception):
+# NB: returns fake tensors
+def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exception):
     # these should all be supported, just to be safe
     # avoid fallback for operators which inplace modify metadata
     # because the input fake tensors would be umodified
@@ -679,6 +717,8 @@
         def to_real_tensor(e):
             if isinstance(e, FakeTensor):
                 out = torch.zeros_like(e, device=e.fake_device)
+                if e.is_sparse:
+                    out._coalesced_(e.is_coalesced())
                 inp_impls[id(out)] = e
                 return out
             return e
@@ -693,7 +733,8 @@
 
         for e in tree_flatten((args, kwargs))[0]:
             if isinstance(e, torch.Tensor):
-                storages.add(e.storage()._cdata)
+                if not e.is_sparse:
+                    storages.add(e.storage()._cdata)
 
         # TODO: also check metadata change on inputs
         # proper aliasing/metadata relationship between outputs and inputs will
@@ -701,16 +742,20 @@
         # input impl
         for e in tree_flatten(r)[0]:
             if id(e) not in inp_impls and (
-                isinstance(e, torch.Tensor) and e.storage()._cdata in storages
+                isinstance(e, torch.Tensor)
+                and not e.is_sparse
+                and e.storage()._cdata in storages
             ):
                 raise orig_not_implemented_exception
 
-    # the outputs which are are not reused from impls will be converted
-    # to fake tensors later
-    meta_converter = MetaConverter()
-
     def map_out(e):
-        return inp_impls.get(id(e), meta_converter(e))
+        if isinstance(e, torch.Tensor):
+            if id(e) in inp_impls:
+                return inp_impls[id(e)]
+            else:
+                return fake_mode.fake_tensor_converter(fake_mode, e)
+        else:
+            return e
 
     return tree_map(map_out, r)
 
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 5e554fb..d14685e 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -94,7 +94,10 @@
         # hold a weak ref to self, otherwise it will be kept alive
         # by the del_ten closure
         self_weak_ref = weakref.ref(self)
-        weak_st = StorageWeakRef(t.storage())
+        if t.is_sparse:
+            weak_st = None
+        else:
+            weak_st = StorageWeakRef(t.storage())
         tensor_ref_key = WeakTensorRefKey(t)
 
         def del_ten():
@@ -106,7 +109,7 @@
             self_ref.tensor_memo.pop(tensor_ref_key, None)
             if weak_st and weak_st.expired():
                 self_ref.storage_memo.pop(weak_st, None)
-            else:
+            elif weak_st is not None:
                 # [expired-storages]
                 # NB: even though the tensor has died,
                 # the deallocation of its storage can take longer,
@@ -143,7 +146,25 @@
 
         if self.get_tensor_memo(t) is None:
             with torch.inference_mode(t.is_inference()):
-                if t._is_view():
+                if t.is_sparse:
+                    is_leaf = safe_is_leaf(t)
+                    r = torch.ops.aten._sparse_coo_tensor_with_dims(
+                        t.sparse_dim(),
+                        t.dense_dim(),
+                        t.shape,
+                        dtype=t.dtype,
+                        layout=torch.sparse_coo,
+                        device="meta",
+                    )
+                    r._coalesced_(t.is_coalesced())
+                    if t.requires_grad:
+                        r.requires_grad = True
+                    if t.requires_grad and not is_leaf:
+                        with torch.enable_grad():
+                            r = r.clone()
+                            r._coalesced_(t.is_coalesced())
+
+                elif t._is_view():
                     # Construct views in two steps: recursively meta-fy their
                     # base, and then create the view off that.  NB: doing it
                     # directly from storage is WRONG because this won't cause
@@ -211,10 +232,11 @@
             if any(
                 [
                     t.is_sparse_csr,
-                    t.is_sparse,
+                    t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
                     t.is_mkldnn,
                     t.is_quantized,
                     t.is_nested,
+                    t._is_view() and t._base is not None and t._base.is_sparse,
                     torch._is_functional_tensor(t),
                     # these are supported in meta conversion but the fallbacks
                     # don't work
diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py
index 0308d02..493f176 100644
--- a/torch/_tensor_str.py
+++ b/torch/_tensor_str.py
@@ -409,7 +409,10 @@
     )
     if self.is_sparse:
         suffixes.append("size=" + str(tuple(self.shape)))
-        suffixes.append("nnz=" + str(self._nnz()))
+        from torch._subclasses.fake_tensor import FakeTensor
+
+        if not self.is_meta and not isinstance(self, FakeTensor):
+            suffixes.append("nnz=" + str(self._nnz()))
         if not has_default_dtype:
             suffixes.append("dtype=" + str(self.dtype))
         if not custom_contents_provided: