| |
| |
| |
| |
| |
| import abc |
| |
| |
class NetModifier(metaclass=abc.ABCMeta):
    """
    An abstraction class for supporting modifying a generated net.
    Inherited classes should implement the modify_net method where
    related operators are added to the net.

    Instances are callable: calling the modifier forwards every argument
    to modify_net as-is.

    Example usage:
        modifier = SomeNetModifier(opts)
        modifier(net)
    """

    def __init__(self):
        # The base class keeps no state; defined so subclasses can always
        # call super().__init__().
        pass

    @abc.abstractmethod
    def modify_net(self, net, init_net=None, grad_map=None,
                   blob_to_device=None, modify_output_record=False):
        """
        Apply this modifier to `net`. Must be overridden by subclasses.

        The signature mirrors exactly what __call__ forwards, so an
        override has to accept all of these keywords (in particular
        `modify_output_record`, which __call__ always passes).

        Args:
            net: the net to modify.
            init_net: optional companion net; None when not supplied.
            grad_map: optional gradient mapping; None when not supplied.
            blob_to_device: optional blob-to-device mapping; None when
                not supplied.
            modify_output_record: flag forwarded unchanged from __call__;
                its interpretation is left to the subclass.
        """
        pass

    def __call__(self, net, init_net=None, grad_map=None, blob_to_device=None,
                 modify_output_record=False):
        """
        Run modify_net on `net`, passing all optional arguments by keyword.

        The return value of modify_net is discarded; this call returns None.
        """
        self.modify_net(
            net,
            init_net=init_net,
            grad_map=grad_map,
            blob_to_device=blob_to_device,
            modify_output_record=modify_output_record)