Move retrieve_fd from util.h to AutoRemoteSyscalls.
diff --git a/src/AutoRemoteSyscalls.cc b/src/AutoRemoteSyscalls.cc
index b0a4d46..30b0dc2 100644
--- a/src/AutoRemoteSyscalls.cc
+++ b/src/AutoRemoteSyscalls.cc
@@ -8,6 +8,8 @@
#include "task.h"
#include "util.h"
+using namespace rr;
+
const uint8_t AutoRemoteSyscalls::syscall_insn[] = { 0xcd, 0x80 };
void AutoRestoreMem::init(const uint8_t* mem, ssize_t num_bytes) {
@@ -98,3 +100,165 @@
}
SupportedArch AutoRemoteSyscalls::arch() const { return t->arch(); }
+
+// Bundle (arg1, arg2, arg3) into a socketcall_args record and copy that
+// record to |remote_mem| through t->write_mem().
+static void write_socketcall_args(Task* t,
+                                  remote_ptr<struct socketcall_args> remote_mem,
+                                  long arg1, long arg2, long arg3) {
+  struct socketcall_args packed_args = { { arg1, arg2, arg3 } };
+  t->write_mem(remote_mem, packed_args);
+}
+
+// Round |size| up to the next multiple of 8 bytes and return the result.
+// Adding (align_amount - 1) before masking means a size that is already a
+// multiple of 8 comes back unchanged instead of growing by another 8 bytes.
+static size_t align_size(size_t size) {
+  // Must stay a power of two for the mask below to be valid.
+  static const size_t align_amount = 8;
+  return (size + align_amount - 1) & ~(align_amount - 1);
+}
+
+// Obtain, in this process, a descriptor referring to the tracee's file
+// descriptor |fd| and return it wrapped in a ScopedFd.
+//
+// Steps, in order:
+//  1. Create, bind and listen on an AF_UNIX stream socket at
+//     /tmp/rr-tracee-fd-transfer-<tid>.
+//  2. Have the tracee create its own AF_UNIX socket and connect() to that
+//     path via remote syscalls; accept() the connection here.
+//  3. Have the tracee sendmsg() one byte plus an SCM_RIGHTS control message
+//     carrying |fd|; recvmsg() it here and extract the received descriptor.
+//  4. Close the tracee-side socket (remote close) and our connected socket.
+//
+// Any failing step calls FATAL(); malformed ancillary data trips an assert().
+ScopedFd AutoRemoteSyscalls::retrieve_fd(int fd) {
+ struct sockaddr_un socket_addr;
+ struct msghdr msg;
+ // Unfortunately we need to send at least one byte of data in our
+ // message for it to work
+ struct iovec msgdata;
+ char received_data;
+ // Control-message buffer with room for exactly one descriptor-sized payload.
+ char cmsgbuf[CMSG_SPACE(sizeof(fd))];
+ // Size of the tracee scratch area: a socketcall_args record, followed by a
+ // region that holds either the sockaddr (for connect) or the
+ // msghdr + iovec + cmsgbuf group (for sendmsg). Both uses start at the same
+ // offset below, so only the larger of the two needs to fit.
+ int data_length =
+ align_size(sizeof(struct socketcall_args)) +
+ std::max(align_size(sizeof(socket_addr)),
+ align_size(sizeof(msg)) + align_size(sizeof(cmsgbuf)) +
+ align_size(sizeof(msgdata)));
+ // NOTE(review): presumably reserves |data_length| bytes of tracee memory
+ // and restores the old contents when the holder goes out of scope --
+ // confirm against AutoRestoreMem.
+ AutoRestoreMem remote_socketcall_args_holder(*this, nullptr, data_length);
+ auto remote_socketcall_args = remote_socketcall_args_holder.get();
+ // Selects between the multiplexed socketcall syscall and the individual
+ // socket/connect/sendmsg syscalls in every branch below.
+ bool using_socketcall = has_socketcall_syscall(arch());
+
+ // The whole struct is zeroed first, and snprintf is limited to one byte
+ // less than sun_path, so the path is always NUL-terminated.
+ memset(&socket_addr, 0, sizeof(socket_addr));
+ socket_addr.sun_family = AF_UNIX;
+ snprintf(socket_addr.sun_path, sizeof(socket_addr.sun_path) - 1,
+ "/tmp/rr-tracee-fd-transfer-%d", t->tid);
+
+ int listen_sock = socket(AF_UNIX, SOCK_STREAM, 0);
+ if (listen_sock < 0) {
+ FATAL() << "Failed to create listen socket";
+ }
+ if (::bind(listen_sock, (struct sockaddr*)&socket_addr,
+ sizeof(socket_addr))) {
+ FATAL() << "Failed to bind listen socket";
+ }
+ if (listen(listen_sock, 1)) {
+ FATAL() << "Failed to mark listening for listen socket";
+ }
+
+ // Create the tracee-side socket with a remote syscall.
+ int child_sock;
+ if (using_socketcall) {
+ write_socketcall_args(t,
+ remote_socketcall_args.cast<struct socketcall_args>(),
+ AF_UNIX, SOCK_STREAM, 0);
+ child_sock = syscall(syscall_number_for_socketcall(arch()), SYS_SOCKET,
+ remote_socketcall_args);
+ } else {
+ child_sock =
+ syscall(syscall_number_for_socket(arch()), AF_UNIX, SOCK_STREAM, 0);
+ }
+ if (child_sock < 0) {
+ FATAL() << "Failed to create child socket";
+ }
+
+ // Place the socket address right after the socketcall_args slot in the
+ // scratch area, then set up the tracee's connect() call.
+ auto remote_sockaddr =
+ (remote_socketcall_args + align_size(sizeof(struct socketcall_args)))
+ .cast<struct sockaddr_un>();
+ t->write_mem(remote_sockaddr, socket_addr);
+ Registers callregs = initial_regs;
+ int remote_syscall;
+ if (using_socketcall) {
+ write_socketcall_args(
+ t, remote_socketcall_args.cast<struct socketcall_args>(), child_sock,
+ remote_sockaddr.as_int(), sizeof(socket_addr));
+ callregs.set_arg1(SYS_CONNECT);
+ callregs.set_arg2(remote_socketcall_args);
+ remote_syscall = syscall_number_for_socketcall(arch());
+ } else {
+ callregs.set_arg1(child_sock);
+ callregs.set_arg2(remote_sockaddr);
+ callregs.set_arg3(sizeof(socket_addr));
+ remote_syscall = syscall_number_for_connect(arch());
+ }
+ // Start the connect without waiting for it; its result is collected by
+ // wait_syscall() only after we have accept()ed below.
+ syscall_helper(DONT_WAIT, remote_syscall, callregs);
+ // Now the child is waiting for us to accept it.
+
+ int sock = accept(listen_sock, NULL, NULL);
+ if (sock < 0) {
+ FATAL() << "Failed to create parent socket";
+ }
+ // Any nonzero result from the tracee's connect is treated as failure.
+ int child_ret = wait_syscall(remote_syscall);
+ if (child_ret) {
+ FATAL() << "Failed to connect() in tracee";
+ }
+ // Listening socket not needed anymore
+ close(listen_sock);
+ unlink(socket_addr.sun_path);
+
+ // Pull the puppet strings to have the child send its fd
+ // to us. Similarly to above, we DONT_WAIT on the
+ // call to finish, since it's likely not defined whether the
+ // sendmsg() may block on our recvmsg()ing what the tracee
+ // sent us (in which case we would deadlock with the tracee).
+ //
+ // Scratch layout from here on (reusing the bytes that held the sockaddr):
+ // msghdr, then iovec, then the control-message buffer.
+ auto remote_msg =
+ remote_socketcall_args + align_size(sizeof(struct socketcall_args));
+ auto remote_msgdata = remote_msg + align_size(sizeof(msg));
+ auto remote_cmsgbuf = remote_msgdata + align_size(sizeof(msgdata));
+ // XXX should be using Arch::iovec
+ msgdata.iov_base =
+ (void*)remote_msg.as_int(); // doesn't matter much, we ignore the data
+ msgdata.iov_len = 1;
+ t->write_mem(remote_msgdata.cast<struct iovec>(), msgdata);
+ memset(&msg, 0, sizeof(msg));
+ // Point msg_control at the local buffer first so the CMSG_* macros can
+ // fill in the header and payload here.
+ msg.msg_control = cmsgbuf;
+ msg.msg_controllen = sizeof(cmsgbuf);
+ msg.msg_iov = reinterpret_cast<struct iovec*>(remote_msgdata.as_int());
+ msg.msg_iovlen = 1;
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg->cmsg_len = CMSG_LEN(sizeof(fd));
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_RIGHTS;
+ *(int*)CMSG_DATA(cmsg) = fd;
+ // Copy the finished control buffer to the tracee, then retarget
+ // msg_control at that remote copy before writing the msghdr itself.
+ t->write_bytes_helper(remote_cmsgbuf, sizeof(cmsgbuf), &cmsgbuf);
+ msg.msg_control = (void*)remote_cmsgbuf.as_int();
+ t->write_mem(remote_msg.cast<struct msghdr>(), msg);
+ callregs = initial_regs;
+ if (using_socketcall) {
+ write_socketcall_args(t,
+ remote_socketcall_args.cast<struct socketcall_args>(),
+ child_sock, remote_msg.as_int(), 0);
+ callregs.set_arg1(SYS_SENDMSG);
+ callregs.set_arg2(remote_socketcall_args);
+ remote_syscall = syscall_number_for_socketcall(arch());
+ } else {
+ callregs.set_arg1(child_sock);
+ callregs.set_arg2(remote_msg);
+ callregs.set_arg3(0);
+ remote_syscall = syscall_number_for_sendmsg(arch());
+ }
+ syscall_helper(DONT_WAIT, remote_syscall, callregs);
+ // Child may be waiting on our recvmsg().
+
+ // Our 'msg' struct is mostly already OK.
+ // msg_controllen and msg_iovlen keep the values set above; only the
+ // pointers are switched back from tracee addresses to local buffers.
+ msg.msg_control = cmsgbuf;
+ msgdata.iov_base = &received_data;
+ msg.msg_iov = &msgdata;
+ if (0 > recvmsg(sock, &msg, 0)) {
+ FATAL() << "Failed to receive fd";
+ }
+ // The descriptor delivered to this process is the int payload of the
+ // first control message.
+ cmsg = CMSG_FIRSTHDR(&msg);
+ assert(cmsg && cmsg->cmsg_level == SOL_SOCKET &&
+ cmsg->cmsg_type == SCM_RIGHTS);
+ int our_fd = *(int*)CMSG_DATA(cmsg);
+ assert(our_fd >= 0);
+
+ // A zero or negative sendmsg result in the tracee is treated as failure
+ // (one data byte was queued for sending).
+ if (0 >= wait_syscall(remote_syscall)) {
+ FATAL() << "Failed to sendmsg() in tracee";
+ }
+
+ // Close the tracee's end with a remote close(), then our connected end.
+ syscall(syscall_number_for_close(arch()), child_sock);
+ close(sock);
+
+ return ScopedFd(our_fd);
+}