# blob: 70235665cb6fef4bded0e27f5785689558a59a4f [file] [log] [blame] [edit]
# Owner(s): ["module: linear algebra"]
import unittest
from itertools import product
from functools import partial
from typing import Optional
import re
import torch
from torch.quantization._quantized_conversions import (
pack_int4_to_int8,
quantized_weight_reorder_for_mixed_dtypes_linear_cutlass,
)
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
SM53OrLater,
SM90OrLater,
_get_torch_cuda_version,
PLATFORM_SUPPORTS_FP8
)
from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
onlyCUDA,
tol as xtol,
toleranceOverride,
)
from torch.testing._internal.common_utils import (
IS_ARM64,
IS_JETSON,
IS_WINDOWS,
parametrize,
run_tests,
skipIfRocmVersionLessThan,
TEST_WITH_ROCM,
skipIfRocm,
TestCase,
)
# True only when CUDA is usable and device 0 reports compute capability major
# version 8; used below to gate the mixed-dtypes linear tests.
_IS_SM8X = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] == 8

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32
@unittest.skipIf(IS_ARM64, "Issue with numpy version on arm")
class TestMatmulCuda(TestCase):
    """Accuracy checks for addmm / linear / baddbmm executed on CUDA.

    Results are compared against CPU references or against an alternative
    formulation of the same computation on the GPU.
    """

    def setUp(self):
        # NOTE(review): the explicit `super(self.__class__, self)` form is kept
        # on purpose. Presumably `instantiate_device_type_tests` re-homes these
        # methods on generated classes, where zero-argument `super()` may not
        # resolve -- confirm before modernizing.
        super(self.__class__, self).setUp()
        # Every test in this class runs with TF32 matmuls disabled.
        torch.backends.cuda.matmul.allow_tf32 = False

    def tearDown(self):
        # NOTE(review): TF32 is enabled unconditionally here rather than being
        # restored to the value seen before setUp -- confirm this is intended.
        torch.backends.cuda.matmul.allow_tf32 = True
        super(self.__class__, self).tearDown()

    def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False):
        """Compare ``torch.addmm`` on CUDA against the CPU result.

        Checks for catastrophic cuBLAS inaccuracy by measuring the deviation
        between the CUDA invocation of ``torch.addmm`` and the CPU invocation
        (which does not use the CUDA backend).

        Args:
            size: Base dimension; the operands are ``(size + 1, size)`` and
                ``(size, size + 2)``.
            dtype: Element type of the operands.
            reduced_precision: Value assigned to both the fp16 and bf16
                ``allow_*_reduced_precision_reduction`` flags for the duration
                of the check.
        """
        # Get dims
        n, m, p = (size + 1, size, size + 2)

        # Disable reduced precision reductions in BFloat16 to bypass some kernels
        # which fail the threshold check
        orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
        orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision

        # The flags are process-global: restore them even when the comparison
        # fails so a failing case cannot leak its settings into later tests.
        try:
            # Make random tensors on CPU (seed set on common_utils.py import)
            # (Not using numpy because it does not support bfloat16)
            make_arg = partial(make_tensor, dtype=dtype, device="cpu")
            m_beta = make_arg(1)
            m_input = make_arg((n, p))
            m_1 = make_arg((n, m))
            m_2 = make_arg((m, p))

            low_precision = dtype in (torch.float16, torch.bfloat16)

            # *(B)FLOAT16 Special Handling*
            # Backend does not tensorize float16 on CPU,
            # and bfloat16 may present accuracy issues,
            # so convert to float32 for these cases
            # (but keep same for other types, e.g. float32 and int*)
            if low_precision:
                m_beta = m_beta.to(dtype=torch.float32)
                m_input = m_input.to(dtype=torch.float32)
                m_1 = m_1.to(dtype=torch.float32)
                m_2 = m_2.to(dtype=torch.float32)

            # Get CPU result
            res_cpu = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())

            # *(B)FLOAT16 Special Handling*
            # Convert back to (b)float16
            if low_precision:
                m_beta = m_beta.to(dtype=dtype)
                m_input = m_input.to(dtype=dtype)
                m_1 = m_1.to(dtype=dtype)
                m_2 = m_2.to(dtype=dtype)
                res_cpu = res_cpu.to(dtype=dtype)

            # Move arg tensors to CUDA
            m_beta = m_beta.to("cuda")
            m_input = m_input.to("cuda")
            m_1 = m_1.to("cuda")
            m_2 = m_2.to("cuda")

            # Get CUDA result
            res_cuda = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())

            # Move to CPU for comparison
            res_cuda = res_cuda.to("cpu")

            # Compare
            self.assertEqual(res_cpu, res_cuda)
        finally:
            torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=1e-1, rtol=1e-1),
                        torch.bfloat16: xtol(atol=1e-1, rtol=1e-1),
                        torch.float32: xtol(atol=1e-1, rtol=1e-1)})
    @dtypes(torch.float16, torch.bfloat16, torch.float32)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
        """addmm CUDA-vs-CPU check with reduced-precision reductions disabled."""
        self.cublas_addmm(size, dtype, False)

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
        """addmm CUDA-vs-CPU check with reduced-precision reductions enabled."""
        self.cublas_addmm(size, dtype, True)

    @onlyCUDA
    @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)})
    @dtypes(torch.float16)
    def test_cublas_addmm_alignment(self, dtype):
        """Check ``linear`` against ``matmul + bias`` when one operand is offset.

        Each operand in turn is carved out of a slightly larger flat buffer at
        an offset of 1 or 2 elements, so its data does not start at the
        beginning of its allocation.
        """
        device = 'cuda'
        # perturb X, A, or B alignment
        for idx in range(0, 3):
            for offset in range(1, 3):
                offsets = [0, 0, 0]
                offsets[idx] = offset
                x_offset, a_offset, b_offset = offsets
                A = torch.rand((5120 * 2560 + a_offset), requires_grad=True, dtype=dtype, device=device)
                A = A[a_offset:].reshape(5120, 2560)
                X = torch.rand((26 * 2560 + x_offset), requires_grad=True, dtype=dtype, device=device)
                X = X[x_offset:].reshape(26, 1, 2560)
                B = torch.rand((5120 + b_offset), requires_grad=True, dtype=dtype, device=device)
                B = B[b_offset:].reshape(5120)
                out = torch.nn.functional.linear(X, A, B)
                self.assertEqual(out, torch.matmul(X, A.transpose(1, 0)) + B)

    @onlyCUDA
    @unittest.skipIf(IS_JETSON, "Too large for Jetson")
    @toleranceOverride({torch.float32: xtol(atol=1e-5, rtol=1.1e-5)})
    # The conditional must be parenthesized: `a + b if c else []` parses as
    # `(a + b) if c else []`, which would drop *all* dtypes when c is false.
    @dtypes(*([torch.float32, torch.float16] +
              ([torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])))
    @parametrize(
        "batch_size, N, M, P",
        [(2, 100, 100, 100),
         (2, 1000, 1000, 1000),
         (1, 10000, 1000, 10000),
         (1, 10000, 10000, 10000)],
        name_fn=lambda batch_size, N, M, P: f"{batch_size}_{N}_{M}_{P}",
    )
    @skipIfRocm
    def test_cublas_baddbmm_large_input(self, device, batch_size, N, M, P, dtype):
        """Compare ``linear`` and ``baddbmm`` on large inputs against CPU.

        The CPU reference is computed in float32 for half/bfloat16 inputs and
        cast back before comparison. For square problems the product with an
        identity matrix is also checked, and finally the ``linear`` result is
        compared with the first batch of the ``baddbmm`` result.
        """
        cpu_dtype = dtype
        if dtype == torch.float16 or dtype == torch.bfloat16:
            cpu_dtype = torch.float32

        M1 = torch.rand((N, M), device=device, dtype=dtype)
        M2 = torch.rand((M, P), device=device, dtype=dtype)
        A = torch.rand((N, P), device=device, dtype=dtype)

        def _convert_to_cpu(t):
            return t.to(device='cpu', dtype=cpu_dtype)
        M1_cpu, M2_cpu, A_cpu = map(_convert_to_cpu, [M1, M2, A])

        # linear
        out1_cpu = torch.nn.functional.linear(M1_cpu, M2_cpu.t(), A_cpu).to(dtype=dtype)
        out1_gpu = torch.nn.functional.linear(M1, M2.t(), A).cpu()
        self.assertEqual(out1_cpu, out1_gpu)
        # test multiply the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype)
            out1_eye_gpu = torch.nn.functional.linear(M1, M2_eye.t(), torch.zeros_like(A))
            self.assertEqual(M1_cpu.to(dtype=dtype), out1_eye_gpu.cpu())

        # baddbmm
        def _expand_to_batch(t: torch.Tensor):
            return t.expand((batch_size, ) + t.size())
        alpha, beta = 1.0, 1.0
        M1, M2, A, M1_cpu, M2_cpu, A_cpu = map(_expand_to_batch, [M1, M2, A, M1_cpu, M2_cpu, A_cpu])

        out2_cpu = torch.baddbmm(A_cpu, M1_cpu, M2_cpu, beta=beta, alpha=alpha).to(dtype=dtype)
        out2_gpu = torch.baddbmm(A, M1, M2, beta=beta, alpha=alpha).cpu()
        self.assertEqual(out2_cpu, out2_gpu)
        # test multiply the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype).expand(batch_size, N, N)
            out2_eye_gpu = torch.baddbmm(torch.zeros_like(A), M1, M2_eye, beta=beta, alpha=alpha)
            self.assertEqual(M1_cpu.to(dtype=dtype), out2_eye_gpu.cpu())

        # cross comparison
        self.assertEqual(out1_gpu, out2_gpu[0])
