| #pragma once |
| |
| #include <cstddef> |
| #include <exception> |
| #include <memory> |
| #include <string> |
| #include <unordered_set> |
| #include <utility> |
| #include <vector> |
| |
| #include "caffe2/core/logging.h" |
| |
| namespace caffe2 { |
| |
/**
 * Base class for an observer in the Observer Pattern template.
 *
 * Derive from ObserverBase<T> and override the virtual hooks you need.
 * Every hook has a default implementation, so a derived class only has to
 * override what it actually uses.
 */
template <class T>
class ObserverBase {
 public:
  /// Remembers `subject` as the observed object. Only the pointer is stored;
  /// this class never deletes it.
  explicit ObserverBase(T* subject) : subject_{subject} {}

  /// Hook with an empty default implementation.
  virtual void Start() {}

  /// Hook with an empty default implementation.
  virtual void Stop() {}

  /// Human-readable description of the observer; the default is a
  /// placeholder string.
  virtual std::string debugInfo() {
    return std::string("Not implemented.");
  }

  virtual ~ObserverBase() noexcept = default;

  /// The pointer that was handed to the constructor.
  T* subject() const {
    return subject_;
  }

  /// Copy hook; the default implementation produces no copy and returns a
  /// null pointer. NOTE(review): the meaning of `rnn_order` is defined by the
  /// overriding classes and is not visible here.
  virtual std::unique_ptr<ObserverBase<T>> rnnCopy(T* subject, int rnn_order)
      const {
    (void)subject;
    (void)rnn_order;
    return std::unique_ptr<ObserverBase<T>>{};
  }

 protected:
  T* subject_;
};
| |
| /** |
| * Inherit to make your class observable. |
| */ |
| template <class T> |
| class Observable { |
| public: |
| Observable() = default; |
| |
| Observable(Observable&&) = default; |
| Observable& operator =(Observable&&) = default; |
| |
| virtual ~Observable() = default; |
| |
| C10_DISABLE_COPY_AND_ASSIGN(Observable); |
| |
| using Observer = ObserverBase<T>; |
| |
| /* Returns a reference to the observer after addition. */ |
| const Observer* AttachObserver(std::unique_ptr<Observer> observer) { |
| CAFFE_ENFORCE(observer, "Couldn't attach a null observer."); |
| std::unordered_set<const Observer*> observers; |
| for (auto& ob : observers_list_) { |
| observers.insert(ob.get()); |
| } |
| |
| const auto* observer_ptr = observer.get(); |
| if (observers.count(observer_ptr)) { |
| return observer_ptr; |
| } |
| observers_list_.push_back(std::move(observer)); |
| UpdateCache(); |
| |
| return observer_ptr; |
| } |
| |
| /** |
| * Returns a unique_ptr to the removed observer. If not found, return a |
| * nullptr |
| */ |
| std::unique_ptr<Observer> DetachObserver(const Observer* observer_ptr) { |
| for (auto it = observers_list_.begin(); it != observers_list_.end(); ++it) { |
| if (it->get() == observer_ptr) { |
| auto res = std::move(*it); |
| observers_list_.erase(it); |
| UpdateCache(); |
| return res; |
| } |
| } |
| return nullptr; |
| } |
| |
| virtual size_t NumObservers() { |
| return num_observers_; |
| } |
| |
| private: |
| inline static void StartObserver(Observer* observer) { |
| try { |
| observer->Start(); |
| } catch (const std::exception& e) { |
| LOG(ERROR) << "Exception from observer: " << e.what(); |
| } catch (...) { |
| LOG(ERROR) << "Exception from observer: unknown"; |
| } |
| } |
| |
| inline static void StopObserver(Observer* observer) { |
| try { |
| observer->Stop(); |
| } catch (const std::exception& e) { |
| LOG(ERROR) << "Exception from observer: " << e.what(); |
| } catch (...) { |
| LOG(ERROR) << "Exception from observer: unknown"; |
| } |
| } |
| |
| void UpdateCache() { |
| num_observers_ = observers_list_.size(); |
| if (num_observers_ != 1) { |
| // we cannot take advantage of the cache |
| return; |
| } |
| observer_cache_ = observers_list_[0].get(); |
| } |
| |
| public: |
| void StartAllObservers() { |
| // do not access observers_list_ unless necessary |
| if (num_observers_ == 0) { |
| return; |
| } else if (num_observers_ == 1) { |
| StartObserver(observer_cache_); |
| } else { |
| for (auto& observer : observers_list_) { |
| StartObserver(observer.get()); |
| } |
| } |
| } |
| |
| void StopAllObservers() { |
| // do not access observers_list_ unless necessary |
| if (num_observers_ == 0) { |
| return; |
| } else if (num_observers_ == 1) { |
| StopObserver(observer_cache_); |
| } else { |
| for (auto& observer : observers_list_) { |
| StopObserver(observer.get()); |
| } |
| } |
| } |
| |
| private: |
| // an on-stack cache for fast iteration; |
| // ideally, inside StartAllObservers and StopAllObservers, |
| // we should never access observers_list_ |
| Observer* observer_cache_; |
| size_t num_observers_ = 0; |
| |
| protected: |
| std::vector<std::unique_ptr<Observer>> observers_list_; |
| }; |
| |
| } // namespace caffe2 |