# blob: 21b9c44f2a79ec341ca69db6febbf3a8671d478f [file] [log] [blame]
## @package reservoir_sampling
# Module caffe2.python.layers.reservoir_sampling
from caffe2.python import core, schema
from caffe2.python.layers.layers import ModelLayer
class ReservoirSampling(ModelLayer):
    """
    Collect samples from input record w/ reservoir sampling. If you have complex
    data, use PackRecords to pack it before using this layer.

    This layer is not thread safe.

    The layer owns three non-optimized parameters -- ``reservoir``,
    ``num_visited`` and ``mutex`` -- and exposes all of them through
    ``output_schema``. When the input record contains an ``object_id`` field,
    two additional parameters (``object_to_pos`` and ``pos_to_object``) are
    created and handed to the ``ReservoirSampling`` operator as extra
    inputs/outputs.
    """

    def __init__(self, model, input_record, num_to_collect,
                 name='reservoir_sampling', **kwargs):
        """
        Create the parameters and the output schema of the layer.

        Args:
            model: model the layer is attached to; its ``NoOptim`` attribute
                is used as the optimizer of every parameter created here.
            input_record: input schema record. Its ``data`` field is the blob
                that gets sampled; an optional ``object_id`` field enables the
                two extra map/position parameters.
            num_to_collect: positive number forwarded to the operator as its
                ``num_to_collect`` argument (presumably the reservoir
                capacity -- confirm against the operator documentation).
            name: layer name passed to ``ModelLayer``.
            **kwargs: forwarded unchanged to ``ModelLayer.__init__``.

        Raises:
            AssertionError: if ``num_to_collect`` is not greater than zero.
        """
        super(ReservoirSampling, self).__init__(
            model, name, input_record, **kwargs)
        assert num_to_collect > 0, (
            'num_to_collect must be positive, got {}'.format(num_to_collect))
        self.num_to_collect = num_to_collect

        # Blob holding the collected samples; starts out empty.
        self.reservoir = self.create_param(
            param_name='reservoir',
            shape=[0],
            initializer=('ConstantFill',),
            optimizer=model.NoOptim,
        )
        # INT64 scalar initialised to 0; read and written by the operator.
        self.num_visited_blob = self.create_param(
            param_name='num_visited',
            shape=[],
            initializer=('ConstantFill', {
                'value': 0,
                'dtype': core.DataType.INT64,
            }),
            optimizer=model.NoOptim,
        )
        # Mutex blob passed to the operator as an input.
        self.mutex = self.create_param(
            param_name='mutex',
            shape=[],
            initializer=('CreateMutex',),
            optimizer=model.NoOptim,
        )

        # Only populated when the input record carries an `object_id` field.
        self.extra_input_blobs = []
        self.extra_output_blobs = []
        if 'object_id' in input_record:
            object_to_pos = self.create_param(
                param_name='object_to_pos',
                shape=None,
                initializer=('CreateMap', {
                    'key_dtype': core.DataType.INT64,
                    # The CreateMap argument is spelled `value_dtype`; the
                    # previous `valued_dtype` key was not a recognised
                    # argument, so the value type was left to the operator
                    # default instead of being set explicitly.
                    'value_dtype': core.DataType.INT32,
                }),
                optimizer=model.NoOptim,
            )
            pos_to_object = self.create_param(
                param_name='pos_to_object',
                shape=[0],
                initializer=('ConstantFill', {
                    'value': 0,
                    'dtype': core.DataType.INT64,
                }),
                optimizer=model.NoOptim,
            )
            # Operator input order: object ids first, then the two state
            # blobs, which are also written back as outputs.
            self.extra_input_blobs.append(input_record.object_id())
            self.extra_input_blobs.extend([object_to_pos, pos_to_object])
            self.extra_output_blobs.extend([object_to_pos, pos_to_object])

        self.output_schema = schema.Struct(
            (
                'reservoir',
                schema.from_blob_list(input_record.data, [self.reservoir])
            ),
            ('num_visited', schema.Scalar(blob=self.num_visited_blob)),
            ('mutex', schema.Scalar(blob=self.mutex)),
        )

    def add_ops(self, net):
        """
        Add the ``ReservoirSampling`` operator to ``net``.

        Inputs are the reservoir, the visited counter, the input data blob and
        the mutex, followed by any extra input blobs; outputs are the
        reservoir and the visited counter, followed by any extra output blobs.
        """
        op_inputs = [
            self.reservoir,
            self.num_visited_blob,
            self.input_record.data(),
            self.mutex,
        ] + self.extra_input_blobs
        op_outputs = (
            [self.reservoir, self.num_visited_blob] + self.extra_output_blobs
        )
        net.ReservoirSampling(
            op_inputs,
            op_outputs,
            num_to_collect=self.num_to_collect,
        )