# blob: 07dc45fac2fd6336a5a4940ccdced7e988ca37f4  (gitiles web-UI header, kept as a comment so the file parses)
# Owner(s): ["module: sparse"]
#
# Test to ensure sparsity information propagates properly into traced graph.
#
import sys
import unittest
import torch
from torch._dynamo.config import is_fbcode
from torch._subclasses.fake_tensor import FakeTensor
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
subtest,
TestCase,
)
# Various data types (preserved over operations).
# NOTE: element dtypes exercised by the parametrized tests below; sparse
# layout propagation should be independent of the element type.
DTYPES = [
    torch.int64,
    torch.float16,
    torch.bfloat16,
    torch.float32,
    torch.float64,
]
# Various index types.
# Index dtypes accepted for sparse compressed/COO index tensors.
ITYPES = [torch.int32, torch.int64]
# Constructs a subtest for every sparse layout currently supported in torch.sparse.
def all_sparse_layouts(test_name="layout"):
return parametrize(
test_name,
[
subtest(torch.sparse_coo, name="SparseCOO"),
subtest(torch.sparse_csr, name="SparseCSR"),
subtest(torch.sparse_csc, name="SparseCSC"),
subtest(torch.sparse_bsr, name="SparseBSR"),
subtest(torch.sparse_bsc, name="SparseBSC"),
],
)
#
# Various network examples.
#
class IdNet(torch.nn.Module):
    """Identity network: returns its input unchanged."""

    def forward(self, x):
        # Pure pass-through; no computation is performed.
        return x
class SumNet(torch.nn.Module):
    """Reduces its input to a scalar via summation."""

    def forward(self, x):
        # Equivalent to x.sum(); full reduction over all elements.
        return torch.sum(x)
class EltwiseNet(torch.nn.Module):
    """Chains several elementwise ops: neg, abs, scale-by-2, relu."""

    def forward(self, x):
        negated = -x
        magnitude = torch.abs(negated)
        # relu is a no-op here since |.| is non-negative, but it keeps an
        # extra elementwise node in the traced graph.
        return torch.nn.functional.relu(2 * magnitude)
class ToDenseNet(torch.nn.Module):
    """Converts its (typically sparse) input to a dense strided tensor."""

    def forward(self, x):
        dense = x.to_dense()
        return dense
class AddNet(torch.nn.Module):
    """Elementwise addition of two tensors (sparse and/or dense)."""

    def forward(self, x, y):
        # Operator form of torch.add(x, y).
        return x + y
class SparseActivationCOO(torch.nn.Module):
    """Converts each tensor in a list to sparse COO layout."""

    def forward(self, x):
        return [tensor.to_sparse() for tensor in x]
class SparseActivationCSR(torch.nn.Module):
    """Converts each tensor in a list to sparse CSR layout."""

    def forward(self, x):
        return [tensor.to_sparse_csr() for tensor in x]
