| ## @package sampling_trainable_mixin |
| # Module caffe2.python.layers.sampling_trainable_mixin |
| |
| |
| |
| |
| |
| import abc |
| |
| |
class SamplingTrainableMixin(metaclass=abc.ABCMeta):
    """Mixin for layers whose training net may use different parameter blobs.

    Subclasses must provide ``param_blobs`` (the parameter blobs for the
    prediction net) and ``_add_ops`` (which adds ops to a net using an
    explicitly supplied list of blobs).

    ``add_ops`` builds with ``param_blobs``. ``add_train_ops`` builds with
    ``train_param_blobs``, which defaults to ``param_blobs`` unless it is
    assigned first.

    ``train_param_blobs`` may be assigned at most once. Reading it before any
    assignment also fixes it (to ``param_blobs``). Any later assignment raises
    ``AssertionError``.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative init: forward all arguments along the MRO so this mixin
        # can be combined with other base classes.
        super().__init__(*args, **kwargs)
        self._train_param_blobs = None
        # Becomes True once train_param_blobs has been assigned (explicitly or
        # by the getter's default); further assignments are then rejected.
        self._train_param_blobs_frozen = False

    @property
    @abc.abstractmethod
    def param_blobs(self):
        """List of parameter blobs for the prediction net."""

    @property
    def train_param_blobs(self):
        """Parameter blobs used when building the training net.

        If nothing has been assigned yet, this defaults to ``param_blobs``.
        That first read goes through the setter, so it also freezes the value.
        """
        if self._train_param_blobs is None:
            self.train_param_blobs = self.param_blobs
        return self._train_param_blobs

    @train_param_blobs.setter
    def train_param_blobs(self, blobs):
        """Set the training-net parameter blobs (allowed only once).

        Raises:
            AssertionError: if the value was already set (or already
                defaulted by a read), or if ``blobs`` is None. A rejected
                None does not freeze the value.
        """
        # Explicit raises instead of ``assert`` so the checks survive
        # ``python -O``. AssertionError is kept for backward compatibility.
        if self._train_param_blobs_frozen:
            raise AssertionError(
                "train_param_blobs has already been set and cannot be changed"
            )
        if blobs is None:
            raise AssertionError("train_param_blobs cannot be None")
        self._train_param_blobs_frozen = True
        self._train_param_blobs = blobs

    @abc.abstractmethod
    def _add_ops(self, net, param_blobs):
        """Add ops to the given net, using the given param_blobs."""

    def add_ops(self, net):
        """Add ops to ``net`` using the prediction-net ``param_blobs``."""
        self._add_ops(net, self.param_blobs)

    def add_train_ops(self, net):
        """Add ops to ``net`` using ``train_param_blobs``.

        This reads ``train_param_blobs``, so it freezes that value if it had
        not been assigned yet.
        """
        self._add_ops(net, self.train_param_blobs)