blob: b2f8135f5da2dcab47b0018980525ee74a9566b7 [file] [log] [blame]
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This tool converts an image dataset into a database (type chosen by --db).
//
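// FLAGS_input_folder is the root folder that holds all the images. It is
// prepended verbatim to each path in the list file, so it should end with a
// path separator (e.g. "/data/images/").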
//
// FLAGS_list_file is the path to a file containing a list of files
// and their labels, as follows:
//
// subfolder1/file1.JPEG 7
// subfolder1/file2.JPEG 7
// subfolder2/file1.JPEG 8
// ...
//
#include <algorithm>
#include <condition_variable>
#include <cstdio>
#include <deque>
#include <fstream>
#include <iterator>
#include <mutex>
#include <queue>
#include <random>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <opencv2/opencv.hpp>

#include "caffe2/core/common.h"
#include "caffe2/core/db.h"
#include "caffe2/core/init.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2_pb.h"
// Command-line flags for the converter. See the header comment for the
// list-file format.

// Shuffles the list before conversion, using a fixed seed so runs are
// reproducible.
C10_DEFINE_bool(
    shuffle,
    false,
    "Randomly shuffle the order of images and their labels");
// Prepended verbatim to each relative path read from --list_file.
C10_DEFINE_string(
    input_folder,
    "",
    "The root folder holding the input images; prepended verbatim to each "
    "path in --list_file.");
C10_DEFINE_string(
    list_file,
    "",
    "The text file containing the list of images.");
C10_DEFINE_string(output_db_name, "", "The output database name.");
C10_DEFINE_string(db, "leveldb", "The db type.");
// With --raw, images are decoded with OpenCV, resized (see --scale/--warp) and
// stored as a BYTE tensor; without it, the encoded file bytes are stored as a
// single-element STRING tensor.
C10_DEFINE_bool(
    raw,
    false,
    "If set, we pre-read the images and store the raw buffer.");
// Only used with --raw: selects color vs. grayscale decoding.
C10_DEFINE_bool(color, true, "If set, load images in color.");
// Only used with --raw.
C10_DEFINE_int(
    scale,
    256,
    "If FLAGS_raw is set, scale the shorter edge to the given value.");
// Only used with --raw: resize to scale x scale, ignoring the aspect ratio.
C10_DEFINE_bool(warp, false, "If warp is set, warp the images to square.");
// Values < 1 select std::thread::hardware_concurrency().
C10_DEFINE_int(
    num_threads,
    -1,
    "Number of image parsing and conversion threads.");
namespace caffe2 {
class Converter {
public:
explicit Converter() {
data_ = protos_.add_protos();
label_ = protos_.add_protos();
if (FLAGS_raw) {
data_->set_data_type(TensorProto::BYTE);
data_->add_dims(0);
data_->add_dims(0);
if (FLAGS_color) {
data_->add_dims(3);
}
} else {
data_->set_data_type(TensorProto::STRING);
data_->add_dims(1);
data_->add_string_data("");
}
label_->set_data_type(TensorProto::INT32);
label_->add_dims(1);
label_->add_int32_data(0);
}
~Converter() {
if (thread_.joinable()) {
thread_.join();
}
}
void queue(const std::pair<std::string, int>& pair) {
in_.push(pair);
}
void start() {
thread_ = std::thread(&Converter::run, this);
}
std::string get() {
std::unique_lock<std::mutex> lock(mutex_);
while (out_.empty()) {
cv_.wait(lock);
}
auto value = out_.front();
out_.pop();
cv_.notify_one();
return value;
}
void run() {
const auto& input_folder = FLAGS_input_folder;
std::unique_lock<std::mutex> lock(mutex_);
std::string value;
while (!in_.empty()) {
auto pair = in_.front();
in_.pop();
lock.unlock();
label_->set_int32_data(0, pair.second);
// Add raw file contents to DB if !raw
if (!FLAGS_raw) {
std::ifstream image_file_stream(input_folder + pair.first);
if (!image_file_stream) {
LOG(ERROR) << "Cannot open " << input_folder << pair.first
<< ". Skipping.";
} else {
data_->mutable_string_data(0)->assign(
std::istreambuf_iterator<char>(image_file_stream),
std::istreambuf_iterator<char>());
}
} else {
// Load image
cv::Mat img = cv::imread(
input_folder + pair.first,
FLAGS_color ? cv::IMREAD_COLOR : cv::IMREAD_GRAYSCALE);
// Resize image
cv::Mat resized_img;
int scaled_width, scaled_height;
if (FLAGS_warp) {
scaled_width = FLAGS_scale;
scaled_height = FLAGS_scale;
} else if (img.rows > img.cols) {
scaled_width = FLAGS_scale;
scaled_height = static_cast<float>(img.rows) * FLAGS_scale / img.cols;
} else {
scaled_height = FLAGS_scale;
scaled_width = static_cast<float>(img.cols) * FLAGS_scale / img.rows;
}
cv::resize(
img,
resized_img,
cv::Size(scaled_width, scaled_height),
0,
0,
cv::INTER_LINEAR);
data_->set_dims(0, scaled_height);
data_->set_dims(1, scaled_width);
// Assert we don't have to deal with alignment
DCHECK(resized_img.isContinuous());
auto nbytes = resized_img.total() * resized_img.elemSize();
data_->set_byte_data(resized_img.ptr(), nbytes);
}
protos_.SerializeToString(&value);
// Add serialized proto to out queue or wait if it is not empty
lock.lock();
while (!out_.empty()) {
cv_.wait(lock);
}
out_.push(value);
cv_.notify_one();
}
}
protected:
TensorProtos protos_;
TensorProto* data_;
TensorProto* label_;
std::queue<std::pair<std::string, int>> in_;
std::queue<std::string> out_;
std::mutex mutex_;
std::condition_variable cv_;
std::thread thread_;
};
void ConvertImageDataset(
const string& input_folder,
const string& list_filename,
const string& output_db_name,
const bool /*shuffle*/) {
std::ifstream list_file(list_filename);
std::vector<std::pair<std::string, int> > lines;
std::string filename;
int file_label;
while (list_file >> filename >> file_label) {
lines.push_back(std::make_pair(filename, file_label));
}
if (FLAGS_shuffle) {
LOG(INFO) << "Shuffling data";
std::shuffle(lines.begin(), lines.end(), std::default_random_engine(1701));
}
auto num_threads = FLAGS_num_threads;
if (num_threads < 1) {
num_threads = std::thread::hardware_concurrency();
}
LOG(INFO) << "Processing " << lines.size() << " images...";
LOG(INFO) << "Opening DB " << output_db_name;
auto db = db::CreateDB(FLAGS_db, output_db_name, db::NEW);
auto transaction = db->NewTransaction();
LOG(INFO) << "Using " << num_threads << " processing threads...";
std::vector<Converter> converters(num_threads);
// Queue entries across converters
for (auto i = 0; i < lines.size(); i++) {
converters[i % converters.size()].queue(lines[i]);
}
// Start all converters
for (auto& converter : converters) {
converter.start();
}
constexpr auto key_max_length = 256;
char key_cstr[key_max_length];
int count = 0;
for (auto i = 0; i < lines.size(); i++) {
// Get serialized proto for this entry
auto value = converters[i % converters.size()].get();
// Synthesize key for this entry
auto key_len = snprintf(
key_cstr, sizeof(key_cstr), "%08d_%s", i, lines[i].first.c_str());
DCHECK_LE(key_len, sizeof(key_cstr));
// Put in db
transaction->Put(string(key_cstr), std::move(value));
if (++count % 1000 == 0) {
// Commit the current writes.
transaction->Commit();
LOG(INFO) << "Processed " << count << " files.";
}
}
// Commit final transaction
transaction->Commit();
LOG(INFO) << "Processed " << count << " files.";
}
} // namespace caffe2
int main(int argc, char** argv) {
caffe2::GlobalInit(&argc, &argv);
caffe2::ConvertImageDataset(
FLAGS_input_folder, FLAGS_list_file, FLAGS_output_db_name, FLAGS_shuffle);
return 0;
}