| import io |
| |
| import onnx |
| |
| import torch.onnx |
| from caffe2.python.core import BlobReference, Net |
| from caffe2.python.onnx.backend import Caffe2Backend |
| |
| _next_idx = 0 |
| # Net.Clone takes a dict for blob remapping instead of a callable. |
| # A callable would be more flexible, so _FakeDict below wraps a function |
| # in an object that exposes the dict-style ``get`` method. |
| |
| |
| class _FakeDict: |
| def __init__(self, fn): |
| self.fn = fn |
| |
| def get(self, name, _): |
| return self.fn(name) |
| |
| |
def PyTorchModule(helper, model, sample_arguments, caffe2_inputs, prefix_name=None):
    """
    Embed an ONNX-exportable PyTorch Model into a Caffe2 model being built.

    The model is exported to ONNX in memory, converted to a Caffe2
    init net / predict net pair, and both nets are appended to ``helper``
    with their blobs renamed so they do not clash with existing blobs.

    Args:
        helper (ModelHelper): the model helper where this imported network
            should be inserted; the predict net is appended to
            ``helper.net`` and the init net to ``helper.param_init_net``
        model (torch.nn.Module): the model to be exported
        sample_arguments (tuple of arguments): the inputs to
            the model, e.g., such that ``model(*args)`` is a valid
            invocation of the model.  Any non-Variable arguments will
            be hard-coded into the exported model; any Variable arguments
            will become inputs of the exported model, in the order they
            occur in args.  If args is a Variable, this is equivalent
            to having called it with a 1-ary tuple of that Variable.
            (Note: passing keyword arguments to the model is not currently
            supported.  Give us a shout if you need it.)
        caffe2_inputs (list of str or caffe2.python.core.BlobReference): the
            caffe2 Blobs that should be inputs to this network. There must
            be exactly one per input of the exported ONNX graph that is not
            backed by an initializer, in graph-input order
        prefix_name: prefix to add to the name of every imported blob that
            is not one of the network inputs; if None then a fresh prefix
            ``pytorch_import_N/`` is used

    Returns:
        A tuple of caffe2.python.core.BlobReference objects referring to the
        model's outputs, one per output of the exported ONNX graph (a
        1-tuple when the model returns a single value).

    Raises:
        ValueError: if the number of non-initializer inputs of the exported
            graph differs from ``len(caffe2_inputs)``.
    """
    global _next_idx
    if prefix_name is None:
        # Each call without an explicit prefix gets its own namespace.
        prefix_name = f"pytorch_import_{_next_idx}/"
        _next_idx += 1

    # TODO: handle the case where model cannot be exported
    # and embed as a Python op in Caffe2
    f = io.BytesIO()
    torch.onnx.export(model, sample_arguments, f, export_params=True)
    onnx_model = onnx.load(io.BytesIO(f.getvalue()))
    init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)

    # Graph inputs backed by an initializer are parameters, not real inputs.
    initialized = {x.name for x in onnx_model.graph.initializer}
    # Number only the real inputs: the index must be a position in
    # caffe2_inputs, not a position in the full graph.input list, which may
    # also contain initializer entries.
    uninitialized_inputs = {
        name: i
        for i, name in enumerate(
            x.name for x in onnx_model.graph.input if x.name not in initialized
        )
    }

    if len(uninitialized_inputs) != len(caffe2_inputs):
        raise ValueError(
            f"Expected {len(uninitialized_inputs)} inputs but found {len(caffe2_inputs)}"
        )

    def remap_blob_name(name):
        # Real inputs are wired to the caller's blobs; everything else is
        # moved under the prefix.
        if name in uninitialized_inputs:
            idx = uninitialized_inputs[name]
            return str(caffe2_inputs[idx])
        return prefix_name + name

    predict_net = Net(predict_net).Clone("anon", _FakeDict(remap_blob_name))
    helper.net.AppendNet(predict_net)

    init_net = Net(init_net).Clone("anon", _FakeDict(remap_blob_name))
    helper.param_init_net.AppendNet(init_net)

    results = tuple(
        BlobReference(remap_blob_name(x.name), helper.net)
        for x in onnx_model.graph.output
    )
    return results