#pragma once

#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "caffe2/core/logging.h"
| |
| namespace caffe2 { |
| |
/**
 * A C++ `thread_local` variable gives each thread its own copy, but as a class
 * member it is implicitly static, so that one per-thread copy is shared by
 * every instance of the class. Sometimes we want state that is both per thread
 * and per instance, e.g. given:
 *   class A {
 *     ThreadLocalPtr<int> x;
 *   };
 * we want a separate copy of x for every (thread, instance of A) pair.
 * This is useful for storing per-instance thread-local state of a class when
 * multiple instances of that class may be used from the same thread.
 * Only the subset of folly::ThreadLocalPtr's functionality needed to support
 * BlackBoxPredictor is implemented here.
 */
| |
// Forward declarations: the aliases below only need these as incomplete types.
class ThreadLocalPtrImpl;
class ThreadLocalHelper;

/**
 * Per-thread table keyed by ThreadLocalPtrImpl (one entry per owning object).
 * One such table per thread (via ThreadLocalHelper) plus one key per object
 * yields storage that is both per thread and per object.
 * Values are held as std::shared_ptr<void> so each stored object keeps the
 * deleter for its real type. "Unsafe": the container itself does no locking.
 */
using UnsafeThreadLocalMap =
    std::unordered_map<ThreadLocalPtrImpl*, std::shared_ptr<void>>;

/// Accessor for the calling thread's ThreadLocalHelper; defined out of line.
ThreadLocalHelper* getThreadLocalHelper();

/// Plain, unsynchronized list of ThreadLocalHelper pointers.
using UnsafeAllThreadLocalHelperVector = std::vector<ThreadLocalHelper*>;
| |
| /** |
| * A thread safe vector of all ThreadLocalHelper, this will be used |
| * to encapuslate the locking in the APIs for the changes to the global |
| * AllThreadLocalHelperVector instance. |
| */ |
| class AllThreadLocalHelperVector { |
| public: |
| AllThreadLocalHelperVector() {} |
| |
| // Add a new ThreadLocalHelper to the vector |
| void push_back(ThreadLocalHelper* helper); |
| |
| // Erase a ThreadLocalHelper to the vector |
| void erase(ThreadLocalHelper* helper); |
| |
| // Erase object in all the helpers stored in vector |
| // Called during destructor of a ThreadLocalPtrImpl |
| void erase_tlp(ThreadLocalPtrImpl* ptr); |
| |
| private: |
| UnsafeAllThreadLocalHelperVector vector_; |
| std::mutex mutex_; |
| }; |
| |
| /** |
| * ThreadLocalHelper is per thread |
| */ |
| class ThreadLocalHelper { |
| public: |
| ThreadLocalHelper(); |
| |
| // When the thread dies, we want to clean up *this* |
| // in AllThreadLocalHelperVector |
| ~ThreadLocalHelper(); |
| |
| // Insert a (object, ptr) pair into the thread local map |
| void insert(ThreadLocalPtrImpl* tl_ptr, std::shared_ptr<void> ptr); |
| // Get the ptr by object |
| void* get(ThreadLocalPtrImpl* key); |
| // Erase the ptr associated with the object in the map |
| void erase(ThreadLocalPtrImpl* key); |
| |
| private: |
| // mapping of object -> ptr in each thread |
| UnsafeThreadLocalMap mapping_; |
| std::mutex mutex_; |
| }; // ThreadLocalHelper |
| |
| /** ThreadLocalPtrImpl is per object |
| */ |
| class ThreadLocalPtrImpl { |
| public: |
| ThreadLocalPtrImpl() {} |
| // Delete copy and move constructors |
| ThreadLocalPtrImpl(const ThreadLocalPtrImpl&) = delete; |
| ThreadLocalPtrImpl(ThreadLocalPtrImpl&&) = delete; |
| ThreadLocalPtrImpl& operator=(const ThreadLocalPtrImpl&) = delete; |
| ThreadLocalPtrImpl& operator=(const ThreadLocalPtrImpl&&) = delete; |
| |
| // In the case when object dies first, we want to |
| // clean up the states in all child threads |
| ~ThreadLocalPtrImpl(); |
| |
| template <typename T> |
| T* get() { |
| return static_cast<T*>(getThreadLocalHelper()->get(this)); |
| } |
| |
| template <typename T> |
| void reset(T* newPtr = nullptr) { |
| VLOG(2) << "In Reset(" << newPtr << ")"; |
| auto* wrapper = getThreadLocalHelper(); |
| // Cleaning up the objects(T) stored in the ThreadLocalPtrImpl in the thread |
| wrapper->erase(this); |
| if (newPtr != nullptr) { |
| std::shared_ptr<void> sharedPtr(newPtr); |
| // Deletion of newPtr is handled by shared_ptr |
| // as it implements type erasure |
| wrapper->insert(this, std::move(sharedPtr)); |
| } |
| } |
| |
| }; // ThreadLocalPtrImpl |
| |
| template <typename T> |
| class ThreadLocalPtr { |
| public: |
| auto* operator->() { |
| return get(); |
| } |
| |
| auto& operator*() { |
| return *get(); |
| } |
| |
| auto* get() { |
| return impl_.get<T>(); |
| } |
| |
| auto* operator->() const { |
| return get(); |
| } |
| |
| auto& operator*() const { |
| return *get(); |
| } |
| |
| auto* get() const { |
| return impl_.get<T>(); |
| } |
| |
| void reset(unique_ptr<T> ptr = nullptr) { |
| impl_.reset<T>(ptr.release()); |
| } |
| |
| private: |
| ThreadLocalPtrImpl impl_; |
| }; |
| |
| } // namespace caffe2 |