| import warnings |
| from collections import namedtuple |
| from typing import Any, Optional |
| |
| import torch |
| |
| __all__ = [ |
| "SparseSemiStructuredTensor", |
| "to_sparse_semi_structured", |
| ] |
| |
| _SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple( |
| "_SEMI_STRUCTURED_SPARSE_CONFIG", "min_rows min_cols" |
| ) |
| _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG = { |
| torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 128), |
| torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64), |
| torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64), |
| # TODO enable float32 support when adding cuSPARSELt as a backend |
| # torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32) |
| } |
| |
| |
| class SparseSemiStructuredTensor(torch.Tensor): |
| """This class implementes semi-structured sparsity as a Tensor subclass. |
| |
| Semi-structured sparsity describes a sparsity pattern where n out of every 2n elements are zero, |
| with n depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained |
| structured sparsity. |
| |
| Currently, this class supports 2:4 sparsity for the int8, float16, and bfloat16 dtypes. |
| 1:2 sparsity for float32 is not yet enabled (see the TODO on _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG). |
| |
| This subclass stores the dense tensor in a compressed form by only storing the specified elements and corresponding metadata. |
| |
| The subclass supports two backends: CUTLASS and cuSPARSELt. |
| |
| The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor: |
| |
| compressed tensor = [ specified elements of original tensor | metadata ] |
| |
| For an original tensor of size (m, k), the first m * k // 2 elements of the compressed tensor hold the kept elements. |
| The rest of the compressed tensor is metadata. |
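| |
| For example (an illustrative sketch that mirrors how aten.values and aten.indices are handled in __torch_dispatch__ below): |
| |
| values = compressed_tensor[: m * k // 2].view(m, k // 2) |
| metadata = compressed_tensor[m * k // 2 :].view(m, -1) |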
| |
| For the CUTLASS backend, the specified elements and the metadata are kept in separate tensors. |
| |
| When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_linear |
| for sparse matmul and sparse_semi_structured_from_dense_cutlass for conversion to the compressed format. |
| |
| When PyTorch is compiled with cuSPARSELt support, this subclass calls into _cslt_sparse_mm for sparse matmul and |
| _cslt_compress for conversion to the compressed format. |
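| |
| A minimal usage sketch (illustrative only; assumes a CUDA device with the required sparse kernels and a dense tensor |
| that already follows a 2:4 pattern): |
| |
| A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda() |
| A_sparse = to_sparse_semi_structured(A) |
| x = torch.rand(128, 64).half().cuda() |
| out = torch.mm(A_sparse, x) |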
| """ |
| |
| _FUSE_TRANSPOSE = False |
| _FORCE_CUTLASS = True |
| _PROTOTYPE_WARNING_SHOWN = False |
| |
| @staticmethod |
| def __new__( |
| cls, |
| original_tensor: Optional[torch.Tensor], |
| original_shape: Optional[torch.Size] = None, |
| compressed_tensor_cusparselt: Optional[torch.Tensor] = None, |
| sparse_tensor_cutlass: Optional[torch.Tensor] = None, |
| meta_tensor_cutlass: Optional[torch.Tensor] = None, |
| transposed: bool = False, |
| ): |
| """ |
| Create a new instance of the class. |
| |
| When original_tensor is passed in, we compress it and store the compressed representation. |
| We can also create a new instance of the class from the compressed representation, without the original tensor. |
| |
| Args: |
| original_tensor: The original dense tensor, or None, if we have already compressed the tensor. |
| original_shape: The shape of the original dense tensor |
| compressed_tensor_cusparselt: For cuSPARSELt backend, a flattened tensor to store the specified elements and metadata. |
| sparse_tensor_cutlass: For CUTLASS backend, tensor to store the specified elements. |
| meta_tensor_cutlass: For CUTLASS backend, tensor to store metadata. |
| transposed: Whether the tensor is transposed or not. |
| |
| Returns: |
| torch.Tensor: A torch.Tensor wrapper subclass. |
| |
| Raises: |
| ValueError: If all of the tensor arguments are None. |
| |
| """ |
| assert compressed_tensor_cusparselt is None or (sparse_tensor_cutlass is None and meta_tensor_cutlass is None) |
| |
| if not cls._PROTOTYPE_WARNING_SHOWN: |
| warnings.warn( |
| ( |
| "The PyTorch API of SparseSemiStructuredTensor is in prototype stage " |
| "and will change in the near future. Please open a Github issue " |
| "for features requests and see our documentation on the torch.sparse " |
| "module for further information about the project." |
| ), |
| UserWarning, |
| ) |
| cls._PROTOTYPE_WARNING_SHOWN = True |
| |
| if original_tensor is not None: |
| previous_tensor = original_tensor |
| original_shape = original_tensor.shape |
| elif compressed_tensor_cusparselt is not None: |
| previous_tensor = compressed_tensor_cusparselt |
| elif sparse_tensor_cutlass is not None: |
| previous_tensor = sparse_tensor_cutlass |
| else: |
| raise ValueError("All of the tensor arguments are None!") |
| |
| kwargs = {} |
| kwargs["device"] = previous_tensor.device # type: ignore[assignment] |
| kwargs["dtype"] = previous_tensor.dtype # type: ignore[assignment] |
| kwargs["layout"] = previous_tensor.layout # type: ignore[assignment] |
| kwargs["requires_grad"] = False # type: ignore[assignment] |
| |
| return torch.Tensor._make_wrapper_subclass(cls, original_shape, **kwargs) # type: ignore[attr-defined] |
| |
| @staticmethod |
| def __get_indices_dtype(values_dtype): |
| if values_dtype == torch.int8: |
| return torch.int32 |
| elif values_dtype in (torch.float16, torch.bfloat16): |
| return torch.int16 |
| else: |
| raise RuntimeError(f"Datatype {values_dtype} is not supported!") |
| |
| def __init__( |
| self, |
| original_tensor: Optional[torch.Tensor], |
| original_shape: Optional[torch.Size] = None, |
| compressed_tensor_cusparselt: Optional[torch.Tensor] = None, |
| sparse_tensor_cutlass: Optional[torch.Tensor] = None, |
| meta_tensor_cutlass: Optional[torch.Tensor] = None, |
| transposed: bool = False, |
| ) -> None: |
| """SparseSemiStructuredTensor constructor. |
| |
| Args: |
| original_tensor: The original dense tensor, or None, if we have already compressed the tensor. |
| original_shape: The shape of the original dense tensor |
| compressed_tensor_cusparselt: For cuSPARSELt backend, a flattened tensor to store the specified elements and metadata. |
| sparse_tensor_cutlass: For CUTLASS backend, tensor to store the specified elements. |
| meta_tensor_cutlass: For CUTLASS backend, tensor to store metadata. |
| transposed: Whether the tensor is transposed or not. |
| |
| Returns: |
| None |
| |
| Raises: |
| RuntimeError: If original_tensor is not a supported dtype, dim, shape, or device. |
| """ |
| # if original tensor is passed in, we need to compress it and store the compressed representation. |
| if original_tensor is not None: |
| # TODO right now we have unified checks and constraints for cuSPARSELt and CUTLASS, but these are not actually the same. |
| # We should consolidate similar checks here and leave backend specific checks like shape in the op implementation. |
| |
| # check device |
| if not original_tensor.is_cuda: |
| raise RuntimeError( |
| f"Error original_tensor.device= {original_tensor.device} is not supported! " |
| "Only CUDA tensors are currently supported." |
| ) |
| |
| # check dim |
| if original_tensor.dim() != 2: |
| raise RuntimeError( |
| f"Error original_tensor.dim = {original_tensor.dim()} is not supported! " |
| "Only 2d tensors are currently supported." |
| ) |
| |
| # check dtype |
| if original_tensor.dtype not in _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG: |
| raise RuntimeError( |
| f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! " |
| "dtype must be one of: {_DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG}" |
| ) |
| |
| # check shape |
| m, n = original_tensor.shape |
| min_rows = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[ |
| original_tensor.dtype |
| ].min_rows |
| min_cols = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[ |
| original_tensor.dtype |
| ].min_cols |
| if m < min_rows or m % min_rows or n < min_cols or n % min_cols: |
| # TODO in the future we can add in padding to support dimensions that aren't perfect multiples |
| raise RuntimeError( |
| f"Error original_tensor.shape {original_tensor.shape} is not supported! " |
| f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})" |
| ) |
| |
| compressed_tensor_cusparselt = None |
| sparse_tensor_cutlass = None |
| meta_tensor_cutlass = None |
| if self._FORCE_CUTLASS: |
| from torch.sparse._semi_structured_conversions import ( |
| sparse_semi_structured_from_dense_cutlass, |
| ) |
| |
| sparse_tensor_cutlass, meta_tensor_cutlass = sparse_semi_structured_from_dense_cutlass(original_tensor) |
| else: |
| # use cuSPARSELt |
| compressed_tensor_cusparselt = torch._cslt_compress(original_tensor) |
| |
| # set values |
| self.original_tensor = None |
| self.compressed_tensor_cusparselt = compressed_tensor_cusparselt |
| self.sparse_tensor_cutlass = sparse_tensor_cutlass |
| self.meta_tensor_cutlass = meta_tensor_cutlass |
| self.transposed = transposed |
| self.original_shape = original_shape |
| |
| def __tensor_flatten__(self): |
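| # Tensor subclass flattening protocol: return the attribute names of the inner tensors plus the constant |
| # metadata needed to rebuild this object in __tensor_unflatten__ (used, e.g., by torch.compile tracing). |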
| if self.compressed_tensor_cusparselt is not None: |
| return ['compressed_tensor_cusparselt'], (self.original_shape, self.transposed) |
| else: |
| return ['sparse_tensor_cutlass', 'meta_tensor_cutlass'], (self.original_shape, self.transposed) |
| |
| @staticmethod |
| def __tensor_unflatten__(inner_tensors, meta): |
| original_shape, transposed = meta |
| |
| if len(inner_tensors) == 2: |
| sparse_tensor_cutlass = inner_tensors['sparse_tensor_cutlass'] |
| meta_tensor_cutlass = inner_tensors['meta_tensor_cutlass'] |
| compressed_tensor_cusparselt = None |
| elif len(inner_tensors) == 1: |
| sparse_tensor_cutlass = None |
| meta_tensor_cutlass = None |
| compressed_tensor_cusparselt = inner_tensors['compressed_tensor_cusparselt'] |
| else: |
| raise RuntimeError(f"Expected 1 or 2 inner tensors but got {len(inner_tensors)}") |
| |
| return SparseSemiStructuredTensor( |
| None, |
| original_shape=original_shape, |
| compressed_tensor_cusparselt=compressed_tensor_cusparselt, |
| sparse_tensor_cutlass=sparse_tensor_cutlass, |
| meta_tensor_cutlass=meta_tensor_cutlass, |
| transposed=transposed, |
| ) |
| |
| def __repr__(self) -> str: # type: ignore[override] |
| """Return string representation of SparseSemiStructuredTensor |
| |
| Returns: |
| str: String representation |
| |
| Raises: |
| None |
| """ |
| return ( |
| f"SparseSemiStructuredTensor(shape={self.shape}, " |
| f"transposed={self.transposed}" |
| f"values={self.values()}" |
| f"metadata={self.indices()})" |
| ) |
| |
| __torch_function__ = torch._C._disabled_torch_function_impl |
| |
| def _pad_tensor_for_matmul(self, original_tensor: torch.Tensor) -> torch.Tensor: |
| """ |
| Calculates padding for dense tensor and pads tensor if necessary. |
| If padding is not required, this function returns the original tensor. |
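| |
| For example (illustrative, based on the _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG entries above), a (129, 128) |
| float16 tensor would be zero-padded to (160, 128), since float16 rows must be a multiple of 32. |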
| """ |
| # only 2d matmul |
| assert original_tensor.dim() == 2 |
| |
| # check shape |
| m, n = original_tensor.shape |
| min_rows = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].min_rows |
| min_cols = _DTYPE_TO_SEMI_STRUCTURED_SPARSE_CONFIG[original_tensor.dtype].min_cols |
| to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0 |
| to_pad_n = -n % min_cols if n < min_cols or n % min_cols else 0 |
| if to_pad_m or to_pad_n: |
| return torch.nn.functional.pad(original_tensor, (0, to_pad_n, 0, to_pad_m)) |
| else: |
| return original_tensor |
| |
| |
| @classmethod |
| def __torch_dispatch__(cls, func, types, args, kwargs) -> Any: |
| """Overload __torch_dispatch__ to use torch._sparse_semi_structured_linear. |
| |
| `torch.structured_sparse_linear` uses accelerated sparse CUTLASS kernels. |
| In the future we plan to also add in support for cuSPARSELt kernels. |
| |
| Args: |
| func: The function being dispatched. |
| types: The types of the arguments. |
| args: The arguments passed to the function. |
| kwargs: The keyword arguments passed to the function. |
| |
| Returns: |
| Any: The result of the dispatched operation. |
| |
| Raises: |
| NotImplementedError: If the dispatched operation is not implemented. |
| """ |
| # Since this code runs below autograd, a detach corresponds to only returning a new object |
| if func is torch.ops.aten.detach.default: |
| return SparseSemiStructuredTensor( |
| args[0].original_tensor, |
| original_shape=args[0].shape, |
| compressed_tensor_cusparselt=args[0].compressed_tensor_cusparselt, |
| sparse_tensor_cutlass=args[0].sparse_tensor_cutlass, |
| meta_tensor_cutlass=args[0].meta_tensor_cutlass, |
| transposed=args[0].transposed, |
| ) |
| |
| # Because we cannot go from the compressed representation back to the dense representation currently, |
| # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix |
| # is the first or second argument, we expect an even / odd number of calls to transpose respectively. |
| if func is torch.ops.aten.t.default: |
| return SparseSemiStructuredTensor( |
| args[0].original_tensor, |
| # transpose shape |
| original_shape=torch.Size([args[0].shape[1], args[0].shape[0]]), |
| compressed_tensor_cusparselt=args[0].compressed_tensor_cusparselt, |
| sparse_tensor_cutlass=args[0].sparse_tensor_cutlass, |
| meta_tensor_cutlass=args[0].meta_tensor_cutlass, |
| transposed=not args[0].transposed, |
| ) |
| |
| # handle addmm |
| if func is torch.ops.aten.addmm.default: |
| bias, input_A, input_B = args |
| |
| # Currently, we only support the first matrix being sparse for addmm/mm in CUTLASS and cuSPARSELt. |
| # CUTLASS only supports the first input of a given matmul to be sparse. |
| # cuSPARSELt does not have this limitation, although our implementation currently only handles a sparse first operand. |
| |
| # We support second-matrix-sparse matmul by taking advantage of some transpose properties: |
| # This is also why we expect an odd number of transpose calls when the second matrix is sparse vs an even number |
| # of transpose calls when the first matrix is sparse. |
| # F.linear(x) = addmm(bias, input, weight.t()) = b + xW' = (b + xW')'' |
| # = (W''x' + b')' = (Wx' + b')' = addmm(bias.T, weight, input.T).T |
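| # A dense sanity check of the identity above (illustrative only, not executed here): |
| #   x, W, b = torch.randn(8, 16), torch.randn(4, 16), torch.randn(4) |
| #   torch.allclose(torch.nn.functional.linear(x, W, b), |
| #                  torch.addmm(b.unsqueeze(1), W, x.t()).t())  # -> True |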
| if isinstance(input_B, cls) and input_B.transposed: |
| row, col = input_A.shape |
| input_A_padded = input_B._pad_tensor_for_matmul(input_A) |
| if input_B.compressed_tensor_cusparselt is None: |
| assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None |
| res = torch._sparse_semi_structured_linear( |
| input_A_padded, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass, bias=bias |
| ) |
| else: |
| res = torch._cslt_sparse_mm( |
| input_B.compressed_tensor_cusparselt, input_A_padded.t(), bias # type: ignore[arg-type] |
| ).t() |
| return res[:row, :] |
| |
| # handle mm |
| if func is torch.ops.aten.mm.default: |
| input_A, input_B = args |
| |
| # first element sparse |
| if isinstance(input_A, cls) and not input_A.transposed: |
| row, col = input_B.shape |
| input_B_padded = input_A._pad_tensor_for_matmul(input_B) |
| if input_A.compressed_tensor_cusparselt is None: |
| assert input_A.sparse_tensor_cutlass is not None and input_A.meta_tensor_cutlass is not None |
| res = torch._sparse_semi_structured_linear( |
| input_B_padded.t(), input_A.sparse_tensor_cutlass, input_A.meta_tensor_cutlass |
| ).t() |
| else: |
| res = torch._cslt_sparse_mm( |
| input_A.compressed_tensor_cusparselt, input_B_padded, None # type: ignore[arg-type] |
| ) |
| return res[:, :col] |
| |
| # second element sparse |
| elif isinstance(input_B, cls) and input_B.transposed: |
| row, col = input_A.shape |
| input_A_padded = input_B._pad_tensor_for_matmul(input_A) |
| |
| if input_B.compressed_tensor_cusparselt is None: |
| assert input_B.sparse_tensor_cutlass is not None and input_B.meta_tensor_cutlass is not None |
| res = torch._sparse_semi_structured_linear( |
| input_A_padded, input_B.sparse_tensor_cutlass, input_B.meta_tensor_cutlass |
| ) |
| else: |
| res = torch._cslt_sparse_mm(input_B.compressed_tensor_cusparselt, input_A_padded.t(), None).t() # type: ignore[arg-type] |
| |
| return res[:row, :] |
| |
| # When torch is run with inference mode, pytorch does not decompose torch.ops.aten.linear into a .t() and addmm(), |
| # so we must match the aten.linear op. In this case, we need to explicitly handle collapsing to 2d matmul |
| # TODO see if there's a way to force pytorch to decompose the op so we don't have to handle this here. |
| if func is torch.ops.aten.linear.default: |
| input_tensor, weight, bias = args |
| shape = input_tensor.shape |
| |
| input_tensor_2d = input_tensor.view(-1, shape[-1]) |
| row, col = input_tensor_2d.shape |
| # this is a no-op if no padding is needed |
| input_tensor_2d_padded = weight._pad_tensor_for_matmul(input_tensor_2d) |
| |
| if isinstance(weight, cls): |
| if weight.compressed_tensor_cusparselt is None: |
| assert weight.sparse_tensor_cutlass is not None and weight.meta_tensor_cutlass is not None |
| res = torch._sparse_semi_structured_linear( |
| input_tensor_2d_padded, |
| weight.sparse_tensor_cutlass, |
| weight.meta_tensor_cutlass, |
| bias=bias |
| ) |
| else: |
| res = torch._cslt_sparse_mm( |
| weight.compressed_tensor_cusparselt, # type: ignore[arg-type] |
| input_tensor_2d_padded.t(), |
| bias |
| ).t() |
| return res[:row, :].view(*shape[:-1], -1) |
| |
| |
| # handle values |
| if func is torch.ops.aten.values.default: |
| if args[0].compressed_tensor_cusparselt is None: |
| return args[0].sparse_tensor_cutlass.detach() |
| else: |
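| # cuSPARSELt packs [specified elements | metadata] into one flat tensor (see the class docstring); |
| # the first m * k // 2 entries are the kept values. |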
| m, k = args[0].shape |
| num_kept_elements = m * k // 2 |
| return args[0].compressed_tensor_cusparselt[:num_kept_elements].view(m, k // 2) |
| |
| # handle indices |
| if func is torch.ops.aten.indices.default: |
| if args[0].compressed_tensor_cusparselt is None: |
| return args[0].meta_tensor_cutlass |
| else: |
| m, k = args[0].shape |
| num_kept_elements = m * k // 2 |
| metadata = args[0].compressed_tensor_cusparselt[num_kept_elements:].view(m, -1) |
| indices_dtype = SparseSemiStructuredTensor.__get_indices_dtype( |
| args[0].dtype |
| ) |
| return metadata.view(indices_dtype) |
| |
| error_string = "\n".join( |
| [f"func {func} with args: "] |
| + [f"arg{i}: {arg}" for i, arg in enumerate(args)] |
| ) |
| raise NotImplementedError(error_string) |
| |
| |
| def to_dense(self): |
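| """Return the dense representation of this tensor (currently only supported for the CUTLASS backend). |
| |
| Illustrative usage (assumes A_sparse was created by to_sparse_semi_structured using the CUTLASS backend): |
| |
| A_dense = A_sparse.to_dense() |
| """ |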
| if self.compressed_tensor_cusparselt is not None: |
| raise RuntimeError("Converting to dense is not yet supported by cuSPARSELt backend!") |
| |
| from torch.sparse._semi_structured_conversions import ( |
| sparse_semi_structured_to_dense_cutlass, |
| ) |
| |
| return sparse_semi_structured_to_dense_cutlass( |
| self.sparse_tensor_cutlass, |
| self.meta_tensor_cutlass, |
| ) |
| |
| |
| def to_sparse_semi_structured( |
| original_tensor: torch.Tensor, |
| transposed: bool = False, |
| ) -> SparseSemiStructuredTensor: |
| """ |
| This function converts a dense tensor into a sparse semi-structured tensor. |
| It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor. |
| |
| This function will check to ensure the dense tensor has the right dtype, size, dims, and device. |
| We currently only support semi-structured sparse tensors for 2d CUDA tensors. |
| Additionally, each dimension of your tensor must be at least, and a positive multiple of, the minimum block size for its dtype: |
| |
| - torch.int8: rows must be >= and a multiple of 32, columns must be >= and a multiple of 128 |
| - torch.float16 / torch.bfloat16: rows must be >= and a multiple of 32, columns must be >= and a multiple of 64 |
| |
| Args: |
| original_tensor (Tensor): the dense tensor to convert |
| transposed (bool, optional): whether the dense tensor is transposed |
| |
| Returns: |
| SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor |
| |
| Raises: |
| RuntimeError: If original_tensor is not a supported dtype, dim, shape, or device. |
| Example: |
| >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) |
| >>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda() |
| >>> A |
| tensor([[0., 0., 1., ..., 0., 1., 1.], |
| [0., 0., 1., ..., 0., 1., 1.], |
| [0., 0., 1., ..., 0., 1., 1.], |
| ..., |
| [0., 0., 1., ..., 0., 1., 1.], |
| [0., 0., 1., ..., 0., 1., 1.], |
| [0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16) |
| >>> A_sparse = to_sparse_semi_structured(A) |
| >>> A_sparse |
| SparseSemiStructuredTensor(shape=torch.Size([128, 128]), transposed=False, values=tensor([[1., 1., 1., ..., 1., 1., 1.], |
| [1., 1., 1., ..., 1., 1., 1.], |
| [1., 1., 1., ..., 1., 1., 1.], |
| ..., |
| [1., 1., 1., ..., 1., 1., 1.], |
| [1., 1., 1., ..., 1., 1., 1.], |
| [1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16), |
| metadata=tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370], |
| [-4370, -4370, -4370, ..., -4370, -4370, -4370], |
| [-4370, -4370, -4370, ..., -4370, -4370, -4370], |
| ..., |
| [-4370, -4370, -4370, ..., -4370, -4370, -4370], |
| [-4370, -4370, -4370, ..., -4370, -4370, -4370], |
| [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', |
| dtype=torch.int16)) |
| """ |
| return SparseSemiStructuredTensor( |
| original_tensor, original_shape=original_tensor.shape, transposed=transposed |
| ) |