#include <fcntl.h>
#include <poll.h>
#include <sys/mman.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include <c10/util/tempfile.h>

#include <libshm/err.h>
#include <libshm/socket.h>
| |
// poll() timeout in milliseconds, applied while no clients are connected.
// If it expires without any activity, the manager leaves its event loop
// and shuts down.
constexpr int SHUTDOWN_TIMEOUT = 2000; // 2s
| |
// Debug logging to stderr. It is compiled in only when DEBUG_LOG is defined;
// otherwise DEBUG(...) expands to a no-op expression.
//
// Usage: DEBUG(fmt, args...) takes printf-style arguments and appends a
// newline. The output is wrapped in ANSI escape codes (bold red, then reset).
//
// Both expansions are single expressions with no trailing semicolon, so
// `DEBUG(...);` is exactly one statement and is safe in an unbraced if/else.
// (The previous helper macro ended in `;`, which produced an extra empty
// statement, and its name `__DEBUG` was a reserved identifier.)
#ifdef DEBUG_LOG
#define SHM_DEBUG_COLOR "\033[31;1m"
#define SHM_DEBUG_RESET "\033[0m"
// The '\n' is passed through the variadic arguments. This keeps __VA_ARGS__
// non-empty even for a bare DEBUG("message") with no format arguments.
#define SHM_DEBUG_IMPL(msg, ...) \
  fprintf(stderr, SHM_DEBUG_COLOR msg "%c" SHM_DEBUG_RESET, __VA_ARGS__)
#define DEBUG(...) SHM_DEBUG_IMPL(__VA_ARGS__, '\n')
#else
#define DEBUG(...) (void)0
#endif
| |
| struct ClientSession { |
| ClientSession(ManagerSocket s) : socket(std::move(s)), pid(0) {} |
| |
| ManagerSocket socket; |
| pid_t pid; |
| }; |
| |
| std::vector<struct pollfd> pollfds; |
| std::unordered_map<int, ClientSession> client_sessions; |
| // TODO: check if objects have been freed from time to time |
| std::set<std::string> used_objects; |
| |
| void register_fd(int fd) { |
| struct pollfd pfd = {0}; |
| pfd.fd = fd; |
| pfd.events = POLLIN; |
| pollfds.push_back(pfd); |
| } |
| |
| void unregister_fd(int fd) { |
| pollfds.erase( |
| std::remove_if( |
| pollfds.begin(), |
| pollfds.end(), |
| [fd](const struct pollfd& pfd) { return pfd.fd == fd; }), |
| pollfds.end()); |
| client_sessions.erase(fd); |
| } |
| |
// Write all `len` bytes of `data` to `fd`. It retries when write() is
// interrupted by a signal (EINTR) and after short writes. Returns false if
// write() reports any other error, or makes no progress.
static bool write_all_bytes(int fd, const char* data, size_t len) {
  while (len > 0) {
    const ssize_t written = write(fd, data, len);
    if (written < 0 && errno == EINTR) {
      continue;
    }
    if (written <= 0) {
      return false;
    }
    data += written;
    len -= static_cast<size_t>(written);
  }
  return true;
}

