| # Owner(s): ["module: linear algebra"] |
| |
| import unittest |
| from itertools import product |
| from functools import partial |
| from typing import Optional |
| import re |
| |
| import torch |
| |
| from torch.quantization._quantized_conversions import ( |
| pack_int4_to_int8, |
| quantized_weight_reorder_for_mixed_dtypes_linear_cutlass, |
| ) |
| |
| from torch.testing import make_tensor |
| from torch.testing._internal.common_cuda import ( |
| SM53OrLater, |
| SM90OrLater, |
| _get_torch_cuda_version, |
| PLATFORM_SUPPORTS_FP8 |
| ) |
| from torch.testing._internal.common_device_type import ( |
| dtypes, |
| instantiate_device_type_tests, |
| onlyCUDA, |
| tol as xtol, |
| toleranceOverride, |
| ) |
| |
| from torch.testing._internal.common_utils import ( |
| IS_ARM64, |
| IS_JETSON, |
| IS_WINDOWS, |
| parametrize, |
| run_tests, |
| skipIfRocmVersionLessThan, |
| TEST_WITH_ROCM, |
| skipIfRocm, |
| TestCase, |
| ) |
| |
| _IS_SM8X = False |
| if torch.cuda.is_available(): |
| _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8 |
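# _IS_SM8X gates the CUTLASS mixed-dtypes linear tests below, which are only
# compiled for SM 8.x GPUs.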
| |
# Protects against imports accidentally setting the default dtype
| assert torch.get_default_dtype() is torch.float32 |
| |
| |
| @unittest.skipIf(IS_ARM64, "Issue with numpy version on arm") |
| class TestMatmulCuda(TestCase): |
| def setUp(self): |
        super().setUp()
| torch.backends.cuda.matmul.allow_tf32 = False |
| |
| def tearDown(self): |
| torch.backends.cuda.matmul.allow_tf32 = True |
        super().tearDown()
| |
| def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False): |
| # |
| # Check for catastrophic cuBLAS inaccuracy by measuring the deviation between |
| # results from the CUDA invocation of torch.addmm and the CPU invocation |
| # (which does not use CUDA backend). |
| # |
| # Get dims |
| n, m, p = (size + 1, size, size + 2) |
        # Set the reduced-precision reduction flags for (b)float16 matmuls to the
        # requested value; some kernels only pass the threshold check when these
        # reductions are disabled.
| orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction |
| orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction |
| torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision |
| torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision |
| # Make random tensors on CPU (seed set on common_utils.py import) |
| # (Not using numpy because it does not support bfloat16) |
| make_arg = partial(make_tensor, dtype=dtype, device="cpu") |
| m_beta = make_arg(1) |
| m_input = make_arg((n, p)) |
| m_1 = make_arg((n, m)) |
| m_2 = make_arg((m, p)) |
        # *(B)FLOAT16 Special Handling*
        # The CPU backend does not tensorize float16,
        # and bfloat16 may present accuracy issues,
        # so convert to float32 for these cases
        # (but keep the original dtype for other types, e.g. float32 and int*)
| if dtype == torch.float16 or dtype == torch.bfloat16: |
| m_beta = m_beta.to(dtype=torch.float32) |
| m_input = m_input.to(dtype=torch.float32) |
| m_1 = m_1.to(dtype=torch.float32) |
| m_2 = m_2.to(dtype=torch.float32) |
| # Get CPU result |
| res_cpu = torch.addmm(m_input, m_1, m_2, beta=m_beta.item()) |
        # *(B)FLOAT16 Special Handling*
| # Convert back to (b)float16 |
| if dtype == torch.float16 or dtype == torch.bfloat16: |
| m_beta = m_beta.to(dtype=dtype) |
| m_input = m_input.to(dtype=dtype) |
| m_1 = m_1.to(dtype=dtype) |
| m_2 = m_2.to(dtype=dtype) |
| res_cpu = res_cpu.to(dtype=dtype) |
| # Move arg tensors to CUDA |
| m_beta = m_beta.to("cuda") |
| m_input = m_input.to("cuda") |
| m_1 = m_1.to("cuda") |
| m_2 = m_2.to("cuda") |
| # Get CUDA result |
| res_cuda = torch.addmm(m_input, m_1, m_2, beta=m_beta.item()) |
| # Move to CPU for comparison |
| res_cuda = res_cuda.to("cpu") |
| # Compare |
| self.assertEqual(res_cpu, res_cuda) |
| torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16 |
| torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16 |
| |
| @onlyCUDA |
| @skipIfRocmVersionLessThan((5, 2)) |
| # imported 'tol' as 'xtol' to avoid aliasing in code above |
| @toleranceOverride({torch.float16: xtol(atol=1e-1, rtol=1e-1), |
| torch.bfloat16: xtol(atol=1e-1, rtol=1e-1), |
| torch.float32: xtol(atol=1e-1, rtol=1e-1)}) |
| @dtypes(torch.float16, torch.bfloat16, torch.float32) |
| @parametrize("size", [100, 1000, 10000]) |
| def test_cublas_addmm(self, size: int, dtype: torch.dtype): |
| self.cublas_addmm(size, dtype, False) |
| |
| @onlyCUDA |
| @skipIfRocmVersionLessThan((5, 2)) |
| # imported 'tol' as 'xtol' to avoid aliasing in code above |
| @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1), |
| torch.bfloat16: xtol(atol=1e1, rtol=2e-1)}) |
| @dtypes(torch.float16, torch.bfloat16) |
| @parametrize("size", [100, 1000, 10000]) |
| def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype): |
| self.cublas_addmm(size, dtype, True) |
| |
| @onlyCUDA |
| @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)}) |
| @dtypes(torch.float16) |
| def test_cublas_addmm_alignment(self, dtype): |
| device = 'cuda' |
| # perturb X, A, or B alignment |
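        # (Slicing off `offset` leading elements shifts the underlying data
        # pointer, which can change which cuBLAS/cuBLASLt kernel is selected;
        # the result must still match the unperturbed reference below.)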
| for idx in range(0, 3): |
| for offset in range(1, 3): |
| offsets = [0, 0, 0] |
| offsets[idx] = offset |
| x_offset, a_offset, b_offset = offsets |
| A = torch.rand((5120 * 2560 + a_offset), requires_grad=True, dtype=dtype, device=device) |
| A = A[a_offset:].reshape(5120, 2560) |
| X = torch.rand((26 * 2560 + x_offset), requires_grad=True, dtype=dtype, device=device) |
| X = X[x_offset:].reshape(26, 1, 2560) |
| B = torch.rand((5120 + b_offset), requires_grad=True, dtype=dtype, device=device) |
| B = B[b_offset:].reshape(5120) |
| out = torch.nn.functional.linear(X, A, B) |
| self.assertEqual(out, torch.matmul(X, A.transpose(1, 0)) + B) |
| |
| @onlyCUDA |
| @unittest.skipIf(IS_JETSON, "Too large for Jetson") |
| @toleranceOverride({torch.float32: xtol(atol=1e-5, rtol=1.1e-5)}) |
    @dtypes(*([torch.float32, torch.float16] +
              ([torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])))
| @parametrize( |
| "batch_size, N, M, P", |
| [(2, 100, 100, 100), |
| (2, 1000, 1000, 1000), |
| (1, 10000, 1000, 10000), |
| (1, 10000, 10000, 10000)], |
| name_fn=lambda batch_size, N, M, P: f"{batch_size}_{N}_{M}_{P}", |
| ) |
| @skipIfRocm |
| def test_cublas_baddbmm_large_input(self, device, batch_size, N, M, P, dtype): |
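        # Run the CPU reference in float32 for half/bfloat16 inputs so that the
        # reference itself does not contribute low-precision error.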
| cpu_dtype = dtype |
| if dtype == torch.float16 or dtype == torch.bfloat16: |
| cpu_dtype = torch.float32 |
| |
| M1 = torch.rand((N, M), device=device, dtype=dtype) |
| M2 = torch.rand((M, P), device=device, dtype=dtype) |
| A = torch.rand((N, P), device=device, dtype=dtype) |
| |
| def _convert_to_cpu(t): |
| return t.to(device='cpu', dtype=cpu_dtype) |
| M1_cpu, M2_cpu, A_cpu = map(_convert_to_cpu, [M1, M2, A]) |
| |
| # linear |
| out1_cpu = torch.nn.functional.linear(M1_cpu, M2_cpu.t(), A_cpu).to(dtype=dtype) |
| out1_gpu = torch.nn.functional.linear(M1, M2.t(), A).cpu() |
| self.assertEqual(out1_cpu, out1_gpu) |
        # test multiplication by the identity matrix
| if N == M and M == P: |
| M2_eye = torch.eye(N, device=device, dtype=dtype) |
| out1_eye_gpu = torch.nn.functional.linear(M1, M2_eye.t(), torch.zeros_like(A)) |
| self.assertEqual(M1_cpu.to(dtype=dtype), out1_eye_gpu.cpu()) |
| |
| # baddbmm |
| def _expand_to_batch(t: torch.Tensor): |
| return t.expand((batch_size, ) + t.size()) |
| alpha, beta = 1.0, 1.0 |
| M1, M2, A, M1_cpu, M2_cpu, A_cpu = map(_expand_to_batch, [M1, M2, A, M1_cpu, M2_cpu, A_cpu]) |
| |
| out2_cpu = torch.baddbmm(A_cpu, M1_cpu, M2_cpu, beta=beta, alpha=alpha).to(dtype=dtype) |
| out2_gpu = torch.baddbmm(A, M1, M2, beta=beta, alpha=alpha).cpu() |
| self.assertEqual(out2_cpu, out2_gpu) |
        # test multiplication by the identity matrix
| if N == M and M == P: |
| M2_eye = torch.eye(N, device=device, dtype=dtype).expand(batch_size, N, N) |
| out2_eye_gpu = torch.baddbmm(torch.zeros_like(A), M1, M2_eye, beta=beta, alpha=alpha) |
| self.assertEqual(M1_cpu.to(dtype=dtype), out2_eye_gpu.cpu()) |
| |
| # cross comparison |
| self.assertEqual(out1_gpu, out2_gpu[0]) |
| |
| |
f8_msg = "FP8 is only supported on H100+, sm_89, and MI300+ devices"
| |
| if torch.version.hip: |
| e4m3_type = torch.float8_e4m3fnuz |
| e5m2_type = torch.float8_e5m2fnuz |
| E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max |
| E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max |
| else: |
| e4m3_type = torch.float8_e4m3fn |
| e5m2_type = torch.float8_e5m2 |
| E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max |
| E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max |
| |
| # avoid division by zero when calculating scale |
| EPS = 1e-12 |
| |
| def amax_to_scale( |
| amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype |
| ): |
| """ Converts the amax value of a tensor to the fp8 scale. |
| Args: |
| amax: The amax value of the tensor. |
        float8_dtype: The float8 dtype.
| orig_dtype: The original dtype of the tensor. |
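
    Example (illustrative): float8_e4m3fn has a representable maximum of 448.0,
    so a tensor with amax == 2.0 gets scale == 448.0 / 2.0 == 224.0; multiplying
    by the scale maps the largest magnitude onto the fp8 maximum.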
| """ |
| scale = torch.empty_like(amax, dtype=torch.float32) |
| if float8_dtype == e4m3_type: |
| res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) |
| elif float8_dtype == e5m2_type: |
        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
| else: |
| raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") |
| |
    # Ensure the scale is representable in float16; this helps when amax is
    # small. We assume this is not a concern for float32/bfloat16.
| if orig_dtype is torch.float16: |
| res = torch.clamp(res, max=torch.finfo(torch.float16).max) |
| |
| scale.copy_(res) |
| return scale |
| |
| def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None): |
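    # dim=None computes a single tensor-wise amax; otherwise amax is taken along
    # `dim` with keepdim=True so the resulting scale broadcasts row- or
    # column-wise against the tensor.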
| if dim is None: |
| amax = torch.max(torch.abs(x)) |
| else: |
| amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values |
| |
| return amax_to_scale(amax, float8_dtype, x.dtype) |
| |
| def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: |
| # naive implementation: dq -> op -> q |
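    # (dequantize the inputs by dividing out their scales, do the matmul in
    # float32, then cast to the requested output dtype)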
| x_fp32 = x.to(torch.float) / x_scale |
| y_fp32 = y.to(torch.float) / y_scale |
| out_fp32 = torch.mm(x_fp32, y_fp32) |
| |
| return out_fp32.to(out_dtype) |
| |
| def addmm_float8_unwrapped( |
| a_data: torch.Tensor, |
| a_scale: torch.Tensor, |
| b_data: torch.Tensor, |
    b_scale: torch.Tensor,
| output_dtype: torch.dtype, |
| output_scale: Optional[torch.Tensor], |
| bias: Optional[torch.Tensor] = None, |
| ) -> torch.Tensor: |
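    # The scales produced by tensor_to_scale() map amax onto the fp8 maximum
    # (max / amax); torch._scaled_mm takes the corresponding dequantization
    # factors, hence the reciprocals.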
| a_inverse_scale = a_scale.reciprocal() |
| b_inverse_scale = b_scale.reciprocal() |
| if output_dtype == torch.float32 and bias is not None: |
| # Bias is not supported by _scaled_mm when output is fp32 |
| output = torch._scaled_mm( |
| a_data, |
| b_data, |
| scale_a=a_inverse_scale, |
| scale_b=b_inverse_scale, |
| scale_result=output_scale, |
| out_dtype=output_dtype, |
| ) |
| output += bias |
| return output |
| output = torch._scaled_mm( |
| a_data, |
| b_data, |
| bias=bias, |
| scale_a=a_inverse_scale, |
| scale_b=b_inverse_scale, |
| scale_result=output_scale, |
| out_dtype=output_dtype, |
| ) |
| return output |
| |
| def mm_float8( |
| a: torch.Tensor, |
| b: torch.Tensor, |
| a_scale: torch.Tensor, |
| b_scale: torch.Tensor, |
| output_dtype: torch.dtype, # output dtype |
| output_scale: Optional[torch.Tensor] = None, # output scale, precomputed |
| ) -> torch.Tensor: |
| return addmm_float8_unwrapped( |
| a, a_scale, b, b_scale, output_dtype, output_scale |
| ) |
| |
| def to_fp8_saturated( |
| x: torch.Tensor, |
| fp8_dtype: torch.dtype |
| ): |
| if fp8_dtype == e4m3_type: |
| x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) |
| elif fp8_dtype == e5m2_type: |
| x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS) |
| else: |
| raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}") |
| |
| return x.to(fp8_dtype) |
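
# Illustrative sketch (not called by any test): how the helpers above compose
# for a tensor-wise fp8 round trip. `_fp8_roundtrip_sketch` is a hypothetical
# helper added for documentation only; it is not part of the test suite's API.
def _fp8_roundtrip_sketch(x: torch.Tensor) -> torch.Tensor:
    scale = tensor_to_scale(x, e4m3_type)           # quantization scale (max / amax)
    x_fp8 = to_fp8_saturated(x * scale, e4m3_type)  # quantize with saturation
    return x_fp8.to(torch.float32) / scale          # dequantize back to float32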
| |
| @unittest.skipIf(not torch.cuda.is_available(), "CUDA not found") |
| class TestFP8MatmulCuda(TestCase): |
| |
| @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) |
| def _test_tautological_mm(self, device: str = "cuda", |
| x_dtype: torch.dtype = e4m3_type, |
| y_dtype: torch.dtype = e4m3_type, |
| out_dtype: Optional[torch.dtype] = None, |
| size: int = 16) -> None: |
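        # "Tautological" because y is the identity matrix (and both scales are
        # 1.0), so the scaled matmul must reproduce x exactly.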
| x_fp8 = torch.rand(size, size, device=device).to(x_dtype) |
| y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t() |
| out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float)) |
| scale_a = torch.tensor(1.0, device=device) |
| scale_b = torch.tensor(1.0, device=device) |
| out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype) |
| if out_dtype is not None: |
| self.assertEqual(out_dtype, out_fp8.dtype) |
| self.assertEqual(out_fp32, out_fp8.to(torch.float)) |
| |
| @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) |
| def test_float8_basics(self, device) -> None: |
| self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16) |
        # hipblaslt does not yet support mixed e4m3/e5m2 inputs
| if torch.version.hip is None: |
| self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32) |
| self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48) |
        # According to https://docs.nvidia.com/cuda/cublas/#id99, e5m2 x e5m2 matmul is unsupported
| with self.assertRaises(RuntimeError): |
| self._test_tautological_mm(device, e5m2_type, e5m2_type) |
| |
| self._test_tautological_mm(device, size=64, out_dtype=torch.float16) |
| self._test_tautological_mm(device, size=96, out_dtype=torch.float32) |
| # hipblaslt does not yet support bfloat16 output |
| if torch.version.hip is None: |
| self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16) |
| with self.assertRaises(RuntimeError): |
| self._test_tautological_mm(device, out_dtype=e5m2_type) |
| |
| @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) |
| def test_float8_scale(self, device) -> None: |
| size = (16, 16) |
| x = torch.full(size, .5, device=device, dtype=e4m3_type) |
        # hipblaslt does not yet support mixed e4m3/e5m2 inputs
| y_type = e4m3_type if torch.version.hip else e5m2_type |
| y = torch.full(size, .5, device=device, dtype=y_type).t() |
| scale_a = torch.tensor(1.5, device=device) |
| scale_b = torch.tensor(0.66, device=device) |
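        # Each entry of x @ y is 16 * 0.5 * 0.5 = 4.0; after applying the
        # near-unit combined scale (1.5 * 0.66 ~ 0.99) the result still rounds
        # to 4.0 in fp8.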
| out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b) |
| self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device)) |
| out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b) |
| self.assertEqual(out_fp8, out_fp8_s) |
| |
| @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) |
| @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32]) |
| def test_scaled_mm_vs_emulated(self, base_dtype): |
| torch.manual_seed(42) |
| input_dtype = e4m3_type |
| output_dtype = base_dtype |
| compare_type = torch.float32 |
| |
| x = torch.randn(16, 16, device="cuda", dtype=base_dtype) |
| y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() |
| |
| x_scale = tensor_to_scale(x, input_dtype).float() |
| y_scale = tensor_to_scale(y, input_dtype).float() |
| |
| x_fp8 = to_fp8_saturated(x * x_scale, input_dtype) |
| y_fp8 = to_fp8_saturated(y * y_scale, input_dtype) |
| |
| # Calculate actual F8 mm |
| out_scaled_mm = mm_float8( |
| x_fp8, |
| y_fp8, |
| a_scale=x_scale, |
| b_scale=y_scale, |
| output_dtype=output_dtype |
| ) |
| |
| # Calculate emulated F8 mm |
| out_emulated = mm_float8_emulated( |
| x_fp8, |
| x_scale, |
| y_fp8, |
| y_scale, |
| output_dtype |
| ) |
| |
| if output_dtype != base_dtype: |
| out_scaled_mm = out_scaled_mm.to(compare_type) |
| out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype) |
| |
| out_emulated = out_emulated.to(compare_type) |
| out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype) |
| |
| if base_dtype in {torch.bfloat16, torch.float16}: |
| atol, rtol = 7e-2, 7e-2 |
| else: |
| atol, rtol = 3e-3, 3e-3 |
| |
| torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) |
| |
| @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) |
| @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32]) |
| def test_scaled_mm_change_stride(self, base_dtype): |
| torch.manual_seed(42) |
| input_dtype = e4m3_type |
| output_dtype = base_dtype |
| compare_type = torch.float32 |
| |
| x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype) |
| y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype) |
| |
| x_scale = tensor_to_scale(x, input_dtype).float() |
| y_scale = tensor_to_scale(y, input_dtype).float() |
| |
| x_fp8 = to_fp8_saturated(x * x_scale, input_dtype) |
| y_fp8 = to_fp8_saturated(y * y_scale, input_dtype) |
| |
| # Calculate actual F8 mm |
| out_scaled_mm = mm_float8( |
| x_fp8, |
| y_fp8, |
| a_scale=x_scale, |
| b_scale=y_scale, |
| output_dtype=output_dtype |
| ) |
| |
| # Calculate emulated F8 mm |
| out_emulated = mm_float8_emulated( |
| x_fp8, |
| x_scale, |
| y_fp8, |
| y_scale, |
| output_dtype |
| ) |
| |
| if output_dtype != base_dtype: |
| out_scaled_mm = out_scaled_mm.to(compare_type) |
| out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype) |
| |
| out_emulated = out_emulated.to(compare_type) |
| out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype) |
| |
| if base_dtype in {torch.bfloat16, torch.float16}: |
| atol, rtol = 7e-2, 7e-2 |
| else: |
| atol, rtol = 3e-3, 3e-3 |
| |
| torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) |
| |
| @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) |
| def test_float8_bias(self, device) -> None: |
| (k, l, m) = (16, 48, 32) |
| x = torch.ones((k, l), device=device).to(e4m3_type) |
| y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t() |
| bias = torch.full((m,), 4.0, device=device, dtype=torch.half) |
| scale_a = torch.tensor(1.0, device=device) |
| scale_b = torch.tensor(1.0, device=device) |
| out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b) |
| outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias) |
| # this fails on ROCm currently because hipblaslt doesn't have amax op |
| out_fp32 = out_fp8.to(torch.float32) |
| outb_fp32 = outb_fp8.to(torch.float32) |
| difference = torch.abs(out_fp32 - outb_fp32) |
| self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32)) |
| |
| @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) |
| @parametrize("bias", [True, False]) |
| def test_non_divisible_leading_dim(self, device, bias: bool) -> None: |
| x = torch.rand((17, 16), device=device).to(e4m3_type) |
| y = torch.rand((16, 16), device=device).to(e4m3_type).t() |
| scale_a = torch.tensor(1.0, device=device) |
| scale_b = torch.tensor(1.0, device=device) |
| input_bias = None |
| if bias: |
| input_bias = torch.rand((16,), device=device).to(torch.half) |
| _ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias) |
| |
| @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) |
| def test_float8_bias_relu_edgecase(self, device) -> None: |
| (k, l, m) = (16, 48, 32) |
| x = torch.full((k, l), 0.0, device=device).to(e4m3_type) |
| y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t() |
| bias = torch.full((m,), -3.0, device=device, dtype=torch.half) |
| scale_a = torch.tensor(1.0, device=device) |
| scale_b = torch.tensor(1.0, device=device) |
| outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias) |
| outb_fp32 = outb_fp8.to(torch.float32) |
| self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32)) |
| |
| @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) |
| def test_float32_output_errors_with_bias(self, device) -> None: |
| (k, l, m) = (16, 48, 32) |
| x = torch.rand((k, l), device=device).to(e4m3_type) |
| y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t() |
| scale_a = torch.tensor(1.0, device=device) |
| scale_b = torch.tensor(1.0, device=device) |
| bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Bias is not supported when out_dtype is set to Float32", |
| lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32), |
| ) |
| |
| @unittest.skipIf(PLATFORM_SUPPORTS_FP8, |
| "This test is only for devices with compute capability < 8.9") |
| def test_error_message_fp8_pre_sm89(self, device) -> None: |
| (k, l, m) = (16, 48, 32) |
| x = torch.rand((k, l), device=device).to(e4m3_type) |
| y = torch.rand((m, l), device=device).to(e4m3_type).t() |
| scale_a = torch.tensor(1.0, device=device) |
| scale_b = torch.tensor(1.0, device=device) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+", |
| lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32), |
| ) |
| |
| @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) |
| def test_float8_scale_fast_accum(self, device) -> None: |
| size = (16, 16) |
| x = torch.full(size, .5, device=device, dtype=e4m3_type) |
        # hipblaslt does not yet support mixed e4m3/e5m2 inputs
| y_type = e4m3_type if torch.version.hip else e5m2_type |
| y = torch.full(size, .5, device=device, dtype=y_type).t() |
| scale_a = torch.tensor(1.5, device=device) |
| scale_b = torch.tensor(0.66, device=device) |
| out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True) |
| self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device)) |
| out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True) |
| self.assertEqual(out_fp8, out_fp8_s) |
| |
| @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) |
| @skipIfRocm() |
| @parametrize("use_fast_accum", [True, False]) |
| def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None: |
| M, K, N = (1024, 512, 2048) |
| fill_value = 0.5 |
| x = torch.full((M, K), fill_value, device=device) |
| y = torch.full((N, K), fill_value, device=device) |
| |
| x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32) |
| y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32) |
| |
| x_fp8 = x.to(torch.float8_e4m3fn) |
| y_fp8 = y.to(torch.float8_e4m3fn).t() |
| |
| out_fp8 = torch._scaled_mm( |
| x_fp8, |
| y_fp8, |
| scale_a=x_scales, |
| scale_b=y_scales, |
| out_dtype=torch.bfloat16, |
| use_fast_accum=use_fast_accum, |
| ) |
| self.assertEqual( |
| out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device) |
| ) |
| |
| @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) |
| @skipIfRocm() |
| def test_float8_error_messages(self, device) -> None: |
| M, K, N = (1024, 512, 2048) |
| fill_value = 0.5 |
| x = torch.full((M, K), fill_value, device=device) |
| y = torch.full((N, K), fill_value, device=device) |
| |
| x_fp8 = x.to(torch.float8_e4m3fn) |
| y_fp8 = y.to(torch.float8_e4m3fn).t() |
| |
| with self.assertRaisesRegex( |
| RuntimeError, |
| re.escape( |
| "For RowWise scaling, scale_a should be (1024, 1) and scale_b " |
| "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)" |
| ), |
| ): |
| torch._scaled_mm( |
| x_fp8, |
| y_fp8, |
| scale_a=torch.ones((1, 1), device="cuda"), |
| scale_b=torch.ones((1, 2), device="cuda"), |
| out_dtype=torch.bfloat16, |
| ) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, |
| re.escape( |
| " For RowWise scaling, scale_a should be (1024, 1) and scale_b " |
| "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)" |
| ), |
| ): |
| torch._scaled_mm( |
| x_fp8, |
| y_fp8, |
| scale_a=torch.ones((M, 1), device="cuda"), |
| scale_b=torch.ones((1, N + 1), device="cuda"), |
| out_dtype=torch.bfloat16, |
| ) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"), |
| ): |
| torch._scaled_mm( |
| x_fp8, |
| y_fp8, |
| scale_a=torch.ones((M), device="cuda"), |
| scale_b=torch.ones((N, N), device="cuda"), |
| out_dtype=torch.bfloat16, |
| ) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, |
| re.escape( |
| "Both scale_a and scale_b must be contiguous for RowWise scaling." |
| ), |
| ): |
| torch._scaled_mm( |
| x_fp8, |
| y_fp8, |
| scale_a=torch.ones((M, 1), device="cuda"), |
| scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2], |
| out_dtype=torch.bfloat16, |
| ) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, |
| re.escape("For RowWise scaling the second input is required to be a float8_e4m3fn dtype."), |
| ): |
| torch._scaled_mm( |
| x_fp8, |
| y_fp8.to(torch.float8_e5m2), |
| scale_a=torch.ones((M, 1), device="cuda"), |
| scale_b=torch.ones((1, N), device="cuda"), |
| out_dtype=torch.bfloat16, |
| ) |
| |
| @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) |
| @unittest.skipIf(not SM90OrLater, "rowwise implementation is currently sm90 specific") |
| @skipIfRocm() |
| @parametrize("base_dtype", [torch.bfloat16]) |
| def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): |
| torch.manual_seed(42) |
| input_dtype = e4m3_type |
| output_dtype = base_dtype |
| |
| x = torch.randn(16, 16, device="cuda", dtype=base_dtype) |
| y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() |
| |
| x_scales = tensor_to_scale(x, input_dtype, dim=1).float() |
| y_scales = tensor_to_scale(y, input_dtype, dim=0).float() |
| |
| x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type) |
| y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type) |
| |
| # Calculate actual F8 mm |
| out_scaled_mm = mm_float8( |
| x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype |
| ) |
| |
| # Calculate emulated F8 mm |
| out_emulated = mm_float8_emulated( |
| x_fp8, x_scales, y_fp8, y_scales, output_dtype |
| ) |
| |
| if base_dtype in {torch.bfloat16, torch.float16}: |
| atol, rtol = 7e-2, 7e-2 |
| else: |
| atol, rtol = 2e-3, 2e-3 |
| |
| torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) |
| |
| |
| @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") |
| @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") |
| @unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x") |
| class TestMixedDtypesLinearCuda(TestCase): |
| @dtypes(torch.float16, torch.bfloat16) |
| def test_mixed_dtypes_linear(self, dtype: torch.dtype, device: str = "cuda"): |
| version = _get_torch_cuda_version() |
| if version < (11, 8): |
| self.skipTest("_mixed_dtypes_linear only compiled for CUDA 11.8+") |
| |
| def run_test( |
| batch_shape, |
| m, |
| n, |
| k, |
| add_bias, |
| activation, |
| dtype, |
| dtypeq, |
| device, |
| rtol, |
| atol, |
| ): |
| if not add_bias and activation != "none": |
| return |
| |
| val_lo, val_hi = -1, 1 |
| valq_lo, valq_hi = -2, 2 |
| input = make_tensor( |
| *batch_shape, m, k, low=val_lo, high=val_hi, dtype=dtype, device=device |
| ) |
| weight = make_tensor( |
| n, k, low=valq_lo, high=valq_hi, dtype=torch.int8, device=device |
| ) |
| scale = make_tensor( |
| (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device |
| ) |
| bias = ( |
| make_tensor( |
| (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device |
| ) |
| if add_bias |
| else None |
| ) |
| |
| input_ref = input.reshape(-1, input.shape[-1]) |
| |
| # First, test plain multiplication. |
| weight_ref = weight.T.to(input.dtype) * scale.view(1, n) |
| weightq = ( |
| pack_int4_to_int8(weight.T) if dtypeq == torch.quint4x2 else weight.T |
| ) |
| output_ref = torch.mm(input_ref, weight_ref).reshape(*input.shape[:-1], n) |
| output = torch.ops.aten._mixed_dtypes_linear( |
| input, |
| quantized_weight_reorder_for_mixed_dtypes_linear_cutlass( |
| weightq, dtypeq, transpose=False |
| ), |
| scale, |
| ) |
| torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol) |
| |
| # Second, test the linear operator itself. |
| weight_ref = weight.to(input.dtype) * scale.view(n, 1) |
| weightq = pack_int4_to_int8(weight) if dtypeq == torch.quint4x2 else weight |
| bias_ref = bias.view(1, n) if add_bias else None |
| output_ref = torch.nn.functional.linear( |
| input_ref, weight_ref, bias=bias_ref |
| ).reshape(*input.shape[:-1], n) |
| if activation == "relu": |
| relu = torch.nn.ReLU() |
| output_ref = relu(output_ref) |
| elif activation == "silu": |
| silu = torch.nn.SiLU() |
| output_ref = silu(output_ref) |
| output = torch.ops.aten._mixed_dtypes_linear( |
| input, |
| quantized_weight_reorder_for_mixed_dtypes_linear_cutlass( |
| weightq, dtypeq, transpose=True |
| ), |
| scale, |
| bias=bias, |
| activation=activation, |
| ) |
| torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol) |
| |
| dtypeqs = [torch.int8, torch.quint4x2] |
| batch_shapes = [[], [2], [2, 1]] |
| shapes = [ |
| [8, 64, 64], |
| [8, 64, 128], |
| [8, 128, 64], |
| [8, 128, 128], |
| [8, 128, 192], |
| [8, 128, 256], |
| [8, 256, 128], |
| [8, 256, 384], |
| [8, 384, 256], |
| ] |
| activations = [None, "relu", "silu"] |
| rtol, atol = 1e-3, 1e-3 |
| if dtype == torch.bfloat16: |
| rtol, atol = 1e-2, 1e-3 |
| for dtypeq, batch_shape, (m, n, k), add_bias, activation in product( |
| dtypeqs, batch_shapes, shapes, (False, True), activations |
| ): |
| run_test( |
| batch_shape, |
| m, |
| n, |
| k, |
| add_bias, |
| activation, |
| dtype, |
| dtypeq, |
| device, |
| rtol, |
| atol, |
| ) |
| |
| instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu") |
| instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu") |
| instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu") |
| |
| if __name__ == '__main__': |
| TestCase._default_dtype_check_enabled = True |
| run_tests() |