| #ifndef CAFFE2_OPERATORS_COUNTER_OPS_H |
| #define CAFFE2_OPERATORS_COUNTER_OPS_H |
| |
| #include <atomic> |
| |
| #include "caffe2/core/context.h" |
| #include "caffe2/core/logging.h" |
| #include "caffe2/core/operator.h" |
| |
| namespace caffe2 { |
// A counter whose value is stored in a std::atomic<T>. Every member function
// performs exactly one atomic operation on that value.
//
// The value is not clamped at zero: countDown() decrements unconditionally, so
// calling it on a counter that is already at or below zero pushes the stored
// value further down, and that value is what retrieve(), countUp() and reset()
// subsequently observe.
template <typename T>
class TORCH_API Counter {
 public:
  // Creates a counter whose initial value is `count`.
  explicit Counter(T count) : count_(count) {}

  // Atomically decrements the counter.
  // Returns false if the value held *before* the decrement was greater than
  // zero, and true otherwise.
  bool countDown() {
    const T before = count_--;
    return !(before > 0);
  }

  // Atomically increments the counter and returns the value it held before
  // the increment.
  T countUp() {
    return count_++;
  }

  // Returns the current value of the counter.
  T retrieve() const {
    return count_.load();
  }

  // Returns true when the current value is less than or equal to zero.
  // The result is a predicate, not a count, so it is typed as bool; it still
  // converts implicitly for any caller that stored it in a T.
  bool checkIfDone() const {
    return count_.load() <= 0;
  }

  // Atomically replaces the counter value with `init_count` and returns the
  // value that was stored before the replacement.
  T reset(T init_count) {
    return count_.exchange(init_count);
  }

 private:
  std::atomic<T> count_;
};
| |
| // TODO(jiayq): deprecate these ops & consolidate them with IterOp/AtomicIterOp |
| |
| template <typename T, class Context> |
| class CreateCounterOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| template <class... Args> |
| explicit CreateCounterOp(Args&&... args) |
| : Operator<Context>(std::forward<Args>(args)...), |
| init_count_(this->template GetSingleArgument<T>("init_count", 0)) { |
| CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted."); |
| } |
| |
| bool RunOnDevice() override { |
| *this->template Output<std::unique_ptr<Counter<T>>>(0) = |
| std::unique_ptr<Counter<T>>(new Counter<T>(init_count_)); |
| return true; |
| } |
| |
| private: |
| T init_count_ = 0; |
| }; |
| |
| template <typename T, class Context> |
| class ResetCounterOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| template <class... Args> |
| explicit ResetCounterOp(Args&&... args) |
| : Operator<Context>(std::forward<Args>(args)...), |
| init_count_(this->template GetSingleArgument<T>("init_count", 0)) { |
| CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted."); |
| } |
| |
| bool RunOnDevice() override { |
| auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0); |
| auto previous = counterPtr->reset(init_count_); |
| if (OutputSize() == 1) { |
| auto* output = Output(0); |
| output->Resize(); |
| *output->template mutable_data<T>() = previous; |
| } |
| return true; |
| } |
| |
| private: |
| T init_count_; |
| }; |
| |
| // Will always use TensorCPU regardless the Context |
| template <typename T, class Context> |
| class CountDownOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| template <class... Args> |
| explicit CountDownOp(Args&&... args) |
| : Operator<Context>(std::forward<Args>(args)...) {} |
| |
| bool RunOnDevice() override { |
| auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0); |
| auto* output = Output(0); |
| output->Resize(std::vector<int>{}); |
| *output->template mutable_data<bool>() = counterPtr->countDown(); |
| return true; |
| } |
| }; |
| |
| // Will always use TensorCPU regardless the Context |
| template <typename T, class Context> |
| class CheckCounterDoneOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| template <class... Args> |
| explicit CheckCounterDoneOp(Args&&... args) |
| : Operator<Context>(std::forward<Args>(args)...) {} |
| |
| bool RunOnDevice() override { |
| auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0); |
| auto* output = Output(0); |
| output->Resize(std::vector<int>{}); |
| *output->template mutable_data<bool>() = counterPtr->checkIfDone(); |
| return true; |
| } |
| }; |
| |
| // Will always use TensorCPU regardless the Context |
| template <typename T, class Context> |
| class CountUpOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| template <class... Args> |
| explicit CountUpOp(Args&&... args) |
| : Operator<Context>(std::forward<Args>(args)...) {} |
| |
| bool RunOnDevice() override { |
| auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0); |
| auto* output = Output(0); |
| output->Resize(std::vector<int>{}); |
| *output->template mutable_data<T>() = counterPtr->countUp(); |
| return true; |
| } |
| }; |
| |
| // Will always use TensorCPU regardless the Context |
| template <typename T, class Context> |
| class RetrieveCountOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| template <class... Args> |
| explicit RetrieveCountOp(Args&&... args) |
| : Operator<Context>(std::forward<Args>(args)...) {} |
| |
| bool RunOnDevice() override { |
| auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0); |
| auto* output = Output(0); |
| output->Resize(std::vector<int>{}); |
| *output->template mutable_data<T>() = counterPtr->retrieve(); |
| return true; |
| } |
| }; |
| |
| } // namespace caffe2 |
#endif // CAFFE2_OPERATORS_COUNTER_OPS_H