| commit 3851430dec41c204d6219a84c47ca6885622a98e |
| Author: kolapapa <kola@kolapapas-MacBook-Pro.local> |
| Date: Sun Dec 20 14:02:52 2020 +0100 |
| |
| Add Socket::(bind_)device |
| |
| Co-authored-by: Thomas de Zeeuw <thomasdezeeuw@gmail.com> |
| |
| diff --git a/src/sys/unix.rs b/src/sys/unix.rs |
| index 72097ae..1a0f24e 100644 |
| --- a/src/sys/unix.rs |
| +++ b/src/sys/unix.rs |
| @@ -7,6 +7,8 @@ |
| // except according to those terms. |
| |
| use std::cmp::min; |
| +#[cfg(all(feature = "all", target_os = "linux"))] |
| +use std::ffi::{CStr, CString}; |
| #[cfg(not(target_os = "redox"))] |
| use std::io::{IoSlice, IoSliceMut}; |
| use std::mem::{self, size_of, MaybeUninit}; |
| @@ -19,6 +21,8 @@ |
| use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream}; |
| #[cfg(feature = "all")] |
| use std::path::Path; |
| +#[cfg(all(feature = "all", target_os = "linux"))] |
| +use std::slice; |
| use std::time::Duration; |
| use std::{io, ptr}; |
| |
| @@ -867,6 +871,73 @@ pub fn set_mark(&self, mark: u32) -> io::Result<()> { |
| unsafe { setsockopt::<c_int>(self.inner, libc::SOL_SOCKET, libc::SO_MARK, mark as c_int) } |
| } |
| |
| + /// Gets the value for the `SO_BINDTODEVICE` option on this socket. |
| + /// |
| + /// This value gets the socket binded device's interface name. |
| + /// |
| + /// This function is only available on Linux. |
| + #[cfg(all(feature = "all", target_os = "linux"))] |
| + pub fn device(&self) -> io::Result<Option<CString>> { |
| + // TODO: replace with `MaybeUninit::uninit_array` once stable. |
| + let mut buf: [MaybeUninit<u8>; libc::IFNAMSIZ] = |
| + unsafe { MaybeUninit::<[MaybeUninit<u8>; libc::IFNAMSIZ]>::uninit().assume_init() }; |
| + let mut len = buf.len() as libc::socklen_t; |
| + unsafe { |
| + syscall!(getsockopt( |
| + self.inner, |
| + libc::SOL_SOCKET, |
| + libc::SO_BINDTODEVICE, |
| + buf.as_mut_ptr().cast(), |
| + &mut len, |
| + ))?; |
| + } |
| + if len == 0 { |
| + Ok(None) |
| + } else { |
| + // Allocate a buffer for `CString` with the length including the |
| + // null terminator. |
| + let len = len as usize; |
| + let mut name = Vec::with_capacity(len); |
| + |
| + // TODO: use `MaybeUninit::slice_assume_init_ref` once stable. |
| + // Safety: `len` bytes are writen by the OS, this includes a null |
| + // terminator. However we don't copy the null terminator because |
| + // `CString::from_vec_unchecked` adds its own null terminator. |
| + let buf = unsafe { slice::from_raw_parts(buf.as_ptr().cast(), len - 1) }; |
| + name.extend_from_slice(buf); |
| + |
| + // Safety: the OS initialised the string for us, which shouldn't |
| + // include any null bytes. |
| + Ok(Some(unsafe { CString::from_vec_unchecked(name) })) |
| + } |
| + } |
| + |
| + /// Sets the value for the `SO_BINDTODEVICE` option on this socket. |
| + /// |
| + /// If a socket is bound to an interface, only packets received from that |
| + /// particular interface are processed by the socket. Note that this only |
| + /// works for some socket types, particularly `AF_INET` sockets. |
| + /// |
| + /// If `interface` is `None` or an empty string it removes the binding. |
| + /// |
| + /// This function is only available on Linux. |
| + #[cfg(all(feature = "all", target_os = "linux"))] |
| + pub fn bind_device(&self, interface: Option<&CStr>) -> io::Result<()> { |
| + let (value, len) = if let Some(interface) = interface { |
| + (interface.as_ptr(), interface.to_bytes_with_nul().len()) |
| + } else { |
| + (ptr::null(), 0) |
| + }; |
| + syscall!(setsockopt( |
| + self.inner, |
| + libc::SOL_SOCKET, |
| + libc::SO_BINDTODEVICE, |
| + value.cast(), |
| + len as libc::socklen_t, |
| + )) |
| + .map(|_| ()) |
| + } |
| + |
| /// Get the value of the `SO_REUSEPORT` option on this socket. |
| /// |
| /// For more information about this option, see [`set_reuse_port`]. |
| diff --git a/tests/socket.rs b/tests/socket.rs |
| index 11a3d9c..75890af 100644 |
| --- a/tests/socket.rs |
| +++ b/tests/socket.rs |
| @@ -1,3 +1,5 @@ |
| +#[cfg(all(feature = "all", target_os = "linux"))] |
| +use std::ffi::CStr; |
| #[cfg(any(windows, target_vendor = "apple"))] |
| use std::io; |
| #[cfg(unix)] |
| @@ -271,3 +273,19 @@ fn keepalive() { |
| ))] |
| assert_eq!(socket.keepalive_retries().unwrap(), 10); |
| } |
| + |
| +#[cfg(all(feature = "all", target_os = "linux"))] |
| +#[test] |
| +fn device() { |
| + const INTERFACE: &str = "lo0\0"; |
| + let interface = CStr::from_bytes_with_nul(INTERFACE.as_bytes()).unwrap(); |
| + let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap(); |
| + |
| + assert_eq!(socket.device().unwrap(), None); |
| + |
| + socket.bind_device(Some(interface)).unwrap(); |
| + assert_eq!(socket.device().unwrap().as_deref(), Some(interface)); |
| + |
| + socket.bind_device(None).unwrap(); |
| + assert_eq!(socket.device().unwrap(), None); |
| +} |