blob: 9104e9bb251bad6244d283d64e821b0c1f366968 [file] [log] [blame]
#include <gtest/gtest.h>
#include <thread>
#include <torch/csrc/monitor/counters.h>
#include <torch/csrc/monitor/events.h>
using namespace torch::monitor;
TEST(MonitorTest, CounterDouble) {
Stat<double> a{
"a",
{Aggregation::MEAN, Aggregation::COUNT},
std::chrono::milliseconds(100000),
2,
};
a.add(5.0);
ASSERT_EQ(a.count(), 1);
a.add(6.0);
ASSERT_EQ(a.count(), 0);
auto stats = a.get();
std::unordered_map<Aggregation, double, AggregationHash> want = {
{Aggregation::MEAN, 5.5},
{Aggregation::COUNT, 2.0},
};
ASSERT_EQ(stats, want);
}
TEST(MonitorTest, CounterInt64Sum) {
Stat<int64_t> a{
"a",
{Aggregation::SUM},
std::chrono::milliseconds(100000),
2,
};
a.add(5);
a.add(6);
auto stats = a.get();
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::SUM, 11},
};
ASSERT_EQ(stats, want);
}
TEST(MonitorTest, CounterInt64Value) {
Stat<int64_t> a{
"a",
{Aggregation::VALUE},
std::chrono::milliseconds(100000),
2,
};
a.add(5);
a.add(6);
auto stats = a.get();
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::VALUE, 6},
};
ASSERT_EQ(stats, want);
}
TEST(MonitorTest, CounterInt64Mean) {
Stat<int64_t> a{
"a",
{Aggregation::MEAN},
std::chrono::milliseconds(100000),
2,
};
{
// zero samples case
auto stats = a.get();
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::MEAN, 0},
};
ASSERT_EQ(stats, want);
}
a.add(0);
a.add(10);
{
auto stats = a.get();
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::MEAN, 5},
};
ASSERT_EQ(stats, want);
}
}
TEST(MonitorTest, CounterInt64Count) {
Stat<int64_t> a{
"a",
{Aggregation::COUNT},
std::chrono::milliseconds(100000),
2,
};
ASSERT_EQ(a.count(), 0);
a.add(0);
ASSERT_EQ(a.count(), 1);
a.add(10);
ASSERT_EQ(a.count(), 0);
auto stats = a.get();
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::COUNT, 2},
};
ASSERT_EQ(stats, want);
}
TEST(MonitorTest, CounterInt64MinMax) {
Stat<int64_t> a{
"a",
{Aggregation::MIN, Aggregation::MAX},
std::chrono::milliseconds(100000),
6,
};
{
auto stats = a.get();
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::MAX, 0},
{Aggregation::MIN, 0},
};
ASSERT_EQ(stats, want);
}
a.add(0);
a.add(5);
a.add(-5);
a.add(-6);
a.add(9);
a.add(2);
{
auto stats = a.get();
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::MAX, 9},
{Aggregation::MIN, -6},
};
ASSERT_EQ(stats, want);
}
}
TEST(MonitorTest, CounterInt64WindowSize) {
Stat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::milliseconds(100000),
/*windowSize=*/3,
};
a.add(1);
a.add(2);
ASSERT_EQ(a.count(), 2);
a.add(3);
ASSERT_EQ(a.count(), 0);
// after logging max for window, should be zero
a.add(4);
ASSERT_EQ(a.count(), 0);
auto stats = a.get();
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::COUNT, 3},
{Aggregation::SUM, 6},
};
ASSERT_EQ(stats, want);
}
TEST(MonitorTest, CounterInt64WindowSizeHuge) {
Stat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::hours(24 * 365 * 10), // 10 years
/*windowSize=*/3,
};
a.add(1);
a.add(2);
ASSERT_EQ(a.count(), 2);
a.add(3);
ASSERT_EQ(a.count(), 0);
// after logging max for window, should be zero
a.add(4);
ASSERT_EQ(a.count(), 0);
auto stats = a.get();
std::unordered_map<Aggregation, int64_t, AggregationHash> want = {
{Aggregation::COUNT, 3},
{Aggregation::SUM, 6},
};
ASSERT_EQ(stats, want);
}
template <typename T>
struct TestStat : public Stat<T> {
uint64_t mockWindowId{1};
TestStat(
std::string name,
std::initializer_list<Aggregation> aggregations,
std::chrono::milliseconds windowSize,
int64_t maxSamples = std::numeric_limits<int64_t>::max())
: Stat<T>(name, aggregations, windowSize, maxSamples) {}
uint64_t currentWindowId() const override {
return mockWindowId;
}
};
struct AggregatingEventHandler : public EventHandler {
std::vector<Event> events;
void handle(const Event& e) override {
events.emplace_back(e);
}
};
template <typename T>
struct HandlerGuard {
std::shared_ptr<T> handler;
HandlerGuard() : handler(std::make_shared<T>()) {
registerEventHandler(handler);
}
~HandlerGuard() {
unregisterEventHandler(handler);
}
};
TEST(MonitorTest, Stat) {
HandlerGuard<AggregatingEventHandler> guard;
Stat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::milliseconds(1),
};
ASSERT_EQ(guard.handler->events.size(), 0);
a.add(1);
ASSERT_LE(a.count(), 1);
std::this_thread::sleep_for(std::chrono::milliseconds(2));
a.add(2);
ASSERT_LE(a.count(), 1);
ASSERT_GE(guard.handler->events.size(), 1);
ASSERT_LE(guard.handler->events.size(), 2);
}
TEST(MonitorTest, StatEvent) {
HandlerGuard<AggregatingEventHandler> guard;
TestStat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::milliseconds(1),
};
ASSERT_EQ(guard.handler->events.size(), 0);
a.add(1);
ASSERT_EQ(a.count(), 1);
a.add(2);
ASSERT_EQ(a.count(), 2);
ASSERT_EQ(guard.handler->events.size(), 0);
a.mockWindowId = 100;
a.add(3);
ASSERT_LE(a.count(), 1);
ASSERT_EQ(guard.handler->events.size(), 1);
Event e = guard.handler->events.at(0);
ASSERT_EQ(e.name, "torch.monitor.Stat");
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
std::unordered_map<std::string, data_value_t> data{
{"a.sum", 3L},
{"a.count", 2L},
};
ASSERT_EQ(e.data, data);
}
TEST(MonitorTest, StatEventDestruction) {
HandlerGuard<AggregatingEventHandler> guard;
{
TestStat<int64_t> a{
"a",
{Aggregation::COUNT, Aggregation::SUM},
std::chrono::hours(10),
};
a.add(1);
ASSERT_EQ(a.count(), 1);
ASSERT_EQ(guard.handler->events.size(), 0);
}
ASSERT_EQ(guard.handler->events.size(), 1);
Event e = guard.handler->events.at(0);
ASSERT_EQ(e.name, "torch.monitor.Stat");
ASSERT_NE(e.timestamp, std::chrono::system_clock::time_point{});
std::unordered_map<std::string, data_value_t> data{
{"a.sum", 1L},
{"a.count", 1L},
};
ASSERT_EQ(e.data, data);
}