import copy
import functools
import io
import logging
import os
import subprocess
import tempfile

import numpy as np

import torch

from ..utils import identity
from .subgraph import SubGraph

log = logging.getLogger(__name__)
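# Registry of compiler backends, keyed by function name. Each entry takes an
# FX GraphModule plus example inputs and returns a compiled callable, or None
# if compilation failed, e.g.:
#   compiled_fn = BACKENDS["eager"](gm, example_inputs)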
BACKENDS = dict()
_NP_DTYPE = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.uint8: np.uint8,
    torch.int8: np.int8,
    torch.int16: np.int16,
    torch.int32: np.int32,
    torch.int64: np.longlong,
    torch.bool: np.bool_,
}

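# `register_backend` registers a backend that works directly on an FX
# GraphModule and example inputs; `create_backend` (below) additionally wraps
# that pair in a SubGraph and swallows exceptions, returning None on failure.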
def register_backend(fn):
    @functools.wraps(fn)
    def inner(gm, example_inputs, **kwargs):
        return fn(gm, example_inputs, **kwargs)

    BACKENDS[fn.__name__] = inner
    return inner

def create_backend(fn):
    @functools.wraps(fn)
    def inner(model, example_inputs=None, **kwargs):
        if model is None:
            return None

        if not isinstance(model, SubGraph):
            with tempfile.TemporaryDirectory() as tmp:
                return inner(SubGraph(model, example_inputs, tmp), **kwargs)
        else:
            assert example_inputs is None

        try:
            return fn(model, **kwargs)
        except KeyboardInterrupt:
            raise
        except Exception:
            log.exception(f"{fn.__name__} error")
            return None

    BACKENDS[fn.__name__] = inner
    return inner

@create_backend
def eager(subgraph):
    return subgraph.model


@create_backend
def ts(subgraph):
    return subgraph.scripted

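# Reload the scripted module via an in-memory save/load round trip to get a
# fresh copy, then run it a few times to populate the JIT profiling/fusion
# caches under whichever fuser context is active (see the fuser backends below).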
def reload_jit_model(subgraph, opt_fn=identity):
    tmp = io.BytesIO()
    torch.jit.save(subgraph.scripted, tmp)
    tmp.seek(0)
    model = torch.jit.load(tmp)
    model = opt_fn(model)
    # populate cache
    for _ in range(3):
        model(*subgraph.example_inputs)
    return model


def reload_jit_model_ofi(subgraph):
    return reload_jit_model(subgraph, torch.jit.optimize_for_inference)

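# TorchScript fuser backends: "fuser1" selects NNC, "fuser2" selects nvFuser,
# and "fuser3" selects the oneDNN fuser; the *_ofi variants additionally apply
# torch.jit.optimize_for_inference.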
@create_backend
def nnc(subgraph):
    with torch.jit.fuser("fuser1"):
        return reload_jit_model(subgraph)


@create_backend
def nnc_ofi(subgraph):
    with torch.jit.fuser("fuser1"):
        return reload_jit_model_ofi(subgraph)


@create_backend
def nvfuser(subgraph):
    with torch.jit.fuser("fuser2"):
        return reload_jit_model(subgraph)


@create_backend
def nvfuser_ofi(subgraph):
    with torch.jit.fuser("fuser2"):
        return reload_jit_model_ofi(subgraph)


@create_backend
def onednn(subgraph):
    with torch.jit.fuser("fuser3"):
        return reload_jit_model(subgraph)


@create_backend
def ofi(subgraph):
    return torch.jit.optimize_for_inference(subgraph.scripted)

@create_backend
def static_runtime(subgraph):
    scripted = subgraph.scripted
    if hasattr(scripted, "_c"):
        static_module = torch._C._jit_to_static_module(scripted._c)
    else:
        static_module = torch._C._jit_to_static_module(scripted.graph)
    return subgraph.wrap_returns(static_module)

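# Shared onnxruntime path: bind input/output torch tensors directly through
# ORT's IOBinding API (by device, dtype, shape, and data_ptr) so the session
# reads and writes the tensors in place without extra host copies.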
def onnxrt_common(subgraph, provider, onnx_filename=None):
    import onnxruntime

    assert provider in onnxruntime.get_available_providers()
    session = onnxruntime.InferenceSession(
        onnx_filename or subgraph.onnx_filename, providers=[provider]
    )
    input_names = subgraph.input_names
    output_names = subgraph.output_names
    create_outputs = subgraph.empty_outputs_factory()
    is_cpu = subgraph.is_cpu

    def _call(*args):
        binding = session.io_binding()
        args = [a.contiguous() for a in args]
        for name, value in zip(input_names, args):
            dev = value.device
            binding.bind_input(
                name,
                dev.type,
                dev.index or 0,
                _NP_DTYPE[value.dtype],
                value.size(),
                value.data_ptr(),
            )
        outputs = create_outputs()
        for name, value in zip(output_names, outputs):
            dev = value.device
            binding.bind_output(
                name,
                dev.type,
                dev.index or 0,
                _NP_DTYPE[value.dtype],
                value.size(),
                value.data_ptr(),
            )
        session.run_with_iobinding(binding)
        if is_cpu:
            binding.copy_outputs_to_cpu()
        return outputs

    return subgraph.wrap_returns(_call)

@create_backend
def onnxrt_cpu(subgraph):
    return onnxrt_common(subgraph, provider="CPUExecutionProvider")


@create_backend
def onnxrt_cuda(subgraph):
    return onnxrt_common(subgraph, provider="CUDAExecutionProvider")


@create_backend
def onnx2tensorrt(subgraph):
    if subgraph.will_tensorrt_barf():
        # TensorRT fails violently with an abort() on this
        return None

    return onnxrt_common(subgraph, provider="TensorrtExecutionProvider")

@create_backend
def onnxrt_cpu_numpy(subgraph, provider="CPUExecutionProvider"):
    """Alternate version that integrates via numpy"""
    import onnxruntime

    assert provider in onnxruntime.get_available_providers()
    ort_session = onnxruntime.InferenceSession(
        subgraph.onnx_filename, providers=[provider]
    )

    def to_numpy(x):
        try:
            return x.numpy()
        except RuntimeError:
            return x.detach().numpy()

    def _call(*args):
        res = ort_session.run(
            None, {f"i{i}": to_numpy(arg) for i, arg in enumerate(args)}
        )
        res = [torch.from_numpy(x) for x in res]
        return res

    return subgraph.wrap_returns(_call)


@create_backend
def onnxrt(subgraph):
    if subgraph.is_cuda:
        return onnxrt_cuda(subgraph)
    else:
        return onnxrt_cpu(subgraph)

@functools.lru_cache(None)
def _init_tensorflow():
    import tensorflow as tf

    # prevent tensorflow from eating all the GPU memory
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    return tf

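# Convert the ONNX export to a TensorFlow SavedModel with onnx-tf, compile it
# with tf.function(jit_compile=True), and shuttle tensors between torch and
# TensorFlow via DLPack to avoid extra copies.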
@create_backend
def onnx2tf(subgraph):
    import onnx
    from onnx_tf.backend import prepare

    tf = _init_tensorflow()
    filename = subgraph.filename("tensorflow")
    input_names = subgraph.input_names
    output_names = subgraph.output_names
    device = "/CPU:0" if subgraph.is_cpu else f"/GPU:{subgraph.device_index}"
    with tf.device(device):
        if not os.path.exists(filename):
            prepare(onnx.load(subgraph.onnx_filename)).export_graph(filename)
        tf_module = tf.saved_model.load(filename)
        tf_module = tf.function(tf_module, jit_compile=True)

    def run(*args):
        args = [a.contiguous() for a in args]
        with tf.device(device):
            outs = tf_module(
                **{
                    name: tf.experimental.dlpack.from_dlpack(
                        torch.utils.dlpack.to_dlpack(args[idx])
                    )
                    for idx, name in enumerate(input_names)
                }
            )
            return [
                torch.utils.dlpack.from_dlpack(
                    tf.experimental.dlpack.to_dlpack(outs[name])
                )
                for name in output_names
            ]

    return subgraph.wrap_returns(run)

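# Run TASO on the exported ONNX file by shelling out to a dedicated
# ~/conda/envs/taso Python environment, then execute the optimized model with
# the onnxruntime CUDA provider.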
@create_backend
def taso(subgraph):
    taso_filename = subgraph.filename("taso")
    subprocess.check_call(
        [
            os.path.expanduser("~/conda/envs/taso/bin/python"),
            "-c",
            "import taso,onnx; onnx.save(taso.export_onnx(taso.optimize("
            f"taso.load_onnx('{subgraph.onnx_filename}'))), '{taso_filename}')",
        ]
    )
    return onnxrt_common(
        subgraph, provider="CUDAExecutionProvider", onnx_filename=taso_filename
    )

@create_backend
def ipex(subgraph, **kwargs):
    import intel_extension_for_pytorch as ipex

    inputs = subgraph.example_inputs
    model = subgraph.model
    with torch.no_grad():
        model.eval()
        if kwargs.get("datatype", "fp32") == "bf16":
            model = ipex.optimize(model, dtype=torch.bfloat16)
        else:
            model = ipex.optimize(model, dtype=torch.float32)
        try:
            traced_model = torch.jit.trace(model, inputs).eval()
            traced_model = torch.jit.freeze(traced_model)
            return traced_model
        except Exception:
            log.warning("JIT trace failed during the 'ipex' optimize process.")
            return model

def _raise_timeout(signum, frame):
    raise TimeoutError()

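# torch_tensorrt FX path: normalize the graph, trace it with acc_tracer, split
# TensorRT-supported submodules from unsupported ones, compile each supported
# submodule with TRTInterpreter, and splice the resulting TRTModules back into
# the split module.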
@create_backend
def fx2trt(subgraph, **kwargs):
    if subgraph.will_tensorrt_barf():
        # TensorRT fails violently with an abort() on this
        return None

    from torch_tensorrt.fx.fx2trt import InputTensorSpec, TRTInterpreter
    from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem
    from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
    from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
    from torch_tensorrt.fx.trt_module import TRTModule
    from torch_tensorrt.fx.utils import LowerPrecision

    from .normalize import normalize_ir

    try:
        model = subgraph.model
        inputs = subgraph.example_inputs
        # normalize
        model = normalize_ir(model, inputs)
        # pass rewrite
        model = transform_setitem(model, inputs)
        acc_model = acc_tracer.trace(model, inputs)
        # Split out unsupported ops
        splitter_setting = TRTSplitterSetting()
        splitter_setting.use_implicit_batch_dim = False
        splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
        splitter.node_support_preview()
        split_mod = splitter()
        num_piece = 0
        for name, _ in split_mod.named_children():
            print(f"graph is split into {name}")
            num_piece += 1

        # If the graph module is split into more than 8 pieces, we assume its
        # performance will be poor and fall back to the non-TRT module.
        if num_piece > 8:
            print(
                f"The graph module is split into {num_piece} pieces, which is "
                "larger than the threshold of 8. Falling back to the non-TRT module."
            )
            return None

        if "fp16_mode" in kwargs and kwargs["fp16_mode"]:
            precision = LowerPrecision.FP16
        else:
            precision = LowerPrecision.FP32

        def get_submod_inputs(mod, submod, inputs):
            # Capture the inputs the submodule actually receives by registering
            # a forward pre-hook and running the full split module once.
            acc_inputs = None

            def get_input(self, inputs):
                nonlocal acc_inputs
                acc_inputs = inputs

            handle = submod.register_forward_pre_hook(get_input)
            mod(*inputs)
            handle.remove()
            return acc_inputs

        for name, _ in split_mod.named_children():
            if "_run_on_acc" in name:
                submod = getattr(split_mod, name)
                # print("acc=", submod.code)
                # Get submodule inputs for fx2trt
                acc_inputs = get_submod_inputs(split_mod, submod, inputs)

                # fx2trt replacement
                interp = TRTInterpreter(
                    submod,
                    InputTensorSpec.from_tensors(acc_inputs),
                    explicit_batch_dimension=True,
                )
                r = interp.run(
                    max_workspace_size=20 << 30,
                    lower_precision=precision,
                    # profiling_verbosity=trt.ProfilingVerbosity.DETAILED,  # for profiling
                )
                # For profiling:
                # from fx2trt_oss.fx.tools.trt_profiler_sorted import profile_trt_module
                # profile_trt_module("", trt_mod, acc_inputs)
                trt_mod = TRTModule(*r)

                setattr(split_mod, name, trt_mod)
            else:
                submod = getattr(split_mod, name)
                # print("gpu=", submod.code)
        return subgraph.wrap_returns(split_mod)
    except Exception:
        log.exception("FX2TRT conversion error")
        return None

@create_backend
def torch2trt(subgraph):
    if subgraph.will_tensorrt_barf():
        # TensorRT fails violently with an abort() on this
        return None

    from torch2trt import torch2trt

    inputs = subgraph.example_inputs
    trt_mod = torch2trt(
        subgraph.model,
        inputs,
        max_batch_size=len(inputs[0]),
        strict_type_constraints=True,
    )
    return subgraph.wrap_returns(trt_mod)

@create_backend
def tensorrt(subgraph):
    if subgraph.will_tensorrt_barf():
        # TensorRT fails violently with an abort() on this
        return None

    model = onnx2tensorrt(subgraph)
    if model is None:
        model = torch2trt(subgraph)
    return model

@create_backend
def onnx2tensorrt_alt(subgraph):
    if subgraph.will_tensorrt_barf():
        # TensorRT fails violently with an abort() on this
        return None

    import tensorrt as trt

    from torch.fx.experimental.fx2trt.trt_module import TRTModule

    inputs = subgraph.example_inputs

    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    assert isinstance(inputs, (list, tuple))
    inputs = tuple(inputs)
    input_names = subgraph.input_names
    output_names = subgraph.output_names
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)
    with open(subgraph.onnx_filename, "rb") as f:
        success = parser.parse(f.read())
    for idx in range(parser.num_errors):
        print(parser.get_error(idx))
    assert success

    config.max_workspace_size = 1 << 25
    config.set_flag(trt.BuilderFlag.STRICT_TYPES)
    builder.max_batch_size = len(inputs[0])

    engine = builder.build_engine(network, config)
    assert engine

    trt_mod = TRTModule(engine, input_names, output_names)
    return subgraph.wrap_returns(trt_mod)

@create_backend
def cudagraphs(subgraph):
    model = subgraph.model
    inputs = subgraph.example_inputs
    assert subgraph.is_cuda
    return subgraph.wrap_returns(cudagraphs_inner(model, inputs))


@create_backend
def cudagraphs_ts(subgraph):
    assert subgraph.is_cuda
    model = subgraph.scripted
    inputs = subgraph.example_inputs

    # warmup
    for _ in range(3):
        model(*inputs)

    return subgraph.wrap_returns(cudagraphs_inner(model, inputs))


@create_backend
def cudagraphs_ts_ofi(subgraph):
    assert subgraph.is_cuda
    model = torch.jit.optimize_for_inference(torch.jit.freeze(subgraph.scripted))
    inputs = subgraph.example_inputs

    # warmup
    for _ in range(3):
        model(*inputs)

    return subgraph.wrap_returns(cudagraphs_inner(model, inputs))

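# Capture `model` into a CUDA graph over static input buffers (after a warmup
# run on a side stream); the returned runner copies new inputs into those
# buffers, replays the graph, and by default clones the static outputs since
# they are overwritten on every replay.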
def cudagraphs_inner(model, inputs, copy_outputs=True):
    assert isinstance(inputs, (list, tuple))
    static_inputs = [torch.zeros_like(x) for x in inputs]

    # warmup
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        model(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        static_outputs = model(*static_inputs)
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*new_inputs):
        assert len(static_inputs) == len(new_inputs)
        for dst, src in zip(static_inputs, new_inputs):
            dst.copy_(src)
        graph.replay()
        if copy_outputs:
            return [x.clone() for x in static_outputs]
        else:
            return static_outputs

    return run

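# Hand the subgraph to functorch's aot_module_simplified; the backward compiler
# is wrapped in disable() so TorchDynamo does not try to trace the generated
# backward pass (see the inline comment below).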
@create_backend
def aot_autograd(subgraph, **kwargs):
    def _wrapped_bw_compiler(*args, **kwargs):
        # stop TorchDynamo from trying to compile our generated backwards pass
        return disable(bw_compiler(*args, **kwargs))

    bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
    kwargs["bw_compiler"] = _wrapped_bw_compiler

    from functorch.compile import aot_module_simplified

    from .. import disable

    return aot_module_simplified(subgraph.model, **kwargs)

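# Non-decorated TVM entry point that takes a TorchScript module directly; it
# removes a partially written tuning log if compilation fails.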
def tvm_compile(jit_mod, example_inputs, log_file=None, **kwargs):
    if jit_mod is None:
        return None
    try:
        return tvm_compile_inner(jit_mod, example_inputs, None, log_file, **kwargs)
    except Exception as e:
        if log_file and os.path.exists(log_file):
            os.unlink(log_file)
        if isinstance(e, KeyboardInterrupt):
            raise
        log.exception("tvm error")
        return None

@create_backend
def tvm(subgraph):
    return subgraph.wrap_returns(
        tvm_compile_inner(
            subgraph.scripted,
            subgraph.example_inputs,
            tuning_option=None,
            cuda=subgraph.is_cuda,
        )
    )


@create_backend
def ansor(subgraph):
    """
    WARNING: this backend takes hours or days to train and
    often produces a slower result than the default schedule.
    """
    return subgraph.wrap_returns(
        tvm_compile_inner(
            subgraph.scripted,
            subgraph.example_inputs,
            tuning_option="auto_scheduler",
            log_file=subgraph.filename("ansor"),
            cuda=subgraph.is_cuda,
        )
    )


@create_backend
def tvm_meta_schedule(subgraph):
    return subgraph.wrap_returns(
        tvm_compile_inner(
            subgraph.scripted,
            subgraph.example_inputs,
            tuning_option="meta_schedule",
            trials=20000,
            cuda=subgraph.is_cuda,
        )
    )


@functools.lru_cache(None)
def llvm_target():
    if "avx512" in open("/proc/cpuinfo").read():
        return "llvm -mcpu=skylake-avx512"
    return "llvm -mcpu=core-avx2"

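# Lower the TorchScript module to Relay, optionally tune it with Ansor
# ("auto_scheduler") or MetaSchedule ("meta_schedule"), build for the selected
# target, and return a callable that converts inputs/outputs between torch
# tensors and TVM arrays. Any failure falls back to returning jit_mod (eager).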
def tvm_compile_inner(
    jit_mod, example_inputs, tuning_option=None, log_file=None, trials=20000, cuda=False
):
    try:
        import tvm
        from tvm import relay
        from tvm.contrib import graph_executor

        shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
        mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
        if cuda:
            dev = tvm.cuda(0)
            target = tvm.target.cuda()
        else:
            dev = tvm.cpu(0)
            target = tvm.target.Target(llvm_target())

        if tuning_option == "auto_scheduler":
            from tvm import auto_scheduler

            if log_file is None:
                # auto_scheduler records tuning results to a file path, so create
                # a fresh path in a temporary directory when none was provided
                log_file = os.path.join(tempfile.mkdtemp(), "tvm_tuning_log.json")
            if not os.path.exists(log_file):
                tasks, task_weights = auto_scheduler.extract_tasks(
                    mod["main"], params, target
                )
                if tasks:
                    for task in tasks:
                        print(task.compute_dag)
                else:
                    print("No tasks")
                if len(tasks) != 0:
                    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
                    if not os.path.exists(log_file):
                        assert trials > 0
                        tune_option = auto_scheduler.TuningOptions(
                            num_measure_trials=trials,
                            measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
                            early_stopping=2000,
                        )
                        try:
                            tuner.tune(tune_option)
                        except Exception:
                            if os.path.exists(log_file):
                                os.unlink(log_file)
                            raise

            with auto_scheduler.ApplyHistoryBest(log_file):
                with tvm.transform.PassContext(
                    opt_level=3, config={"relay.backend.use_auto_scheduler": True}
                ):
                    lib = relay.build(mod, target=target, params=params)
        elif tuning_option == "meta_schedule":
            from os import path as osp

            from tvm.contrib.torch import optimize_torch

            with tempfile.TemporaryDirectory() as work_dir:
                if log_file is not None:
                    assert osp.isdir(
                        log_file
                    ), "TVM's meta_schedule requires a directory for storing log files."
                    work_dir = log_file

                lib = optimize_torch(
                    jit_mod,
                    example_inputs,
                    max_trials_global=20000,
                    work_dir=work_dir,
                    target=target,
                    max_trials_per_task=64,
                )

        elif tuning_option is None:
            # no autotuning (for debugging)
            with tvm.transform.PassContext(opt_level=10):
                lib = relay.build(mod, target=target, params=params)
        else:
            raise NotImplementedError(
                "This tuning option is invalid/not implemented for torchdynamo's TVM-related backend. "
                "The three available options are None, 'auto_scheduler', and 'meta_schedule'."
            )
        if tuning_option != "meta_schedule":
            m = graph_executor.GraphModule(lib["default"](dev))

            def to_torch_tensor(nd_tensor):
                """A helper function to transfer a NDArray to torch.tensor."""
                if nd_tensor.dtype == "bool":
                    # DLPack does not support booleans, so the tensor can't go
                    # through torch.utils.dlpack.from_dlpack. Work around it by
                    # going through numpy, although this adds a data copy.
                    return torch.from_numpy(nd_tensor.numpy())
                return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())

            def exec_tvm(*args):
                args = [a.contiguous() for a in args]
                for idx, arg in enumerate(args, 0):
                    if arg.dim() != 0:
                        if arg.requires_grad:
                            arg = arg.detach()
                        m.set_input(
                            f"inp_{idx}",
                            tvm.nd.array(arg.numpy(), dev),
                        )
                m.run()
                return [
                    to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())
                ]

        else:

            def exec_tvm(*args):
                args = [a.contiguous() for a in args]
                return lib(*args)

        return exec_tvm

    except Exception:
        log.exception("tvm error")
        return jit_mod  # explicit fall back to eager

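# Lazy Tensor (LTC) backends: ltc_reuse_graph extracts a compiled graph via
# torch._lazy, while ltc_trivial simply moves the module and per-call tensors
# to the "lazy" device and back.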
@functools.lru_cache(None)
def _init_ltc():
    try:
        import torch._lazy.extract_compiled_graph
        from torch._lazy.ts_backend import init as init_ts_backend

        # hopefully changing this line to something like _ltc_init_xla_backend
        # in the future will enable XLA
        init_ts_backend()

        return torch._lazy
    except ModuleNotFoundError as e:
        print(f"LTC backend failed: cannot import {e.name}")
        raise

def ltc_reuse_graph(gm: torch.fx.GraphModule, example_inputs):
    ltc = _init_ltc()
    return ltc.extract_compiled_graph.extract_compiled_graph(gm, example_inputs)


def ltc_trivial(gm: torch.fx.GraphModule, example_inputs):
    ltc = _init_ltc()
    lazy_model = copy.deepcopy(gm).to(device="lazy")
    ltc.extract_compiled_graph.force_lazy_device(lazy_model)

    def ltc_model(*inputs):
        orig_device = inputs[0].device if len(inputs) > 0 else "cuda"
        lazy_inputs = tuple(inp.to(device="lazy") for inp in inputs)

        lazy_out = lazy_model(*lazy_inputs)
        out = tuple(out.to(device=orig_device) for out in lazy_out)
        return out

    return ltc_model

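# Thin wrappers with the plain (gm, example_inputs) signature that dispatch to
# the registered backends above with preset kwargs.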
def ipex_fp32(gm: torch.fx.GraphModule, example_inputs):
    kwargs_ipex = {"datatype": "fp32"}
    return BACKENDS["ipex"](gm, example_inputs, **kwargs_ipex)


def ipex_bf16(gm: torch.fx.GraphModule, example_inputs):
    kwargs_ipex = {"datatype": "bf16"}
    return BACKENDS["ipex"](gm, example_inputs, **kwargs_ipex)

def fx2trt_compiler_fp16(gm: torch.fx.GraphModule, example_inputs):
    kwargs_fx2trt = {"fp16_mode": True}
    trt_compiled = BACKENDS["fx2trt"](gm, example_inputs, **kwargs_fx2trt)
    if trt_compiled is not None:
        return trt_compiled
    else:
        print(
            "FX2TRT conversion failed on the subgraph. Returning GraphModule forward instead."
        )
        return gm.forward


def fx2trt_compiler(gm: torch.fx.GraphModule, example_inputs):
    kwargs_fx2trt = {"fp16_mode": False}
    trt_compiled = BACKENDS["fx2trt"](gm, example_inputs, **kwargs_fx2trt)
    if trt_compiled is not None:
        return trt_compiled
    else:
        print(
            "FX2TRT conversion failed on the subgraph. Returning GraphModule forward instead."
        )
        return gm.forward