| #include <torch/csrc/DataLoader.h> |
| |
| // Together with `torch/utils/data/_utils/signal_handling.py`, the following |
| // is an effort to do our best to provide some error message to users when a |
| // worker dies due to error / critical signals. |
| // |
| // See NOTE [ Signal handling in multiprocessing data loading ] for more |
| // details. |
| |
| // TODO: The following don't work on Windows. Specifically, sigaction, waitid |
| // calls, and SIGCHLD handler. Currently, dummy implementations are provided |
| // for Windows. |
| |
| #ifndef _WIN32 |
| |
| #include <torch/csrc/Exceptions.h> |
| #include <torch/csrc/utils/python_numbers.h> |
| |
| #include <c10/util/irange.h> |
| #include <fmt/format.h> |
| |
| #include <sys/wait.h> |
| #include <csignal> |
| #include <map> |
| #include <set> |
| #include <sstream> |
| |
| using namespace torch; |
| |
// Critical signal handlers should be registered on worker processes before
// doing work. Python handle is _set_worker_signal_handlers().
//
// SIGNAL_HANDLER defines one such handler:
//   SIGNAL       - signal number the generated handler restores and re-raises.
//   HANDLER_NAME - name of the generated static function; its signature is
//                  the three-argument (SA_SIGINFO) handler form.
//   ERROR_MSG    - message written to stderr. It must be a string literal
//                  (adjacent literals are fine) because its length is taken
//                  with sizeof.
//
// The generated handler writes ERROR_MSG to stderr, resets SIGNAL to its
// default disposition and raises it again, so that the kill information can
// be retrieved from the main process. If the default disposition cannot be
// restored, the process exits with EXIT_FAILURE instead.
#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, ERROR_MSG)                    \
  static void HANDLER_NAME(int sig, siginfo_t* info, void* ctx) {          \
    /* sizeof(ERROR_MSG) counts the terminating NUL, which must not be */  \
    /* written to stderr, hence the minus one. */                          \
    auto _w = write(STDERR_FILENO, ERROR_MSG, sizeof(ERROR_MSG) - 1);      \
    /* Best effort only: there is nothing useful to do if write fails. */  \
    (void)_w;                                                              \
    struct sigaction sa {};                                                \
    sa.sa_handler = SIG_DFL;                                               \
    sa.sa_flags = 0;                                                       \
    if (sigemptyset(&sa.sa_mask) != 0 ||                                   \
        sigaction(SIGNAL, &sa, nullptr) != 0) {                            \
      _exit(EXIT_FAILURE);                                                 \
    } else {                                                               \
      raise(SIGNAL);                                                       \
    }                                                                      \
  }
