| # Owner(s): ["module: sparse"] |
| import itertools |
| import random |
| import unittest |
| |
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| |
| from torch.sparse import ( |
| SparseSemiStructuredTensor, |
| SparseSemiStructuredTensorCUSPARSELT, |
| SparseSemiStructuredTensorCUTLASS, |
| to_sparse_semi_structured, |
| ) |
| |
| from torch.sparse._semi_structured_conversions import ( |
| sparse_semi_structured_from_dense_cutlass, |
| _sparse_semi_structured_tile, |
| _compute_compressed_swizzled_bitmask, |
| ) |
| |
| from torch.testing import make_tensor |
| from torch.testing._internal.common_cuda import _get_torch_cuda_version |
| from torch.testing._internal.common_device_type import ( |
| dtypes, |
| instantiate_device_type_tests, |
| ) |
| |
| from torch.testing._internal.common_dtype import all_types_and_complex |
| import torch._dynamo.test_case |
| from torch.testing._internal.common_utils import ( |
| parametrize, |
| run_tests, |
| subtest, |
| TestCase, |
| TEST_WITH_ROCM, |
| IS_WINDOWS, |
| ) |
| |
| import pytest |
| |
| from torch.utils._triton import has_triton |
| |
| SEMI_STRUCTURED_SUPPORTED_BACKENDS = dict() |
| |
| _IS_SM8X = False |
| _IS_SM9X = False |
| |
| if torch.cuda.is_available(): |
| _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8 |
| _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9 |
| |
| # CUTLASS kernels only work for Ampere |
| if _IS_SM8X: |
| SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS |
| |
# add cuSPARSELt tests if available
| if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X): |
| SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT |
| |
| inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.int8) |
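# Note: int8 is covered for inference only; the training tests use float16/bfloat16.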
| training_dtypes = dtypes(torch.float16, torch.bfloat16) |
| parametrize_backends = parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS) |
| |
| atol_rtol_kw = { |
| torch.float16: { |
| "rtol": 1e-3, |
| "atol": 1e-3, |
| }, |
| torch.bfloat16: { |
| "rtol": 1e-1, |
| "atol": 1e-1, |
| }, |
| } |
| |
| def sparse24_largest_mask_2d(original): |
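    """
    Return the boolean 2:4 mask that CUTLASS static-sort pruning keeps for
    `original` (True where an entry survives pruning).
    """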
| sparse = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(original) |
| return sparse.to_dense().bool() |
| |
| def sparsify24_dense(original): |
| return sparse24_largest_mask_2d(original) * original |
| |
| def rand_sparse_semi_structured_mask( |
| r, c, dtype=torch.float16, device="cuda", choice=None |
| ): |
| """ |
    This function returns a 1:2 sparse matrix of size (r, c).
    Note that this means the matrix is also 2:4 and 4:8 sparse.
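    For example, with choice=[0, 1] every consecutive pair of entries along each row is (0, 1).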
| """ |
| |
| choices = [[0, 1], [1, 0]] |
| mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)] |
| |
| return ( |
| torch.tensor(mask_entries, dtype=dtype, device=device) |
| .reshape(r, c) |
| .contiguous() |
| ) |
| |
| def rand_sparse_semi_structured(r, c, dtype, device, choice=None): |
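    """
    Return a dense (r, c) tensor whose zero pattern is 2:4 (1:2 for float32),
    so it can be converted with to_sparse_semi_structured. Positions kept by
    the mask never hold an exact zero value.
    """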
| pattern = '2by4' if dtype != torch.float32 else '1by2' |
| if pattern == '1by2': |
| ksparse = 2 |
| choices = [ |
| [0, 1], |
| [1, 0] |
| ] |
| elif pattern == '2by4': |
| ksparse = 4 |
| choices = [ |
| [1, 1, 0, 0], |
| [1, 0, 1, 0], |
| [1, 0, 0, 1], |
| [0, 1, 1, 0], |
| [0, 1, 0, 1], |
| [0, 0, 1, 1] |
| ] |
| mask_entries = [choice or random.choice(choices) for i in range(r * c // ksparse)] |
| mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device) |
| dense = make_tensor(r, c, dtype=dtype, device=device) |
    dense[dense == 0] = 1  # Replace accidental zeros so zeros only appear where the mask is applied.
| dense = dense.masked_fill(~mask, 0) |
| return dense |
| |
| |
| def rand_sparse_semi_structured_all_patterns(r, c, dtype, device): |
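    """
    Return a pair (dense_inv, dense_val). dense_inv draws from every possible
    per-group mask pattern (including invalid ones with more than two nonzeros),
    while dense_val applies the corresponding valid pattern a conversion
    round-trip is expected to produce, so the two can be compared after
    to_sparse_semi_structured(...).to_dense().
    """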
| pattern = '2by4' if dtype != torch.float32 else '1by2' |
| if pattern == '1by2': |
| ksparse = 2 |
| choices = [ |
| [[0, 0], [0, 1]], |
| [[0, 1], [0, 1]], |
| [[1, 0], [1, 0]], |
| [[1, 1], [1, 0]] |
| ] |
| elif pattern == '2by4': |
| ksparse = 4 |
| choices = [ |
| [[0, 0, 0, 0], [0, 0, 1, 1]], |
| [[0, 0, 0, 1], [0, 0, 1, 1]], |
| [[0, 0, 1, 0], [0, 0, 1, 1]], |
| [[0, 0, 1, 1], [0, 0, 1, 1]], |
| [[0, 1, 0, 0], [0, 1, 1, 0]], |
| [[0, 1, 0, 1], [0, 1, 0, 1]], |
| [[0, 1, 1, 0], [0, 1, 1, 0]], |
| [[0, 1, 1, 1], [0, 1, 0, 1]], |
| [[1, 0, 0, 0], [1, 0, 1, 0]], |
| [[1, 0, 0, 1], [1, 0, 0, 1]], |
| [[1, 0, 1, 0], [1, 0, 1, 0]], |
| [[1, 0, 1, 1], [1, 0, 0, 1]], |
| [[1, 1, 0, 0], [1, 1, 0, 0]], |
| [[1, 1, 0, 1], [1, 1, 0, 0]], |
| [[1, 1, 1, 0], [1, 1, 0, 0]], |
| [[1, 1, 1, 1], [1, 1, 0, 0]], |
| ] |
| mask_rows = [random.randint(0, len(choices) - 1) for i in range(r * c // ksparse)] |
| |
| COL_INV, COL_VAL = 0, 1 |
| mask_entries_inv = [choices[i][COL_INV] for i in mask_rows] |
| mask_entries_val = [choices[i][COL_VAL] for i in mask_rows] |
| mask_inv = torch.tensor(mask_entries_inv, dtype=torch.bool).view(r, c).to(device) |
| mask_val = torch.tensor(mask_entries_val, dtype=torch.bool).view(r, c).to(device) |
| dense = make_tensor(r, c, dtype=dtype, device=device) |
    dense[dense == 0] = 1  # Replace accidental zeros so zeros only appear where the masks below are applied.
| dense_inv = dense.masked_fill(~mask_inv, 0) |
| dense_val = dense_inv.masked_fill(~mask_val, 0) |
| |
| return dense_inv, dense_val |
| |
| |
| class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase): |
| |
| def setUp(self): |
| if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) == 0: |
| self.skipTest('semi-structured sparsity has no available backend!') |
| super().setUp() |
| |
| def tearDown(self): |
| super().tearDown() |
| |
| @staticmethod |
| def _test_mlp_contiguous_relu_compile(backend, dense_input_shape): |
| """ |
| Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile |
| We expect: |
| (1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_addmm` + `aten.contiguous()` |
| (2) Inductor should fuse the .contiguous() call into the relu |
| """ |
| |
| class Model(nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linear = nn.Linear(128, 128) |
| |
| def forward(self, x): |
| x = self.linear(x) |
| x = x.contiguous() |
| return torch.nn.functional.relu(x) |
| |
| input = torch.rand(dense_input_shape, device="cuda").half() |
| model = Model().eval().cuda().half() |
| mod_linear = model.linear |
| m, n = mod_linear.weight.shape |
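        # [1, 0, 0, 1] is one valid 2:4 pattern; tiling it gives a weight with exactly
        # two nonzeros in every group of four columns.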
| mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda() |
| # set masked weight |
| mod_linear.weight = nn.Parameter(mod_linear.weight * mask) |
| |
| dense_result = model(input) |
| mod_linear.weight = nn.Parameter(SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].from_dense(mod_linear.weight)) |
| sparse_result = model(input) |
| |
| model = torch.compile(model, backend="inductor", fullgraph=True) |
| sparse_compile_result = model(input) |
| |
| # test that sparse_compile_result and dense_result are numerically close |
| torch.testing.assert_close(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3) |
| # assert sparse and sparse_compile have the same strides, |
| # as meta registrations may return contiguous tensors when the output is transposed |
| # https://github.com/pytorch/pytorch/pull/114477 |
| assert sparse_result.stride() == sparse_compile_result.stride() |
| |
| @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows") |
| @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine") |
| def test_mlp_contiguous_relu_compile_cusparselt(self): |
| """ |
        test for cuSPARSELt meta registrations (_cslt_sparse_mm) + torch.compile
| """ |
| for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]: |
| SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cusparselt", dense_input_shape) |
| |
| |
| @unittest.skipIf("cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cutlass not supported on this machine") |
| @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows") |
| def test_mlp_contiguous_relu_compile_cutlass(self): |
| """ |
| test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile |
| """ |
| for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]: |
| SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape) |
| |
| |
| @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows") |
| @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine") |
| def test_sp24_compile(self) -> None: |
| x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True) |
| e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16) |
| |
| def fn(x, e): |
| y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x) |
| y = y.t() |
| return x @ y |
| |
| # Eager |
| output = fn(x, e) |
| output.backward(output) |
| # Torch compile |
| output = torch.compile(fn)(x, e) |
| output.backward(output) |
| |
| class TestSparseSemiStructured(TestCase): |
| |
| def setUp(self): |
| if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) == 0: |
| self.skipTest('semi-structured sparsity has no available backend!') |
| if IS_WINDOWS: |
| self.skipTest("torch.compile not supported on windows") |
| |
| @inference_dtypes |
| @parametrize_backends |
| def test_to_sparse_semi_structured(self, dtype, backend): |
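        # _FORCE_CUTLASS steers to_sparse_semi_structured to the CUTLASS backend
        # when True and to cuSPARSELt when False.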
| SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") |
| A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype) |
| A_sparse = to_sparse_semi_structured(A) |
| |
| assert A.shape == A_sparse.shape |
| assert A.device == A_sparse.device |
| assert A.dtype == A_sparse.dtype |
| |
| assert isinstance(A, torch.Tensor) |
| assert isinstance(A_sparse, SparseSemiStructuredTensor) |
| |
| @inference_dtypes |
| @parametrize_backends |
| @parametrize("dense_input_shape", [(128, 1), (128, 64), (128, 128)]) |
| def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend): |
| """ |
        Ensure torch.mm(A_sparse, B) is correct for float16/bfloat16 and throws an error for int8
| """ |
| SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") |
| A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) |
| A_sparse = to_sparse_semi_structured(A) |
| |
| B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype) |
| |
        # int8 matmul in this row-major x row-major layout is unsupported on GPU,
        # so we expect an error rather than comparing against a dense reference
| if dtype is torch.int8: |
| if backend == "cutlass": |
| with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"): |
| sparse_result = torch.mm(A_sparse, B) |
| else: |
| with self.assertRaisesRegex(RuntimeError, |
| "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"): |
| sparse_result = torch.mm(A_sparse, B) |
| else: |
| dense_result = torch.mm(A, B) |
| sparse_result = torch.mm(A_sparse, B) |
| torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) |
| |
| @inference_dtypes |
| @parametrize_backends |
| @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)]) |
| def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend): |
| """ |
| Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16 |
| and will throw an error for int8 + padding |
| """ |
| SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") |
| A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) |
| A_sparse = to_sparse_semi_structured(A) |
| |
| B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype) |
| |
| # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over |
| if dtype is torch.int8 and dense_input_shape in {(1, 128)}: |
| # padding with int8 throws an error because transposing B yields a contiguous output |
| # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS. |
| if backend == "cutlass": |
| with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"): |
| sparse_result = torch.mm(A_sparse, B.t()) |
| else: |
| with self.assertRaisesRegex(RuntimeError, |
| "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"): |
| sparse_result = torch.mm(A_sparse, B.t()) |
| elif dtype is torch.int8: |
| # test transpose |
| dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8) |
| sparse_result = torch.mm(A_sparse, B.t()) |
| torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) |
| else: |
| # test transpose |
| dense_result = torch.mm(A, B.t()) |
| sparse_result = torch.mm(A_sparse, B.t()) |
| torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) |
| |
| @inference_dtypes |
| @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)]) |
| @parametrize_backends |
| def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend): |
| """ |
| Ensure torch.mm(A_sparse.t(), B) throws error |
| """ |
| SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") |
| if backend == "cutlass" and IS_WINDOWS: |
| self.skipTest("CUTLASS not supported on Windows") |
| A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype) |
| A_sparse = to_sparse_semi_structured(A) |
| |
| B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype) |
| |
| with self.assertRaisesRegex( |
| NotImplementedError, |
| r"`SparseSemiStructuredTensor.*` matmul: operation is not supported", |
| ): |
| torch.mm(A_sparse.t(), B) |
| |
| @inference_dtypes |
| @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)]) |
| @parametrize_backends |
| def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend): |
| """ |
| Ensure torch.mm(A, B_sparse.t()) is correct |
| """ |
| SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") |
| if backend == "cutlass" and IS_WINDOWS: |
| self.skipTest("CUTLASS not supported on Windows") |
| B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) |
| B_sparse = to_sparse_semi_structured(B) |
| |
| A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype) |
| |
| # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over |
| if dtype is torch.int8: |
| dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8) |
| sparse_result = torch.mm(A, B_sparse.t()) |
| else: |
| dense_result = torch.mm(A, B.t()) |
| sparse_result = torch.mm(A, B_sparse.t()) |
| |
| torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) |
| |
| @inference_dtypes |
| @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)]) |
| @parametrize_backends |
| def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend): |
| """ |
| Ensure torch.mm(A, B_sparse) throws error |
| """ |
| SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") |
| if backend == "cutlass" and IS_WINDOWS: |
| self.skipTest("CUTLASS not supported on Windows") |
| B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) |
| B_sparse = to_sparse_semi_structured(B) |
| |
| A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype) |
| |
| with self.assertRaisesRegex( |
| NotImplementedError, |
| r"`SparseSemiStructuredTensor.*` matmul: operation is not supported", |
| ): |
| sparse_result = torch.mm(A, B_sparse) |
| |
| @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)]) |
| @parametrize("inference_mode", [subtest(True), subtest(False)]) |
| @parametrize_backends |
| def test_linear(self, dense_input_shape, inference_mode, device, backend): |
| """ |
| Test nn.Linear has the same numerics |
| """ |
| SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") |
| if backend == "cutlass" and IS_WINDOWS: |
| self.skipTest("CUTLASS not supported on Windows") |
        input = torch.rand(dense_input_shape, device=device).half()
| model = nn.Linear(128, 256).to(device).half() |
| m, n = model.weight.shape |
| mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool) |
| # set masked weight |
| model.weight = nn.Parameter(model.weight * mask) |
| |
| dense_result = model(input) |
| |
| model.weight = nn.Parameter(to_sparse_semi_structured(model.weight)) |
| |
| if inference_mode: |
| with torch.inference_mode(): |
| sparse_result = model(input) |
| else: |
| sparse_result = model(input) |
| |
| torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) |
| |
| @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)]) |
| @parametrize_backends |
| def test_mlp(self, device, dense_input_shape, backend): |
| SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") |
| input = torch.rand(dense_input_shape, device=device).half() |
| model = ( |
| nn.Sequential( |
| nn.Linear(128, 256), |
| nn.Linear(256, 128), |
| ) |
| .half() |
| .to(device) |
| ) |
| |
| for i in range(2): |
| m, n = model[i].weight.shape |
| mask = rand_sparse_semi_structured_mask( |
| m, n, device=device, dtype=torch.bool |
| ) |
| # set masked weight |
| model[i].weight = nn.Parameter(model[i].weight * mask) |
| |
| dense_result = model(input) |
| |
| for i in range(2): |
| model[i].weight = nn.Parameter(to_sparse_semi_structured(model[i].weight)) |
| |
| sparse_result = model(input) |
| |
| torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) |
| |
| @parametrize_backends |
| def test_values(self, backend): |
| SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") |
| if backend == "cutlass" and IS_WINDOWS: |
| self.skipTest("CUTLASS not supported on Windows") |
| A = rand_sparse_semi_structured_mask(128, 128) |
| A_sparse = to_sparse_semi_structured(A) |
| assert A_sparse.values().shape == (128, 64) |
| assert (A_sparse.values() == 1).all() |
| |
| @parametrize_backends |
| def test_indices(self, backend): |
| SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") |
| if backend == "cutlass" and IS_WINDOWS: |
| self.skipTest("CUTLASS not supported on Windows") |
| A = rand_sparse_semi_structured_mask(128, 128) |
| A_sparse = to_sparse_semi_structured(A) |
| assert A_sparse.indices().shape == (128, 8) |
| |
| @inference_dtypes |
| @parametrize_backends |
| def test_min_sparse_shape(self, dtype, device, backend): |
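        # Each backend publishes its minimum supported sparse/dense dimensions per
        # dtype in _DTYPE_SHAPE_CONSTRAINTS; exercise a matmul at exactly those sizes.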
| SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") |
| config = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS[dtype] |
| A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device) |
| A_sparse = to_sparse_semi_structured(A) |
| B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype) |
| if dtype == torch.int8: |
| dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int8) |
| # int8 sparse matmul not supported for R/R -> R layout, so we transpose one of the arguments to get R/C -> R |
| B_t = B.t().contiguous() |
| sparse_res = torch.mm(A_sparse, B_t.t()) |
| else: |
| dense_res = torch.mm(A, B) |
| sparse_res = torch.mm(A_sparse, B) |
| torch.testing.assert_close(sparse_res, dense_res, rtol=1e-3, atol=1e-3) |
| |
| @inference_dtypes |
| @parametrize_backends |
| def test_unsupported_shape(self, dtype, device, backend): |
| SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") |
| if backend == "cutlass" and IS_WINDOWS: |
| self.skipTest("CUTLASS not supported on Windows") |
| A = rand_sparse_semi_structured_mask(2, 2, dtype=dtype, device=device) |
| with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"): |
| A_sparse = to_sparse_semi_structured(A) |
| |
| @dtypes(*all_types_and_complex()) |
| @parametrize_backends |
| def test_unsupported_dtype(self, dtype, device, backend): |
| SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") |
| if backend == "cutlass" and IS_WINDOWS: |
| self.skipTest("CUTLASS not supported on Windows") |
| A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device) |
| |
| if dtype not in SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS: |
| with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"): |
| A_sparse = to_sparse_semi_structured(A) |
| else: |
| A_sparse = to_sparse_semi_structured(A) |
| |
| @parametrize_backends |
| def test_unsupported_dim(self, device, backend): |
| SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") |
| if backend == "cutlass" and IS_WINDOWS: |
| self.skipTest("CUTLASS not supported on Windows") |
| A = torch.rand(128, 128, 128, device=device, dtype=torch.float16) |
| |
| with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"): |
| A_sparse = to_sparse_semi_structured(A) |
| |
| |
| def create_random_mask(shape) -> torch.Tensor: |
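    """
    Deterministically (seed 0) build a boolean 2:4 mask by picking one of the six
    valid two-of-four patterns for each group of four columns.
    """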
| r = random.Random(0) |
| mask = torch.zeros(shape, dtype=torch.bool) |
| for line in range(mask.shape[0]): |
| for col in range(0, mask.shape[1], 4): |
| sparsity = r.choice( |
| [ |
| [False, False, True, True], |
| [False, True, False, True], |
| [True, False, False, True], |
| [False, True, True, False], |
| [True, False, True, False], |
| [True, True, False, False], |
| ] |
| ) |
| mask[line, col : col + 4] = torch.tensor(sparsity, dtype=torch.bool) |
| return mask |
| |
| class TestSparseSemiStructuredTraining(TestCase): |
| |
| def setUp(self): |
| if not _IS_SM8X: |
| self.skipTest("SparseSemiStructuredTensor training only supported on SM8x (Ampere)") |
| |
| if IS_WINDOWS: |
| self.skipTest('CUTLASS not supported on windows') |
| |
| |
| @training_dtypes |
| def test_prune_dense_static_sort(self, dtype) -> None: |
        # Ideally we would like to clone and compare, but that won't work because the sorting order will be different.
        # Instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
| dense = torch.randn(128, 128, device="cuda", dtype=dtype) |
| pruned = _sparse_semi_structured_tile(dense) |
| |
| # CUTLASS |
| reference_cutlass = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(pruned, algorithm="largest_abs_values_greedy") |
| torch.testing.assert_close(pruned, reference_cutlass.to_dense()) |
| |
| packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned) |
| packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous()) |
| meta_cutlass = meta_cutlass.as_strided(reference_cutlass.meta.shape, reference_cutlass.meta.stride()) |
| meta_t_cutlass = meta_t_cutlass.as_strided(reference_cutlass.meta_t.shape, reference_cutlass.meta_t.stride()) |
| compressed_swizzled_bitmask = _compute_compressed_swizzled_bitmask(pruned) |
| compressed_swizzled_bitmask = compressed_swizzled_bitmask.as_strided(reference_cutlass.compressed_swizzled_bitmask.shape, |
| reference_cutlass.compressed_swizzled_bitmask.stride()) |
| cutlass = SparseSemiStructuredTensorCUTLASS(dense.shape, |
| packed_cutlass, |
| meta_cutlass, |
| packed_t_cutlass, |
| meta_t_cutlass, |
| compressed_swizzled_bitmask) |
| torch.testing.assert_close(reference_cutlass.to_dense(), cutlass.to_dense()) |
| |
| # CUSPARSELT |
| reference_cusparselt = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(pruned, |
| algorithm="largest_abs_values_greedy") |
| torch.testing.assert_close(pruned, reference_cusparselt.to_dense()) |
| |
| packed_cusparselt = torch._cslt_compress(pruned) |
| packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous()) |
| cusparselt = SparseSemiStructuredTensorCUSPARSELT(dense.shape, |
| packed_cusparselt, |
| None, |
| packed_t_cusparselt, |
| None, |
| compressed_swizzled_bitmask) |
| torch.testing.assert_close(reference_cusparselt.to_dense(), cusparselt.to_dense()) |
| |
| |
| |
| @training_dtypes |
| @parametrize_backends |
| def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None: |
| inp = torch.tensor( |
| [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]], |
| device="cuda", |
| dtype=dtype, |
| ) |
| inp = F.pad(inp, (0, 128 - 4, 0, 128 - 4), "constant", 1) |
| sInp = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(inp, algorithm="largest_abs_values_greedy") |
| |
| mask = sInp.to_dense() / inp |
| assert mask[:4, :4].int().tolist() == [ |
| [1, 1, 0, 0], |
| [0, 1, 1, 0], |
| [0, 0, 1, 1], |
| [1, 0, 0, 1], |
| ] |
| |
| @training_dtypes |
| def test_gemm(self, dtype) -> None: |
| M, N, K = 32, 32, 64 |
| a = torch.randn([M, K], device="cuda", dtype=dtype) |
| b = torch.randn([K, N], device="cuda", dtype=dtype) |
| mask = rand_sparse_semi_structured_mask(M, K, dtype=torch.bool) |
| |
| a.masked_fill_(~mask, 0) |
| |
| a_sparse = to_sparse_semi_structured(a) |
| |
| masked_a = a * mask |
| ref_out = masked_a @ b |
| sp24_out = a_sparse @ b |
| torch.testing.assert_close(ref_out, sp24_out, **atol_rtol_kw[dtype]) |
| |
| |
| @training_dtypes |
| @parametrize_backends |
| def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None: |
| M, N = 128, 256 |
        # Construct a so that the pruning heuristic always keeps exactly 8 elements per 4x4 tile
| a = (4 * torch.arange(8))[:, None] + torch.arange(8)[None, :] |
| a = a.repeat(M // 8, N // 8) |
| assert a.shape == (M, N) |
| a = a.cuda().to(dtype) |
| b = torch.randn([a.shape[1], 128], device="cuda", dtype=dtype) |
| |
| a_sparse = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(a) |
| |
| mask_dense = sparse24_largest_mask_2d(a).to(dtype) |
| |
| if backend == "cutlass": |
| assert isinstance(a_sparse, SparseSemiStructuredTensorCUTLASS) |
| (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile( |
| mask_dense, use_cutlass=True) |
| |
| sparse_mask = SparseSemiStructuredTensorCUTLASS( |
| mask_dense.shape, |
| packed=packed, |
| meta=meta, |
| packed_t=packed_t, |
| meta_t=meta_t, |
| compressed_swizzled_bitmask=bitmask, |
| ) |
| torch.testing.assert_close(a_sparse.meta.view(torch.short), sparse_mask.meta) |
| |
| ref_gemm = (mask_dense * a) @ b |
| pack_gemm = a_sparse @ b |
| torch.testing.assert_close(ref_gemm, pack_gemm, **atol_rtol_kw[dtype]) |
| |
| @training_dtypes |
| def test_pack_both_ways_id(self, dtype) -> None: |
| N = 512 |
| torch.manual_seed(0) |
| a = torch.randn([N, N], dtype=dtype, device="cuda") |
| b = torch.eye(N, dtype=dtype, device="cuda") |
| |
        packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[:4]
| # Heuristic to ensure we pack the same values |
| torch.testing.assert_close( |
| packed.to(torch.float64).sum(), packed_t.to(torch.float64).sum() |
| ) |
| |
| mask_dense = sparse24_largest_mask_2d(a.to(dtype)) |
| |
| ref_gemm = mask_dense * a |
| # Test A@B |
| pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed, meta).t() |
| max_diff = (ref_gemm - pack_gemm).abs().argmax() |
        torch.testing.assert_close(
            ref_gemm, pack_gemm,
            **atol_rtol_kw[dtype],
            msg=f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})",
        )
| # Test A.t@B |
| pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed_t, meta_t) |
| max_diff = (ref_gemm - pack_gemm).abs().argmax() |
| |
        torch.testing.assert_close(
            ref_gemm, pack_gemm,
            **atol_rtol_kw[dtype],
            msg=f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})",
        )
| |
| @training_dtypes |
| def test_pack_both_ways_edge_case1(self, dtype) -> None: |
        # In this case, the heuristic will keep 7 values out of 16
        # instead of 8. Let's see how the kernel handles this.
| quad = torch.tensor( |
| [ |
| [2, -1, -2, -3], # Should be packed as `2 <null>` |
| [-1, 8, -1, 6], |
| [-1, -1, 4, 5], |
| [-1, 3, 7, -1], |
| ], |
| dtype=dtype, |
| device="cuda", |
| ) |
| a = torch.randn([32, 64], dtype=dtype, device="cuda") |
| a[:4, :4] = quad |
| packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[:4] |
| # Check first line in A |
| assert packed[0, 0].item() == 2 |
| assert packed[0, 1].item() == 0 |
| # And first column in A.t |
| assert packed_t[0, 0].item() == 2 |
| assert packed_t[0, 1].item() == 0 |
| |
| @training_dtypes |
| def test_sp24_apply(self, dtype) -> None: |
| M, N = 256, 1024 |
| x = torch.randn([M, N], dtype=dtype, device="cuda") |
| ( |
| packed, |
| meta, |
| packed_t, |
| meta_t, |
| bitmask, |
| ) = torch._sparse_semi_structured_tile(x) |
| packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask) |
| torch.testing.assert_close(packed, packed2) |
| torch.testing.assert_close(packed_t, packed_t2) |
| |
| @training_dtypes |
| def test_sp24_apply_dense(self, dtype) -> None: |
| M, N = 256, 1024 |
| x = torch.randn([M, N], dtype=dtype, device="cuda") |
| ( |
| packed, |
| meta, |
| packed_t, |
| meta_t, |
| bitmask, |
| ) = torch._sparse_semi_structured_tile(x) |
| |
| expected = SparseSemiStructuredTensorCUTLASS( |
| x.shape, |
| packed=packed, |
| meta=meta, |
| packed_t=packed_t, |
| meta_t=meta_t, |
| compressed_swizzled_bitmask=bitmask, |
| ).to_dense() |
| |
| packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask) |
| sparse = SparseSemiStructuredTensorCUTLASS( |
| x.shape, |
| packed=packed2, |
| meta=meta, |
| packed_t=packed_t2, |
| meta_t=meta_t, |
| compressed_swizzled_bitmask=bitmask, |
| ) |
| |
| dense = torch._sparse_semi_structured_apply_dense(x, bitmask) |
| |
| torch.testing.assert_close(dense, expected) |
| torch.testing.assert_close(sparse.to_dense(), expected) |
| |
| |
| @training_dtypes |
| def test_sp24_matmuls(self, dtype) -> None: |
| M, N, K = 64, 256, 1024 |
| a = torch.randn([M, K], device="cuda", dtype=dtype) |
| b = torch.randn([K, N], device="cuda", dtype=dtype) |
| a_m = sparse24_largest_mask_2d(a) |
| b_m = sparse24_largest_mask_2d(b) |
| (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(a) |
| a_s = SparseSemiStructuredTensorCUTLASS( |
| a.shape, |
| packed=packed, |
| meta=meta, |
| packed_t=packed_t, |
| meta_t=meta_t, |
| compressed_swizzled_bitmask=bitmask, |
| ) |
| (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(b) |
| b_s = SparseSemiStructuredTensorCUTLASS( |
| b.shape, |
| packed=packed, |
| meta=meta, |
| packed_t=packed_t, |
| meta_t=meta_t, |
| compressed_swizzled_bitmask=bitmask, |
| ) |
| |
| torch.testing.assert_close(a_s @ b, (a * a_m) @ b, rtol=1e-1, atol=1.5e-1) |
| torch.testing.assert_close(a @ b_s, a @ (b * b_m), rtol=1e-1, atol=1.5e-1) |
| torch.testing.assert_close( |
| a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1.5e-1 |
| ) |
| torch.testing.assert_close( |
| a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1 |
| ) |
| |
| def test_sp24_matmuls_mat_vec(self) -> None: |
| a = torch.randn([64, 128], device="cuda", dtype=torch.float16) |
| b = torch.randn([128], device="cuda", dtype=torch.float16) |
| a_m = sparse24_largest_mask_2d(a) |
| a_s = to_sparse_semi_structured(a) |
| |
| with pytest.raises(NotImplementedError): |
| torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype]) |
| |
| |
| def test_sp24_matmuls_bmm(self) -> None: |
| a = torch.randn([64, 128], device="cuda", dtype=torch.float16) |
| b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16) |
| a_m = sparse24_largest_mask_2d(a) |
| a_s = to_sparse_semi_structured(a) |
| |
| with pytest.raises(NotImplementedError): |
| torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype]) |
| |
| class TestSparseSemiStructuredCUTLASS(TestCase): |
| """ |
| This contains CUTLASS specific tests for |
| - torch._sparse_semi_structured_linear |
| """ |
| def setUp(self): |
| if "cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS: |
| self.skipTest('CUTLASS not enabled') |
| |
    @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows don't support CUTLASS")
| @inference_dtypes |
| def test_linear_cutlass(self, device, dtype): |
| |
| def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activation, rtol, atol): |
| weight = rand_sparse_semi_structured(m, k, dtype, device) |
| input = make_tensor((*batch_shape, n, k), dtype=dtype, device=device) |
| bias = make_tensor((m,), dtype=dtype_out, device=device) if add_bias else None |
| |
| dtype_dense = torch.float32 |
| input_dense = input.to(dtype_dense) |
| weight_dense = weight.to(dtype_dense) |
| bias_dense = bias.to(dtype_dense) if add_bias else None |
| output0 = torch.nn.functional.linear(input_dense, weight_dense, bias=bias_dense) |
| if activation == "relu": |
| relu = torch.nn.ReLU() |
| output0 = relu(output0) |
| elif activation == "silu": |
| silu = torch.nn.SiLU() |
| output0 = silu(output0) |
| |
| compressed = to_sparse_semi_structured(weight) |
| |
| weight_sparse = compressed.values() |
| meta = compressed.indices() |
| |
| output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation, |
| out_dtype=dtype_out if dtype == torch.int8 else None) |
| torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol) |
| |
| if dtype == torch.float32: |
| # Inputs are converted to TF32 internally for sparse GEMM, |
| # so make dense GEMM to do the same for matching results. |
| orig = torch.backends.cuda.matmul.allow_tf32 |
| torch.backends.cuda.matmul.allow_tf32 = True |
| |
| batch_shapes = [[], [3], [3, 1]] |
| dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32} |
| activations = [None, "relu", "silu"] |
| rtol, atol = 1e-3, 1e-3 |
| if dtype == torch.bfloat16: |
| rtol, atol = 5e-3, 5e-3 |
| elif dtype == torch.float32: |
| rtol, atol = 1e-3, 75e-2 |
| for batch_shape, m, n, k, add_bias, activation in \ |
| itertools.product(batch_shapes, range(3), range(3), range(3), (False, True), activations): |
| if activation == "silu" and dtype == torch.int8: |
| continue # SiLU not supported for integer inputs |
| |
| m = 2 ** m * 32 |
| n = 2 ** n * 32 |
| k = 2 ** k * 128 |
| run_test(batch_shape, m, n, k, device, dtype, dtype_out[dtype], add_bias, activation, rtol, atol) |
| |
| if dtype == torch.float32: |
| torch.backends.cuda.matmul.allow_tf32 = orig |
| |
| |
    @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows don't support CUTLASS")
| @parametrize("backend", ["cutlass"]) |
| @inference_dtypes |
| def test_sparse_semi_structured_ops_cutlass(self, device, dtype, backend): |
| SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") |
| if backend == "cutlass" and IS_WINDOWS: |
| self.skipTest("CUTLASS not supported on Windows") |
| |
| def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol): |
| mat1 = rand_sparse_semi_structured(m, k, dtype, device) |
| # mat2 transposed as int8 case supports only row-major/column-major combination |
| mat2 = make_tensor((n, k), dtype=dtype, device=device).t() |
| input = make_tensor((m,), dtype=dtype_out, device=device) if use_input else None |
| |
| if use_input: |
| if dtype.is_floating_point: |
| alpha = 1.3 |
| beta = -0.7 |
| else: |
| alpha = 2 |
| beta = -3 |
| |
| dtype_dense = torch.float32 |
| mat1_dense = mat1.to(dtype_dense) |
| mat2_dense = mat2.to(dtype_dense) |
| if not use_input: |
| output0 = torch.mm(mat1_dense, mat2_dense) |
| else: |
| input_dense = input.to(dtype_dense)[:, None] |
| output0 = torch.addmm(input_dense, mat1_dense, mat2_dense, alpha=alpha, beta=beta) |
| |
| compressed = to_sparse_semi_structured(mat1) |
| |
| mat1_sparse = compressed.values() |
| mat1_meta = compressed.indices() |
| |
| if not use_input: |
| output1 = torch._sparse_semi_structured_mm(mat1_sparse, mat1_meta, mat2, out_dtype=dtype_out) |
| else: |
| output1 = torch._sparse_semi_structured_addmm( |
| input, mat1_sparse, mat1_meta, mat2, alpha=alpha, beta=beta, out_dtype=dtype_out |
| ) |
| torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol) |
| |
| if dtype == torch.float32: |
| # Inputs are converted to TF32 internally for sparse GEMM, |
| # so make dense GEMM to do the same for matching results. |
| orig = torch.backends.cuda.matmul.allow_tf32 |
| torch.backends.cuda.matmul.allow_tf32 = True |
| |
| dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32} |
| rtol, atol = 1e-3, 1e-3 |
| if dtype == torch.bfloat16: |
| rtol, atol = 5e-3, 5e-3 |
| elif dtype == torch.float32: |
| rtol, atol = 1e-3, 75e-2 |
| for m, n, k, use_input in \ |
| itertools.product(range(3), range(3), range(3), (False, True)): |
| m = 2 ** m * 32 |
| n = 2 ** n * 32 |
| k = 2 ** k * 128 |
| run_test(m, n, k, device, dtype, dtype_out[dtype], use_input, rtol, atol) |
| |
| if dtype == torch.float32: |
| torch.backends.cuda.matmul.allow_tf32 = orig |
| |
| |
| @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch") |
| @inference_dtypes |
| def test_conversions(self, device, dtype): |
| |
| def run_test(r, c, device, dtype): |
| dense_ref = rand_sparse_semi_structured(r, c, dtype, device) |
| |
| compressed = to_sparse_semi_structured(dense_ref) |
| |
            # The torch.ops.aten._to_sparse_semi_structured operator
            # uses CUTLASS to convert the given dense matrix into the
            # pair of corresponding sparse and metadata matrices; the
            # latter is used here as a reference against which to
            # compare the metadata matrix produced by the
            # SparseSemiStructuredTensor class constructor.
| _, meta_ref = torch.ops.aten._to_sparse_semi_structured(dense_ref) |
| |
| meta = compressed.indices() |
| torch.testing.assert_close(meta, meta_ref, rtol=0, atol=0) |
| |
| dense = compressed.to_dense() |
| torch.testing.assert_close(dense, dense_ref, rtol=0, atol=0) |
| |
| shapes = [[32, 128], [32, 256], [64, 128], [64, 256]] |
| for r, c in shapes: |
| run_test(r, c, device, dtype) |
| |
| @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch") |
| @inference_dtypes |
| def test_conversions_all_patterns(self, device, dtype): |
| r, c = 32, 128 |
| |
| dense_inv, dense_val = rand_sparse_semi_structured_all_patterns(r, c, dtype, device) |
| |
| compressed = to_sparse_semi_structured(dense_inv) |
| dense = compressed.to_dense() |
| |
| torch.testing.assert_close(dense, dense_val, rtol=0, atol=0) |
| |
| |
| |
| CUSPARSELT_NUM_ALG_IDS = 4 |
| CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32] |
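# CUSPARSELT_NUM_ALG_IDS bounds the algorithm ids probed below;
# CUSPARSELT_MIXED_DTYPE_SUPPORT lists the output dtypes exercised for int8 inputs.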
| |
| |
| class TestSparseSemiStructuredCUSPARSELT(TestCase): |
| """ |
| This contains cuSPARSELt specific tests for |
| torch._cslt_compress |
| torch._cslt_sparse_mm |
| """ |
| def setUp(self): |
| if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS: |
| self.skipTest('cuSPARSELt not enabled') |
| |
| @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT) |
| @parametrize("dense_input_shape", [(128, 128)]) |
| def test_cslt_sparse_mm_mixed_dtype(self, dense_input_shape, out_dtype, device): |
| A = rand_sparse_semi_structured_mask(128, 128, dtype=torch.int8) |
| A_compressed = torch._cslt_compress(A) |
| |
| B = torch.rand(dense_input_shape, device=device).to(torch.int8) |
| |
| dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=out_dtype) |
| sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=out_dtype) |
| torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) |
| |
| @unittest.skip("cuSPARSELt v0.6.x does not support bfloat/float16 alpha scaling") |
| @training_dtypes |
| def test_cslt_sparse_mm_alpha(self, dtype, device): |
| A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda() |
| B = torch.ones((256, 128), device=device).to(dtype) |
| alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda() |
| bias = torch.ones(128, device=device).to(dtype) |
| |
| A_compressed = torch._cslt_compress(A) |
| sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, bias=bias) |
| |
| alpha_scaled = torch.stack([alpha] * 128).t() |
| dense_result = alpha_scaled * torch.mm(A.to(torch.float32), B.to(torch.float32)) |
| dense_result = dense_result.to(dtype) |
| |
| torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) |
| |
| @parametrize("out_dtype", CUSPARSELT_MIXED_DTYPE_SUPPORT) |
| def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device): |
| A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda() |
| B = torch.ones((128, 256), device=device).to(torch.int8).t() |
| alpha = torch.Tensor([2**(-i) if out_dtype is not torch.int32 else 1 |
| for i in range(128)]).cuda() |
| |
| A_compressed = torch._cslt_compress(A) |
| sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=out_dtype).cpu() |
| |
| alpha_scaled = torch.stack([alpha] * 128).t() |
| dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu()) |
| dense_result = dense_result.to(out_dtype) |
| |
| torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) |
| |
| @parametrize("alg_id", range(CUSPARSELT_NUM_ALG_IDS)) |
| @inference_dtypes |
| def test_cslt_sparse_mm_alg_id(self, device, dtype, alg_id): |
| # alg_id=3 not supported for float32 dtype |
| if dtype == torch.float32 and alg_id == 3: |
| return |
| A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype) |
| A_compressed = torch._cslt_compress(A) |
| B = torch.ones((128, 128), device=device).to(dtype) |
| |
| A_compressed = torch._cslt_compress(A) |
| sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id) |
| |
| dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32)) |
| dense_result = dense_result.to(dtype) |
| |
| torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) |
| |
| @inference_dtypes |
| def test_cslt_sparse_mm_search(self, device, dtype): |
| A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype) |
| A_compressed = torch._cslt_compress(A) |
| B = torch.ones((128, 128), device=device).to(dtype) |
| |
| A_compressed = torch._cslt_compress(A) |
| alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t()) |
        # For cuSPARSELt v0.4.0 there is a bug where, although there are 5 alg_ids, we run into an error
        # when using the last one (4).
        # In cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
        # TODO: Move this into the cuSPARSELt backend
| assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1) |
| |
| def test_cusparselt_backend(self): |
| version = _get_torch_cuda_version() |
| assert torch.backends.cusparselt.is_available() |
| |
| # CUDA 11.8 has cuSPARSELt v0.4.0 support |
| if version == (11, 8): |
| assert torch.backends.cusparselt.version() == 400 |
| # CUDA 12.1 has cuSPARSELt v0.5.2 support |
| elif version == (12, 1): |
| assert torch.backends.cusparselt.version() == 502 |
| # CUDA 12.4+ has cuSPARSELt v0.6.2 support |
| elif version >= (12, 4): |
| assert torch.backends.cusparselt.version() == 602 |
| else: |
| assert torch.backends.cusparselt.version() is None |
| |
| if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) > 0: |
| instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda") |
| if "cutlass" in SEMI_STRUCTURED_SUPPORTED_BACKENDS: |
| instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda") |
| instantiate_device_type_tests(TestSparseSemiStructuredTraining, globals(), only_for="cuda") |
| if "cusparselt" in SEMI_STRUCTURED_SUPPORTED_BACKENDS: |
| instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda") |
| |
| if __name__ == "__main__": |
| run_tests() |