f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"
if torch.version.hip:
e4m3_type = torch.float8_e4m3fnuz
e5m2_type = torch.float8_e5m2fnuz
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
else:
e4m3_type = torch.float8_e4m3fn
e5m2_type = torch.float8_e5m2
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
# avoid division by zero when calculating scale
EPS = 1e-12
def amax_to_scale(
    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
):
    """ Converts the amax value of a tensor to the fp8 scale.

    The scale is the largest finite value of the target float8 dtype divided
    by ``amax`` (clamped from below by ``EPS``), returned as float32.

    Args:
        amax: The amax value of the tensor.
        float8_dtype: the float8 dtype.
        orig_dtype: The original dtype of the tensor.

    Returns:
        A float32 tensor shaped like ``amax`` holding the scale.

    Raises:
        ValueError: if ``float8_dtype`` is neither ``e4m3_type`` nor
            ``e5m2_type``.
    """
    # Pick the dynamic range of the *requested* dtype; e5m2 must use its own
    # maximum rather than the (much smaller) e4m3 one.
    if float8_dtype == e4m3_type:
        fp8_max = E4M3_MAX_POS
    elif float8_dtype == e5m2_type:
        fp8_max = E5M2_MAX_POS
    else:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

    res = fp8_max / torch.clamp(amax, min=EPS)

    # Ensure the scale is representable in float16,
    # this helps when amax is small. We are assuming that we don't need
    # to care about this for float32/bfloat16
    if orig_dtype is torch.float16:
        res = torch.clamp(res, max=torch.finfo(torch.float16).max)

    scale = torch.empty_like(amax, dtype=torch.float32)
    scale.copy_(res)
    return scale
