| #ifndef CAFFE2_OPERATORS_PREFETCH_OP_H_ |
| #define CAFFE2_OPERATORS_PREFETCH_OP_H_ |
| |
#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread> // NOLINT

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
| |
| namespace caffe2 { |
| |
| // PrefetchOperator is an operator that prefetches the next batch. It should |
| // almost always be used to read things from disk, so I am setting the input to |
| // zero blobs. |
| // |
| // For any operator that is derived from PrefetchOperator, it should |
| // explicitly call the Finalize() function in its destructor, so that the |
| // prefetching thread is properly destructed. |
| |
| // Note: We inherit from OperatorBase since we control the |
| // synchronization properties of this operator ourselves (we inform |
| // the waiting producer after we synchronize). This is a special-case |
| // - you should generally inherit from Operator<Context> directly. |
| template <class Context> |
| class PrefetchOperator : public OperatorBase { |
| public: |
| PrefetchOperator(const OperatorDef& operator_def, Workspace* ws) |
| : OperatorBase(operator_def, ws), |
| context_(operator_def.device_option()), |
| prefetched_(false), |
| prefetch_success_(true), |
| finalize_(false), |
| no_prefetch_(GetSingleArgument<bool>("no_prefetch", false)) { |
| context_.SwitchToDevice(); |
| } |
| |
| virtual ~PrefetchOperator() noexcept { |
| CHECK(finalize_ || !prefetch_thread_.get()) |
| << "YOU MADE A PROGRAMING ERROR: derived class of PrefetchOperator " |
| "should call Finalize() in its destructor so the prefetching " |
| "thread is joined. "; |
| } |
| |
| void Finalize() { |
| if (prefetch_thread_.get()) { |
| { |
| std::unique_lock<std::mutex> lock(prefetch_access_mutex_); |
| while (!prefetched_) |
| consumer_.wait(lock); |
| finalize_ = true; |
| prefetched_ = false; |
| } |
| producer_.notify_one(); |
| prefetch_thread_->join(); |
| prefetch_thread_.reset(); |
| } else { |
| // If we never initialized the prefetch thread, just set |
| // finalize anyway. |
| finalize_ = true; |
| } |
| } |
| |
| bool Run(int /* unused */ /*stream_id*/) override { |
| if (no_prefetch_) { |
| context_.SwitchToDevice(); |
| bool result = Prefetch() && CopyPrefetched(); |
| context_.FinishDeviceComputation(); |
| return result; |
| } |
| // Note(jiayq): We only start the prefetch_thread at the Run() function |
| // instead of in the constructor, because the prefetch_thread needs to start |
| // after all derived classes' constructors finish. |
| if (!prefetch_thread_) { |
| prefetch_thread_.reset( |
| new std::thread([this] { this->PrefetchWorker(); })); |
| } |
| context_.SwitchToDevice(); |
| std::unique_lock<std::mutex> lock(prefetch_access_mutex_); |
| while (!prefetched_) |
| consumer_.wait(lock); |
| if (!prefetch_success_) { |
| LOG(ERROR) << "Prefetching failed."; |
| return false; |
| } |
| if (!CopyPrefetched()) { |
| LOG(ERROR) << "Error when copying prefetched data."; |
| return false; |
| } |
| prefetched_ = false; |
| context_.FinishDeviceComputation(); |
| producer_.notify_one(); |
| return true; |
| } |
| |
| void PrefetchWorker() { |
| context_.SwitchToDevice(); |
| std::unique_lock<std::mutex> lock(prefetch_access_mutex_); |
| while (prefetched_) |
| producer_.wait(lock); |
| while (!finalize_) { |
| // We will need to run a FinishDeviceComputation() call because the |
| // prefetcher thread and the main thread are potentially using different |
| // streams (like on GPU). |
| try { |
| prefetch_success_ = Prefetch(); |
| context_.FinishDeviceComputation(); |
| } catch (const std::exception& e) { |
| // TODO: propagate exception_ptr to the caller side |
| LOG(ERROR) << "Prefetching error " << e.what(); |
| prefetch_success_ = false; |
| } |
| prefetched_ = true; |
| consumer_.notify_one(); |
| while (prefetched_) |
| producer_.wait(lock); |
| } |
| } |
| |
| // You will need to implement this instead of the Run function. |
| virtual bool Prefetch() = 0; |
| virtual bool CopyPrefetched() = 0; |
| |
| protected: |
| Context context_; |
| std::mutex prefetch_access_mutex_; |
| std::condition_variable producer_, consumer_; |
| // prefetched_ is used to tell the operator that it is done. |
| std::atomic<bool> prefetched_; |
| // prefetch_success_ is used to see if prefetching failed or not. |
| std::atomic<bool> prefetch_success_; |
| // finalize_ is used to tell the prefetcher to quit. |
| std::atomic<bool> finalize_; |
| unique_ptr<std::thread> prefetch_thread_; |
| |
| // Whether to do prefetching or run this as a normal operator |
| const bool no_prefetch_; |
| }; |
| |
| } // namespace caffe2 |
| |
| #endif // CAFFE2_OPERATORS_PREFETCH_OP_H_ |