| ## @package reservoir_sampling |
| # Module caffe2.python.layers.reservoir_sampling |
| |
| |
| |
| |
| |
| from caffe2.python import core, schema |
| from caffe2.python.layers.layers import ModelLayer |
| |
| |
class ReservoirSampling(ModelLayer):
    """
    Collect samples from input record w/ reservoir sampling. If you have complex
    data, use PackRecords to pack it before using this layer.

    This layer is not thread safe.

    The layer owns three non-optimized parameters -- ``reservoir``,
    ``num_visited`` and ``mutex`` -- and exposes all of them through
    ``output_schema``. When the input record carries an ``object_id`` field,
    two more parameters (``object_to_pos`` and ``pos_to_object``) are created
    and wired into the operator as extra inputs/outputs.
    """

    def __init__(self, model, input_record, num_to_collect,
                 name='reservoir_sampling', **kwargs):
        """
        Create the parameter blobs and output schema for the layer.

        Args:
            model: the layer model helper; forwarded to ``ModelLayer`` and
                used for its ``NoOptim`` optimizer marker.
            input_record: schema record with a ``data`` field and, optionally,
                an ``object_id`` field.
            num_to_collect: positive value forwarded to the ReservoirSampling
                operator as its ``num_to_collect`` argument (presumably the
                reservoir capacity -- confirm against the operator docs).
            name: layer name forwarded to ``ModelLayer``.
            **kwargs: forwarded unchanged to ``ModelLayer.__init__``.

        Raises:
            AssertionError: if ``num_to_collect`` is not greater than zero.
        """
        super(ReservoirSampling, self).__init__(
            model, name, input_record, **kwargs)
        assert num_to_collect > 0, (
            'num_to_collect must be positive, got {}'.format(num_to_collect))
        self.num_to_collect = num_to_collect

        # Starts out empty (shape [0]); the operator writes it in place.
        self.reservoir = self.create_param(
            param_name='reservoir',
            shape=[0],
            initializer=('ConstantFill',),
            optimizer=model.NoOptim,
        )
        # Scalar INT64 counter initialised to zero.
        self.num_visited_blob = self.create_param(
            param_name='num_visited',
            shape=[],
            initializer=('ConstantFill', {
                'value': 0,
                'dtype': core.DataType.INT64,
            }),
            optimizer=model.NoOptim,
        )
        self.mutex = self.create_param(
            param_name='mutex',
            shape=[],
            initializer=('CreateMutex',),
            optimizer=model.NoOptim,
        )

        # Only populated when the record has an `object_id` field.
        self.extra_input_blobs = []
        self.extra_output_blobs = []
        if 'object_id' in input_record:
            object_to_pos = self.create_param(
                param_name='object_to_pos',
                shape=None,
                initializer=('CreateMap', {
                    'key_dtype': core.DataType.INT64,
                    # CreateMap's argument is spelled `value_dtype`; the
                    # previous `valued_dtype` key was not a recognised
                    # argument, so the INT32 value type came only from the
                    # operator default.
                    'value_dtype': core.DataType.INT32,
                }),
                optimizer=model.NoOptim,
            )
            pos_to_object = self.create_param(
                param_name='pos_to_object',
                shape=[0],
                initializer=('ConstantFill', {
                    'value': 0,
                    'dtype': core.DataType.INT64,
                }),
                optimizer=model.NoOptim,
            )
            # Input order: object ids first, then the two lookup blobs.
            self.extra_input_blobs.append(input_record.object_id())
            self.extra_input_blobs.extend([object_to_pos, pos_to_object])
            self.extra_output_blobs.extend([object_to_pos, pos_to_object])

        self.output_schema = schema.Struct(
            (
                'reservoir',
                schema.from_blob_list(input_record.data, [self.reservoir])
            ),
            ('num_visited', schema.Scalar(blob=self.num_visited_blob)),
            ('mutex', schema.Scalar(blob=self.mutex)),
        )

    def add_ops(self, net):
        """
        Add one ReservoirSampling operator to ``net``.

        Inputs are the reservoir, the visited counter, the input data blob
        and the mutex, followed by any ``object_id``-related extra blobs.
        Outputs are the reservoir and the visited counter (updated in place),
        followed by the extra output blobs.
        """
        net.ReservoirSampling(
            [self.reservoir, self.num_visited_blob, self.input_record.data(),
             self.mutex] + self.extra_input_blobs,
            [self.reservoir, self.num_visited_blob] + self.extra_output_blobs,
            num_to_collect=self.num_to_collect,
        )