| |
// Installs `handler` as the action for `signal` through sigaction(2).
// sigaction is used because signal(2) is really not portable, see
// http://man7.org/linux/man-pages/man2/signal.2.html
//
// The action is registered with an empty signal mask and the flags
// SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER.
//
//   signal     - signal number to install the handler for.
//   handler    - three-argument handler to run when the signal is delivered.
//   old_sa_ptr - forwarded to sigaction() as the out-parameter for the
//                previously installed action; may be nullptr.
//
// Throws std::runtime_error if the mask cannot be cleared or the action
// cannot be installed.
static inline void setSignalHandler(
    int signal,
    void (*handler)(int, siginfo_t*, void*),
    struct sigaction* old_sa_ptr) {
  struct sigaction action {};
  action.sa_flags = SA_RESTART | SA_SIGINFO | SA_NOCLDSTOP | SA_NODEFER;
  action.sa_sigaction = handler;

  // Short-circuits exactly like the failure check it replaces: sigaction()
  // is only attempted once the mask has been cleared successfully.
  const bool installed = sigemptyset(&action.sa_mask) == 0 &&
      sigaction(signal, &action, old_sa_ptr) == 0;
  if (installed) {
    return;
  }

  std::ostringstream message;
  message << "An error occurred while setting handler for "
          << strsignal(signal) << ".";
  throw std::runtime_error(message.str());
}
| |
| SIGNAL_HANDLER( |
| SIGBUS, |
| handler_SIGBUS, |
| "ERROR: Unexpected bus error encountered in worker. " |
| "This might be caused by insufficient shared memory (shm).\n"); |
| SIGNAL_HANDLER( |
| SIGSEGV, |
| handler_SIGSEGV, |
| "ERROR: Unexpected segmentation fault encountered in worker.\n"); |
| SIGNAL_HANDLER( |
| SIGFPE, |
| handler_SIGFPE, |
| "ERROR: Unexpected floating-point exception encountered in worker.\n"); |
| |
// SIGTERM handler for worker processes.
//
// When an error happens in DataLoader methods and Python starts to exit, the
// error trace keeps the loader alive, and Python may kill the child processes
// before the loader object is deleted. The clean-up in DataLoader.__del__ has
// then not run yet, and the SIGCHLD handler would report that a worker was
// killed by SIGTERM. To avoid that, a SIGTERM sent by the parent (the main
// loader process) is suppressed here with _exit(EXIT_SUCCESS). Exiting with a
// nonzero code would let the loader's SIGCHLD handler report a RuntimeError
// again, which would defeat the whole purpose.
//
// A SIGTERM from any other sender is re-raised with the default disposition;
// if that disposition cannot be restored the process exits with EXIT_FAILURE.
static void handler_SIGTERM(int sig, siginfo_t* info, void* ctx) {
  const bool sent_by_parent = (info->si_pid == getppid());
  if (sent_by_parent) {
    _exit(EXIT_SUCCESS);
  }

  struct sigaction default_action {};
  default_action.sa_flags = 0;
  default_action.sa_handler = SIG_DFL;

  // sigaction() is only attempted once the mask has been cleared.
  const bool restored = sigemptyset(&default_action.sa_mask) == 0 &&
      sigaction(SIGTERM, &default_action, nullptr) == 0;
  if (!restored) {
    _exit(EXIT_FAILURE);
  }
  raise(SIGTERM);
}
| |
// Weak, empty default for an extra signal-handler hook. It is called after
// the worker signal handlers have been installed; because the symbol is weak,
// a non-weak definition of the same function linked into the binary replaces
// this no-op.
// NOTE(review): no overriding definition is visible in this file.
__attribute__((weak)) void setDataLoaderSignalHandlers() {
  // Intentionally does nothing.
}
| |
| static PyObject* THPModule_setWorkerSignalHandlers( |
| PyObject* module, |
| PyObject* arg) { |
| HANDLE_TH_ERRORS |
| setSignalHandler(SIGBUS, &handler_SIGBUS, nullptr); |
| setSignalHandler(SIGSEGV, &handler_SIGSEGV, nullptr); |
| setSignalHandler(SIGTERM, &handler_SIGTERM, nullptr); |
| setSignalHandler(SIGFPE, &handler_SIGFPE, nullptr); |
| setDataLoaderSignalHandlers(); |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
// Registry of worker processes to watch: maps the integer key passed to
// `_set_worker_pids` to the set of worker pids registered under that key.
// Starts out empty.
static std::map<int64_t, std::set<pid_t>> worker_pids;
| |
| static PyObject* THPModule_errorIfAnyWorkerFails( |
| PyObject* module, |
| PyObject* noargs) { |
| HANDLE_TH_ERRORS |
| |
| // Only check the pids we care about |
| for (auto& w : worker_pids) { |
| auto& pid_set = w.second; |
| for (auto worker_pid : pid_set) { |
| // Use waitid rather than waitpid so that we can set NOWAIT, and that |
| // Python and other handlers can get whatever info they want about the |
| // child. |
| siginfo_t infop{}; |
| infop.si_pid = 0; |
| auto error = |
| waitid(P_PID, worker_pid, &infop, WEXITED | WNOHANG | WNOWAIT); |
| // ignore errors and case with no waitable child |
| if (error < 0 || infop.si_pid == 0) |
| continue; |
| if (infop.si_code == CLD_EXITED && |
| infop.si_status != EXIT_SUCCESS) { // exit with error |
| std::ostringstream oss; |
| oss << "DataLoader worker (pid " << worker_pid << ") exited " |
| << "unexpectedly with exit code " << infop.si_status << ". " |
| << "Details are lost due to multiprocessing. Rerunning with " |
| << "num_workers=0 may give better error trace."; |
| // This is necessary. Otherwise, the runtime error will kill the other |
| // workers, and trigger this again. |
| pid_set.clear(); |
| throw std::runtime_error(oss.str()); |
| } else if ( |
| infop.si_code == CLD_KILLED || |
| infop.si_code == CLD_DUMPED) { // killed by signal |
| std::ostringstream oss; |
| oss << "DataLoader worker (pid " << worker_pid << ") is killed " |
| << "by signal: " << strsignal(infop.si_status) << ". "; |
| if (infop.si_status == SIGBUS) { |
| oss << "It is possible that dataloader's workers are out of shared memory. " |
| << "Please try to raise your shared memory limit."; |
| } |
| // This is necessary. Otherwise, the runtime error will kill the other |
| // workers, and trigger this again. |
| pid_set.clear(); |
| throw std::runtime_error(oss.str()); |
| } |
| } |
| } |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| // We don't want to exit on any SIGCHLD from any child. child_pids is a tuple |
| // of pids we are interested in. |
| static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* args) { |
| HANDLE_TH_ERRORS |
| TORCH_CHECK_TYPE( |
| PyTuple_GET_SIZE(args) == 2, |
| "_set_worker_pids expects exactly 2 arguments."); |
| int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0)); |
| TORCH_CHECK_VALUE( |
| worker_pids.find(key) == worker_pids.end(), |
| "_set_worker_pids should be called only once for each _BaseDataLoaderIter."); |
| PyObject* child_pids = PyTuple_GET_ITEM(args, 1); |
| TORCH_CHECK_TYPE( |
| PyTuple_Check(child_pids), |
| "_set_worker_pids expects a tuple for child_pids, but got ", |
| Py_TYPE(child_pids)->tp_name, |
| "."); |
| std::set<pid_t> pids_set = {}; |
| auto size = PyTuple_GET_SIZE(child_pids); |
| for (const auto idx : c10::irange(size)) { |
| PyObject* obj = PyTuple_GET_ITEM(child_pids, idx); |
| pids_set.insert(static_cast<pid_t>(THPUtils_unpackLong(obj))); |
| } |
| |
| worker_pids[key] = pids_set; |
| |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* THPModule_removeWorkerPIDs( |
| PyObject* module, |
| PyObject* loader_id) { |
| HANDLE_TH_ERRORS |
| |
| int64_t key = THPUtils_unpackLong(loader_id); |
| auto it = worker_pids.find(key); |
| TORCH_CHECK_VALUE( |
| it != worker_pids.end(), |
| "Cannot find worker information for _BaseDataLoaderIter with id ", |
| key); |
| worker_pids.erase(it); |
| |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| #undef SIGNAL_HANDLER |
| |
| #else |
| // dummy implementations for windows |
| |
| static PyObject* THPModule_setWorkerSignalHandlers( |
| PyObject* module, |
| PyObject* _ignored) { |
| Py_RETURN_NONE; |
| } |
| |
| static PyObject* THPModule_setWorkerPIDs(PyObject* module, PyObject* _ignored) { |
| Py_RETURN_NONE; |
| } |
| |
| static PyObject* THPModule_removeWorkerPIDs( |
| PyObject* module, |
| PyObject* _ignored) { |
| Py_RETURN_NONE; |
| } |
| |
| static PyObject* THPModule_errorIfAnyWorkerFails( |
| PyObject* module, |
| PyObject* _ignored) { |
| Py_RETURN_NONE; |
| } |
| |
| #endif |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) |
| PyMethodDef DataLoaderMethods[] = { |
| {"_set_worker_signal_handlers", |
| THPModule_setWorkerSignalHandlers, |
| METH_NOARGS, |
| nullptr}, |
| {"_set_worker_pids", THPModule_setWorkerPIDs, METH_VARARGS, nullptr}, |
| {"_remove_worker_pids", THPModule_removeWorkerPIDs, METH_O, nullptr}, |
| {"_error_if_any_worker_fails", |
| THPModule_errorIfAnyWorkerFails, |
| METH_NOARGS, |
| nullptr}, |
| {nullptr, nullptr, 0, nullptr}}; |