// blob: 5ef83270a2a4b38061975cd9ccee0831a4bc7d63 [file] [log] [blame]
#pragma once
#include <bitset>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <limits>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <c10/macros/Macros.h>
#include <torch/csrc/monitor/events.h>
namespace torch {
namespace monitor {
// NUM_AGGREGATIONS is the number of Aggregation enumerators (NONE through
// MIN). It sizes the std::bitset that records which aggregations are enabled,
// so it must be greater than the largest Aggregation value.
constexpr int NUM_AGGREGATIONS = 7;
// Aggregation is the list of possible aggregations for Stats.
// Each enumerator is a bit position (not a bit mask): merge() and Stat use the
// value as an index into a std::bitset<NUM_AGGREGATIONS>, so the set of
// enabled aggregations can be stored efficiently.
enum class C10_API_ENUM Aggregation {
// NONE means no aggregations are set.
NONE = 0,
// VALUE exports the most recently set value.
VALUE = 1,
// MEAN computes the mean of the set values within the window. Zero if no
// values.
MEAN = 2,
// COUNT tracks the number of times a value is set within the window.
COUNT = 3,
// SUM computes the sum of the values set within the window.
SUM = 4,
// MAX computes the maximum of the values set within the window. Zero if no
// values.
MAX = 5,
// MIN computes the minimum of the values set within the window. Zero if no
// values.
MIN = 6,
};
// AggregationHash is the hash functor used for the unordered_map keyed by
// Aggregation that Stat returns. The hash of a key is simply its value
// converted to std::size_t.
struct TORCH_API AggregationHash {
  template <typename T>
  std::size_t operator()(T t) const {
    const auto hashed = static_cast<std::size_t>(t);
    return hashed;
  }
};
// aggregationName returns the human readable name corresponding to the
// aggregation. Stat appends it to the stat name (as `<name>.<aggregation>`) to
// form the keys of the event it logs. Only declared here; the definition is
// not part of this header.
TORCH_API const char* aggregationName(Aggregation agg);
// Forward declaration so that the detail::registerStat/unregisterStat
// overloads below can take Stat pointers before the class template is defined.
template <typename T>
class Stat;
namespace {
template <typename T>
inline std::bitset<NUM_AGGREGATIONS> merge(T& list) {
std::bitset<NUM_AGGREGATIONS> a;
for (Aggregation b : list) {
a.set(static_cast<int>(b));
}
return a;
}
} // namespace
// Registration hooks used by Stat: both Stat constructors call
// registerStat(this) and the Stat destructor calls unregisterStat(this).
// Overloads exist for the two supported value types (double and int64_t).
// The definitions are not in this header; presumably they maintain a registry
// of live Stat instances -- TODO confirm against the implementation file.
namespace detail {
void TORCH_API registerStat(Stat<double>* stat);
void TORCH_API registerStat(Stat<int64_t>* stat);
void TORCH_API unregisterStat(Stat<double>* stat);
void TORCH_API unregisterStat(Stat<int64_t>* stat);
} // namespace detail
// Stat is used to compute summary statistics in a performant way over fixed
// intervals. Stat logs the statistics as an Event once every `windowSize`
// duration. When the window closes the stats are logged via the event handlers
// as a `torch.monitor.Stat` event.
//
// `windowSize` should be set to something relatively high to avoid a huge
// number of events being logged. Ex: 60s. Stat uses millisecond precision.
//
// If maxSamples is set, the stat will cap the number of samples per window by
// discarding `add` calls once `maxSamples` adds have occurred. If it's not set,
// all `add` calls during the window will be included.
// This is an optional field to make aggregations more directly comparable
// across windows when the number of samples might vary.
//
// Stats support double and int64_t data types depending on what needs to be
// logged and needs to be templatized with one of them.
//
// When the Stat is destructed it will log any remaining data even if the window
// hasn't elapsed.
template <typename T>
class Stat {
private:
struct Values {
T value{0};
T sum{0};
T min{0};
T max{0};
int64_t count{0};
};
public:
Stat(
std::string name,
std::initializer_list<Aggregation> aggregations,
std::chrono::milliseconds windowSize,
int64_t maxSamples = std::numeric_limits<int64_t>::max())
: name_(std::move(name)),
aggregations_(merge(aggregations)),
windowSize_(windowSize),
maxSamples_(maxSamples) {
detail::registerStat(this);
}
Stat(
std::string name,
std::vector<Aggregation> aggregations,
std::chrono::milliseconds windowSize,
int64_t maxSamples = std::numeric_limits<int64_t>::max())
: name_(std::move(name)),
aggregations_(merge(aggregations)),
windowSize_(windowSize),
maxSamples_(maxSamples) {
detail::registerStat(this);
}
virtual ~Stat() {
{
// on destruction log if there's unlogged data
std::lock_guard<std::mutex> guard(mu_);
logLocked();
}
detail::unregisterStat(this);
}
// add adds the value v to the current window.
void add(T v) {
std::lock_guard<std::mutex> guard(mu_);
maybeLogLocked();
if (alreadyLogged()) {
return;
}
if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
current_.value = v;
}
if (aggregations_.test(static_cast<int>(Aggregation::MEAN)) ||
aggregations_.test(static_cast<int>(Aggregation::SUM))) {
current_.sum += v;
}
if (aggregations_.test(static_cast<int>(Aggregation::MAX))) {
if (current_.max < v || current_.count == 0) {
current_.max = v;
}
}
if (aggregations_.test(static_cast<int>(Aggregation::MIN))) {
if (current_.min > v || current_.count == 0) {
current_.min = v;
}
}
current_.count += 1;
maybeLogLocked();
}
const std::string& name() const noexcept {
return name_;
}
// count returns the number of items in the current open window.
int64_t count() noexcept {
std::lock_guard<std::mutex> guard(mu_);
return current_.count;
}
std::unordered_map<Aggregation, T, AggregationHash> get() noexcept {
std::lock_guard<std::mutex> guard(mu_);
return getLocked();
}
protected:
virtual uint64_t currentWindowId() const {
std::chrono::milliseconds now =
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now().time_since_epoch());
// always returns a currentWindowId of at least 1 to avoid 0 window issues
return (now / windowSize_) + 1;
}
private:
bool alreadyLogged() {
return lastLoggedWindowId_ == currentWindowId();
}
void maybeLogLocked() {
auto windowId = currentWindowId();
bool shouldLog = windowId_ != windowId || current_.count >= maxSamples_;
if (shouldLog && !alreadyLogged()) {
logLocked();
lastLoggedWindowId_ = windowId_;
windowId_ = windowId;
}
}
void logLocked() {
prev_ = current_;
current_ = Values();
// don't log event if there's no data
if (prev_.count == 0) {
return;
}
Event e;
e.name = "torch.monitor.Stat";
e.timestamp = std::chrono::system_clock::now();
auto stats = getLocked();
e.data.reserve(stats.size());
for (auto& kv : stats) {
std::stringstream key;
key << name_;
key << ".";
key << aggregationName(kv.first);
e.data[key.str()] = kv.second;
}
logEvent(e);
}
std::unordered_map<Aggregation, T, AggregationHash> getLocked()
const noexcept {
std::unordered_map<Aggregation, T, AggregationHash> out;
out.reserve(aggregations_.count());
if (aggregations_.test(static_cast<int>(Aggregation::VALUE))) {
out.emplace(Aggregation::VALUE, prev_.value);
}
if (aggregations_.test(static_cast<int>(Aggregation::MEAN))) {
if (prev_.count == 0) {
out.emplace(Aggregation::MEAN, 0);
} else {
out.emplace(Aggregation::MEAN, prev_.sum / prev_.count);
}
}
if (aggregations_.test(static_cast<int>(Aggregation::COUNT))) {
out.emplace(Aggregation::COUNT, prev_.count);
}
if (aggregations_.test(static_cast<int>(Aggregation::SUM))) {
out.emplace(Aggregation::SUM, prev_.sum);
}
if (aggregations_.test(static_cast<int>(Aggregation::MAX))) {
out.emplace(Aggregation::MAX, prev_.max);
}
if (aggregations_.test(static_cast<int>(Aggregation::MIN))) {
out.emplace(Aggregation::MIN, prev_.min);
}
return out;
}
const std::string name_;
const std::bitset<NUM_AGGREGATIONS> aggregations_;
std::mutex mu_;
Values current_;
Values prev_;
uint64_t windowId_{0};
uint64_t lastLoggedWindowId_{0};
const std::chrono::milliseconds windowSize_;
const int64_t maxSamples_;
};
} // namespace monitor
} // namespace torch