| |
| #pragma once |
| |
| #include <chrono> |
| #include <string> |
| |
| #include "caffe2/core/db.h" |
| #include "caffe2/core/logging.h" |
| #include "caffe2/core/stats.h" |
| #include "caffe2/queue/blobs_queue.h" |
| |
| namespace caffe2 { |
| namespace db { |
| |
| namespace { |
| const std::string& GetStringFromBlob(Blob* blob) { |
| if (blob->template IsType<string>()) { |
| return blob->template Get<string>(); |
| } else if (blob->template IsType<Tensor>()) { |
| return *blob->template Get<Tensor>().template data<string>(); |
| } else { |
| CAFFE_THROW("Unsupported Blob type"); |
| } |
| } |
| } // namespace |
| |
| class BlobsQueueDBCursor : public Cursor { |
| public: |
| explicit BlobsQueueDBCursor( |
| std::shared_ptr<BlobsQueue> queue, |
| int key_blob_index, |
| int value_blob_index, |
| float timeout_secs) |
| : queue_(queue), |
| key_blob_index_(key_blob_index), |
| value_blob_index_(value_blob_index), |
| timeout_secs_(timeout_secs), |
| inited_(false), |
| valid_(false) { |
| LOG(INFO) << "BlobsQueueDBCursor constructed"; |
| CAFFE_ENFORCE(queue_ != nullptr, "queue is null"); |
| CAFFE_ENFORCE(value_blob_index_ >= 0, "value_blob_index < 0"); |
| } |
| |
| virtual ~BlobsQueueDBCursor() {} |
| |
| void Seek(const string& /* unused */) override { |
| CAFFE_THROW("Seek is not supported."); |
| } |
| |
| bool SupportsSeek() override { |
| return false; |
| } |
| |
| void SeekToFirst() override { |
| // not applicable |
| } |
| |
| void Next() override { |
| unique_ptr<Blob> blob = make_unique<Blob>(); |
| vector<Blob*> blob_vector{blob.get()}; |
| auto success = queue_->blockingRead(blob_vector, timeout_secs_); |
| if (!success) { |
| LOG(ERROR) << "Timed out reading from BlobsQueue or it is closed"; |
| valid_ = false; |
| return; |
| } |
| |
| if (key_blob_index_ >= 0) { |
| key_ = GetStringFromBlob(blob_vector[key_blob_index_]); |
| } |
| value_ = GetStringFromBlob(blob_vector[value_blob_index_]); |
| valid_ = true; |
| } |
| |
| string key() override { |
| if (!inited_) { |
| Next(); |
| inited_ = true; |
| } |
| return key_; |
| } |
| |
| string value() override { |
| if (!inited_) { |
| Next(); |
| inited_ = true; |
| } |
| return value_; |
| } |
| |
| bool Valid() override { |
| return valid_; |
| } |
| |
| private: |
| std::shared_ptr<BlobsQueue> queue_; |
| int key_blob_index_; |
| int value_blob_index_; |
| float timeout_secs_; |
| bool inited_; |
| string key_; |
| string value_; |
| bool valid_; |
| }; |
| |
| class BlobsQueueDB : public DB { |
| public: |
| BlobsQueueDB( |
| const string& source, |
| Mode mode, |
| std::shared_ptr<BlobsQueue> queue, |
| int key_blob_index = -1, |
| int value_blob_index = 0, |
| float timeout_secs = 0.0) |
| : DB(source, mode), |
| queue_(queue), |
| key_blob_index_(key_blob_index), |
| value_blob_index_(value_blob_index), |
| timeout_secs_(timeout_secs) { |
| LOG(INFO) << "BlobsQueueDB constructed"; |
| } |
| |
| virtual ~BlobsQueueDB() { |
| Close(); |
| } |
| |
| void Close() override {} |
| unique_ptr<Cursor> NewCursor() override { |
| return make_unique<BlobsQueueDBCursor>( |
| queue_, key_blob_index_, value_blob_index_, timeout_secs_); |
| } |
| |
| unique_ptr<Transaction> NewTransaction() override { |
| CAFFE_THROW("Not implemented."); |
| } |
| |
| private: |
| std::shared_ptr<BlobsQueue> queue_; |
| int key_blob_index_; |
| int value_blob_index_; |
| float timeout_secs_; |
| }; |
| } // namespace db |
| } // namespace caffe2 |