## @package queue_util
# Module caffe2.python.queue_util
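"""Utilities for wrapping a Caffe2 BlobsQueue as a dataio Reader/Writer pipe."""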
from caffe2.python import core, dataio
from caffe2.python.task import TaskGroup

import logging


logger = logging.getLogger(__name__)


class _QueueReader(dataio.Reader):
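    """dataio.Reader that dequeues records from a QueueWrapper's BlobsQueue."""
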
    def __init__(self, wrapper, num_dequeue_records=1):
        assert wrapper.schema() is not None, (
            'Queue needs a schema in order to be read from.')
        dataio.Reader.__init__(self, wrapper.schema())
        self._wrapper = wrapper
        self._num_dequeue_records = num_dequeue_records

    def setup_ex(self, init_net, exit_net):
        exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)

    def read_ex(self, local_init_net, local_finish_net):
        self._wrapper._new_reader(local_init_net)
        dequeue_net = core.Net('dequeue')
        fields, status_blob = dequeue(
            dequeue_net,
            self._wrapper.queue(),
            len(self.schema().field_names()),
            field_names=self.schema().field_names(),
            num_records=self._num_dequeue_records)
        return [dequeue_net], status_blob, fields

    def read(self, net):
        net, _, fields = self.read_ex(net, None)
        return net, fields


class _QueueWriter(dataio.Writer):
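    """dataio.Writer that enqueues records into a QueueWrapper's BlobsQueue."""
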
    def __init__(self, wrapper):
        self._wrapper = wrapper

    def setup_ex(self, init_net, exit_net):
        exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)

    def write_ex(self, fields, local_init_net, local_finish_net, status):
        self._wrapper._new_writer(self.schema(), local_init_net)
        enqueue_net = core.Net('enqueue')
        enqueue(enqueue_net, self._wrapper.queue(), fields, status)
        return [enqueue_net]


class QueueWrapper(dataio.Pipe):
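    """dataio.Pipe backed by an existing BlobsQueue, passed in as `handler`."""
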
    def __init__(self, handler, schema=None, num_dequeue_records=1):
        dataio.Pipe.__init__(self, schema, TaskGroup.LOCAL_SETUP)
        self._queue = handler
        self._num_dequeue_records = num_dequeue_records

    def reader(self):
        return _QueueReader(
            self, num_dequeue_records=self._num_dequeue_records)

    def writer(self):
        return _QueueWriter(self)

    def queue(self):
        return self._queue


class Queue(QueueWrapper):
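    """QueueWrapper that creates and owns its own BlobsQueue via `setup()`.

    Illustrative usage (a sketch; `init_net` and `my_schema` are placeholder
    names, not part of this module):

        q = Queue(capacity=100, schema=my_schema)
        q.setup(init_net)       # adds a CreateBlobsQueue op to init_net
        writer = q.writer()     # dataio.Writer that enqueues into the queue
        reader = q.reader()     # dataio.Reader that dequeues from the queue
    """
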
    def __init__(self, capacity, schema=None, name='queue',
                 num_dequeue_records=1):
        # find a unique blob name for the queue
        net = core.Net(name)
        queue_blob = net.AddExternalInput(net.NextName('handler'))
        QueueWrapper.__init__(
            self, queue_blob, schema, num_dequeue_records=num_dequeue_records)
        self.capacity = capacity
        self._setup_done = False

    def setup(self, global_init_net):
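        """Creates the underlying BlobsQueue on `global_init_net`."""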
        assert self._schema, 'This queue does not have a schema.'
        self._setup_done = True
        global_init_net.CreateBlobsQueue(
            [],
            [self._queue],
            capacity=self.capacity,
            num_blobs=len(self._schema.field_names()),
            field_names=self._schema.field_names())


def enqueue(net, queue, data_blobs, status=None):
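    """Adds a SafeEnqueueBlobs op to `net` that enqueues `data_blobs` into
    `queue`; returns the status blob.
    """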
    if status is None:
        status = net.NextName('status')
    # Enqueueing moves the data into the queue, so enqueueing the same blob
    # twice would corrupt it; duplicate blobs are copied first.
    queue_blobs = []
    for blob in data_blobs:
        if blob not in queue_blobs:
            queue_blobs.append(blob)
        else:
            logger.warning("Need to copy blob {} to enqueue".format(blob))
            queue_blobs.append(net.Copy(blob))
    results = net.SafeEnqueueBlobs([queue] + queue_blobs, queue_blobs + [status])
    return results[-1]


def dequeue(net, queue, num_blobs, status=None, field_names=None,
            num_records=1):
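    """Adds a SafeDequeueBlobs op to `net` that dequeues records from `queue`
    into `num_blobs` data blobs; returns (data_blobs, status_blob).
    """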
    if field_names is not None:
        assert len(field_names) == num_blobs
        data_names = [net.NextName(name) for name in field_names]
    else:
        data_names = [net.NextName('data', i) for i in range(num_blobs)]
    if status is None:
        status = net.NextName('status')
    results = net.SafeDequeueBlobs(
        queue, data_names + [status], num_records=num_records)
    results = list(results)
    status_blob = results.pop(-1)
    return results, status_blob


def close_queue(step, *queues):
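    """Returns an execution step that runs `step` and then closes `queues`."""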
    close_net = core.Net("close_queue_net")
    for queue in queues:
        close_net.CloseBlobsQueue([queue], 0)
    close_step = core.execution_step("%s_step" % str(close_net), close_net)
    return core.execution_step(
        "%s_wrapper_step" % str(close_net),
        [step, close_step])