def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
    """Derive the fp8 scale for ``x`` from its maximum absolute value.

    Args:
        x: Tensor to compute the scale for.
        float8_dtype: Target float8 dtype, forwarded to ``amax_to_scale``.
        dim: When ``None`` a single global maximum is used; otherwise the
            maximum is taken along ``dim`` with that dimension kept.

    Returns:
        Whatever ``amax_to_scale`` returns for the computed maximum.
    """
    magnitude = torch.abs(x)
    if dim is not None:
        amax, _ = torch.max(magnitude, dim=dim, keepdim=True)
    else:
        amax = torch.max(magnitude)
    return amax_to_scale(amax, float8_dtype, x.dtype)
def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
    """Reference matmul: dequantize both operands to float32, multiply, cast.

    Each operand is converted to float32 and divided by its scale before
    ``torch.mm``; the product is returned in ``out_dtype``.
    """
    # naive implementation: dq -> op -> q
    lhs, rhs = (
        operand.to(torch.float) / scale
        for operand, scale in ((x, x_scale), (y, y_scale))
    )
    return torch.mm(lhs, rhs).to(out_dtype)
def addmm_float8_unwrapped(
    a_data: torch.Tensor,
    a_scale: torch.Tensor,
    b_data: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,
    output_scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Call ``torch._scaled_mm`` with reciprocal scales and an optional bias.

    Args:
        a_data: First matmul operand.
        a_scale: Scale of ``a_data``; its reciprocal is passed as ``scale_a``.
        b_data: Second matmul operand.
        b_scale: Scale of ``b_data``; its reciprocal is passed as ``scale_b``.
        output_dtype: Passed through as ``out_dtype``.
        output_scale: Passed through as ``scale_result``.
        bias: Optional bias. Fused into ``_scaled_mm`` unless the output is
            float32, in which case it is added to the result afterwards.

    Returns:
        The output of ``torch._scaled_mm`` (plus ``bias`` when added manually).
    """
    a_inverse_scale = a_scale.reciprocal()
    b_inverse_scale = b_scale.reciprocal()

    # Bias is not supported by _scaled_mm when output is fp32, so in that
    # one case it is withheld from the kernel and added in place afterwards.
    add_bias_manually = output_dtype == torch.float32 and bias is not None

    output = torch._scaled_mm(
        a_data,
        b_data,
        bias=None if add_bias_manually else bias,
        scale_a=a_inverse_scale,
        scale_b=b_inverse_scale,
        scale_result=output_scale,
        out_dtype=output_dtype,
    )
    if add_bias_manually:
        output += bias
    return output
def mm_float8(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,  # output dtype
    output_scale: Optional[torch.Tensor] = None,  # output scale, precomputed
) -> torch.Tensor:
    """Bias-free float8 matmul; forwards to ``addmm_float8_unwrapped``."""
    return addmm_float8_unwrapped(
        a_data=a,
        a_scale=a_scale,
        b_data=b,
        b_scale=b_scale,
        output_dtype=output_dtype,
        output_scale=output_scale,
    )
def to_fp8_saturated(
    x: torch.Tensor,
    fp8_dtype: torch.dtype
):
    """Clamp ``x`` to the finite range of ``fp8_dtype`` and cast to it.

    Raises:
        ValueError: if ``fp8_dtype`` is neither ``e4m3_type`` nor
            ``e5m2_type``.
    """
    if fp8_dtype == e4m3_type:
        max_pos = E4M3_MAX_POS
    elif fp8_dtype == e5m2_type:
        max_pos = E5M2_MAX_POS
    else:
        raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")
    saturated = x.clamp(min=-1 * max_pos, max=max_pos)
    return saturated.to(fp8_dtype)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
class TestFP8MatmulCuda(TestCase):
    """Tests for ``torch._scaled_mm`` with float8 operands."""

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def _test_tautological_mm(self, device: str = "cuda",
                              x_dtype: torch.dtype = e4m3_type,
                              y_dtype: torch.dtype = e4m3_type,
                              out_dtype: Optional[torch.dtype] = None,
                              size: int = 16) -> None:
        """Multiply a random fp8 matrix by an fp8 identity with unit scales.

        The ``_scaled_mm`` result must equal the float32 ``torch.mm`` of the
        same fp8 values, and have ``out_dtype`` when one is requested.
        """
        x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
        y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
        if out_dtype is not None:
            self.assertEqual(out_dtype, out_fp8.dtype)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_basics(self, device) -> None:
        """Exercise supported and unsupported input/output dtype combinations."""
        self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
        # hipblaslt does not yet support mixed e4m3_type input
        if torch.version.hip is None:
            self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
            self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
        # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, e5m2_type, e5m2_type)

        self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
        self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
        # hipblaslt does not yet support bfloat16 output
        if torch.version.hip is None:
            self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, out_dtype=e5m2_type)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_scale(self, device) -> None:
        """Check the product of two constant matrices with non-unit scales."""
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        # NOTE(review): this second call is identical to the first, so the
        # comparison below only checks run-to-run consistency -- confirm
        # whether a different configuration was intended.
        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8, out_fp8_s)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_vs_emulated(self, base_dtype):
        """Compare ``_scaled_mm`` against the dequantize-multiply emulation."""
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8,
            y_fp8,
            a_scale=x_scale,
            b_scale=y_scale,
            output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8,
            x_scale,
            y_fp8,
            y_scale,
            output_dtype
        )

        # The output dtype always equals base_dtype here, so both results are
        # compared directly without any further rescaling.
        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_change_stride(self, base_dtype):
        """Same comparison as above, with explicitly strided operands."""
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype

        # empty_strided returns uninitialized memory; fill both operands with
        # seeded random values so the comparison is deterministic and cannot
        # be poisoned by stray NaN/Inf bit patterns.
        x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype).normal_()
        y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype).normal_()

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8,
            y_fp8,
            a_scale=x_scale,
            b_scale=y_scale,
            output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8,
            x_scale,
            y_fp8,
            y_scale,
            output_dtype
        )

        # The output dtype always equals base_dtype here, so both results are
        # compared directly without any further rescaling.
        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_bias(self, device) -> None:
        """The result with a constant bias of 4 must differ by exactly 4."""
        (k, l, m) = (16, 48, 32)
        x = torch.ones((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
        # this fails on ROCm currently because hipblaslt doesn't have amax op
        out_fp32 = out_fp8.to(torch.float32)
        outb_fp32 = outb_fp8.to(torch.float32)
        difference = torch.abs(out_fp32 - outb_fp32)
        self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32))

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("bias", [True, False])
    def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
        """Smoke test: a 17-row first operand must not raise."""
        x = torch.rand((17, 16), device=device).to(e4m3_type)
        y = torch.rand((16, 16), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        input_bias = None
        if bias:
            input_bias = torch.rand((16,), device=device).to(torch.half)
        _ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_bias_relu_edgecase(self, device) -> None:
        """A zero product plus a negative bias must yield the negative bias."""
        (k, l, m) = (16, 48, 32)
        x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
        y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
        outb_fp32 = outb_fp8.to(torch.float32)
        self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32))

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float32_output_errors_with_bias(self, device) -> None:
        """Requesting float32 output together with a bias must raise."""
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
        self.assertRaisesRegex(
            RuntimeError,
            "Bias is not supported when out_dtype is set to Float32",
            lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
        )

    @unittest.skipIf(PLATFORM_SUPPORTS_FP8,
                     "This test is only for devices with compute capability < 8.9")
    def test_error_message_fp8_pre_sm89(self, device) -> None:
        """On platforms without FP8 support ``_scaled_mm`` must raise."""
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.rand((m, l), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        self.assertRaisesRegex(
            RuntimeError,
            r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
            lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_scale_fast_accum(self, device) -> None:
        """Constant-matrix product with non-unit scales and fast accumulation."""
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8, out_fp8_s)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @skipIfRocm()
    @parametrize("use_fast_accum", [True, False])
    def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
        """Row-wise scaling with all-ones scales on constant matrices."""
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        # One scale per row of x and one per column of the transposed y.
        x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32)
        y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32)

        x_fp8 = x.to(torch.float8_e4m3fn)
        y_fp8 = y.to(torch.float8_e4m3fn).t()

        out_fp8 = torch._scaled_mm(
            x_fp8,
            y_fp8,
            scale_a=x_scales,
            scale_b=y_scales,
            out_dtype=torch.bfloat16,
            use_fast_accum=use_fast_accum,
        )
        self.assertEqual(
            out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device)
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @skipIfRocm()
    def test_float8_error_messages(self, device) -> None:
        """Check the error text raised for several invalid scale arguments."""
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        x_fp8 = x.to(torch.float8_e4m3fn)
        y_fp8 = y.to(torch.float8_e4m3fn).t()

        # Scale shapes that match neither expected row-wise shape.
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((1, 1), device="cuda"),
                scale_b=torch.ones((1, 2), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        # scale_b is one element too long.
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                " For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N + 1), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        # scale_a is one-dimensional.
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M), device="cuda"),
                scale_b=torch.ones((N, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        # scale_b has the right shape but is a strided (non-contiguous) view.
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "Both scale_a and scale_b must be contiguous for RowWise scaling."
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
                out_dtype=torch.bfloat16,
            )

        # The second operand is float8_e5m2 instead of float8_e4m3fn.
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape("For RowWise scaling the second input is required to be a float8_e4m3fn dtype."),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8.to(torch.float8_e5m2),
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @unittest.skipIf(not SM90OrLater, "rowwise implementation is currently sm90 specific")
    @skipIfRocm()
    @parametrize("base_dtype", [torch.bfloat16])
    def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
        """Compare row-wise scaled ``_scaled_mm`` against the emulation."""
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scales = tensor_to_scale(x, input_dtype, dim=1).float()
        y_scales = tensor_to_scale(y, input_dtype, dim=0).float()

        x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
        y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8, x_scales, y_fp8, y_scales, output_dtype
        )

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 2e-3, 2e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
class TestMixedDtypesLinearCuda(TestCase):
    """Tests for ``aten._mixed_dtypes_linear`` with quantized weights."""

    @dtypes(torch.float16, torch.bfloat16)
    def test_mixed_dtypes_linear(self, dtype: torch.dtype, device: str = "cuda"):
        """Compare ``_mixed_dtypes_linear`` with a dequantized reference.

        Sweeps quantized weight dtype, batch shape, problem shape, bias and
        activation; for every combination both the plain product and the full
        linear operator are checked.
        """
        version = _get_torch_cuda_version()
        if version < (11, 8):
            self.skipTest("_mixed_dtypes_linear only compiled for CUDA 11.8+")

        def run_test(
            batch_shape,
            m,
            n,
            k,
            add_bias,
            activation,
            dtype,
            dtypeq,
            device,
            rtol,
            atol,
        ):
            # Activations are only exercised together with a bias. "No
            # activation" is spelled `None` in this sweep, so the check has to
            # be against None; otherwise every bias-free case returns here
            # without testing anything.
            if not add_bias and activation is not None:
                return

            val_lo, val_hi = -1, 1
            valq_lo, valq_hi = -2, 2
            inp = make_tensor(
                *batch_shape, m, k, low=val_lo, high=val_hi, dtype=dtype, device=device
            )
            weight = make_tensor(
                n, k, low=valq_lo, high=valq_hi, dtype=torch.int8, device=device
            )
            scale = make_tensor(
                (n,), low=val_lo, high=val_hi, dtype=inp.dtype, device=device
            )
            bias = (
                make_tensor(
                    (n,), low=val_lo, high=val_hi, dtype=inp.dtype, device=device
                )
                if add_bias
                else None
            )

            # Reference computations work on the input flattened to 2-D.
            input_ref = inp.reshape(-1, inp.shape[-1])

            # First, test plain multiplication.
            weight_ref = weight.T.to(inp.dtype) * scale.view(1, n)
            weightq = (
                pack_int4_to_int8(weight.T) if dtypeq == torch.quint4x2 else weight.T
            )
            output_ref = torch.mm(input_ref, weight_ref).reshape(*inp.shape[:-1], n)
            output = torch.ops.aten._mixed_dtypes_linear(
                inp,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=False
                ),
                scale,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

            # Second, test the linear operator itself.
            weight_ref = weight.to(inp.dtype) * scale.view(n, 1)
            weightq = pack_int4_to_int8(weight) if dtypeq == torch.quint4x2 else weight
            bias_ref = bias.view(1, n) if add_bias else None
            output_ref = torch.nn.functional.linear(
                input_ref, weight_ref, bias=bias_ref
            ).reshape(*inp.shape[:-1], n)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output_ref = relu(output_ref)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output_ref = silu(output_ref)
            output = torch.ops.aten._mixed_dtypes_linear(
                inp,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=True
                ),
                scale,
                bias=bias,
                activation=activation,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

        dtypeqs = [torch.int8, torch.quint4x2]
        batch_shapes = [[], [2], [2, 1]]
        shapes = [
            [8, 64, 64],
            [8, 64, 128],
            [8, 128, 64],
            [8, 128, 128],
            [8, 128, 192],
            [8, 128, 256],
            [8, 256, 128],
            [8, 256, 384],
            [8, 384, 256],
        ]
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 1e-2, 1e-3
        for dtypeq, batch_shape, (m, n, k), add_bias, activation in product(
            dtypeqs, batch_shapes, shapes, (False, True), activations
        ):
            run_test(
                batch_shape,
                m,
                n,
                k,
                add_bias,
                activation,
                dtype,
                dtypeq,
                device,
                rtol,
                atol,
            )
# Instantiate the generic test classes above for each available device type
# except CPU. The module's globals() are passed as the scope the helper
# operates on (presumably where the generated classes are registered -- see
# torch.testing._internal.common_device_type).
instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")
if __name__ == "__main__":
    # Turn on TestCase's default-dtype check before handing control to the
    # test runner.
    setattr(TestCase, "_default_dtype_check_enabled", True)
    run_tests()