| tensor_engine = None |
| |
| |
| def unsupported(func): |
| def wrapper(self): |
| return func(self) |
| |
| wrapper.is_supported = False |
| return wrapper |
| |
| |
| def is_supported(method): |
| if hasattr(method, "is_supported"): |
| return method.is_supported |
| return True |
| |
| |
| def set_engine_mode(mode): |
| global tensor_engine |
| if mode == "tf": |
| from . import tf_engine |
| |
| tensor_engine = tf_engine.TensorFlowEngine() |
| elif mode == "pt": |
| from . import pt_engine |
| |
| tensor_engine = pt_engine.TorchTensorEngine() |
| elif mode == "topi": |
| from . import topi_engine |
| |
| tensor_engine = topi_engine.TopiEngine() |
| elif mode == "relay": |
| from . import relay_engine |
| |
| tensor_engine = relay_engine.RelayEngine() |
| elif mode == "nnc": |
| from . import nnc_engine |
| |
| tensor_engine = nnc_engine.NncEngine() |
| else: |
| raise ValueError(f"invalid tensor engine mode: {mode}") |
| tensor_engine.mode = mode |
| |
| |
| def get_engine(): |
| if tensor_engine is None: |
| raise ValueError("use of get_engine, before calling set_engine_mode is illegal") |
| return tensor_engine |