| import importlib |
| import logging |
| |
| import torch |
| from torch._dynamo import register_backend |
| from .common import fake_tensor_unsupported |
| |
# Module-level logger, named after this module so its records can be filtered by name.
log = logging.getLogger(name=__name__)
| |
| |
@register_backend
@fake_tensor_unsupported
def ipex(model, inputs):
    """TorchDynamo backend that TorchScript-traces and freezes a graph for IPEX.

    Args:
        model: The callable handed to this backend. It must support
            ``.eval()`` and be traceable by ``torch.jit.trace``.
        inputs: Example inputs passed to ``torch.jit.trace``.

    Returns:
        The traced-and-frozen TorchScript module on success, or ``model``
        itself if tracing or freezing raised any ``Exception``.

    Raises:
        ImportError: If ``intel_extension_for_pytorch`` cannot be imported.
            The error is logged (with traceback) before being re-raised.
    """
    try:
        # Imported only for its side effects; the name is never used below.
        # NOTE(review): presumably this registers IPEX's JIT optimizations so
        # the frozen graph picks them up -- confirm against IPEX docs.
        import intel_extension_for_pytorch  # type: ignore[import] # noqa: F401
    except ImportError:
        log.exception(
            "Unable to import Intel Extension for PyTorch (IPEX). "
            "Please install the right version of IPEX that matches the PyTorch version being used. "
            "Refer to https://github.com/intel/intel-extension-for-pytorch for details."
        )
        raise

    try:
        with torch.no_grad():
            # ``.eval()`` is applied in place (for an nn.Module), so the mode
            # change persists even when we fall back to returning ``model``.
            # torch.jit.freeze expects an eval-mode module.
            traced_model = torch.jit.trace(model.eval(), inputs)
            traced_model = torch.jit.freeze(traced_model)
            return traced_model
    except Exception:
        # Best-effort optimization: fall back to the untraced model, but keep
        # the traceback so the cause of the failure is diagnosable.
        log.warning(
            "JIT trace failed during the 'ipex' optimize process.",
            exc_info=True,
        )
        return model
| |
| |
def has_ipex():
    """Report whether Intel Extension for PyTorch can be imported.

    Returns:
        True if importing ``intel_extension_for_pytorch`` succeeds,
        False if the import raises ``ImportError``. Any other exception
        raised during the import propagates to the caller.
    """
    package = "intel_extension_for_pytorch"
    try:
        importlib.import_module(package)
    except ImportError:
        available = False
    else:
        available = True
    return available