| """Functionality for Python <-> C++ frontend inter-op.""" |
| |
| from torch import nn |
| |
| |
class OrderedDictWrapper:
| """ |
    A wrapper around a C++ OrderedDict that dynamically evaluates the
    OrderedDict getter on a bound C++ module, so that changes made on the C++
    side are picked up. Otherwise, accessing e.g. ``cpp_module._parameters``
    just once would return a frozen copy of the parameters at the time of
    access. ``torch.nn.Module`` looks up ``_parameters`` and friends directly
    in ``self.__dict__``, which bypasses descriptors, so defining these
    attributes as properties on the wrapper class does not work.
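
    A minimal sketch of the intended use (``my_cpp_module`` is an assumed
    placeholder for a bound C++ module that exposes a ``_parameters``
    OrderedDict)::

        wrapper = OrderedDictWrapper(my_cpp_module, "_parameters")
        for name, param in wrapper.items():
            print(name, param.size())
        # Parameters registered on the C++ side after this point are still
        # visible, because every access re-reads ``_parameters`` anew.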
| """ |
| |
| def __init__(self, cpp_module, attr): |
| self.cpp_module = cpp_module |
| self.attr = attr |
| |
| @property |
| def cpp_dict(self): |
| return getattr(self.cpp_module, self.attr) |
| |
    # Magic methods are looked up on the type rather than the instance, so they
    # bypass ``__getattr__`` and must be defined explicitly. For example,
    # ``len(wrapper)`` calls ``type(wrapper).__len__`` directly and would fail
    # if we relied on dynamic attribute lookup alone.
| |
| def items(self): |
| return self.cpp_dict.items() |
| |
| def keys(self): |
| return self.cpp_dict.keys() |
| |
| def values(self): |
| return self.cpp_dict.values() |
| |
    def __iter__(self):
        return iter(self.cpp_dict)

    def __len__(self):
        return len(self.cpp_dict)

    def __contains__(self, key):
        return key in self.cpp_dict

    def __getitem__(self, key):
        return self.cpp_dict[key]
| |
| |
| class ModuleWrapper(nn.Module): |
| """ |
    A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and
    delegates all attribute access to it, so the C++ module can be used
    wherever a Python ``nn.Module`` is expected.
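
    A sketch of typical use (``cpp_net`` is an assumed placeholder for a C++
    frontend module bound into Python, e.g. from a C++ extension)::

        net = ModuleWrapper(cpp_net)
        net.eval()  # forwards the mode change to the C++ module
        for name, param in net.named_parameters(recurse=False):
            print(name, param.size())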
| """ |
| |
    def __init__(self, cpp_module):
        # Assign ``cpp_module`` before calling the superclass constructor,
        # because ``nn.Module.__init__`` assigns ``self.training``, and the
        # ``training`` property setter below forwards to the C++ module.
        self.cpp_module = cpp_module
        super().__init__()
        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")  # type: ignore[assignment]
        self._buffers = OrderedDictWrapper(cpp_module, "_buffers")  # type: ignore[assignment]
        self._modules = OrderedDictWrapper(cpp_module, "_modules")  # type: ignore[assignment]
        for attr in dir(cpp_module):
            # Skip underscore-prefixed attributes, which covers magic methods
            # as well as the three dicts wrapped above; copy everything else
            # onto the Python wrapper.
            if not attr.startswith("_"):
                setattr(self, attr, getattr(cpp_module, attr))
| |
| def _apply(self, fn): |
| for param in self.parameters(): |
| # Tensors stored in modules are graph leaves, and we don't |
| # want to create copy nodes, so we have to unpack the data. |
| param.data = fn(param.data) |
| if param._grad is not None: |
| param._grad.data = fn(param._grad.data) |
| |
| for buf in self.buffers(): |
| buf.data = fn(buf.data) |
| |
| return self |
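
    # A sketch of why ``_apply`` is overridden (illustration only, with
    # ``cpp_net`` assumed as above): ``nn.Module.to``, ``.cpu``, ``.float``,
    # etc. all funnel through ``_apply``, so this override lets them mutate
    # the C++-owned tensors in place, e.g.:
    #
    #     net = ModuleWrapper(cpp_net)
    #     net._apply(lambda t: t.double())  # casts parameters and buffers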
| |
    # ``nn.Module`` declares ``training`` as a plain boolean attribute, so it
    # is re-exposed here as a property that reads through to the C++ module.
| @property # type: ignore[override] |
| def training(self): |
| return self.cpp_module.training |
| |
| @training.setter |
| def training(self, mode): |
| self.cpp_module.train(mode) |
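
    # For illustration (``cpp_net`` assumed as above): ``net.training = False``
    # routes through the setter above, calling ``cpp_net.train(False)``, while
    # reading ``net.training`` returns the C++ module's current mode.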
| |
    def __repr__(self):
        return repr(self.cpp_module)