blob: a8070246dc1d3a01ef74af9fabefeac51c37754c [file] [log] [blame]
#include <c10/util/irange.h>
#include "StoreTestCommon.hpp"
#include <cstdlib>
#include <future>
#include <iostream>
#include <system_error>
#include <thread>
#include <gtest/gtest.h>
#include <c10d/PrefixStore.hpp>
#include <c10d/TCPStore.hpp>
// Short store timeout (ms) installed right before a get() that is expected to
// fail, so the failure surfaces quickly.
constexpr int64_t kShortStoreTimeoutMillis{100};
// Upper bound (ms) on how long tests wait for watch-key callbacks to fire.
constexpr int64_t kStoreCallbackTimeoutMillis{5000};
// Default store timeout, in seconds.
constexpr int defaultTimeout{20};
// Creates a server-side TCPStore on 127.0.0.1, port 0, expecting `numWorkers`
// participants. It does not block waiting for them. `timeout` is in seconds.
c10::intrusive_ptr<c10d::TCPStore> _createServer(
    int numWorkers = 1,
    int timeout = defaultTimeout) {
  const auto timeoutSecs = std::chrono::seconds(timeout);
  const c10d::TCPStoreOptions serverOpts{
      /* port */ 0,
      /* isServer */ true,
      numWorkers,
      /* waitWorkers */ false,
      /* timeout */ timeoutSecs};
  return c10::make_intrusive<c10d::TCPStore>("127.0.0.1", serverOpts);
}
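// Servers created via _createServer bind port 0, so each test gets its own
// (presumably OS-assigned) port instead of sharing a fixed one.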
// Stress test for TCPStore, wrapped in a PrefixStore using `prefix`.
// The server thread runs basic set/check/add/compareSet/deleteKey/getNumKeys
// checks. Meanwhile numThreads client threads each increment a shared
// "counter" numIterations times and repeatedly write their own "thread_<i>"
// key. After a barrier, every client verifies the counter and the final value
// written by every other client.
void testHelper(const std::string& prefix = "") {
  constexpr auto numThreads = 16;
  // numThreads client stores plus the server store.
  constexpr auto numWorkers = numThreads + 1;
  auto serverTCPStore = _createServer(numWorkers);
  auto serverStore =
      c10::make_intrusive<c10d::PrefixStore>(prefix, serverTCPStore);
  // server store
  auto serverThread = std::thread([&serverStore, &serverTCPStore] {
    // Wait for all workers to join.
    serverTCPStore->waitForWorkers();
    // Basic set/get on the server store
    c10d::test::set(*serverStore, "key0", "value0");
    c10d::test::set(*serverStore, "key1", "value1");
    c10d::test::set(*serverStore, "key2", "value2");
    c10d::test::check(*serverStore, "key0", "value0");
    c10d::test::check(*serverStore, "key1", "value1");
    c10d::test::check(*serverStore, "key2", "value2");
    serverStore->add("counter", 1);
    auto numKeys = serverStore->getNumKeys();
    // We expect 5 keys since 3 are added above, 'counter' is added by the
    // helper thread, and the init key to coordinate workers.
    // NOTE(review): client threads run concurrently and eventually create
    // "thread_<i>" keys too. This count relies on them not having done so yet
    // (they first perform numIterations adds). Confirm this is not flaky.
    EXPECT_EQ(numKeys, 5);
    // Check compareSet, does not check return value
    c10d::test::compareSet(
        *serverStore, "key0", "wrongExpectedValue", "newValue");
    c10d::test::check(*serverStore, "key0", "value0");
    c10d::test::compareSet(*serverStore, "key0", "value0", "newValue");
    c10d::test::check(*serverStore, "key0", "newValue");
    auto delSuccess = serverStore->deleteKey("key0");
    // Ensure that the key was successfully deleted
    EXPECT_TRUE(delSuccess);
    auto delFailure = serverStore->deleteKey("badKeyName");
    // The key was not in the store so the delete operation should have failed
    // and returned false.
    EXPECT_FALSE(delFailure);
    numKeys = serverStore->getNumKeys();
    EXPECT_EQ(numKeys, 4);
    // "key0" was deleted above, so get() must fail. A short timeout keeps
    // that failure fast.
    auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis);
    serverStore->setTimeout(timeout);
    EXPECT_THROW(serverStore->get("key0"), c10::Error);
  });

  // Hammer on TCPStore
  std::vector<std::thread> threads;
  threads.reserve(numThreads);
  constexpr auto numIterations = 1000;
  c10d::test::Semaphore sem1, sem2;
  c10d::TCPStoreOptions opts{};
  opts.port = serverTCPStore->getPort();
  opts.numWorkers = numWorkers;

  // Each thread will have a client store to send/recv data
  std::vector<c10::intrusive_ptr<c10d::TCPStore>> clientTCPStores;
  std::vector<c10::intrusive_ptr<c10d::PrefixStore>> clientStores;
  clientTCPStores.reserve(numThreads);
  clientStores.reserve(numThreads);
  for (const auto i : c10::irange(numThreads)) {
    clientTCPStores.push_back(
        c10::make_intrusive<c10d::TCPStore>("127.0.0.1", opts));
    clientStores.push_back(
        c10::make_intrusive<c10d::PrefixStore>(prefix, clientTCPStores[i]));
  }

  // Every client adds 1 numIterations times. The server thread adds 1 once.
  std::string expectedCounterRes =
      std::to_string(numThreads * numIterations + 1);
  for (const auto i : c10::irange(numThreads)) {
    threads.emplace_back(
        std::thread([=, &sem1, &sem2, &clientStores, &expectedCounterRes] {
          for (C10_UNUSED const auto j : c10::irange(numIterations)) {
            clientStores[i]->add("counter", 1);
          }
          // Let each thread set and get key on its client store
          std::string key = "thread_" + std::to_string(i);
          for (const auto j : c10::irange(numIterations)) {
            std::string val = "thread_val_" + std::to_string(j);
            c10d::test::set(*clientStores[i], key, val);
            c10d::test::check(*clientStores[i], key, val);
          }
          // Barrier: the main thread posts sem2 only after every client has
          // posted sem1, i.e. after all writes above have completed.
          sem1.post();
          sem2.wait();
          // Check the counter results
          c10d::test::check(*clientStores[i], "counter", expectedCounterRes);
          // Now check other threads' written data. Read thread j's key,
          // not this thread's own key.
          const std::string finalVal =
              "thread_val_" + std::to_string(numIterations - 1);
          for (const auto j : c10::irange(numThreads)) {
            if (j == i) {
              continue;
            }
            const std::string otherKey = "thread_" + std::to_string(j);
            c10d::test::check(*clientStores[i], otherKey, finalVal);
          }
        }));
  }
  sem1.wait(numThreads);
  sem2.post(numThreads);
  for (auto& thread : threads) {
    thread.join();
  }
  serverThread.join();

  // Clear the store to test that client disconnect won't shutdown the store
  clientStores.clear();
  clientTCPStores.clear();

  // Check that the counter has the expected value
  c10d::test::check(*serverStore, "counter", expectedCounterRes);

  // Check that each threads' written data from the main thread
  for (const auto i : c10::irange(numThreads)) {
    std::string key = "thread_" + std::to_string(i);
    std::string val = "thread_val_" + std::to_string(numIterations - 1);
    c10d::test::check(*serverStore, key, val);
  }
}
void testWatchKeyCallback(const std::string& prefix = "") {
// Callback function increments counter of the total number of callbacks that
// were run
std::promise<int> numCallbacksExecutedPromise;
std::atomic<int> numCallbacksExecuted{0};
constexpr int numThreads = 16;
constexpr int keyChangeOperation = 3;
c10d::WatchKeyCallback callback =
[=, &numCallbacksExecuted, &numCallbacksExecutedPromise](
c10::optional<std::string> /* unused */,
c10::optional<std::string> /* unused */) {
numCallbacksExecuted++;
if (numCallbacksExecuted == numThreads * keyChangeOperation * 2) {
numCallbacksExecutedPromise.set_value(numCallbacksExecuted);
}
};
const int numWorkers = numThreads + 1;
auto serverTCPStore = _createServer(numWorkers);
auto serverStore =
c10::make_intrusive<c10d::PrefixStore>(prefix, serverTCPStore);
c10d::TCPStoreOptions opts{};
opts.port = serverTCPStore->getPort();
opts.numWorkers = numWorkers;
// Each thread will have a client store to send/recv data
std::vector<c10::intrusive_ptr<c10d::TCPStore>> clientTCPStores;
std::vector<c10::intrusive_ptr<c10d::PrefixStore>> clientStores;
for (const auto i : c10::irange(numThreads)) {
clientTCPStores.push_back(
c10::make_intrusive<c10d::TCPStore>("127.0.0.1", opts));
clientStores.push_back(
c10::make_intrusive<c10d::PrefixStore>(prefix, clientTCPStores[i]));
}
// Start watching key on server and client stores
std::string internalKey = "internalKey";
std::string internalKeyCount = "internalKeyCount";
for (const auto i : c10::irange(numThreads)) {
serverStore->watchKey(internalKey + std::to_string(i), callback);
serverStore->watchKey(internalKeyCount + std::to_string(i), callback);
clientStores[i]->watchKey(internalKey + std::to_string(i), callback);
clientStores[i]->watchKey(internalKeyCount + std::to_string(i), callback);
}
std::vector<std::thread> threads;
std::atomic<int> keyChangeOperationCount{0};
for (const auto i : c10::irange(numThreads)) {
threads.emplace_back(std::thread([=,
&clientStores,
&internalKey,
&internalKeyCount,
&keyChangeOperationCount] {
// Let each thread set and get key on its client store
std::string key = internalKey + std::to_string(i);
std::string keyCounter = internalKeyCount + std::to_string(i);
std::string val = "thread_val_" + std::to_string(i);
// The set, compareSet, add methods count as key change operations
c10d::test::set(*clientStores[i], key, val);
c10d::test::compareSet(*clientStores[i], key, val, "newValue");
clientStores[i]->add(keyCounter, i);
keyChangeOperationCount += keyChangeOperation * 2;
c10d::test::check(*clientStores[i], key, "newValue");
c10d::test::check(*clientStores[i], keyCounter, std::to_string(i));
}));
}
// Ensures that internal_key has been "set" and "get"
for (auto& thread : threads) {
thread.join();
}
std::future<int> numCallbacksExecutedFuture =
numCallbacksExecutedPromise.get_future();
std::chrono::milliseconds span(kStoreCallbackTimeoutMillis);
if (numCallbacksExecutedFuture.wait_for(span) == std::future_status::timeout)
TORCH_CHECK(false, "Callback execution timed out.");
// Check number of callbacks executed equal to number of key change operations
// Wait for all callbacks to be triggered
EXPECT_EQ(keyChangeOperationCount, numCallbacksExecutedFuture.get());
}
// Runs the TCPStore stress test with the default (empty) PrefixStore prefix.
TEST(TCPStoreTest, testHelper) {
testHelper();
}
// Runs the TCPStore stress test with keys routed through the "testPrefix"
// PrefixStore.
TEST(TCPStoreTest, testHelperPrefix) {
testHelper("testPrefix");
}
// Runs the watch-key callback count test with the default (empty) prefix.
TEST(TCPStoreTest, testWatchKeyCallback) {
testWatchKeyCallback();
}
// Runs the watch-key callback count test through the "testPrefix"
// PrefixStore.
TEST(TCPStoreTest, testWatchKeyCallbackWithPrefix) {
testWatchKeyCallback("testPrefix");
}
// Watches `key` on `store`, then applies the mutation selected by the key's
// name:
//   - set() for testEmptyKeyValue, testRegularKeyValue and testWatchKeyCreate,
//   - add() for testWatchKeyAdd,
//   - deleteKey() for testWatchKeyDelete.
// It then blocks for up to kStoreCallbackTimeoutMillis for the watch callback.
// The callback compares the (old, new) values it receives with the expected
// ones, using "NONE" for an absent value on either side. Any exception raised
// inside the callback is captured and rethrown on the calling thread.
void testKeyChangeHelper(
    c10d::Store& store,
    std::string key,
    const c10::optional<std::string>& expectedOldValue,
    const c10::optional<std::string>& expectedNewValue) {
  std::exception_ptr callbackError = nullptr;
  std::promise<bool> firedPromise;

  c10d::WatchKeyCallback onKeyChange =
      [expectedOldValue, expectedNewValue, &firedPromise, &callbackError](
          c10::optional<std::string> oldValue,
          c10::optional<std::string> newValue) {
        try {
          EXPECT_EQ(
              expectedOldValue.value_or("NONE"), oldValue.value_or("NONE"));
          EXPECT_EQ(
              expectedNewValue.value_or("NONE"), newValue.value_or("NONE"));
        } catch (...) {
          callbackError = std::current_exception();
        }
        firedPromise.set_value(true);
      };
  store.watchKey(key, onKeyChange);

  const bool mutatesViaSet = key == "testEmptyKeyValue" ||
      key == "testRegularKeyValue" || key == "testWatchKeyCreate";
  if (mutatesViaSet) {
    c10d::test::set(store, key, expectedNewValue.value());
  } else if (key == "testWatchKeyAdd") {
    store.add(key, std::stoi(expectedNewValue.value()));
  } else if (key == "testWatchKeyDelete") {
    store.deleteKey(key);
  }

  std::future<bool> fired = firedPromise.get_future();
  const std::chrono::milliseconds waitBudget(kStoreCallbackTimeoutMillis);
  if (fired.wait_for(waitBudget) == std::future_status::timeout) {
    TORCH_CHECK(false, "Callback execution timed out.");
  }

  if (callbackError) {
    std::rethrow_exception(callbackError);
  }
}
// Each test below starts a fresh single-worker server store and optionally
// seeds (and reads back) the watched key. testKeyChangeHelper then performs
// the mutation and verifies the (old, new) pair passed to the watch callback.

// Overwrite an existing empty value: "" -> "2".
TEST(TCPStoreTest, testKeyEmptyUpdate) {
  std::string key{"testEmptyKeyValue"};
  auto store = _createServer();
  c10d::test::set(*store, key, "");
  store->get(key);
  testKeyChangeHelper(*store, key, "", "2");
}

// Overwrite an existing value: "1" -> "2".
TEST(TCPStoreTest, testKeyUpdate) {
  std::string key{"testRegularKeyValue"};
  auto store = _createServer();
  c10d::test::set(*store, key, "1");
  store->get(key);
  testKeyChangeHelper(*store, key, "1", "2");
}

// Create a key by set(): <absent> -> "2".
TEST(TCPStoreTest, testKeyCreate) {
  std::string key{"testWatchKeyCreate"};
  auto store = _createServer();
  testKeyChangeHelper(*store, key, c10::nullopt, "2");
}

// Create a key by add(): <absent> -> "2".
TEST(TCPStoreTest, testKeyAdd) {
  std::string key{"testWatchKeyAdd"};
  auto store = _createServer();
  testKeyChangeHelper(*store, key, c10::nullopt, "2");
}

// Delete an existing key: "1" -> <absent>.
TEST(TCPStoreTest, testKeyDelete) {
  std::string key{"testWatchKeyDelete"};
  auto store = _createServer();
  c10d::test::set(*store, key, "1");
  store->get(key);
  testKeyChangeHelper(*store, key, "1", c10::nullopt);
}
// Verifies that destroying the server store while a client request may be in
// flight makes that request fail with std::system_error.
TEST(TCPStoreTest, testCleanShutdown) {
  int numWorkers = 2;
  // Build the server like every other store in this file: intrusive_ptr +
  // TCPStoreOptions. The positional (host, port, numWorkers, isServer,
  // timeout, wait) constructor was previously used here via std::make_unique.
  auto serverTCPStore = c10::make_intrusive<c10d::TCPStore>(
      "127.0.0.1",
      c10d::TCPStoreOptions{
          /* port */ 0,
          /* isServer */ true,
          numWorkers,
          /* waitWorkers */ false,
          /* timeout */ std::chrono::seconds(defaultTimeout)});
  c10d::test::set(*serverTCPStore, "key", "val");

  auto clientTCPStore = c10::make_intrusive<c10d::TCPStore>(
      "127.0.0.1",
      c10d::TCPStoreOptions{
          /* port */ serverTCPStore->getPort(),
          /* isServer */ false,
          numWorkers,
          /* waitWorkers */ false,
          /* timeout */ std::chrono::seconds(defaultTimeout)});
  clientTCPStore->get("key");

  auto clientThread = std::thread([&clientTCPStore] {
    EXPECT_THROW(clientTCPStore->get("invalid_key"), std::system_error);
  });

  // start server shutdown during a client request
  serverTCPStore.reset();
  clientThread.join();
}
// Two multi-tenant server stores built from identical options (same default
// port) must share one server. The server must also keep working after one of
// the two stores is destroyed.
TEST(TCPStoreTest, testMultiTenantStores) {
  c10d::TCPStoreOptions sharedOpts{};
  sharedOpts.multiTenant = true;
  sharedOpts.isServer = true;

  // Construct two server stores on the same port.
  auto firstStore = c10::make_intrusive<c10d::TCPStore>("localhost", sharedOpts);
  auto secondStore =
      c10::make_intrusive<c10d::TCPStore>("localhost", sharedOpts);

  // A write through one store is visible through the other.
  c10d::test::set(*firstStore, "key0", "value0");
  c10d::test::check(*secondStore, "key0", "value0");

  // Drop the second store; the first must still reach the shared server.
  secondStore.reset();
  c10d::test::set(*firstStore, "key0", "value0");
  c10d::test::check(*firstStore, "key0", "value0");
}