| import sys |
| |
| import caffe2.python.onnx.backend as c2 |
| |
| import onnx |
| import pytorch_test_common |
| import torch |
| import torch.jit |
| from torch.autograd import Variable |
| |
| torch.set_default_tensor_type("torch.FloatTensor") |
| try: |
| import torch |
| except ImportError: |
| print("Cannot import torch, hence caffe2-torch test will not run.") |
| sys.exit(0) |
| |
| |
| def run_embed_params(proto, model, input, state_dict=None, use_gpu=True): |
| """ |
| This is only a helper debug function so we can test embed_params=False |
| case as well on pytorch front |
| This should likely be removed from the release version of the code |
| """ |
| device = "CPU" |
| if use_gpu: |
| device = "CUDA" |
| model_def = onnx.ModelProto.FromString(proto) |
| onnx.checker.check_model(model_def) |
| prepared = c2.prepare(model_def, device=device) |
| |
| if state_dict: |
| parameters = [] |
| # Passed in state_dict may have a different order. Make |
| # sure our order is consistent with the model's order. |
| # TODO: Even better: keyword arguments! |
| for k in model.state_dict(): |
| if k in state_dict: |
| parameters.append(state_dict[k]) |
| else: |
| parameters = list(model.state_dict().values()) |
| |
| W = {} |
| for k, v in zip( |
| model_def.graph.input, pytorch_test_common.flatten((input, parameters)) |
| ): |
| if isinstance(v, Variable): |
| W[k.name] = v.data.cpu().numpy() |
| else: |
| W[k.name] = v.cpu().numpy() |
| |
| caffe2_out = prepared.run(inputs=W) |
| |
| return caffe2_out |