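"""Sanity-check a ``torch._dynamo`` installation.

Verifies the Python, torch, and CUDA/ROCm versions, then compiles and runs a
trivial function and module through the eager, aot_eager, and inductor
backends on CPU and (when available) CUDA.
"""
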
import os
import re
import subprocess
import sys
import traceback
import warnings

MIN_CUDA_VERSION = "11.6"
MIN_ROCM_VERSION = "5.4"
MIN_PYTHON_VERSION = (3, 8)


class VerifyDynamoError(BaseException):
    pass


def check_python():
    if sys.version_info < MIN_PYTHON_VERSION:
        raise VerifyDynamoError(
            f"Python version not supported: {sys.version_info} "
            f"- minimum requirement: {MIN_PYTHON_VERSION}"
        )
    return sys.version_info


def check_torch():
    import torch

    return torch.__version__


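# The two helpers below shell out to the toolkit compiler (`nvcc --version` /
# `hipcc --version`) and parse the release number from its output, so they
# report the *system* toolkit version rather than the one torch was built with.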
# based on torch/utils/cpp_extension.py
def get_cuda_version():
    from torch.torch_version import TorchVersion
    from torch.utils import cpp_extension

    CUDA_HOME = cpp_extension._find_cuda_home()
    if not CUDA_HOME:
        raise VerifyDynamoError(cpp_extension.CUDA_NOT_FOUND_MESSAGE)

    nvcc = os.path.join(CUDA_HOME, "bin", "nvcc")
    cuda_version_str = (
        subprocess.check_output([nvcc, "--version"])
        .strip()
        .decode(*cpp_extension.SUBPROCESS_DECODE_ARGS)
    )
    cuda_version = re.search(r"release (\d+[.]\d+)", cuda_version_str)
    if cuda_version is None:
        raise VerifyDynamoError("CUDA version not found in `nvcc --version` output")

    cuda_str_version = cuda_version.group(1)
    return TorchVersion(cuda_str_version)


def get_rocm_version():
    from torch.torch_version import TorchVersion
    from torch.utils import cpp_extension

    ROCM_HOME = cpp_extension._find_rocm_home()
    if not ROCM_HOME:
        raise VerifyDynamoError(
            "ROCM was not found on the system, please set ROCM_HOME environment variable"
        )

    hipcc = os.path.join(ROCM_HOME, "bin", "hipcc")
    hip_version_str = (
        subprocess.check_output([hipcc, "--version"])
        .strip()
        .decode(*cpp_extension.SUBPROCESS_DECODE_ARGS)
    )
    hip_version = re.search(r"HIP version: (\d+[.]\d+)", hip_version_str)
    if hip_version is None:
        raise VerifyDynamoError("HIP version not found in `hipcc --version` output")

    hip_str_version = hip_version.group(1)
    return TorchVersion(hip_str_version)


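# check_cuda/check_rocm return the system toolkit version for the active
# accelerator, or None when that accelerator is not in use. Mismatched or
# too-old versions are reported as warnings rather than errors (note the
# commented-out `raise` calls in check_cuda) so the remaining checks still run.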
def check_cuda():
    import torch
    from torch.torch_version import TorchVersion

    if not torch.cuda.is_available() or torch.version.hip is not None:
        return None

    torch_cuda_ver = TorchVersion(torch.version.cuda)

    # check if torch cuda version matches system cuda version
    cuda_ver = get_cuda_version()
    if cuda_ver != torch_cuda_ver:
        # raise VerifyDynamoError(
        warnings.warn(
            f"CUDA version mismatch, `torch` version: {torch_cuda_ver}, env version: {cuda_ver}"
        )

    if torch_cuda_ver < MIN_CUDA_VERSION:
        # raise VerifyDynamoError(
        warnings.warn(
            f"(`torch`) CUDA version not supported: {torch_cuda_ver} "
            f"- minimum requirement: {MIN_CUDA_VERSION}"
        )

    if cuda_ver < MIN_CUDA_VERSION:
        # raise VerifyDynamoError(
        warnings.warn(
            f"(env) CUDA version not supported: {cuda_ver} "
            f"- minimum requirement: {MIN_CUDA_VERSION}"
        )

    return cuda_ver if torch.version.hip is None else "None"


def check_rocm():
    import torch
    from torch.torch_version import TorchVersion

    if not torch.cuda.is_available() or torch.version.hip is None:
        return None

    # Extracts the main (major.minor) ROCm version from the full string
    torch_rocm_ver = TorchVersion(".".join(torch.version.hip.split(".")[:2]))

    # check if torch rocm version matches system rocm version
    rocm_ver = get_rocm_version()
    if rocm_ver != torch_rocm_ver:
        warnings.warn(
            f"ROCm version mismatch, `torch` version: {torch_rocm_ver}, env version: {rocm_ver}"
        )

    if torch_rocm_ver < MIN_ROCM_VERSION:
        warnings.warn(
            f"(`torch`) ROCm version not supported: {torch_rocm_ver} "
            f"- minimum requirement: {MIN_ROCM_VERSION}"
        )

    if rocm_ver < MIN_ROCM_VERSION:
        warnings.warn(
            f"(env) ROCm version not supported: {rocm_ver} "
            f"- minimum requirement: {MIN_ROCM_VERSION}"
        )

    return rocm_ver if torch.version.hip else "None"


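# Compile a trivial function and nn.Module with the given backend using
# `nopython=True` (so any graph break is a hard failure) and verify the
# forward and backward results numerically; on failure, print the traceback
# plus `err_msg` and exit with status 1.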
def check_dynamo(backend, device, err_msg) -> None:
    import torch

    if device == "cuda" and not torch.cuda.is_available():
        print(f"CUDA not available -- skipping CUDA check on {backend} backend\n")
        return

    try:
        import torch._dynamo as dynamo

        if device == "cuda":
            from torch.utils._triton import has_triton

            if not has_triton():
                print(
                    f"WARNING: CUDA available but triton cannot be used. "
                    f"Your GPU may not be supported. "
                    f"Skipping CUDA check on {backend} backend\n"
                )
                return

        dynamo.reset()

        @dynamo.optimize(backend, nopython=True)
        def fn(x):
            return x + x

        class Module(torch.nn.Module):
            def forward(self, x):
                return x + x

        mod = Module()
        opt_mod = dynamo.optimize(backend, nopython=True)(mod)

        for f in (fn, opt_mod):
            x = torch.randn(10, 10).to(device)
            x.requires_grad = True
            y = f(x)
            torch.testing.assert_close(y, x + x)
            z = y.sum()
            z.backward()
            torch.testing.assert_close(x.grad, 2 * torch.ones_like(x))
    except Exception:
        sys.stderr.write(traceback.format_exc() + "\n" + err_msg + "\n\n")
        sys.exit(1)


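# (backend, device, error message) triples exercised by main(): every backend
# is checked on both CPU and CUDA.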
_SANITY_CHECK_ARGS = (
    ("eager", "cpu", "CPU eager sanity check failed"),
    ("eager", "cuda", "CUDA eager sanity check failed"),
    ("aot_eager", "cpu", "CPU aot_eager sanity check failed"),
    ("aot_eager", "cuda", "CUDA aot_eager sanity check failed"),
    ("inductor", "cpu", "CPU inductor sanity check failed"),
    (
        "inductor",
        "cuda",
        "CUDA inductor sanity check failed\n"
        + "NOTE: Please check that you installed the correct hash/version of `triton`",
    ),
)


def main() -> None:
    python_ver = check_python()
    torch_ver = check_torch()
    cuda_ver = check_cuda()
    rocm_ver = check_rocm()
    print(
        f"Python version: {python_ver.major}.{python_ver.minor}.{python_ver.micro}\n"
        f"`torch` version: {torch_ver}\n"
        f"CUDA version: {cuda_ver}\n"
        f"ROCM version: {rocm_ver}\n"
    )
    for args in _SANITY_CHECK_ARGS:
        if sys.version_info >= (3, 13):
            warnings.warn("Dynamo not yet supported in Python 3.13. Skipping check.")
            continue
        check_dynamo(*args)
    print("All required checks passed")


if __name__ == "__main__":
    main()
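
# Typical invocation, assuming this script lives at its usual upstream
# location (tools/dynamo/verify_dynamo.py in the PyTorch repo):
#
#   python tools/dynamo/verify_dynamo.py
#
# On success it prints the detected Python/torch/CUDA/ROCm versions followed
# by "All required checks passed"; a failing backend check exits with status 1.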