| ## @package text_file_reader |
| # Module caffe2.python.text_file_reader |
| |
| |
| |
| |
| from caffe2.python import core |
| from caffe2.python.dataio import Reader |
| from caffe2.python.schema import Scalar, Struct, data_type_for_dtype |
| |
| |
class TextFileReader(Reader):
    """
    Wrapper around operators for reading from text files.
    """

    def __init__(self, init_net, filename, schema, num_passes=1, batch_size=1):
        """
        Create op for building a TextFileReader instance in the workspace.

        Args:
            init_net : Net that will be run only once at startup.
            filename : Path to file to read from.
            schema : schema.Struct representing the schema of the data.
                Currently, only support Struct of strings and float32.
            num_passes : Number of passes over the data.
            batch_size : Number of rows to read at a time.
        """
        assert isinstance(schema, Struct), 'Schema must be a schema.Struct'
        # Only flat records of scalar fields are supported; name the
        # offending field so schema mistakes are easy to track down.
        for name, child in schema.get_children():
            assert isinstance(child, Scalar), (
                'Only scalar fields are supported in TextFileReader. '
                'Field "{}" is not a Scalar.'.format(name))
        field_types = [
            data_type_for_dtype(dtype) for dtype in schema.field_types()]
        Reader.__init__(self, schema)
        self._reader = init_net.CreateTextFileReader(
            [],
            filename=filename,
            num_passes=num_passes,
            field_types=field_types)
        self._batch_size = batch_size

    def read(self, net):
        """
        Create op for reading a batch of rows.

        Args:
            net : Net to which the read ops are added.

        Returns:
            Tuple (is_empty, blobs):
                is_empty : BlobReference to a scalar that becomes true once
                    the reader is exhausted (the should-stop signal).
                blobs : list of BlobReferences, one per schema field, each
                    holding up to `batch_size` rows.
        """
        blobs = net.TextFileReaderRead(
            [self._reader],
            len(self.schema().field_names()),
            batch_size=self._batch_size)
        # A single-field schema yields a bare BlobReference instead of a
        # list; normalize so callers always receive a list. isinstance
        # (rather than an exact type check) also covers subclasses.
        if isinstance(blobs, core.BlobReference):
            blobs = [blobs]

        is_empty = net.IsEmpty(
            [blobs[0]],
            core.ScopedBlobReference(net.NextName('should_stop'))
        )

        return (is_empty, blobs)