// blob: 96250cdd25812fb5ff42904e28889c26e842bdd0 [file] [log] [blame]
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/types.h"
#include "caffe2/operators/text_file_reader_utils.h"
#include "caffe2/utils/string_utils.h"
namespace caffe2 {
struct TextFileReaderInstance {
TextFileReaderInstance(
const std::vector<char>& delims,
char escape,
const std::string& filename,
int numPasses,
// NOLINTNEXTLINE(modernize-pass-by-value)
const std::vector<int>& types)
: fileReader(filename),
tokenizer(Tokenizer(delims, escape), &fileReader, numPasses),
fieldTypes(types) {
for (const auto dt : fieldTypes) {
fieldMetas.push_back(
DataTypeToTypeMeta(static_cast<TensorProto_DataType>(dt)));
fieldByteSizes.push_back(fieldMetas.back().itemsize());
}
}
FileReader fileReader;
BufferedTokenizer tokenizer;
std::vector<int> fieldTypes;
std::vector<TypeMeta> fieldMetas;
std::vector<size_t> fieldByteSizes;
size_t rowsRead{0};
// hack to guarantee thread-safeness of the read op
// TODO(azzolini): support multi-threaded reading.
std::mutex globalMutex_;
};
// Operator that creates a TextFileReaderInstance and stores it, wrapped in a
// std::unique_ptr, in output 0.
//
// Arguments read at construction time:
//   filename    - path of the file to read (defaults to "").
//   num_passes  - pass count forwarded to the reader (defaults to 1).
//   field_types - per-field type ids; must contain at least one entry.
class CreateTextFileReaderOp : public Operator<CPUContext> {
 public:
  template <class... Args>
  explicit CreateTextFileReaderOp(Args&&... args)
      : Operator<CPUContext>(std::forward<Args>(args)...),
        filename_(GetSingleArgument<string>("filename", "")),
        numPasses_(GetSingleArgument<int>("num_passes", 1)),
        fieldTypes_(GetRepeatedArgument<int>("field_types")) {
    CAFFE_ENFORCE(fieldTypes_.size() > 0, "field_types arg must be non-empty");
  }

  // Builds a fresh reader instance and moves it into output 0. Tokens are
  // split on '\n' and '\t', and '\0' is passed as the escape character.
  bool RunOnDevice() override {
    using InstancePtr = std::unique_ptr<TextFileReaderInstance>;
    const std::vector<char> delimiters{'\n', '\t'};
    // NOLINTNEXTLINE(modernize-make-unique)
    InstancePtr reader(new TextFileReaderInstance(
        delimiters, '\0', filename_, numPasses_, fieldTypes_));
    *OperatorBase::Output<InstancePtr>(0) = std::move(reader);
    return true;
  }

 private:
  std::string filename_;
  int numPasses_;
  std::vector<int> fieldTypes_;
};
// Converts the characters in [src_start, src_end) to `dst_type` and writes the
// result through `dst`.
//
//   TensorProto_DataType_STRING - `dst` is treated as a std::string* and is
//                                 assigned the character range.
//   TensorProto_DataType_FLOAT  - `dst` is treated as a float* and receives
//                                 the value parsed by strtof.
//
// Throws std::runtime_error when strtof consumes no characters at all, or when
// `dst_type` is any other type. Characters that follow a successfully parsed
// float prefix are ignored.
inline void convert(
    TensorProto_DataType dst_type,
    const char* src_start,
    const char* src_end,
    void* dst) {
  if (dst_type == TensorProto_DataType_STRING) {
    auto* out = static_cast<std::string*>(dst);
    out->assign(src_start, src_end);
    return;
  }

  if (dst_type == TensorProto_DataType_FLOAT) {
    // TODO(azzolini): avoid copy, use faster conversion
    // The token is copied so strtof sees a NUL-terminated buffer.
    const std::string text(src_start, src_end);
    const char* const text_begin = text.c_str();
    char* parse_end = nullptr;
    const float parsed = strtof(text_begin, &parse_end);
    if (parse_end == text_begin) {
      throw std::runtime_error("Invalid float: " + text);
    }
    *static_cast<float*>(dst) = parsed;
    return;
  }

  throw std::runtime_error("Unsupported type.");
}
// Operator that reads up to `batch_size` rows from a TextFileReaderInstance.
//
// Input 0 is the std::unique_ptr<TextFileReaderInstance> handle. There must be
// exactly one output per field of the instance; output i receives the values
// of field i, one element per row, and is shrunk to the number of rows that
// were actually read.
//
// Arguments read at construction time:
//   batch_size - maximum number of rows to read per run (defaults to 1).
class TextFileReaderReadOp : public Operator<CPUContext> {
 public:
  template <class... Args>
  explicit TextFileReaderReadOp(Args&&... args)
      : Operator<CPUContext>(std::forward<Args>(args)...),
        batchSize_(GetSingleArgument<int>("batch_size", 1)) {}