#
# The test driver.
#
@unittest.skipIf(is_fbcode(), "See torch._dynamo.config")
@unittest.skipIf(
    sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
class TestSparseProp(TestCase):
    """Checks that sparsity information (layout, indices, values) propagates
    into the FakeTensor metadata (``node.meta["val"]``) of graphs produced by
    ``torch.export.export``.

    Each test traces a small network on sparse (or to-be-sparsified) inputs
    and then walks the exported graph's nodes in order, comparing each node's
    ``val`` metadata against the real input/result tensors. The expected node
    ordering (placeholders first, then one node per traced op, then output)
    is relied upon by the index checks below.
    """

    def setUp(self):
        # No extra fixtures needed; defer to the base class setup.
        TestCase.setUp(self)

    def assertEqualMeta(self, x, y):
        """Assert that FakeTensor ``x`` matches real tensor ``y`` once ``y``
        is moved to the meta device, including sparse index/value metadata."""
        self.assertIsInstance(x, FakeTensor)
        self.assertIsInstance(y, torch.Tensor)
        # Convert expected value to meta for comparison.
        y = y.to("meta")
        self.assertEqual(x, y, exact_layout=True, exact_is_coalesced=True)
        # When x or y is a meta tensor (say, `x.device == "meta"`), then
        # assertEqual(x, y) compares only x and y attributes but skips
        # comparing their values. In the case of sparse tensors, this means
        # that comparing indices and values attributes are skipped as well,
        # which is why we are doing that explicitly below.
        if x.layout is torch.strided:
            # Dense: the assertEqual above already covered everything.
            pass
        elif x.layout is torch.sparse_coo:
            # COO: compare raw (possibly uncoalesced) indices and values.
            self.assertEqual(x._indices(), y._indices(), exact_layout=True)
            self.assertEqual(x._values(), y._values(), exact_layout=True)
        else:
            # Compressed layouts: pick the compressed/plain index accessors
            # appropriate for row-wise (CSR/BSR) vs column-wise (CSC/BSC).
            if x.layout in {torch.sparse_csr, torch.sparse_bsr}:
                x_meta1, y_meta1 = (x.crow_indices(), y.crow_indices())
                x_meta2, y_meta2 = (x.col_indices(), y.col_indices())
            elif x.layout in {torch.sparse_csc, torch.sparse_bsc}:
                x_meta1, y_meta1 = (x.ccol_indices(), y.ccol_indices())
                x_meta2, y_meta2 = (x.row_indices(), y.row_indices())
            else:
                assert 0  # unreachable
            self.assertEqual(x_meta1, y_meta1, exact_layout=True)
            self.assertEqual(x_meta2, y_meta2, exact_layout=True)
            self.assertEqual(x.values(), y.values(), exact_layout=True)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_idnet(self, dtype, itype, layout):
        """Identity network: the lone placeholder should carry the sparse
        input's metadata; the output node carries no 'val'."""
        net = IdNet()
        # NOTE(review): generate_simple_inputs is presumably provided by the
        # common_utils TestCase base class — not visible in this file.
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/output.
            for i, node in enumerate(prog.graph.nodes):
                meta = node.meta.get("val", None)
                if i == 0:
                    self.assertEqualMeta(meta, sparse_input)
                else:
                    self.assertEqual(meta, None)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_sumnet(self, dtype, itype, layout):
        """Sum reduction: node 0 is the sparse placeholder, node 1 the
        (dense scalar) sum result."""
        net = SumNet()
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            result = net(sparse_input)
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/sum/output.
            for i, node in enumerate(prog.graph.nodes):
                meta = node.meta.get("val", None)
                if i == 0:
                    self.assertEqualMeta(meta, sparse_input)
                elif i == 1:
                    self.assertEqualMeta(meta, result)
                else:
                    self.assertEqual(meta, None)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_eltwisenet(self, dtype, itype, layout):
        """Elementwise chain (neg/abs/mul/relu): every intermediate node
        should preserve the input's sparsity; since each op is elementwise
        on |x|, all five nodes (arg included) compare equal to the result."""
        net = EltwiseNet()
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            result = net(sparse_input)
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/neg/abs/mul/relu/output.
            for i, node in enumerate(prog.graph.nodes):
                meta = node.meta.get("val", None)
                if i <= 4:
                    self.assertEqualMeta(meta, result)
                else:
                    self.assertEqual(meta, None)

    @parametrize("dtype", DTYPES)
    @parametrize("itype", ITYPES)
    @all_sparse_layouts("layout")
    def test_todensenet(self, dtype, itype, layout):
        """to_dense: node 0 keeps the sparse metadata, node 1 is the dense
        conversion result."""
        net = ToDenseNet()
        for sparse_input in self.generate_simple_inputs(
            layout,
            device="cpu",
            dtype=dtype,
            index_dtype=itype,
        ):
            result = net(sparse_input)
            # Build the traced graph.
            prog = torch.export.export(net, (sparse_input,))
            # Test arg/todense/output.
            for i, node in enumerate(prog.graph.nodes):
                meta = node.meta.get("val", None)
                if i == 0:
                    self.assertEqualMeta(meta, sparse_input)
                elif i == 1:
                    self.assertEqualMeta(meta, result)
                else:
                    self.assertEqual(meta, None)

    def test_add(self):
        """Mixed sparse CSR + dense addition: checks metadata for both
        placeholders and for the add result."""
        net = AddNet()
        Y = torch.arange(16, 32, dtype=torch.float32).view(4, 4)
        A = torch.tensor(
            [
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 2.0],
                [0.0, 0.0, 1.0, 1.0],
                [3.0, 0.0, 3.0, 0.0],
            ],
            dtype=torch.float32,
        )
        S = A.to_sparse_csr()
        result = net(S, Y)
        # Build the traced graph.
        prog = torch.export.export(net, (S, Y))
        # Test args/add/output.
        for i, node in enumerate(prog.graph.nodes):
            meta = node.meta.get("val", None)
            if i == 0:
                self.assertEqualMeta(meta, S)
            elif i == 1:
                self.assertEqualMeta(meta, Y)
            elif i == 2:
                self.assertEqualMeta(meta, result)
            else:
                self.assertEqual(meta, None)

    def test_activation_coo(self):
        """Dense-to-COO conversion of a list of inputs: three dense
        placeholders followed by three sparse results."""
        net = SparseActivationCOO()
        x = [torch.randn(3, 3) for _ in range(3)]
        result = net(x)
        # Build the traced graph.
        prog = torch.export.export(net, args=(x,))
        # Test args/to_sparse/output.
        for i, node in enumerate(prog.graph.nodes):
            meta = node.meta.get("val", None)
            if i <= 2:
                self.assertEqualMeta(meta, x[i])
            elif i <= 5:
                self.assertEqualMeta(meta, result[i - 3])
            else:
                self.assertEqual(meta, None)

    def test_activation_csr(self):
        """Dense-to-CSR conversion of a list of inputs: three dense
        placeholders followed by three sparse results."""
        net = SparseActivationCSR()
        x = [torch.randn(3, 3) for _ in range(3)]
        result = net(x)
        # Build the traced graph.
        prog = torch.export.export(net, args=(x,))
        # Test args/to_sparse/output.
        for i, node in enumerate(prog.graph.nodes):
            meta = node.meta.get("val", None)
            if i <= 2:
                self.assertEqualMeta(meta, x[i])
            elif i <= 5:
                self.assertEqualMeta(meta, result[i - 3])
            else:
                self.assertEqual(meta, None)
# Expand the @parametrize/@all_sparse_layouts decorators into concrete
# test methods on TestSparseProp.
instantiate_parametrized_tests(TestSparseProp)

if __name__ == "__main__":
    run_tests()