Fallback to eager for float8 ops in inductor (#108293)
# Summary
As a stopgap until the FP8 dtype is supported within inductor, we would like to fall back to eager. Currently there are three ops that are needed for this:
- `_scaled_mm` (matmul for fp8 types)
- `clone` (for creating new copies of fp8 tensors)
- `to` (for converting to and from fp8 types)

This PR registers a fallback for `_scaled_mm` and adds the fp8 dtypes to the `unsupported_input_tensor` check.
Prior to these changes, this was failing with:
``` Shell
File "/home/drisspg/meta/pytorch/torch/_inductor/codegen/triton_utils.py", line 11, in signature_of
tye = JITFunction._type_of(arg.dtype)
File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/site-packages/triton/runtime/jit.py", line 229, in _type_of
return key if isinstance(key, str) else f"*{tys[dtype_str]}"
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
KeyError: 'float8_e4m3fn'
```
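With the fallback registered, a graph that touches fp8 tensors compiles and the fp8 ops simply run in eager mode instead of tripping the Triton codegen. A minimal sketch of the user-facing behavior (not part of this PR; assumes CUDA and a PyTorch build with the fp8 dtypes, and uses `to`/`clone` rather than `_scaled_mm` for brevity):

```python
import torch

# Hypothetical repro sketch: a compiled function that only casts to/from fp8
# and clones the fp8 tensor. Before this change, inductor's Triton codegen
# raised KeyError: 'float8_e4m3fn'; with the fallback, these ops run eagerly.
def fp8_roundtrip(x: torch.Tensor) -> torch.Tensor:
    y = x.to(torch.float8_e4m3fn)  # `to`: cast into an fp8 dtype
    z = y.clone()                  # `clone`: copy an fp8 tensor
    return z.to(torch.float32)     # cast back out of fp8

compiled = torch.compile(fp8_roundtrip, backend="inductor")
out = compiled(torch.randn(16, 16, device="cuda"))
```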
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108293
Approved by: https://github.com/peterbell10
diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py
new file mode 100644
index 0000000..5faaced
--- /dev/null
+++ b/test/inductor/test_fp8.py
@@ -0,0 +1,66 @@
+# Owner(s): ["module: inductor"]
+
+import unittest
+
+import torch
+from torch._dynamo.test_case import run_tests, TestCase
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+    TEST_WITH_ROCM,
+)
+from torch.testing._internal.inductor_utils import HAS_CUDA
+
+isSM90orLaterDevice = (
+    torch.cuda.is_available()
+    and torch.cuda.get_device_capability()
+    >= (
+        9,
+        0,
+    )
+)
+
+torch.set_float32_matmul_precision("high")
+
+
+@instantiate_parametrized_tests
+class TestFP8Types(TestCase):
+    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
+    @unittest.skipIf(
+        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
+        "FP8 is only supported on H100+",
+    )
+    @parametrize("dtype", (torch.float16, torch.bfloat16))
+    def test_eager_fallback(self, dtype: torch.dtype):
+        x_shape = (16, 16)
+        weight_shape = (32, 16)
+
+        def fp8_matmul_unwrapped(x):
+            a_scale = torch.Tensor([1.0]).to(device="cuda")
+            b_scale = torch.Tensor([1.0]).to(device="cuda")
+            output_scale = None
+            input_bias = torch.rand(32, device="cuda", dtype=dtype)
+            weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T.to(
+                torch.float8_e4m3fn
+            )
+            a_inverse_scale = 1 / a_scale
+            b_inverse_scale = 1 / b_scale
+            output, updated_amax = torch._scaled_mm(
+                x,
+                weight,
+                bias=input_bias,
+                out_dtype=dtype,
+                scale_a=a_inverse_scale,
+                scale_b=b_inverse_scale,
+                scale_result=output_scale,
+            )
+            return output
+
+        x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(torch.float8_e4m3fn)
+        compiled_fp8_matmul = torch.compile(fp8_matmul_unwrapped, backend="inductor")
+        y_fp8 = compiled_fp8_matmul(x)
+
+
+if __name__ == "__main__":
+    if HAS_CUDA:
+        run_tests()
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index ed1b824..3caab47 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -1475,6 +1475,13 @@
     )
 
 
+@functools.lru_cache(None)
+def _warn_float8_not_supported():
+    warnings.warn(
+        "Torchinductor does not support code generation for float8 operators. Performance may be worse than eager."
+    )
+
+
 # There are some types (CPU) which we accept as input but not as
 # output.
 def unsupported_input_tensor(t: torch._subclasses.FakeTensor):
@@ -1482,6 +1489,10 @@
     if t.is_complex():
         _warn_complex_not_supported()
         return True
+    # FP8 Tensors are currently not supported
+    if t.dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
+        _warn_float8_not_supported()
+        return True
     return False
 
 
@@ -1869,6 +1880,7 @@
 make_fallback(aten._thnn_fused_lstm_cell, require_dense)
 make_fallback(aten.topk)
 make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
+make_fallback(aten._scaled_mm.default)
 make_fallback(aten.view_as_complex, require_contiguous)