  // Reads rows until `batch_size` rows have been read or the tokenizer runs
  // out of tokens. Fails (via CAFFE_ENFORCE) when the handle is empty, when
  // the output count does not match the instance's field count, when a row
  // has an unexpected number of columns, or when the file ends in the middle
  // of a row. Exceptions thrown by convert() propagate to the caller.
  bool RunOnDevice() override {
    const int numFields = OutputSize();
    CAFFE_ENFORCE(numFields > 0, "Expected at least one output.");

    auto* instance =
        OperatorBase::Input<std::unique_ptr<TextFileReaderInstance>>(0).get();
    // An empty unique_ptr yields a null pointer here; fail cleanly instead of
    // dereferencing it below.
    CAFFE_ENFORCE(
        instance != nullptr,
        "Input handler does not hold a TextFileReaderInstance.");
    CAFFE_ENFORCE(
        // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
        instance->fieldTypes.size() == numFields,
        "Invalid number of outputs. Expected " +
            to_string(instance->fieldTypes.size()) + " got " +
            to_string(numFields));

    // One write cursor per output. A std::vector is used instead of
    // `char* datas[numFields]` because MSVC rejects the variable-length array
    // form.
    std::vector<char*> datas(numFields);
    for (int i = 0; i < numFields; ++i) {
      Output(i)->Resize(batchSize_);
      datas[i] = static_cast<char*>(
          Output(i)->raw_mutable_data(instance->fieldMetas[i]));
    }

    int rowsRead = 0;
    {
      // TODO(azzolini): support multi-threaded reading
      // The instance mutex is held for the whole read loop.
      std::lock_guard<std::mutex> guard(instance->globalMutex_);
      bool finished = false;
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
      Token token;
      while (!finished && (rowsRead < batchSize_)) {
        for (int field = 0; field < numFields; ++field) {
          finished = !instance->tokenizer.next(token);
          if (finished) {
            // Running out of tokens is only acceptable on a row boundary.
            CAFFE_ENFORCE(
                field == 0, "Invalid number of fields at end of file.");
            break;
          }
          // The first field of a row must follow delimiter 0 and every other
          // field must follow delimiter 1.
          CAFFE_ENFORCE(
              (field == 0 && token.startDelimId == 0) ||
                  (field > 0 && token.startDelimId == 1),
              "Invalid number of columns at row ",
              instance->rowsRead + rowsRead + 1);
          char*& data = datas[field];
          convert(
              static_cast<TensorProto_DataType>(instance->fieldTypes[field]),
              token.start,
              token.end,
              data);
          // Advance this field's cursor by one element.
          data += instance->fieldByteSizes[field];
        }
        if (!finished) {
          ++rowsRead;
        }
      }
      instance->rowsRead += rowsRead;
    }

    // Drop the unused tail of every output.
    for (int i = 0; i < numFields; ++i) {
      Output(i)->ShrinkTo(rowsRead);
    }
    return true;
  }

 private:
  int64_t batchSize_;
};
// Declare the reader handle type, which the two operators pass between each
// other through Output<>/Input<>, as a known type.
CAFFE_KNOWN_TYPE(std::unique_ptr<TextFileReaderInstance>);
// Register the CPU implementations under their operator names.
REGISTER_CPU_OPERATOR(CreateTextFileReader, CreateTextFileReaderOp);
REGISTER_CPU_OPERATOR(TextFileReaderRead, TextFileReaderReadOp);
// Schema for CreateTextFileReader: no inputs, exactly one output (the reader
// handle), configured through the filename, num_passes and field_types
// arguments.
OPERATOR_SCHEMA(CreateTextFileReader)
.NumInputs(0)
.NumOutputs(1)
.ScalarType(TensorProto::UNDEFINED)
.SetDoc("Create a text file reader. Fields are delimited by <TAB>.")
.Arg("filename", "Path to the file.")
.Arg("num_passes", "Number of passes over the file.")
.Arg(
"field_types",
"List with type of each field. Type enum is found at core.DataType.")
.Output(0, "handler", "Pointer to the created TextFileReaderInstance.");
// Schema for TextFileReaderRead: one input (the reader handle) and between 1
// and INT_MAX outputs, configured through the batch_size argument.
OPERATOR_SCHEMA(TextFileReaderRead)
.NumInputs(1)
.NumOutputs(1, INT_MAX)
.SetDoc(
"Read a batch of rows from the given text file reader instance. "
"Expects the number of fields to be equal to the number of outputs. "
"Each output is a 1D tensor containing the values for the given field "
"for each row. When end of file is reached, returns empty tensors.")
.Input(0, "handler", "Pointer to an existing TextFileReaderInstance.")
.Arg("batch_size", "Maximum number of rows to read.");
// Neither reader operator is registered with a gradient.
NO_GRADIENT(CreateTextFileReader);
NO_GRADIENT(TextFileReaderRead);
} // namespace caffe2