blob: 3be979cb4779c48e52387f8d24a23ff3683742b2 [file] [log] [blame]
#include <fcntl.h>
#include <poll.h>
#include <sys/mman.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/util/tempfile.h>
#include <libshm/err.h>
#include <libshm/socket.h>
// How long the manager waits (the poll() timeout, in milliseconds) for new
// activity once no clients are connected before it shuts down.
const int SHUTDOWN_TIMEOUT = 2000; // 2s

// DEBUG(fmt, ...) prints a red, newline-terminated message to stderr when this
// file is compiled with DEBUG_LOG defined, and expands to a no-op otherwise.
// In both configurations it expands to a single expression, so the caller's
// own ';' terminates it (safe inside an unbraced if/else).
#ifdef DEBUG_LOG
#define COLOR "\033[31;1m"
#define RESET "\033[0m"
// The newline is passed as an extra trailing variadic argument so that DEBUG
// also works with a bare format string and no further arguments.
#define DEBUG_IMPL(msg, ...) \
  fprintf(stderr, COLOR msg "%c" RESET, __VA_ARGS__)
#define DEBUG(...) DEBUG_IMPL(__VA_ARGS__, '\n')
#else
#define DEBUG(...) (void)0
#endif
// Per-connection state for one client attached to the manager.
struct ClientSession {
  // Takes over (moves from) the accepted client socket.
  ClientSession(ManagerSocket s) : socket(std::move(s)) {}

  // Connection on which AllocInfo records arrive and confirmations are sent.
  ManagerSocket socket;
  // PID from the most recent AllocInfo received on this connection; 0 until
  // the first record arrives.
  pid_t pid{0};
};
// Descriptors handed to poll(): the listening socket plus one per client.
std::vector<pollfd> pollfds;
// Connected clients, keyed by their socket descriptor.
std::unordered_map<int, ClientSession> client_sessions;
// Names of shared-memory objects registered by clients; whatever is still
// here when the manager exits gets shm_unlink()ed.
// TODO: check if objects have been freed from time to time
std::set<std::string> used_objects;
// Adds `fd` to the poll() set, watching it for readability (POLLIN).
void register_fd(int fd) {
  pollfd entry{}; // value-initialized: revents starts at 0
  entry.fd = fd;
  entry.events = POLLIN;
  pollfds.emplace_back(entry);
}
// Removes every poll() entry for `fd` and drops the client session (if any)
// keyed by that descriptor.
void unregister_fd(int fd) {
  auto refers_to_fd = [fd](const pollfd& entry) { return entry.fd == fd; };
  auto first_removed =
      std::remove_if(pollfds.begin(), pollfds.end(), refers_to_fd);
  pollfds.erase(first_removed, pollfds.end());
  client_sessions.erase(fd);
}
// Writes all `len` bytes of `data` to `fd`. Retries after short writes and
// EINTR; returns false if write() reports any other error (or makes no
// progress), true once everything has been written.
static bool write_fully(int fd, const char* data, size_t len) {
  while (len > 0) {
    ssize_t written = write(fd, data, len);
    if (written < 0 && errno == EINTR) {
      continue;
    }
    if (written <= 0) {
      return false;
    }
    data += written;
    len -= static_cast<size_t>(written);
  }
  return true;
}

