import importlib
import logging
import os
import tempfile
import torch
from .common import device_from_inputs, fake_tensor_unsupported
from .registry import register_backend

try:
    import numpy as np

    # Maps torch dtypes to numpy dtypes; used below to describe tensor
    # memory to ONNX Runtime's IO binding.
    _np_dtype = {
        torch.float16: np.float16,
        torch.float32: np.float32,
        torch.float64: np.float64,
        torch.uint8: np.uint8,
        torch.int8: np.int8,
        torch.int16: np.int16,
        torch.int32: np.int32,
        torch.int64: np.longlong,
        torch.bool: np.bool_,
    }
except ImportError:
    # numpy is optional at import time; the backend asserts on it when called.
    _np_dtype = None

log = logging.getLogger(__name__)


def default_provider(device_type):
    # An explicit ONNXRT_PROVIDER in the environment wins; otherwise choose
    # an execution provider based on the device of the example inputs.
    if "ONNXRT_PROVIDER" in os.environ:
        return os.environ["ONNXRT_PROVIDER"]
    return {
        "cpu": "CPUExecutionProvider",
        "cuda": "CUDAExecutionProvider",
        # "TensorrtExecutionProvider" is another option
    }[device_type]
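

# Example (illustrative sketch, not executed): with no ONNXRT_PROVIDER set,
#
#   default_provider("cpu")   # -> "CPUExecutionProvider"
#   default_provider("cuda")  # -> "CUDAExecutionProvider"
#
# while e.g. ONNXRT_PROVIDER=TensorrtExecutionProvider forces that provider
# regardless of device type.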


def has_onnxruntime():
    try:
        importlib.import_module("onnxruntime")
        return True
    except ImportError:
        return False
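

# Example (illustrative sketch): callers can guard on availability before
# requesting the backend by name, e.g.
#
#   if has_onnxruntime():
#       compiled = torch.compile(fn, backend="onnxrt")
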
@register_backend
@fake_tensor_unsupported
def onnxrt(gm, example_inputs, *, filename=None, provider=None):
    if filename is None:
        # No destination given: export to a temporary .onnx file and recurse
        # with the concrete path.
        with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp:
            return onnxrt(gm, example_inputs, filename=tmp.name)

    import onnxruntime  # type: ignore[import]

    assert _np_dtype, "requires numpy"

    device_type = device_from_inputs(example_inputs).type
    example_outputs = gm(*example_inputs)
    # Record shape/dtype/layout/device of each output so the runtime wrapper
    # can preallocate result tensors for IO binding.
    output_spec = [
        (o.shape, o.dtype, o.layout, o.device, o.requires_grad)
        for o in example_outputs
    ]
    input_names = [f"i{i}" for i in range(len(example_inputs))]
    output_names = [f"o{x}" for x in range(len(example_outputs))]

    # Script the graph module and export it to ONNX with stable tensor names.
    torch.onnx.export(
        torch.jit.script(gm),
        example_inputs,
        filename,
        input_names=input_names,
        output_names=output_names,
    )
    del example_inputs, example_outputs

    if provider is None:
        provider = default_provider(device_type)
    assert provider in onnxruntime.get_available_providers()
    session = onnxruntime.InferenceSession(filename, providers=[provider])

    def _call(*initial_args):
        binding = session.io_binding()
        # The exporter may prune unused inputs; only bind the names the
        # session actually expects.
        active_inputs = {inp.name for inp in session.get_inputs()}
        args = [a.contiguous() for a in initial_args]
        for name, value in zip(input_names, args):
            if name not in active_inputs:
                log.warning(
                    "input %s skipped as not found in onnx inference session", name
                )
                continue
            dev = value.device
            # Bind the tensor's memory in place via its data pointer (no copy).
            binding.bind_input(
                name,
                dev.type,
                dev.index or 0,
                _np_dtype[value.dtype],
                value.size(),
                value.data_ptr(),
            )
        # Preallocate outputs matching the spec captured at compile time, then
        # bind them so ONNX Runtime writes results directly into them.
        outputs = [
            torch.empty(
                shape,
                dtype=dtype,
                layout=layout,
                device=device,
                requires_grad=requires_grad,
            )
            for shape, dtype, layout, device, requires_grad in output_spec
        ]
        for name, value in zip(output_names, outputs):
            dev = value.device
            binding.bind_output(
                name,
                dev.type,
                dev.index or 0,
                _np_dtype[value.dtype],
                value.size(),
                value.data_ptr(),
            )
        session.run_with_iobinding(binding)
        if device_type == "cpu":
            binding.copy_outputs_to_cpu()
        return outputs

    return _call
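

if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustrative addition, not upstream code);
    # assumes numpy and onnxruntime are installed.  @register_backend above
    # exposes this backend to torch.compile under the name "onnxrt".
    def fn(x, y):
        return torch.relu(x + y)

    compiled = torch.compile(fn, backend="onnxrt")
    a, b = torch.randn(4, 8), torch.randn(4, 8)
    torch.testing.assert_close(compiled(a, b), fn(a, b))
    print("onnxrt backend output matches eager")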