blob: cd9a18e7c1d74d8a75831f584cbd8326e723546e [file] [log] [blame]
#include <c10/util/irange.h>
#include "StoreTestCommon.hpp"
#include <cstdlib>
#include <future>
#include <iostream>
#include <system_error>
#include <thread>
#include <gtest/gtest.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
// Store timeout, in milliseconds, installed before a get() that is expected to
// fail, so the failure surfaces quickly.
constexpr int64_t kShortStoreTimeoutMillis{100};
// Timeout, in seconds, used by the TCPStore servers and clients in this file.
constexpr int defaultTimeout{20};
// Builds a server-side TCPStore bound to 127.0.0.1 on port 0 (the OS picks a
// free port; read it back with getPort()). The constructor does not wait for
// workers; callers that need that call waitForWorkers() themselves.
//
// @param useLibUV   selects TCPStoreOptions::useLibUV.
// @param numWorkers expected worker count (forwarded to TCPStoreOptions).
// @param timeout    store timeout in seconds.
c10::intrusive_ptr<c10d::TCPStore> _createServer(
    bool useLibUV,
    int numWorkers = 1,
    int timeout = defaultTimeout) {
  c10d::TCPStoreOptions serverOpts{};
  serverOpts.port = 0;
  serverOpts.isServer = true;
  serverOpts.numWorkers = numWorkers;
  serverOpts.waitWorkers = false;
  serverOpts.timeout = std::chrono::seconds(timeout);
  serverOpts.multiTenant = false;
  serverOpts.masterListenFd = c10::nullopt;
  serverOpts.useLibUV = useLibUV;
  return c10::make_intrusive<c10d::TCPStore>("127.0.0.1", serverOpts);
}
// Shared scenario for the testHelper* TESTs. The server binds port 0, so the
// OS assigns a free port on every run.
//
// Phases:
//   1. A server thread waits for all workers to join, then exercises
//      set/get, compareSet, deleteKey and getNumKeys through `serverStore`.
//      It also checks that get() on a missing key throws once the timeout is
//      shortened. This thread is joined before any client thread starts, so
//      its key counts cannot observe client-written keys. Its
//      add("counter", 1) is also guaranteed to land before the clients check
//      the counter total.
//   2. `numThreads` client threads each add 1 to "counter" `numIterations`
//      times and repeatedly set/check their own "thread_<i>" key. After a
//      barrier, each one verifies the counter total and every OTHER thread's
//      final value.
//   3. With all clients destroyed, the main thread re-checks the data through
//      `serverStore`, confirming that client disconnects do not shut the
//      store down.
//
// @param useLibUV forwarded to the server's TCPStoreOptions::useLibUV.
// @param prefix   key prefix applied (via PrefixStore) to server and clients.
void testHelper(bool useLibUV, const std::string& prefix = "") {
  constexpr auto numThreads = 16;
  // One worker per client store, plus the server itself.
  constexpr auto numWorkers = numThreads + 1;
  auto serverTCPStore = _createServer(useLibUV, numWorkers);
  auto serverStore =
      c10::make_intrusive<c10d::PrefixStore>(prefix, serverTCPStore);
  // server store
  auto serverThread = std::thread([&serverStore, &serverTCPStore] {
    // Wait for all workers to join. The client stores are constructed on the
    // main thread below; this call returns once they have joined.
    serverTCPStore->waitForWorkers();
    // Basic set/get on the server store
    c10d::test::set(*serverStore, "key0", "value0");
    c10d::test::set(*serverStore, "key1", "value1");
    c10d::test::set(*serverStore, "key2", "value2");
    c10d::test::check(*serverStore, "key0", "value0");
    c10d::test::check(*serverStore, "key1", "value1");
    c10d::test::check(*serverStore, "key2", "value2");
    serverStore->add("counter", 1);
    auto numKeys = serverStore->getNumKeys();
    // We expect 5 keys: the 3 set above, 'counter' (just added), and the init
    // key used to coordinate workers. No client thread is running yet, so no
    // client-written keys can be counted here.
    EXPECT_EQ(numKeys, 5);
    // Check compareSet, does not check return value
    c10d::test::compareSet(
        *serverStore, "key0", "wrongExpectedValue", "newValue");
    c10d::test::check(*serverStore, "key0", "value0");
    c10d::test::compareSet(*serverStore, "key0", "value0", "newValue");
    c10d::test::check(*serverStore, "key0", "newValue");
    auto delSuccess = serverStore->deleteKey("key0");
    // Ensure that the key was successfully deleted
    EXPECT_TRUE(delSuccess);
    auto delFailure = serverStore->deleteKey("badKeyName");
    // The key was not in the store so the delete operation should have failed
    // and returned false.
    EXPECT_FALSE(delFailure);
    numKeys = serverStore->getNumKeys();
    EXPECT_EQ(numKeys, 4);
    // With a short timeout, get() on the deleted key must throw.
    auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis);
    serverStore->setTimeout(timeout);
    EXPECT_THROW(serverStore->get("key0"), c10::Error);
  });
  // Hammer on TCPStore
  std::vector<std::thread> threads;
  threads.reserve(numThreads);
  constexpr auto numIterations = 1000;
  c10d::test::Semaphore sem1, sem2;
  c10d::TCPStoreOptions opts{};
  opts.port = serverTCPStore->getPort();
  opts.numWorkers = numWorkers;
  // Each thread will have a client store to send/recv data
  std::vector<c10::intrusive_ptr<c10d::TCPStore>> clientTCPStores;
  std::vector<c10::intrusive_ptr<c10d::PrefixStore>> clientStores;
  clientTCPStores.reserve(numThreads);
  clientStores.reserve(numThreads);
  for (const auto i : c10::irange(numThreads)) {
    clientTCPStores.push_back(
        c10::make_intrusive<c10d::TCPStore>("127.0.0.1", opts));
    clientStores.push_back(
        c10::make_intrusive<c10d::PrefixStore>(prefix, clientTCPStores[i]));
  }
  // All clients exist now, so the server thread's waitForWorkers() can
  // return. Finish the server-side checks before any client thread writes.
  // This keeps the getNumKeys() expectations exact and orders the server's
  // counter increment before the clients' counter check.
  serverThread.join();
  std::string expectedCounterRes =
      std::to_string(numThreads * numIterations + 1);
  for (const auto i : c10::irange(numThreads)) {
    threads.emplace_back(
        [=, &sem1, &sem2, &clientStores, &expectedCounterRes] {
          for (C10_UNUSED const auto j : c10::irange(numIterations)) {
            clientStores[i]->add("counter", 1);
          }
          // Let each thread set and get key on its client store
          std::string key = "thread_" + std::to_string(i);
          for (const auto j : c10::irange(numIterations)) {
            std::string val = "thread_val_" + std::to_string(j);
            c10d::test::set(*clientStores[i], key, val);
            c10d::test::check(*clientStores[i], key, val);
          }
          // Barrier: no thread proceeds until every thread has finished
          // its writes.
          sem1.post();
          sem2.wait();
          // Check the counter results
          c10d::test::check(*clientStores[i], "counter", expectedCounterRes);
          // Now check other threads' written data: every other thread's key
          // must hold the last value that thread wrote.
          const std::string lastVal =
              "thread_val_" + std::to_string(numIterations - 1);
          for (const auto j : c10::irange(numThreads)) {
            if (j == i) {
              continue;
            }
            const std::string otherKey = "thread_" + std::to_string(j);
            c10d::test::check(*clientStores[i], otherKey, lastVal);
          }
        });
  }
  sem1.wait(numThreads);
  sem2.post(numThreads);
  for (auto& thread : threads) {
    thread.join();
  }
  // Clear the store to test that client disconnect won't shutdown the store
  clientStores.clear();
  clientTCPStores.clear();
  // Check that the counter has the expected value
  c10d::test::check(*serverStore, "counter", expectedCounterRes);
  // Check each thread's written data from the main thread
  for (const auto i : c10::irange(numThreads)) {
    std::string key = "thread_" + std::to_string(i);
    std::string val = "thread_val_" + std::to_string(numIterations - 1);
    c10d::test::check(*serverStore, key, val);
  }
}
// Run the shared testHelper scenario with useLibUV off and on, first without
// a key prefix and then with "testPrefix".
TEST(TCPStoreTest, testHelper) {
  testHelper(/* useLibUV */ false);
}
TEST(TCPStoreTest, testHelperUV) {
  testHelper(/* useLibUV */ true);
}
TEST(TCPStoreTest, testHelperPrefix) {
  testHelper(/* useLibUV */ false, /* prefix */ "testPrefix");
}
TEST(TCPStoreTest, testHelperPrefixUV) {
  testHelper(/* useLibUV */ true, /* prefix */ "testPrefix");
}
// Destroys the server while a client is issuing a get() for a key that was
// never set. The test expects that request to fail with c10::DistNetworkError
// rather than hang or crash.
TEST(TCPStoreTest, testCleanShutdown) {
  int numWorkers = 2;

  // Server created through the positional constructor, without waiting for
  // workers.
  auto serverTCPStore = std::make_unique<c10d::TCPStore>(
      "127.0.0.1",
      /* port */ 0,
      /* numWorkers */ numWorkers,
      /* isServer */ true,
      /* timeout */ std::chrono::seconds(defaultTimeout),
      /* wait */ false);
  c10d::test::set(*serverTCPStore, "key", "val");

  c10d::TCPStoreOptions clientOpts{};
  clientOpts.port = serverTCPStore->getPort();
  clientOpts.isServer = false;
  clientOpts.numWorkers = numWorkers;
  clientOpts.waitWorkers = false;
  clientOpts.timeout = std::chrono::seconds(defaultTimeout);
  auto clientTCPStore =
      c10::make_intrusive<c10d::TCPStore>("127.0.0.1", clientOpts);
  // Sanity check: the client can read from the live server.
  clientTCPStore->get("key");

  auto clientThread = std::thread([&clientTCPStore] {
    EXPECT_THROW(clientTCPStore->get("invalid_key"), c10::DistNetworkError);
  });
  // Shut the server down. NOTE(review): nothing synchronizes this with the
  // client thread, so the request may be in flight or not yet sent when this
  // runs; the expectation above covers whichever ordering occurs.
  serverTCPStore.reset();
  clientThread.join();
}
void testMultiTenantStores(bool libUV) {
c10d::TCPStoreOptions opts{};
opts.isServer = true;
opts.multiTenant = true;
opts.useLibUV = libUV;
// Construct two server stores on the same port.
auto store1 = c10::make_intrusive<c10d::TCPStore>("localhost", opts);
auto store2 = c10::make_intrusive<c10d::TCPStore>("localhost", opts);
// Assert that the two stores share the same server.
c10d::test::set(*store1, "key0", "value0");
c10d::test::check(*store2, "key0", "value0");
// Dispose the second instance and assert that the server is still alive.
store2.reset();
c10d::test::set(*store1, "key0", "value0");
c10d::test::check(*store1, "key0", "value0");
}
// Multi-tenant server sharing, with useLibUV off and on.
TEST(TCPStoreTest, testMultiTenantStores) {
  testMultiTenantStores(/* libUV */ false);
}
TEST(TCPStoreTest, testMultiTenantStoresUV) {
  testMultiTenantStores(/* libUV */ true);
}