// Writes `message` followed by a newline to stdout (fd 1). main() uses this
// to report either the manager socket path or an "ERROR: ..." line.
// Best-effort: if stdout is unwritable there is nothing useful left to do, so
// failures are ignored, but short writes no longer truncate the line.
void print_init_message(const char* message) {
  (void)write_fully(1, message, strlen(message));
  (void)write_fully(1, "\n", 1);
}
// Returns whether a POSIX shared-memory object called `name` still exists.
//
// Only ENOENT from shm_open() is taken to mean "does not exist". Any other
// failure (e.g. EACCES, EMFILE) is reported as "exists". A false "exists"
// merely keeps the name in used_objects, where the shutdown shm_unlink()
// fails harmlessly. A false "does not exist" would drop the name and leak the
// segment.
bool object_exists(const char* name) {
  int fd = shm_open(name, O_RDONLY, 0);
  if (fd < 0) {
    return errno != ENOENT;
  }
  close(fd);
  return true;
}
// Handles a client's "free" notification for `name`. The name is dropped from
// used_objects only if the shared-memory object no longer exists. Otherwise it
// stays tracked and will be shm_unlink()ed when the manager shuts down.
void free_used_object(const std::string& name) {
  const char* raw_name = name.c_str();
  if (object_exists(raw_name)) {
    DEBUG("object %s still exists", raw_name);
    return;
  }
  DEBUG("object %s appears to have been freed", raw_name);
  used_objects.erase(name);
}
// Entry point of the shared-memory manager process.
//
// Startup: calls setsid() and creates a listening ManagerServerSocket at
// "<tempdir>/manager.sock" inside a fresh "torch-shm-dir-" temporary
// directory. It then reports the socket path on stdout via
// print_init_message(). If startup fails, it reports "ERROR: ..." the same way
// and returns 1.
//
// Event loop: poll()s the listening socket and every client socket.
//   * POLLERR/POLLHUP -> the client's session is dropped.
//   * POLLIN on the listening socket -> a new client is accepted.
//   * POLLIN on a client -> one AllocInfo is received. A free request is
//     handed to free_used_object(). Otherwise the object name is recorded in
//     used_objects and the client is sent a confirmation.
// NOTE(review): the listening socket is never inserted into client_sessions,
// so POLLERR/POLLHUP on it makes client_sessions.at() throw
// std::out_of_range, which escapes main.
//
// Shutdown: once no clients are connected and poll() sees no event within
// SHUTDOWN_TIMEOUT ms, the manager does four things:
//   * shm_unlink()s every name still in used_objects;
//   * unregisters the remaining descriptors;
//   * removes the socket file;
//   * returns 0.
// NOLINTNEXTLINE(bugprone-exception-escape)
int main(int argc, char* argv[]) {
  setsid(); // Daemonize the process
  std::unique_ptr<ManagerServerSocket> srv_socket;
  c10::optional<c10::TempDir> tempdir;
  try {
    tempdir = c10::try_make_tempdir(/*name_prefix=*/"torch-shm-dir-");
    if (!tempdir.has_value()) {
      throw std::runtime_error(
          "could not generate a random directory for manager socket");
    }
    std::string tempfile = tempdir->name + "/manager.sock";
    srv_socket = std::make_unique<ManagerServerSocket>(tempfile);
    register_fd(srv_socket->socket_fd);
    print_init_message(tempfile.c_str());
    DEBUG("opened socket %s", tempfile.c_str());
  } catch (const std::exception& e) {
    std::string message("ERROR: ");
    message += e.what();
    print_init_message(message.c_str());
    return 1;
  } catch (...) {
    print_init_message("ERROR: unhandled exception");
    return 1;
  }

  int timeout = -1;
  // pollfds must not change while it is being iterated below, so newly
  // accepted and dead descriptors are queued here and applied afterwards.
  std::vector<int> to_add;
  std::vector<int> to_remove;
  for (;;) {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int nevents;
    // With no clients attached, wait at most SHUTDOWN_TIMEOUT ms for one to
    // connect; otherwise block until there is activity.
    if (client_sessions.empty())
      timeout = SHUTDOWN_TIMEOUT;
    SYSCHECK_ERR_RETURN_NEG1(
        nevents = poll(pollfds.data(), pollfds.size(), timeout));
    timeout = -1;
    if (nevents == 0 && client_sessions.empty())
      break;

    for (auto& pfd : pollfds) {
      if (pfd.revents & (POLLERR | POLLHUP)) {
        // some process died
        DEBUG("detaching process");
        auto& session = client_sessions.at(pfd.fd);
        (void)session; // only read by DEBUG, which may compile to a no-op
        DEBUG("%d has died", session.pid);
        to_remove.push_back(pfd.fd);
      } else if (pfd.revents & POLLIN) {
        if (pfd.fd == srv_socket->socket_fd) {
          // someone is joining
          DEBUG("registered new client");
          auto client = srv_socket->accept();
          int fd = client.socket_fd;
          to_add.push_back(fd);
          client_sessions.emplace(fd, std::move(client));
        } else {
          // someone wants to register a segment
          DEBUG("got alloc info");
          auto& session = client_sessions.at(pfd.fd);
          AllocInfo info = session.socket.receive();
          session.pid = info.pid;
          DEBUG(
              "got alloc info: %d %d %s",
              (int)info.free,
              info.pid,
              info.filename);
          if (info.free) {
            free_used_object(info.filename);
          } else {
            used_objects.insert(info.filename);
            DEBUG("registered object %s", info.filename);
            session.socket.confirm();
          }
        }
      }
    }

    for (int fd : to_add)
      register_fd(fd);
    to_add.clear();
    for (int fd : to_remove)
      unregister_fd(fd);
    to_remove.clear();
  }

  for (auto& obj_name : used_objects) {
    DEBUG("freeing %s", obj_name.c_str());
    shm_unlink(obj_name.c_str());
  }

  // Clean up file descriptors. unregister_fd() erases from pollfds, and
  // erasing from a vector while range-iterating it invalidates the loop's
  // iterators, so iterate over a snapshot of the descriptors instead.
  std::vector<int> remaining_fds;
  remaining_fds.reserve(pollfds.size());
  for (const auto& pfd : pollfds) {
    remaining_fds.push_back(pfd.fd);
  }
  for (int fd : remaining_fds) {
    unregister_fd(fd);
  }

  // Clean up manager.sock
  srv_socket->remove();
  // Clean up directory automatically
  DEBUG("manager done");
  return 0;
}