| #pragma once |
| |
#include <bitset>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <limits>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/macros/Macros.h>

#include <torch/csrc/monitor/events.h>
| |
| namespace torch { |
| namespace monitor { |
| |
// NUM_AGGREGATIONS is the number of Aggregation enumerators (NONE included).
// It is the width of the std::bitset that records which aggregations are
// enabled, so it must stay greater than the largest Aggregation value.
constexpr int NUM_AGGREGATIONS{7};
| |
| // Aggregation is the list of possible aggregations for Stats. |
| // These use bitwise flags so they can be efficiently stored. |
| enum class C10_API_ENUM Aggregation { |
| // NONE means no aggregations are set. |
| NONE = 0, |
| // VALUE exports the most recently set value. |
| VALUE = 1, |
| // MEAN computes the mean of the set values within the window. Zero if no |
| // values. |
| MEAN = 2, |
| // COUNT tracks the number of times a value is set within the window. |
| COUNT = 3, |
| // SUM computes the sum of the values set within the window. |
| SUM = 4, |
| // MIN computes the minimum of the values set within the window. Zero if no |
| // values. |
| MAX = 5, |
| // MAX computes the maximum of the values set within the window. Zero if no |
| // values. |
| MIN = 6, |
| }; |
| |
| struct TORCH_API AggregationHash { |
| template <typename T> |
| std::size_t operator()(T t) const { |
| return static_cast<std::size_t>(t); |
| } |
| }; |
| |
| // aggregationName returns the human readable name corresponding to the |
| // aggregation. |
| TORCH_API const char* aggregationName(Aggregation agg); |
| |
| template <typename T> |
| class Stat; |
| |
| namespace { |
| template <typename T> |
| inline std::bitset<NUM_AGGREGATIONS> merge(T& list) { |
| std::bitset<NUM_AGGREGATIONS> a; |
| for (Aggregation b : list) { |
| a.set(static_cast<int>(b)); |
| } |
| return a; |
| } |
| } // namespace |
| |
| namespace detail { |
| void TORCH_API registerStat(Stat<double>* stat); |
| void TORCH_API registerStat(Stat<int64_t>* stat); |
| void TORCH_API unregisterStat(Stat<double>* stat); |
| void TORCH_API unregisterStat(Stat<int64_t>* stat); |
| } // namespace detail |
| |
| // Stat is used to compute summary statistics in a performant way over fixed |
| // intervals. Stat logs the statistics as an Event once every `windowSize` |
| // duration. When the window closes the stats are logged via the event handlers |
| // as a `torch.monitor.Stat` event. |
| // |
| // `windowSize` should be set to something relatively high to avoid a huge |
| // number of events being logged. Ex: 60s. Stat uses millisecond precision. |
| // |
| // If maxSamples is set, the stat will cap the number of samples per window by |
| // discarding `add` calls once `maxSamples` adds have occurred. If it's not set, |
| // all `add` calls during the window will be included. |
| // This is an optional field to make aggregations more directly comparable |
| // across windows when the number of samples might vary. |
| // |
| // Stats support double and int64_t data types depending on what needs to be |
| // logged and needs to be templatized with one of them. |
| // |
| // When the Stat is destructed it will log any remaining data even if the window |
| // hasn't elapsed. |
| template <typename T> |
| class Stat { |
| private: |
| struct Values { |
| T value{0}; |
| T sum{0}; |
| T min{0}; |
| T max{0}; |
| int64_t count{0}; |
| }; |
| |
| public: |
| Stat( |
| std::string name, |
| std::initializer_list<Aggregation> aggregations, |
| std::chrono::milliseconds windowSize, |
| int64_t maxSamples = std::numeric_limits<int64_t>::max()) |
| : name_(std::move(name)), |
| aggregations_(merge(aggregations)), |
| windowSize_(windowSize), |
| maxSamples_(maxSamples) { |
| detail::registerStat(this); |
| } |
| |
| Stat( |
| std::string name, |
| std::vector<Aggregation> aggregations, |
| std::chrono::milliseconds windowSize, |
| int64_t maxSamples = std::numeric_limits<int64_t>::max()) |
| : name_(std::move(name)), |
| aggregations_(merge(aggregations)), |
| windowSize_(windowSize), |
| maxSamples_(maxSamples) { |
| detail::registerStat(this); |
| } |
| |
| virtual ~Stat() { |
| { |
| // on destruction log if there's unlogged data |
| std::lock_guard<std::mutex> guard(mu_); |
| logLocked(); |
| } |
| detail::unregisterStat(this); |
| } |
| |
| // add adds the value v to the current window. |
| void add(T v) { |
| std::lock_guard<std::mutex> guard(mu_); |
| maybeLogLocked(); |
| |
| if (alreadyLogged()) { |
| return; |
| } |
| |
| if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) { |
| current_.value = v; |
| } |
| if (aggregations_.test(static_cast<int>(Aggregation::MEAN)) || |
| aggregations_.test(static_cast<int>(Aggregation::SUM))) { |
| current_.sum += v; |
| } |
| |
| if (aggregations_.test(static_cast<int>(Aggregation::MAX))) { |
| if (current_.max < v || current_.count == 0) { |
| current_.max = v; |
| } |
| } |
| if (aggregations_.test(static_cast<int>(Aggregation::MIN))) { |
| if (current_.min > v || current_.count == 0) { |
| current_.min = v; |
| } |
| } |
| |
| current_.count += 1; |
| maybeLogLocked(); |
| } |
| |
| const std::string& name() const noexcept { |
| return name_; |
| } |
| |
| // count returns the number of items in the current open window. |
| int64_t count() noexcept { |
| std::lock_guard<std::mutex> guard(mu_); |
| |
| return current_.count; |
| } |
| |
| std::unordered_map<Aggregation, T, AggregationHash> get() noexcept { |
| std::lock_guard<std::mutex> guard(mu_); |
| return getLocked(); |
| } |
| |
| protected: |
| virtual uint64_t currentWindowId() const { |
| std::chrono::milliseconds now = |
| std::chrono::duration_cast<std::chrono::milliseconds>( |
| std::chrono::steady_clock::now().time_since_epoch()); |
| |
| // always returns a currentWindowId of at least 1 to avoid 0 window issues |
| return (now / windowSize_) + 1; |
| } |
| |
| private: |
| bool alreadyLogged() { |
| return lastLoggedWindowId_ == currentWindowId(); |
| } |
| |
| void maybeLogLocked() { |
| auto windowId = currentWindowId(); |
| bool shouldLog = windowId_ != windowId || current_.count >= maxSamples_; |
| if (shouldLog && !alreadyLogged()) { |
| logLocked(); |
| lastLoggedWindowId_ = windowId_; |
| windowId_ = windowId; |
| } |
| } |
| |
| void logLocked() { |
| prev_ = current_; |
| current_ = Values(); |
| |
| // don't log event if there's no data |
| if (prev_.count == 0) { |
| return; |
| } |
| |
| Event e; |
| e.name = "torch.monitor.Stat"; |
| e.timestamp = std::chrono::system_clock::now(); |
| |
| auto stats = getLocked(); |
| e.data.reserve(stats.size()); |
| for (auto& kv : stats) { |
| std::stringstream key; |
| key << name_; |
| key << "."; |
| key << aggregationName(kv.first); |
| e.data[key.str()] = kv.second; |
| } |
| |
| logEvent(e); |
| } |
| |
| std::unordered_map<Aggregation, T, AggregationHash> getLocked() |
| const noexcept { |
| std::unordered_map<Aggregation, T, AggregationHash> out; |
| out.reserve(aggregations_.count()); |
| |
| if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) { |
| out.emplace(Aggregation::VALUE, prev_.value); |
| } |
| if (aggregations_.test(static_cast<int>(Aggregation::MEAN))) { |
| if (prev_.count == 0) { |
| out.emplace(Aggregation::MEAN, 0); |
| } else { |
| out.emplace(Aggregation::MEAN, prev_.sum / prev_.count); |
| } |
| } |
| if (aggregations_.test(static_cast<int>(Aggregation::COUNT))) { |
| out.emplace(Aggregation::COUNT, prev_.count); |
| } |
| if (aggregations_.test(static_cast<int>(Aggregation::SUM))) { |
| out.emplace(Aggregation::SUM, prev_.sum); |
| } |
| if (aggregations_.test(static_cast<int>(Aggregation::MAX))) { |
| out.emplace(Aggregation::MAX, prev_.max); |
| } |
| if (aggregations_.test(static_cast<int>(Aggregation::MIN))) { |
| out.emplace(Aggregation::MIN, prev_.min); |
| } |
| |
| return out; |
| } |
| |
| const std::string name_; |
| const std::bitset<NUM_AGGREGATIONS> aggregations_; |
| |
| std::mutex mu_; |
| Values current_; |
| Values prev_; |
| |
| uint64_t windowId_{0}; |
| uint64_t lastLoggedWindowId_{0}; |
| const std::chrono::milliseconds windowSize_; |
| const int64_t maxSamples_; |
| }; |
| } // namespace monitor |
| } // namespace torch |