| // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. |
| // SPDX-License-Identifier: Apache-2.0 |
| |
| //! Structs for Unix Domain Socket listener and endpoint. |
| |
| #![allow(dead_code)] |
| |
| use std::fs::File; |
| use std::io::ErrorKind; |
| use std::marker::PhantomData; |
| use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; |
| use std::os::unix::net::{UnixListener, UnixStream}; |
| use std::path::{Path, PathBuf}; |
| use std::{mem, slice}; |
| |
| use libc::{c_void, iovec}; |
| use vm_memory::ByteValued; |
| use vmm_sys_util::sock_ctrl_msg::ScmSocket; |
| |
| use super::message::*; |
| use super::{Error, Result}; |
| |
/// Unix domain socket listener for accepting incoming connections.
pub struct Listener {
    /// The bound listening socket.
    fd: UnixListener,
    /// Socket file path unlinked when the listener is dropped; `None` when the
    /// listener was created from a raw fd and owns no filesystem entry.
    path: Option<PathBuf>,
}
| |
| impl Listener { |
| /// Create a unix domain socket listener. |
| /// |
| /// # Return: |
| /// * - the new Listener object on success. |
| /// * - SocketError: failed to create listener socket. |
| pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> { |
| if unlink { |
| let _ = std::fs::remove_file(&path); |
| } |
| let fd = UnixListener::bind(&path).map_err(Error::SocketError)?; |
| Ok(Listener { |
| fd, |
| path: Some(path.as_ref().to_owned()), |
| }) |
| } |
| |
| /// Accept an incoming connection. |
| /// |
| /// # Return: |
| /// * - Some(UnixStream): new UnixStream object if new incoming connection is available. |
| /// * - None: no incoming connection available. |
| /// * - SocketError: errors from accept(). |
| pub fn accept(&self) -> Result<Option<UnixStream>> { |
| loop { |
| match self.fd.accept() { |
| Ok((socket, _addr)) => return Ok(Some(socket)), |
| Err(e) => { |
| match e.kind() { |
| // No incoming connection available. |
| ErrorKind::WouldBlock => return Ok(None), |
| // New connection closed by peer. |
| ErrorKind::ConnectionAborted => return Ok(None), |
| // Interrupted by signals, retry |
| ErrorKind::Interrupted => continue, |
| _ => return Err(Error::SocketError(e)), |
| } |
| } |
| } |
| } |
| } |
| |
| /// Change blocking status on the listener. |
| /// |
| /// # Return: |
| /// * - () on success. |
| /// * - SocketError: failure from set_nonblocking(). |
| pub fn set_nonblocking(&self, block: bool) -> Result<()> { |
| self.fd.set_nonblocking(block).map_err(Error::SocketError) |
| } |
| } |
| |
| impl AsRawFd for Listener { |
| fn as_raw_fd(&self) -> RawFd { |
| self.fd.as_raw_fd() |
| } |
| } |
| |
| impl FromRawFd for Listener { |
| unsafe fn from_raw_fd(fd: RawFd) -> Self { |
| Listener { |
| fd: UnixListener::from_raw_fd(fd), |
| path: None, |
| } |
| } |
| } |
| |
| impl Drop for Listener { |
| fn drop(&mut self) { |
| if let Some(path) = &self.path { |
| let _ = std::fs::remove_file(path); |
| } |
| } |
| } |
| |
/// Unix domain socket endpoint for vhost-user connection.
pub(super) struct Endpoint<R: Req> {
    /// Connected stream socket used for all message I/O.
    sock: UnixStream,
    /// Ties the endpoint to the request type `R` used in its message headers; stores no data.
    _r: PhantomData<R>,
}
| |
| impl<R: Req> Endpoint<R> { |
| /// Create a new stream by connecting to server at `str`. |
| /// |
| /// # Return: |
| /// * - the new Endpoint object on success. |
| /// * - SocketConnect: failed to connect to peer. |
| pub fn connect<P: AsRef<Path>>(path: P) -> Result<Self> { |
| let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?; |
| Ok(Self::from_stream(sock)) |
| } |
| |
| /// Create an endpoint from a stream object. |
| pub fn from_stream(sock: UnixStream) -> Self { |
| Endpoint { |
| sock, |
| _r: PhantomData, |
| } |
| } |
| |
| /// Sends bytes from scatter-gather vectors over the socket with optional attached file |
| /// descriptors. |
| /// |
| /// # Return: |
| /// * - number of bytes sent on success |
| /// * - SocketRetry: temporary error caused by signals or short of resources. |
| /// * - SocketBroken: the underline socket is broken. |
| /// * - SocketError: other socket related errors. |
| pub fn send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> { |
| let rfds = match fds { |
| Some(rfds) => rfds, |
| _ => &[], |
| }; |
| self.sock.send_with_fds(iovs, rfds).map_err(Into::into) |
| } |
| |
| /// Sends all bytes from scatter-gather vectors over the socket with optional attached file |
| /// descriptors. Will loop until all data has been transfered. |
| /// |
| /// # Return: |
| /// * - number of bytes sent on success |
| /// * - SocketBroken: the underline socket is broken. |
| /// * - SocketError: other socket related errors. |
| pub fn send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> { |
| let mut data_sent = 0; |
| let mut data_total = 0; |
| let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.len()).collect(); |
| for len in &iov_lens { |
| data_total += len; |
| } |
| |
| while (data_total - data_sent) > 0 { |
| let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_sent); |
| let iov = &iovs[nr_skip][offset..]; |
| |
| let data = &[&[iov], &iovs[(nr_skip + 1)..]].concat(); |
| let sfds = if data_sent == 0 { fds } else { None }; |
| |
| let sent = self.send_iovec(data, sfds); |
| match sent { |
| Ok(0) => return Ok(data_sent), |
| Ok(n) => data_sent += n, |
| Err(e) => match e { |
| Error::SocketRetry(_) => {} |
| _ => return Err(e), |
| }, |
| } |
| } |
| Ok(data_sent) |
| } |
| |
| /// Sends bytes from a slice over the socket with optional attached file descriptors. |
| /// |
| /// # Return: |
| /// * - number of bytes sent on success |
| /// * - SocketRetry: temporary error caused by signals or short of resources. |
| /// * - SocketBroken: the underline socket is broken. |
| /// * - SocketError: other socket related errors. |
| pub fn send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> { |
| self.send_iovec(&[data], fds) |
| } |
| |
| /// Sends a header-only message with optional attached file descriptors. |
| /// |
| /// # Return: |
| /// * - number of bytes sent on success |
| /// * - SocketRetry: temporary error caused by signals or short of resources. |
| /// * - SocketBroken: the underline socket is broken. |
| /// * - SocketError: other socket related errors. |
| /// * - PartialMessage: received a partial message. |
| pub fn send_header( |
| &mut self, |
| hdr: &VhostUserMsgHeader<R>, |
| fds: Option<&[RawFd]>, |
| ) -> Result<()> { |
| // SAFETY: Safe because there can't be other mutable referance to hdr. |
| let iovs = unsafe { |
| [slice::from_raw_parts( |
| hdr as *const VhostUserMsgHeader<R> as *const u8, |
| mem::size_of::<VhostUserMsgHeader<R>>(), |
| )] |
| }; |
| let bytes = self.send_iovec_all(&iovs[..], fds)?; |
| if bytes != mem::size_of::<VhostUserMsgHeader<R>>() { |
| return Err(Error::PartialMessage); |
| } |
| Ok(()) |
| } |
| |
| /// Send a message with header and body. Optional file descriptors may be attached to |
| /// the message. |
| /// |
| /// # Return: |
| /// * - number of bytes sent on success |
| /// * - SocketRetry: temporary error caused by signals or short of resources. |
| /// * - SocketBroken: the underline socket is broken. |
| /// * - SocketError: other socket related errors. |
| /// * - PartialMessage: received a partial message. |
| pub fn send_message<T: ByteValued>( |
| &mut self, |
| hdr: &VhostUserMsgHeader<R>, |
| body: &T, |
| fds: Option<&[RawFd]>, |
| ) -> Result<()> { |
| if mem::size_of::<T>() > MAX_MSG_SIZE { |
| return Err(Error::OversizedMsg); |
| } |
| let bytes = self.send_iovec_all(&[hdr.as_slice(), body.as_slice()], fds)?; |
| if bytes != mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() { |
| return Err(Error::PartialMessage); |
| } |
| Ok(()) |
| } |
| |
| /// Send a message with header, body and payload. Optional file descriptors |
| /// may also be attached to the message. |
| /// |
| /// # Return: |
| /// * - number of bytes sent on success |
| /// * - SocketRetry: temporary error caused by signals or short of resources. |
| /// * - SocketBroken: the underline socket is broken. |
| /// * - SocketError: other socket related errors. |
| /// * - OversizedMsg: message size is too big. |
| /// * - PartialMessage: received a partial message. |
| /// * - IncorrectFds: wrong number of attached fds. |
| pub fn send_message_with_payload<T: ByteValued>( |
| &mut self, |
| hdr: &VhostUserMsgHeader<R>, |
| body: &T, |
| payload: &[u8], |
| fds: Option<&[RawFd]>, |
| ) -> Result<()> { |
| let len = payload.len(); |
| if mem::size_of::<T>() > MAX_MSG_SIZE { |
| return Err(Error::OversizedMsg); |
| } |
| if len > MAX_MSG_SIZE - mem::size_of::<T>() { |
| return Err(Error::OversizedMsg); |
| } |
| if let Some(fd_arr) = fds { |
| if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES { |
| return Err(Error::IncorrectFds); |
| } |
| } |
| |
| let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() + len; |
| let len = self.send_iovec_all(&[hdr.as_slice(), body.as_slice(), payload], fds)?; |
| if len != total { |
| return Err(Error::PartialMessage); |
| } |
| Ok(()) |
| } |
| |
| /// Reads bytes from the socket into the given scatter/gather vectors. |
| /// |
| /// # Return: |
| /// * - (number of bytes received, buf) on success |
| /// * - SocketRetry: temporary error caused by signals or short of resources. |
| /// * - SocketBroken: the underline socket is broken. |
| /// * - SocketError: other socket related errors. |
| pub fn recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)> { |
| let mut rbuf = vec![0u8; len]; |
| let mut iovs = [iovec { |
| iov_base: rbuf.as_mut_ptr() as *mut c_void, |
| iov_len: len, |
| }]; |
| // SAFETY: Safe because we own rbuf and it's safe to fill a byte array with arbitrary data. |
| let (bytes, _) = unsafe { self.sock.recv_with_fds(&mut iovs, &mut [])? }; |
| Ok((bytes, rbuf)) |
| } |
| |
| /// Reads bytes from the socket into the given scatter/gather vectors with optional attached |
| /// file. |
| /// |
| /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little |
| /// tricky to pass file descriptors through such a communication channel. Let's assume that a |
| /// sender sending a message with some file descriptors attached. To successfully receive those |
| /// attached file descriptors, the receiver must obey following rules: |
| /// 1) file descriptors are attached to a message. |
| /// 2) message(packet) boundaries must be respected on the receive side. |
| /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the |
| /// attached file descriptors will get lost. |
| /// Note that this function wraps received file descriptors as `File`. |
| /// |
| /// # Return: |
| /// * - (number of bytes received, [received files]) on success |
| /// * - SocketRetry: temporary error caused by signals or short of resources. |
| /// * - SocketBroken: the underline socket is broken. |
| /// * - SocketError: other socket related errors. |
| /// |
| /// # Safety |
| /// |
| /// It is the callers responsibility to ensure it is safe for arbitrary data to be |
| /// written to the iovec pointers. |
| pub unsafe fn recv_into_iovec( |
| &mut self, |
| iovs: &mut [iovec], |
| ) -> Result<(usize, Option<Vec<File>>)> { |
| let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES]; |
| let (bytes, fds) = self.sock.recv_with_fds(iovs, &mut fd_array)?; |
| |
| let files = match fds { |
| 0 => None, |
| n => { |
| let files = fd_array |
| .iter() |
| .take(n) |
| .map(|fd| { |
| // Safe because we have the ownership of `fd`. |
| File::from_raw_fd(*fd) |
| }) |
| .collect(); |
| Some(files) |
| } |
| }; |
| |
| Ok((bytes, files)) |
| } |
| |
| /// Reads all bytes from the socket into the given scatter/gather vectors with optional |
| /// attached files. Will loop until all data has been transferred. |
| /// |
| /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little |
| /// tricky to pass file descriptors through such a communication channel. Let's assume that a |
| /// sender sending a message with some file descriptors attached. To successfully receive those |
| /// attached file descriptors, the receiver must obey following rules: |
| /// 1) file descriptors are attached to a message. |
| /// 2) message(packet) boundaries must be respected on the receive side. |
| /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the |
| /// attached file descriptors will get lost. |
| /// Note that this function wraps received file descriptors as `File`. |
| /// |
| /// # Return: |
| /// * - (number of bytes received, [received fds]) on success |
| /// * - SocketBroken: the underline socket is broken. |
| /// * - SocketError: other socket related errors. |
| /// |
| /// # Safety |
| /// |
| /// It is the callers responsibility to ensure it is safe for arbitrary data to be |
| /// written to the iovec pointers. |
| pub unsafe fn recv_into_iovec_all( |
| &mut self, |
| iovs: &mut [iovec], |
| ) -> Result<(usize, Option<Vec<File>>)> { |
| let mut data_read = 0; |
| let mut data_total = 0; |
| let mut rfds = None; |
| let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.iov_len).collect(); |
| for len in &iov_lens { |
| data_total += len; |
| } |
| |
| while (data_total - data_read) > 0 { |
| let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_read); |
| let iov = &mut iovs[nr_skip]; |
| |
| let mut data = [ |
| &[iovec { |
| iov_base: (iov.iov_base as usize + offset) as *mut c_void, |
| iov_len: iov.iov_len - offset, |
| }], |
| &iovs[(nr_skip + 1)..], |
| ] |
| .concat(); |
| |
| let res = self.recv_into_iovec(&mut data); |
| match res { |
| Ok((0, _)) => return Ok((data_read, rfds)), |
| Ok((n, fds)) => { |
| if data_read == 0 { |
| rfds = fds; |
| } |
| data_read += n; |
| } |
| Err(e) => match e { |
| Error::SocketRetry(_) => {} |
| _ => return Err(e), |
| }, |
| } |
| } |
| Ok((data_read, rfds)) |
| } |
| |
| /// Reads bytes from the socket into a new buffer with optional attached |
| /// files. Received file descriptors are set close-on-exec and converted to `File`. |
| /// |
| /// # Return: |
| /// * - (number of bytes received, buf, [received files]) on success. |
| /// * - SocketRetry: temporary error caused by signals or short of resources. |
| /// * - SocketBroken: the underline socket is broken. |
| /// * - SocketError: other socket related errors. |
| pub fn recv_into_buf( |
| &mut self, |
| buf_size: usize, |
| ) -> Result<(usize, Vec<u8>, Option<Vec<File>>)> { |
| let mut buf = vec![0u8; buf_size]; |
| let (bytes, files) = { |
| let mut iovs = [iovec { |
| iov_base: buf.as_mut_ptr() as *mut c_void, |
| iov_len: buf_size, |
| }]; |
| // SAFETY: Safe because we own buf and it's safe to fill a byte array with arbitrary data. |
| unsafe { self.recv_into_iovec(&mut iovs)? } |
| }; |
| Ok((bytes, buf, files)) |
| } |
| |
| /// Receive a header-only message with optional attached files. |
| /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be |
| /// accepted and all other file descriptor will be discard silently. |
| /// |
| /// # Return: |
| /// * - (message header, [received files]) on success. |
| /// * - SocketRetry: temporary error caused by signals or short of resources. |
| /// * - SocketBroken: the underline socket is broken. |
| /// * - SocketError: other socket related errors. |
| /// * - PartialMessage: received a partial message. |
| /// * - InvalidMessage: received a invalid message. |
| pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<File>>)> { |
| let mut hdr = VhostUserMsgHeader::default(); |
| let mut iovs = [iovec { |
| iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, |
| iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), |
| }]; |
| // SAFETY: Safe because we own hdr and it's ByteValued. |
| let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? }; |
| |
| if bytes == 0 { |
| return Err(Error::Disconnected); |
| } else if bytes != mem::size_of::<VhostUserMsgHeader<R>>() { |
| return Err(Error::PartialMessage); |
| } else if !hdr.is_valid() { |
| return Err(Error::InvalidMessage); |
| } |
| |
| Ok((hdr, files)) |
| } |
| |
| /// Receive a message with optional attached file descriptors. |
| /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be |
| /// accepted and all other file descriptor will be discard silently. |
| /// |
| /// # Return: |
| /// * - (message header, message body, [received files]) on success. |
| /// * - SocketRetry: temporary error caused by signals or short of resources. |
| /// * - SocketBroken: the underline socket is broken. |
| /// * - SocketError: other socket related errors. |
| /// * - PartialMessage: received a partial message. |
| /// * - InvalidMessage: received a invalid message. |
| pub fn recv_body<T: ByteValued + Sized + VhostUserMsgValidator>( |
| &mut self, |
| ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<File>>)> { |
| let mut hdr = VhostUserMsgHeader::default(); |
| let mut body: T = Default::default(); |
| let mut iovs = [ |
| iovec { |
| iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, |
| iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), |
| }, |
| iovec { |
| iov_base: (&mut body as *mut T) as *mut c_void, |
| iov_len: mem::size_of::<T>(), |
| }, |
| ]; |
| // SAFETY: Safe because we own hdr and body and they're ByteValued. |
| let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? }; |
| |
| let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); |
| if bytes != total { |
| return Err(Error::PartialMessage); |
| } else if !hdr.is_valid() || !body.is_valid() { |
| return Err(Error::InvalidMessage); |
| } |
| |
| Ok((hdr, body, files)) |
| } |
| |
| /// Receive a message with header and optional content. Callers need to |
| /// pre-allocate a big enough buffer to receive the message body and |
| /// optional payload. If there are attached file descriptor associated |
| /// with the message, the first MAX_ATTACHED_FD_ENTRIES file descriptors |
| /// will be accepted and all other file descriptor will be discard |
| /// silently. |
| /// |
| /// # Return: |
| /// * - (message header, message size, [received files]) on success. |
| /// * - SocketRetry: temporary error caused by signals or short of resources. |
| /// * - SocketBroken: the underline socket is broken. |
| /// * - SocketError: other socket related errors. |
| /// * - PartialMessage: received a partial message. |
| /// * - InvalidMessage: received a invalid message. |
| pub fn recv_body_into_buf( |
| &mut self, |
| buf: &mut [u8], |
| ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<File>>)> { |
| let mut hdr = VhostUserMsgHeader::default(); |
| let mut iovs = [ |
| iovec { |
| iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, |
| iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), |
| }, |
| iovec { |
| iov_base: buf.as_mut_ptr() as *mut c_void, |
| iov_len: buf.len(), |
| }, |
| ]; |
| // SAFETY: Safe because we own hdr and have a mutable borrow of buf, and hdr is ByteValued |
| // and it's safe to fill a byte slice with arbitrary data. |
| let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? }; |
| |
| if bytes < mem::size_of::<VhostUserMsgHeader<R>>() { |
| return Err(Error::PartialMessage); |
| } else if !hdr.is_valid() { |
| return Err(Error::InvalidMessage); |
| } |
| |
| Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), files)) |
| } |
| |
| /// Receive a message with optional payload and attached file descriptors. |
| /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be |
| /// accepted and all other file descriptor will be discard silently. |
| /// |
| /// # Return: |
| /// * - (message header, message body, size of payload, [received files]) on success. |
| /// * - SocketRetry: temporary error caused by signals or short of resources. |
| /// * - SocketBroken: the underline socket is broken. |
| /// * - SocketError: other socket related errors. |
| /// * - PartialMessage: received a partial message. |
| /// * - InvalidMessage: received a invalid message. |
| #[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))] |
| pub fn recv_payload_into_buf<T: ByteValued + Sized + VhostUserMsgValidator>( |
| &mut self, |
| buf: &mut [u8], |
| ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<File>>)> { |
| let mut hdr = VhostUserMsgHeader::default(); |
| let mut body: T = Default::default(); |
| let mut iovs = [ |
| iovec { |
| iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, |
| iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), |
| }, |
| iovec { |
| iov_base: (&mut body as *mut T) as *mut c_void, |
| iov_len: mem::size_of::<T>(), |
| }, |
| iovec { |
| iov_base: buf.as_mut_ptr() as *mut c_void, |
| iov_len: buf.len(), |
| }, |
| ]; |
| // SAFETY: Safe because we own hdr and body and have a mutable borrow of buf, and |
| // hdr and body are ByteValued, and it's safe to fill a byte slice with |
| // arbitrary data. |
| let (bytes, files) = unsafe { self.recv_into_iovec_all(&mut iovs[..])? }; |
| |
| let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); |
| if bytes < total { |
| return Err(Error::PartialMessage); |
| } else if !hdr.is_valid() || !body.is_valid() { |
| return Err(Error::InvalidMessage); |
| } |
| |
| Ok((hdr, body, bytes - total, files)) |
| } |
| } |
| |
| impl<T: Req> AsRawFd for Endpoint<T> { |
| fn as_raw_fd(&self) -> RawFd { |
| self.sock.as_raw_fd() |
| } |
| } |
| |
/// Locate byte position `skip_size` within a sequence of buffers whose lengths are `iov_lens`.
///
/// Returns `(index, offset)`. `index` is the number of leading buffers entirely covered
/// by `skip_size` bytes; zero-length buffers at that point are also skipped. `offset` is
/// the number of bytes still to skip inside buffer `index`.
///
/// For example, `get_sub_iovs_offset(&[4, 4, 5], 6)` returns `(1, 2)`.
fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
    let mut remaining = skip_size;
    let mut idx = 0;
    while idx < iov_lens.len() && remaining >= iov_lens[idx] {
        remaining -= iov_lens[idx];
        idx += 1;
    }
    (idx, remaining)
}
| |
#[cfg(test)]
mod tests {
    //! Unit tests for `Listener` and `Endpoint` over real Unix domain sockets in /tmp.

