# Source blob: 1c2f929791376f07a5271dcc5e828b32d371b29b
# Owner(s): ["module: sparse"]
import itertools
import random
import unittest
import torch
from torch import nn
from torch.sparse.semi_structured import (
_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG,
SparseSemiStructuredTensor,
to_sparse_semi_structured,
)
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
)
from torch.testing._internal.common_dtype import all_types_and_complex
import torch._dynamo.test_case
from torch.testing._internal.common_utils import (
parametrize,
run_tests,
subtest,
TestCase,
TEST_WITH_ROCM,
IS_WINDOWS,
)
from torch.utils._triton import has_triton
CUSPARSELT_NUM_ALG_IDS = 4
SEMI_STRUCTURED_SUPPORTED_DTYPES = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG.keys()
SEMI_STRUCTURED_SUPPORTED_BACKENDS = []
_IS_SM8X = False
if torch.cuda.is_available():
_IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cutlass")
# check if cslt is available for now using this:
# TODO when we add cusparselt as a backend, we can update this to be use torch.cusparselt.is_available()
try:
torch._cslt_compress(torch.ones(128, 256).cuda())
SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cusparselt")
except Exception:
pass
def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    Build a random 1:2 sparse 0/1 matrix of size (r, c).

    Every consecutive pair of entries is either [0, 1] or [1, 0], which means
    the result is also 2:4 and 4:8 sparse. When `choice` is given (and truthy)
    it is used for every pair instead of a random pick.
    """
    pair_options = [[0, 1], [1, 0]]
    pairs = []
    for _ in range(r * c // 2):
        # `or` short-circuits, so the RNG is only consumed when no fixed choice is given.
        pairs.append(choice or random.choice(pair_options))
    flat_mask = torch.tensor(pairs, dtype=dtype, device=device)
    return flat_mask.reshape(r, c).contiguous()
def rand_sparse_semi_structured(r, c, dtype, device, pattern='2by4', choice=None):
    """
    Build a random dense (r, c) tensor whose zeros follow a semi-structured pattern.

    Args:
        r, c: number of rows / columns of the result.
        dtype, device: forwarded to `make_tensor` for the random values.
        pattern: '2by4' (two entries kept in each group of four) or
            '1by2' (one entry kept in each group of two).
        choice: optional fixed group mask; when given (and truthy) it is used
            for every group instead of a random pick.

    Returns:
        A tensor that is zero exactly where the mask is zero and non-zero
        everywhere else.

    Raises:
        ValueError: if `pattern` is not one of the supported patterns.
    """
    choices_by_pattern = {
        '2by4': [
            [1, 1, 0, 0],
            [1, 0, 1, 0],
            [1, 0, 0, 1],
            [0, 1, 1, 0],
            [0, 1, 0, 1],
            [0, 0, 1, 1],
        ],
        '1by2': [
            [0, 1],
            [1, 0],
        ],
    }
    if pattern not in choices_by_pattern:
        # The original `assert(false)` referenced an undefined name and
        # surfaced as a NameError; fail with a clear message instead.
        raise ValueError(
            f"Unsupported pattern {pattern!r}; expected one of {sorted(choices_by_pattern)}"
        )
    choices = choices_by_pattern[pattern]
    group_size = len(choices[0])
    mask_entries = [choice or random.choice(choices) for _ in range(r * c // group_size)]

    mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device)
    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # To prevent zeros except where mask applied.
    dense = dense.masked_fill(~mask, 0)
    return dense
def rand_sparse_semi_structured_all_patterns(r, c, dtype, device, pattern='2by4'):
    """
    Build a pair of dense (r, c) tensors covering all 16 possible 4-element masks.

    Each group of four entries is assigned a random row of the table below.
    The first column of that row is an arbitrary 0/1 mask, the second column
    is a mask with exactly two ones.

    Args:
        r, c: number of rows / columns of the results.
        dtype, device: forwarded to `make_tensor` for the random values.
        pattern: only '2by4' is supported.

    Returns:
        (dense_inv, dense_val): `dense_inv` is the random tensor masked with
        the first-column masks; `dense_val` is `dense_inv` further masked with
        the second-column masks, so each group of four has at most two
        non-zeros. (Presumably "inv"/"val" stand for invalid/valid 2:4
        inputs - confirm against the caller.)

    Raises:
        ValueError: if `pattern` is not '2by4'.
    """
    if pattern != '2by4':
        # The original `assert(false)` referenced an undefined name and
        # surfaced as a NameError; fail with a clear message instead.
        raise ValueError(f"Unsupported pattern {pattern!r}; expected '2by4'")

    choices = [
        [[0, 0, 0, 0], [0, 0, 1, 1]],
        [[0, 0, 0, 1], [0, 0, 1, 1]],
        [[0, 0, 1, 0], [0, 0, 1, 1]],
        [[0, 0, 1, 1], [0, 0, 1, 1]],
        [[0, 1, 0, 0], [0, 1, 0, 1]],
        [[0, 1, 0, 1], [0, 1, 0, 1]],
        [[0, 1, 1, 0], [0, 1, 1, 0]],
        [[0, 1, 1, 1], [0, 1, 1, 0]],
        [[1, 0, 0, 0], [1, 0, 0, 1]],
        [[1, 0, 0, 1], [1, 0, 0, 1]],
        [[1, 0, 1, 0], [1, 0, 1, 0]],
        [[1, 0, 1, 1], [1, 0, 1, 0]],
        [[1, 1, 0, 0], [1, 1, 0, 0]],
        [[1, 1, 0, 1], [1, 1, 0, 0]],
        [[1, 1, 1, 0], [1, 0, 1, 0]],
        [[1, 1, 1, 1], [1, 0, 1, 0]],
    ]
    mask_rows = [random.randint(0, len(choices) - 1) for _ in range(r * c // 4)]

    COL_INV, COL_VAL = 0, 1
    mask_entries_inv = [choices[i][COL_INV] for i in mask_rows]
    mask_entries_val = [choices[i][COL_VAL] for i in mask_rows]
    mask_inv = torch.tensor(mask_entries_inv, dtype=torch.bool).view(r, c).to(device)
    mask_val = torch.tensor(mask_entries_val, dtype=torch.bool).view(r, c).to(device)

    dense = make_tensor(r, c, dtype=dtype, device=device)
    dense[dense == 0] = 1  # To prevent zeros except where mask below applied.
    dense_inv = dense.masked_fill(~mask_inv, 0)
    dense_val = dense_inv.masked_fill(~mask_val, 0)
    return dense_inv, dense_val
class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
    """
    torch.compile tests for nn.Linear with a SparseSemiStructuredTensor weight.
    Every test is skipped unless the module-level _IS_SM8X flag is set.
    """

    def setUp(self):
        if _IS_SM8X:
            super().setUp()
            return
        self.skipTest('Only runs on SM80')

    def tearDown(self):
        super().tearDown()

    @staticmethod
    def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
        """
        Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
        We expect:
            (1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_linear` + `aten.contiguous()`
            (2) Inductor should fuse the .contiguous() call into the relu
        """

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(128, 128)

            def forward(self, x):
                return torch.nn.functional.relu(self.linear(x).contiguous())

        SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"

        dense_input = torch.rand(dense_input_shape, device="cuda").half()
        model = Model().eval().cuda().half()
        linear = model.linear

        # Zero out the weight with a fixed [1, 0, 0, 1] pattern so it is 2:4 sparse.
        rows, cols = linear.weight.shape
        keep = torch.Tensor([1, 0, 0, 1]).tile((rows, cols // 4)).bool().cuda()
        linear.weight = nn.Parameter(linear.weight * keep)
        dense_result = model(dense_input)

        # Swap in the compressed weight and run eagerly, then compiled.
        linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight))
        sparse_result = model(dense_input)

        compiled_model = torch.compile(model, backend="inductor", fullgraph=True)
        sparse_compile_result = compiled_model(dense_input)

        # The compiled sparse output has to match the dense reference numerically.
        assert torch.allclose(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
        # Eager and compiled sparse outputs must agree on strides, since meta
        # registrations may return contiguous tensors when the output is transposed
        # https://github.com/pytorch/pytorch/pull/114477
        assert sparse_result.stride() == sparse_compile_result.stride()

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
    def test_mlp_contiguous_relu_compile_cusparselt(self):
        """
        test for cuSPARSELt meta registrations (_cslt_sparse_mm) + torch.compile
        """
        run_case = SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile
        shapes = [(1, 128), (64, 128), (128, 128), (64, 128, 128)]
        for shape in shapes:
            run_case("cusparselt", shape)

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    def test_mlp_contiguous_relu_compile_cutlass(self):
        """
        test for CUTLASS meta registrations (_sparse_semi_structured_linear) + torch.compile
        """
        run_case = SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile
        shapes = [(1, 128), (64, 128), (128, 128), (64, 128, 128)]
        for shape in shapes:
            run_case("cutlass", shape)
class TestSparseSemiStructured(TestCase):
    """
    Tests for SparseSemiStructuredTensor and to_sparse_semi_structured.

    Most tests are parametrized over SEMI_STRUCTURED_SUPPORTED_BACKENDS and set
    SparseSemiStructuredTensor._FORCE_CUTLASS from their `backend` parameter
    before doing anything else.
    """

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_to_sparse_semi_structured(self, dtype, backend):
        """
        Conversion keeps shape, device and dtype and yields a SparseSemiStructuredTensor
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        assert A.shape == A_sparse.shape
        assert A.device == A_sparse.device
        assert A.dtype == A_sparse.dtype

        assert isinstance(A, torch.Tensor)
        assert isinstance(A_sparse, SparseSemiStructuredTensor)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("dense_input_shape", [(128, 1), (128, 64), (128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B) matches the dense result for non-int8 dtypes
        and will throw error for int8
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        if dtype is torch.int8:
            # This should fail; each backend reports the failure differently.
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
                    torch.mm(A_sparse, B)
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    torch.mm(A_sparse, B)
        else:
            dense_result = torch.mm(A, B)
            sparse_result = torch.mm(A_sparse, B)
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16
        and will throw an error for int8 + padding
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        if dtype is torch.int8 and dense_input_shape == (1, 128):
            # padding with int8 throws an error because transposing B yields a contiguous output
            # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
            if backend == "cutlass":
                with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
                    torch.mm(A_sparse, B.t())
            else:
                with self.assertRaisesRegex(RuntimeError,
                                            "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
                    torch.mm(A_sparse, B.t())
        elif dtype is torch.int8:
            # test transpose
            # The dense int8 reference is computed on CPU and copied over.
            # NOTE: CUTLASS and cuSPARSELt have slightly different int8 behavior.
            # CUTLASS will output to an int32 tensor while cuSPARSELt will output to a int8 tensor
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
            sparse_result = torch.mm(A_sparse, B.t())
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
        else:
            # test transpose
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A_sparse, B.t())
            assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend):
        """
        Ensure torch.mm(A_sparse.t(), B) throws error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
        A_sparse = to_sparse_semi_structured(A)

        B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"arg0: SparseSemiStructuredTensor\(.*transposed=True",
        ):
            torch.mm(A_sparse.t(), B)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse.t()) is correct
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        if dtype is torch.int8:
            # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
            dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
            sparse_result = torch.mm(A, B_sparse.t())
        else:
            dense_result = torch.mm(A, B.t())
            sparse_result = torch.mm(A, B_sparse.t())

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend):
        """
        Ensure torch.mm(A, B_sparse) throws error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
        B_sparse = to_sparse_semi_structured(B)

        A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)

        with self.assertRaisesRegex(
            NotImplementedError,
            r"arg1: SparseSemiStructuredTensor\(.*transposed=False",
        ):
            torch.mm(A, B_sparse)

    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
    @parametrize("inference_mode", [subtest(True), subtest(False)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_linear(self, dense_input_shape, inference_mode, device, backend):
        """
        Test nn.Linear has the same numerics
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        input = torch.rand((dense_input_shape), device=device).half()
        model = nn.Linear(128, 256).to(device).half()
        m, n = model.weight.shape
        mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
        # set masked weight
        model.weight = nn.Parameter(model.weight * mask)

        dense_result = model(input)

        model.weight = nn.Parameter(to_sparse_semi_structured(model.weight))

        if inference_mode:
            with torch.inference_mode():
                sparse_result = model(input)
        else:
            sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_mlp(self, device, dense_input_shape, backend):
        """
        Test a two-layer nn.Sequential of nn.Linear has the same numerics
        with dense-masked and sparse-compressed weights
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
        input = torch.rand(dense_input_shape, device=device).half()
        model = (
            nn.Sequential(
                nn.Linear(128, 256),
                nn.Linear(256, 128),
            )
            .half()
            .to(device)
        )

        for i in range(2):
            m, n = model[i].weight.shape
            mask = rand_sparse_semi_structured_mask(
                m, n, device=device, dtype=torch.bool
            )
            # set masked weight
            model[i].weight = nn.Parameter(model[i].weight * mask)

        dense_result = model(input)

        for i in range(2):
            model[i].weight = nn.Parameter(to_sparse_semi_structured(model[i].weight))

        sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_values(self, backend):
        """
        values() of a compressed (128, 128) 0/1 mask has shape (128, 64) and is all ones
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
        assert A_sparse.values().shape == (128, 64)
        assert (A_sparse.values() == 1).all()

    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_indices(self, backend):
        """
        indices() of a compressed (128, 128) mask has shape (128, 8)
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 128)
        A_sparse = to_sparse_semi_structured(A)
        assert A_sparse.indices().shape == (128, 8)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_min_sparse_shape(self, dtype, device, backend):
        """
        Sparse matmul matches dense matmul at the minimum shapes listed in the dtype's config
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        config = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[dtype]
        A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device)
        A_sparse = to_sparse_semi_structured(A)
        B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
        if dtype == torch.int8:
            dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int32 if backend == "cutlass" else torch.int8)
            # int8 sparse matmul not supported for R/R -> R layout, so we transpose one of the arguments to get R/C -> R
            B_t = B.t().contiguous()
            sparse_res = torch.mm(A_sparse, B_t.t())
        else:
            dense_res = torch.mm(A, B)
            sparse_res = torch.mm(A_sparse, B)
        assert torch.allclose(sparse_res, dense_res, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_unsupported_shape(self, dtype, device, backend):
        """
        Converting a (2, 2) tensor raises a shape error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(2, 2, dtype=dtype, device=device)
        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"):
            to_sparse_semi_structured(A)

    @dtypes(*all_types_and_complex())
    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_unsupported_dtype(self, dtype, device, backend):
        """
        Converting raises a dtype error for dtypes outside SEMI_STRUCTURED_SUPPORTED_DTYPES
        and succeeds for the supported ones
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device)

        if dtype not in SEMI_STRUCTURED_SUPPORTED_DTYPES:
            with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"):
                to_sparse_semi_structured(A)
        else:
            to_sparse_semi_structured(A)

    @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
    def test_unsupported_dim(self, device, backend):
        """
        Converting a 3-dimensional tensor raises a dim error
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        A = torch.rand(128, 128, 128, device=device, dtype=torch.float16)

        with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"):
            to_sparse_semi_structured(A)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
    @parametrize("backend", ["cutlass"])
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_linear_cutlass(self, device, dtype, backend):
        """
        Compare torch._sparse_semi_structured_linear against a float32 dense
        torch.nn.functional.linear reference over a grid of batch shapes,
        sizes, bias on/off and activations.
        """
        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

        def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activation, rtol, atol):
            pattern = '2by4' if dtype != torch.float32 else '1by2'
            weight = rand_sparse_semi_structured(m, k, dtype, device, pattern=pattern)
            input = make_tensor((*batch_shape, n, k), dtype=dtype, device=device)
            bias = make_tensor((m,), dtype=dtype_out, device=device) if add_bias else None

            # Dense reference, computed in float32.
            dtype_dense = torch.float32
            input_dense = input.to(dtype_dense)
            weight_dense = weight.to(dtype_dense)
            bias_dense = bias.to(dtype_dense) if add_bias else None
            output0 = torch.nn.functional.linear(input_dense, weight_dense, bias=bias_dense)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output0 = relu(output0)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output0 = silu(output0)

            compressed = to_sparse_semi_structured(weight)
            weight_sparse = compressed.values()
            meta = compressed.indices()

            output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation)
            torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)

        batch_shapes = [[], [3], [3, 1]]
        dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 5e-3, 5e-3
        elif dtype == torch.float32:
            rtol, atol = 1e-3, 5e-1

        orig_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
        if dtype == torch.float32:
            # Inputs are converted to TF32 internally for sparse GEMM,
            # so make dense GEMM to do the same for matching results.
            torch.backends.cuda.matmul.allow_tf32 = True
        try:
            for batch_shape, m, n, k, add_bias, activation in \
                    itertools.product(batch_shapes, range(3), range(3), range(3), (False, True), activations):
                if activation == "silu" and dtype == torch.int8:
                    continue  # SiLU not supported for integer inputs

                m = 2 ** m * 32
                n = 2 ** n * 32
                k = 2 ** k * 128
                run_test(batch_shape, m, n, k, device, dtype, dtype_out[dtype], add_bias, activation, rtol, atol)
        finally:
            # Restore the global TF32 setting even when a case fails, so a
            # failure here cannot change the numerics of later tests.
            torch.backends.cuda.matmul.allow_tf32 = orig_allow_tf32

    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @parametrize("backend", ["cutlass"])
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_conversions(self, device, dtype, backend):
        """
        Check metadata and to_dense() round trip of SparseSemiStructuredTensor
        against torch.ops.aten._to_sparse_semi_structured (float32 is not checked)
        """
        if dtype == torch.float32:
            return

        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")

        def run_test(r, c, device, dtype):
            pattern = '2by4' if dtype != torch.float32 else '1by2'
            dense_ref = rand_sparse_semi_structured(r, c, dtype, device, pattern=pattern)

            compressed = to_sparse_semi_structured(dense_ref)

            # The torch.ops.aten._to_sparse_semi_structured operator
            # uses CUTLASS to perform conversion from given dense
            # matrix to the pair of corresponding sparse and metadata
            # matrices, with the later used here as a reference to
            # compare the metadata matrix produced by conversion
            # performed by SparseSemiStructuredTensor class
            # constructor against.
            _, meta_ref = torch.ops.aten._to_sparse_semi_structured(dense_ref)

            meta = compressed.indices()
            torch.testing.assert_close(meta, meta_ref, rtol=0, atol=0)

            dense = compressed.to_dense()
            torch.testing.assert_close(dense, dense_ref, rtol=0, atol=0)

        shapes = [[32, 128], [32, 256], [64, 128], [64, 256]]
        for r, c in shapes:
            run_test(r, c, device, dtype)

    @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
    @parametrize("backend", ["cutlass"])
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_conversions_all_patterns(self, device, dtype, backend):
        """
        Compressing `dense_inv` and converting back to dense must give `dense_val`
        (float32 is not checked)
        """
        if dtype == torch.float32:
            return

        SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
        r, c = 32, 128

        dense_inv, dense_val = rand_sparse_semi_structured_all_patterns(r, c, dtype, device)

        compressed = to_sparse_semi_structured(dense_inv)
        dense = compressed.to_dense()

        torch.testing.assert_close(dense, dense_val, rtol=0, atol=0)
class TestCUSPARSELT(TestCase):
    """
    This contains cuSPARSELt specific tests.

    The tests call torch._cslt_compress / torch._cslt_sparse_mm directly and
    are skipped when _IS_SM8X is unset or "cusparselt" was not detected.
    """

    def setUp(self):
        if not _IS_SM8X:
            self.skipTest('Only runs on SM80')
        if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
            self.skipTest('cuSPARSELt not enabled')
        # skipTest raises, so this line is only reached when cuSPARSELt is available.
        SparseSemiStructuredTensor._FORCE_CUTLASS = False

    @parametrize("dense_input_shape", [(128, 128)])
    def test_cslt_sparse_mm_int8_in_fp16_out(self, dense_input_shape, device):
        """
        int8 inputs with out_dtype=torch.float16 match a CPU int64 reference cast to float16
        """
        A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8)
        A_compressed = torch._cslt_compress(A)
        B = torch.rand(dense_input_shape, device=device).to(torch.int8)

        dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=torch.float16)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=torch.float16)
        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

    @dtypes(torch.float16, torch.bfloat16)
    def test_cslt_sparse_mm_alpha(self, dtype, device):
        """
        The `alpha` vector scales each output row: row i is multiplied by alpha[i]
        """
        A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda()
        B = torch.ones((256, 128), device=device).to(dtype)

        alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha)

        # Stacking alpha as rows and transposing gives alpha_scaled[i][j] == alpha[i].
        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled * torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)

        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    def test_cslt_sparse_mm_alpha_int8_in_f16_out(self, device):
        """
        Per-row `alpha` scaling with int8 inputs and out_dtype=torch.float16,
        compared on CPU against an int32 reference
        """
        A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
        B = torch.ones((128, 256), device=device).to(torch.int8).t()

        alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()

        A_compressed = torch._cslt_compress(A)
        sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=torch.float16).cpu()

        # Stacking alpha as rows and transposing gives alpha_scaled[i][j] == alpha[i].
        alpha_scaled = torch.stack([alpha] * 128).t()
        dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int32).cpu(), B.to(torch.int32).cpu())
        dense_result = dense_result.to(torch.float16)

        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @parametrize("alg_id", range(CUSPARSELT_NUM_ALG_IDS))
    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_cslt_sparse_mm_alg_id(self, device, dtype, alg_id):
        """
        torch._cslt_sparse_mm gives the dense result for every explicitly selected alg_id
        """
        # alg_id=3 not supported for float32 dtype
        if dtype == torch.float32 and alg_id == 3:
            return
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_compressed = torch._cslt_compress(A)
        B = torch.ones((128, 128), device=device).to(dtype)

        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)

        # B is square and all ones, so B and B.t() hold the same values in the reference.
        dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
        dense_result = dense_result.to(dtype)

        assert torch.allclose(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
    def test_cslt_sparse_mm_search(self, device, dtype):
        """
        torch._cslt_sparse_mm_search returns an alg_id within the expected range
        """
        A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
        A_compressed = torch._cslt_compress(A)
        B = torch.ones((128, 128), device=device).to(dtype)

        alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
        # for cuSPARSELt v0.4.0 there is a bug where although there are 5 alg_ids, we run into an error
        # when setting using the last one (4)
        # in cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
        assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)
# Generate the device-specific variants of the generic test classes; only the
# "cuda" device type is requested for both.
instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
instantiate_device_type_tests(TestCUSPARSELT, globals(), only_for="cuda")

# Run the tests when this file is executed as a script.
if __name__ == "__main__":
    run_tests()