| ## @package cnn |
| # Module caffe2.python.cnn |
| |
| |
| |
| |
| |
| from caffe2.python import brew, workspace |
| from caffe2.python.model_helper import ModelHelper |
| from caffe2.proto import caffe2_pb2 |
| import logging |
| |
| |
class CNNModelHelper(ModelHelper):
    """A helper model so we can write CNN models more easily, without having to
    manually define parameter initializations and operators separately.

    Every layer method on this class is a thin wrapper that forwards to the
    matching function in the ``brew`` module, passing this model as the first
    argument and, where the wrapper does so explicitly, the model-wide
    ``order`` / cuDNN settings stored on the instance.

    This class is deprecated (a warning is logged on construction); new code
    should use ``ModelHelper`` together with ``brew`` directly.
    """

    def __init__(self, order="NCHW", name=None,
                 use_cudnn=True, cudnn_exhaustive_search=False,
                 ws_nbytes_limit=None, init_params=True,
                 skip_sparse_optim=False,
                 param_model=None):
        """Create the helper and record the model-wide CNN settings.

        Args:
            order: Storage order, either "NCHW" or "NHWC".
            name: Model name; "CNN" is used when None.
            use_cudnn: Forwarded as ``use_cudnn`` by the wrappers that
                accept it.
            cudnn_exhaustive_search: Forwarded by the convolution wrappers.
            ws_nbytes_limit: Forwarded by the convolution wrappers; only
                placed into the arg scope when truthy.
            init_params: Passed through to ``ModelHelper.__init__``.
            skip_sparse_optim: Passed through to ``ModelHelper.__init__``.
            param_model: Passed through to ``ModelHelper.__init__``.

        Raises:
            ValueError: If ``order`` is neither "NHWC" nor "NCHW".
        """
        logging.warning(
            "[====DEPRECATE WARNING====]: you are creating an "
            "object from CNNModelHelper class which will be deprecated soon. "
            "Please use ModelHelper object with brew module. For more "
            "information, please refer to caffe2.ai and python/brew.py, "
            "python/brew_test.py for more information."
        )

        # Validate before touching the base class so that an unsupported
        # order fails fast instead of after the base model has been built.
        if order not in ("NHWC", "NCHW"):
            raise ValueError(
                "Cannot understand the CNN storage order %s." % order
            )

        cnn_arg_scope = {
            'order': order,
            'use_cudnn': use_cudnn,
            'cudnn_exhaustive_search': cudnn_exhaustive_search,
        }
        # A falsy limit (None or 0) is treated as "not specified" and is left
        # out of the arg scope.
        if ws_nbytes_limit:
            cnn_arg_scope['ws_nbytes_limit'] = ws_nbytes_limit
        super(CNNModelHelper, self).__init__(
            skip_sparse_optim=skip_sparse_optim,
            name="CNN" if name is None else name,
            init_params=init_params,
            param_model=param_model,
            arg_scope=cnn_arg_scope,
        )

        self.order = order
        self.use_cudnn = use_cudnn
        self.cudnn_exhaustive_search = cudnn_exhaustive_search
        self.ws_nbytes_limit = ws_nbytes_limit

    def ImageInput(self, blob_in, blob_out, use_gpu_transform=False, **kwargs):
        """Forward to ``brew.image_input`` using this model's ``order``."""
        return brew.image_input(
            self,
            blob_in,
            blob_out,
            order=self.order,
            use_gpu_transform=use_gpu_transform,
            **kwargs
        )

    def VideoInput(self, blob_in, blob_out, **kwargs):
        """Forward to ``brew.video_input``."""
        return brew.video_input(
            self,
            blob_in,
            blob_out,
            **kwargs
        )

    def PadImage(self, blob_in, blob_out, **kwargs):
        """Add a ``PadImage`` operator directly on ``self.net``.

        Returns whatever ``self.net.PadImage`` returns so that this helper
        can be chained like the other layer wrappers.
        """
        # TODO(wyiming): remove this dummy helper later
        return self.net.PadImage(blob_in, blob_out, **kwargs)

    def ConvNd(self, *args, **kwargs):
        """Forward to ``brew.conv_nd`` with this model's order/cuDNN settings."""
        return brew.conv_nd(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def Conv(self, *args, **kwargs):
        """Forward to ``brew.conv`` with this model's order/cuDNN settings."""
        return brew.conv(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def ConvTranspose(self, *args, **kwargs):
        """Forward to ``brew.conv_transpose`` with this model's order/cuDNN
        settings.
        """
        return brew.conv_transpose(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def GroupConv(self, *args, **kwargs):
        """Forward to ``brew.group_conv`` with this model's order/cuDNN
        settings.
        """
        return brew.group_conv(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def GroupConv_Deprecated(self, *args, **kwargs):
        """Forward to ``brew.group_conv_deprecated`` with this model's
        order/cuDNN settings.
        """
        return brew.group_conv_deprecated(
            self,
            *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs
        )

    def FC(self, *args, **kwargs):
        """Forward to ``brew.fc``."""
        return brew.fc(self, *args, **kwargs)

    def PackedFC(self, *args, **kwargs):
        """Forward to ``brew.packed_fc``."""
        return brew.packed_fc(self, *args, **kwargs)

    def FC_Prune(self, *args, **kwargs):
        """Forward to ``brew.fc_prune``."""
        return brew.fc_prune(self, *args, **kwargs)

    def FC_Decomp(self, *args, **kwargs):
        """Forward to ``brew.fc_decomp``."""
        return brew.fc_decomp(self, *args, **kwargs)

    def FC_Sparse(self, *args, **kwargs):
        """Forward to ``brew.fc_sparse``."""
        return brew.fc_sparse(self, *args, **kwargs)

    def Dropout(self, *args, **kwargs):
        """Forward to ``brew.dropout`` with this model's order and
        ``use_cudnn``.
        """
        return brew.dropout(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def LRN(self, *args, **kwargs):
        """Forward to ``brew.lrn`` with this model's order and ``use_cudnn``."""
        return brew.lrn(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def Softmax(self, *args, **kwargs):
        """Forward to ``brew.softmax`` with this model's ``use_cudnn``."""
        return brew.softmax(self, *args, use_cudnn=self.use_cudnn, **kwargs)

    def SpatialBN(self, *args, **kwargs):
        """Forward to ``brew.spatial_bn`` with this model's ``order``."""
        return brew.spatial_bn(self, *args, order=self.order, **kwargs)

    def SpatialGN(self, *args, **kwargs):
        """Forward to ``brew.spatial_gn`` with this model's ``order``."""
        return brew.spatial_gn(self, *args, order=self.order, **kwargs)

    def InstanceNorm(self, *args, **kwargs):
        """Forward to ``brew.instance_norm`` with this model's ``order``."""
        return brew.instance_norm(self, *args, order=self.order, **kwargs)

    def Relu(self, *args, **kwargs):
        """Forward to ``brew.relu`` with this model's order and ``use_cudnn``."""
        return brew.relu(
            self, *args, order=self.order, use_cudnn=self.use_cudnn, **kwargs
        )

    def PRelu(self, *args, **kwargs):
        """Forward to ``brew.prelu``."""
        return brew.prelu(self, *args, **kwargs)

    def Concat(self, *args, **kwargs):
        """Forward to ``brew.concat`` with this model's ``order``."""
        return brew.concat(self, *args, order=self.order, **kwargs)

    def DepthConcat(self, *args, **kwargs):
        """The old depth concat function - we should move to use concat."""
        print("DepthConcat is deprecated. use Concat instead.")
        return self.Concat(*args, **kwargs)

    def Sum(self, *args, **kwargs):
        """Forward to ``brew.sum``."""
        return brew.sum(self, *args, **kwargs)

    def Transpose(self, *args, **kwargs):
        """Forward to ``brew.transpose`` with this model's ``use_cudnn``."""
        return brew.transpose(self, *args, use_cudnn=self.use_cudnn, **kwargs)

    def Iter(self, *args, **kwargs):
        """Forward to ``brew.iter``."""
        return brew.iter(self, *args, **kwargs)

    def Accuracy(self, *args, **kwargs):
        """Forward to ``brew.accuracy``."""
        return brew.accuracy(self, *args, **kwargs)

    def MaxPool(self, *args, **kwargs):
        """Forward to ``brew.max_pool`` with this model's order and
        ``use_cudnn``.
        """
        return brew.max_pool(
            self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
        )

    def MaxPoolWithIndex(self, *args, **kwargs):
        """Forward to ``brew.max_pool_with_index`` with this model's
        ``order``.
        """
        return brew.max_pool_with_index(self, *args, order=self.order, **kwargs)

    def AveragePool(self, *args, **kwargs):
        """Forward to ``brew.average_pool`` with this model's order and
        ``use_cudnn``.
        """
        return brew.average_pool(
            self, *args, use_cudnn=self.use_cudnn, order=self.order, **kwargs
        )

    @property
    def XavierInit(self):
        """Initializer spec ``('XavierFill', {})`` as an (op name, kwargs)
        tuple.
        """
        return ('XavierFill', {})

    def ConstantInit(self, value):
        """Initializer spec ``('ConstantFill', {'value': value})``."""
        return ('ConstantFill', dict(value=value))

    @property
    def MSRAInit(self):
        """Initializer spec ``('MSRAFill', {})``."""
        return ('MSRAFill', {})

    @property
    def ZeroInit(self):
        """Initializer spec ``('ConstantFill', {})`` with no explicit value.

        NOTE(review): presumably relies on the operator's default fill value
        being zero - confirm against the ConstantFill operator schema.
        """
        return ('ConstantFill', {})

    def AddWeightDecay(self, weight_decay):
        """Forward to ``brew.add_weight_decay``."""
        return brew.add_weight_decay(self, weight_decay)

    @property
    def CPU(self):
        """A fresh ``DeviceOption`` whose ``device_type`` is
        ``caffe2_pb2.CPU``.
        """
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CPU
        return device_option

    @property
    def GPU(self, gpu_id=0):
        """A fresh ``DeviceOption`` for ``workspace.GpuDeviceType``.

        Because this is a property, plain attribute access (``model.GPU``)
        cannot supply ``gpu_id``, so the returned option always targets
        device 0 when used that way. The parameter is kept only so the
        existing signature stays unchanged.
        """
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = workspace.GpuDeviceType
        device_option.device_id = gpu_id
        return device_option