    use super::*;
    use std::io::{Read, Seek, SeekFrom, Write};
    use std::os::unix::io::IntoRawFd;
    use vmm_sys_util::rand::rand_alphanumerics;
    use vmm_sys_util::tempfile::TempFile;

    /// Returns a random socket path under /tmp so concurrently running tests don't collide.
    fn temp_path() -> PathBuf {
        PathBuf::from(format!(
            "/tmp/vhost_test_{}",
            rand_alphanumerics(8).to_str().unwrap()
        ))
    }

    #[test]
    fn create_listener() {
        let path = temp_path();
        let listener = Listener::new(path, true).unwrap();

        assert!(listener.as_raw_fd() > 0);
    }

    #[test]
    fn create_listener_from_raw_fd() {
        let path = temp_path();
        let file = File::create(&path).unwrap();

        // SAFETY: `into_raw_fd()` releases ownership of a valid fd to a file just created,
        // so the listener is its sole owner and closes it exactly once. Using `as_raw_fd()`
        // here would leave both `file` and `listener` closing the same fd, which can close an
        // unrelated descriptor opened concurrently by another test thread.
        let listener = unsafe { Listener::from_raw_fd(file.into_raw_fd()) };

        assert!(listener.as_raw_fd() > 0);

        // A listener built from a raw fd has no path to unlink, so clean up the file here.
        drop(listener);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn accept_connection() {
        let path = temp_path();
        let listener = Listener::new(path, true).unwrap();
        listener.set_nonblocking(true).unwrap();

        // accept on a fd without incoming connection
        let conn = listener.accept().unwrap();
        assert!(conn.is_none());
    }

    #[test]
    fn send_data() {
        let path = temp_path();
        let listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);

        let buf1 = vec![0x1, 0x2, 0x3, 0x4];
        let mut len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let (bytes, buf2, _) = slave.recv_into_buf(0x1000).unwrap();
        assert_eq!(bytes, 4);
        assert_eq!(&buf1[..], &buf2[..bytes]);

        len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf2[..]);
        let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
    }

    #[test]
    fn send_fd() {
        let path = temp_path();
        let listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);

        let mut fd = TempFile::new().unwrap().into_file();
        write!(fd, "test").unwrap();

        // Normal case for sending/receiving file descriptors
        let buf1 = vec![0x1, 0x2, 0x3, 0x4];
        let len = master
            .send_slice(&buf1[..], Some(&[fd.as_raw_fd()]))
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, buf2, files) = slave.recv_into_buf(4).unwrap();
        assert_eq!(bytes, 4);
        assert_eq!(&buf1[..], &buf2[..]);
        assert!(files.is_some());
        let files = files.unwrap();
        {
            assert_eq!(files.len(), 1);
            let mut file = &files[0];
            let mut content = String::new();
            file.seek(SeekFrom::Start(0)).unwrap();
            file.read_to_string(&mut content).unwrap();
            assert_eq!(content, "test");
        }

        // Following communication pattern should work:
        // Sending side: data(header, body) with fds
        // Receiving side: data(header) with fds, data(body)
        let len = master
            .send_slice(
                &buf1[..],
                Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
            )
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf2[..]);
        assert!(files.is_some());
        let files = files.unwrap();
        {
            assert_eq!(files.len(), 3);
            let mut file = &files[1];
            let mut content = String::new();
            file.seek(SeekFrom::Start(0)).unwrap();
            file.read_to_string(&mut content).unwrap();
            assert_eq!(content, "test");
        }
        let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
        assert!(files.is_none());

        // Following communication pattern should not work:
        // Sending side: data(header, body) with fds
        // Receiving side: data(header), data(body) with fds
        let len = master
            .send_slice(
                &buf1[..],
                Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
            )
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, buf4) = slave.recv_data(2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf4[..]);
        let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
        assert!(files.is_none());

        // Following communication pattern should work:
        // Sending side: data, data with fds
        // Receiving side: data, data with fds
        let len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let len = master
            .send_slice(
                &buf1[..],
                Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
            )
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, buf2, files) = slave.recv_into_buf(0x4).unwrap();
        assert_eq!(bytes, 4);
        assert_eq!(&buf1[..], &buf2[..]);
        assert!(files.is_none());

        let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf2[..]);
        assert!(files.is_some());
        let files = files.unwrap();
        {
            assert_eq!(files.len(), 3);
            let mut file = &files[1];
            let mut content = String::new();
            file.seek(SeekFrom::Start(0)).unwrap();
            file.read_to_string(&mut content).unwrap();
            assert_eq!(content, "test");
        }
        let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
        assert!(files.is_none());

        // Following communication pattern should not work:
        // Sending side: data1, data2 with fds
        // Receiving side: data + partial of data2, left of data2 with fds
        let len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let len = master
            .send_slice(
                &buf1[..],
                Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
            )
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, _) = slave.recv_data(5).unwrap();
        assert_eq!(bytes, 5);

        let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
        assert_eq!(bytes, 3);
        assert!(files.is_none());

        // If the target fd array is too small, extra file descriptors will get lost.
        let len = master
            .send_slice(
                &buf1[..],
                Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
            )
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
        assert_eq!(bytes, 4);
        assert!(files.is_some());
    }

    #[test]
    fn send_recv() {
        let path = temp_path();
        let listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);

        let mut hdr1 =
            VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, mem::size_of::<u64>() as u32);
        hdr1.set_need_reply(true);
        let features1 = 0x1u64;
        master.send_message(&hdr1, &features1, None).unwrap();

        let mut features2 = 0u64;

        // SAFETY: Safe because features2 is valid and it's an `u64`.
        let slice = unsafe {
            slice::from_raw_parts_mut(
                (&mut features2 as *mut u64) as *mut u8,
                mem::size_of::<u64>(),
            )
        };
        let (hdr2, bytes, files) = slave.recv_body_into_buf(slice).unwrap();
        assert_eq!(hdr1, hdr2);
        assert_eq!(bytes, 8);
        assert_eq!(features1, features2);
        assert!(files.is_none());

        master.send_header(&hdr1, None).unwrap();
        let (hdr2, files) = slave.recv_header().unwrap();
        assert_eq!(hdr1, hdr2);
        assert!(files.is_none());
    }

    #[test]
    fn partial_message() {
        let path = temp_path();
        let listener = Listener::new(&path, true).unwrap();
        let mut master = UnixStream::connect(&path).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);

        write!(master, "a").unwrap();
        drop(master);
        assert!(matches!(slave.recv_header(), Err(Error::PartialMessage)));
    }

    #[test]
    fn disconnected() {
        let path = temp_path();
        let listener = Listener::new(&path, true).unwrap();
        // `let _ =` drops the client stream immediately, so the peer is already gone.
        let _ = UnixStream::connect(&path).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);

        assert!(matches!(slave.recv_header(), Err(Error::Disconnected)));
    }
}