// Print `message` followed by a newline on stdout (fd 1). main() uses this
// to report either the manager socket path or an "ERROR: ..." line.
// The newline is only written if the whole message went out, so a truncated
// message is never presented as a complete line. Failures are otherwise
// ignored because there is no other channel to report them on.
void print_init_message(const char* message) {
  if (write_all_bytes(1, message, strlen(message))) {
    write_all_bytes(1, "\n", 1);
  }
}
| |
// Return whether a POSIX shared-memory object called `name` currently
// exists. It probes with a read-only shm_open() and closes the descriptor
// again.
//
// Only ENOENT is treated as "does not exist". Other failures (for example
// EACCES, EMFILE, ENFILE) say nothing about whether the object is still
// there, so the object is conservatively reported as existing. That way
// free_used_object() does not drop it from `used_objects`, the set main()
// unlinks at shutdown, and the segment is not leaked.
bool object_exists(const char* name) {
  const int fd = shm_open(name, O_RDONLY, 0);
  if (fd >= 0) {
    close(fd);
    return true;
  }
  return errno != ENOENT;
}
| |
| void free_used_object(const std::string& name) { |
| if (!object_exists(name.c_str())) { |
| DEBUG("object %s appears to have been freed", name.c_str()); |
| used_objects.erase(name); |
| } else { |
| DEBUG("object %s still exists", name.c_str()); |
| } |
| } |
| |
| // NOLINTNEXTLINE(bugprone-exception-escape) |
| int main(int argc, char* argv[]) { |
| setsid(); // Daemonize the process |
| |
| std::unique_ptr<ManagerServerSocket> srv_socket; |
| c10::optional<c10::TempDir> tempdir; |
| try { |
| tempdir = c10::try_make_tempdir(/*name_prefix=*/"torch-shm-dir-"); |
| if (!tempdir.has_value()) { |
| throw std::runtime_error( |
| "could not generate a random directory for manager socket"); |
| } |
| |
| std::string tempfile = tempdir->name + "/manager.sock"; |
| |
| srv_socket = std::make_unique<ManagerServerSocket>(tempfile); |
| register_fd(srv_socket->socket_fd); |
| print_init_message(tempfile.c_str()); |
| DEBUG("opened socket %s", tempfile.c_str()); |
| } catch (const std::exception& e) { |
| std::string message("ERROR: "); |
| message += e.what(); |
| print_init_message(message.c_str()); |
| return 1; |
| } catch (...) { |
| print_init_message("ERROR: unhandled exception"); |
| return 1; |
| } |
| |
| int timeout = -1; |
| std::vector<int> to_add; |
| std::vector<int> to_remove; |
| for (;;) { |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| int nevents; |
| if (client_sessions.size() == 0) |
| timeout = SHUTDOWN_TIMEOUT; |
| SYSCHECK_ERR_RETURN_NEG1( |
| nevents = poll(pollfds.data(), pollfds.size(), timeout)); |
| timeout = -1; |
| if (nevents == 0 && client_sessions.size() == 0) |
| break; |
| |
| for (auto& pfd : pollfds) { |
| if (pfd.revents & (POLLERR | POLLHUP)) { |
| // some process died |
| DEBUG("detaching process"); |
| auto& session = client_sessions.at(pfd.fd); |
| (void)session; |
| DEBUG("%d has died", session.pid); |
| to_remove.push_back(pfd.fd); |
| } else if (pfd.revents & POLLIN) { |
| if (pfd.fd == srv_socket->socket_fd) { |
| // someone is joining |
| DEBUG("registered new client"); |
| auto client = srv_socket->accept(); |
| int fd = client.socket_fd; |
| to_add.push_back(fd); |
| client_sessions.emplace(fd, std::move(client)); |
| } else { |
| // someone wants to register a segment |
| DEBUG("got alloc info"); |
| auto& session = client_sessions.at(pfd.fd); |
| AllocInfo info = session.socket.receive(); |
| session.pid = info.pid; |
| DEBUG( |
| "got alloc info: %d %d %s", |
| (int)info.free, |
| info.pid, |
| info.filename); |
| if (info.free) { |
| free_used_object(info.filename); |
| } else { |
| used_objects.insert(info.filename); |
| DEBUG("registered object %s", info.filename); |
| session.socket.confirm(); |
| } |
| } |
| } |
| } |
| |
| for (int fd : to_add) |
| register_fd(fd); |
| to_add.clear(); |
| |
| for (int fd : to_remove) |
| unregister_fd(fd); |
| to_remove.clear(); |
| } |
| |
| for (auto& obj_name : used_objects) { |
| DEBUG("freeing %s", obj_name.c_str()); |
| shm_unlink(obj_name.c_str()); |
| } |
| |
| // Clean up file descriptors |
| for (auto& pfd : pollfds) { |
| unregister_fd(pfd.fd); |
| } |
| // Clean up manager.sock |
| srv_socket->remove(); |
| // Clean up directory automatically |
| |
| DEBUG("manager done"); |
| return 0; |
| } |