| use crate::codec::UserError; |
| use crate::codec::UserError::*; |
| use crate::frame::{self, Frame, FrameSize}; |
| use crate::hpack; |
| |
| use bytes::{Buf, BufMut, BytesMut}; |
| use std::pin::Pin; |
| use std::task::{Context, Poll}; |
| use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
| use tokio_util::io::poll_write_buf; |
| |
| use std::io::{self, Cursor}; |
| |
// A macro to get around a method needing to borrow &mut self
macro_rules! limited_write_buf {
    ($self:expr) => {{
        // Cap the writable region at one full frame (max payload plus the
        // 9-byte frame header) so a single encode can never overrun the
        // peer's advertised max frame size.
        let limit = $self.max_frame_size() + frame::HEADER_LEN;
        $self.buf.get_mut().limit(limit)
    }};
}
| |
/// Writes frames to an underlying `AsyncWrite`, buffering and encoding
/// them through an internal [`Encoder`].
#[derive(Debug)]
pub struct FramedWrite<T, B> {
    /// Upstream `AsyncWrite`
    inner: T,

    /// Encoding state plus the intermediate write buffer.
    encoder: Encoder<B>,
}
| |
/// Encoding state for [`FramedWrite`]: the write buffer plus any frame
/// that could not be fully encoded into it.
#[derive(Debug)]
struct Encoder<B> {
    /// HPACK encoder
    hpack: hpack::Encoder,

    /// Write buffer
    ///
    /// TODO: Should this be a ring buffer?
    // The cursor position tracks how much of the buffer has already been
    // written to the wire; `unset_frame` rewinds it to 0.
    buf: Cursor<BytesMut>,

    /// Next frame to encode
    next: Option<Next<B>>,

    /// Last data frame
    // Retained after writing so the caller can reclaim it via
    // `take_last_data_frame`.
    last_data_frame: Option<frame::Data<B>>,

    /// Max frame size, this is specified by the peer
    max_frame_size: FrameSize,

    /// Chain payloads bigger than this.
    // Payloads at or above this size are chained at flush time instead of
    // being copied into `buf`.
    chain_threshold: usize,

    /// Min buffer required to attempt to write a frame
    // Equal to `chain_threshold + frame::HEADER_LEN`; see `has_capacity`.
    min_buffer_capacity: usize,
}
| |
/// A frame that was not fully encoded into the write buffer and must be
/// finished on a later write.
#[derive(Debug)]
enum Next<B> {
    /// A DATA frame whose payload is chained behind the buffer at flush
    /// time rather than copied into it.
    Data(frame::Data<B>),
    /// The remainder of a header block that did not fit in one frame.
    Continuation(frame::Continuation),
}
| |
/// Initialize the connection with this amount of write buffer.
///
/// The minimum MAX_FRAME_SIZE is 16kb, so always be able to send a HEADERS
/// frame that big.
const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024;

/// Chain payloads bigger than this when vectored I/O is enabled. The remote
/// will never advertise a max frame size less than this (well, the spec says
/// the max frame size can't be less than 16kb, so not even close).
const CHAIN_THRESHOLD: usize = 256;

/// Chain payloads bigger than this when vectored I/O is **not** enabled.
/// A larger value in this scenario will reduce the number of small and
/// fragmented data being sent, and thereby improve the throughput.
const CHAIN_THRESHOLD_WITHOUT_VECTORED_IO: usize = 1024;
| |
| // TODO: Make generic |
impl<T, B> FramedWrite<T, B>
where
    T: AsyncWrite + Unpin,
    B: Buf,
{
    /// Creates a `FramedWrite` wrapping the I/O handle `inner`.
    ///
    /// The chain threshold depends on whether `inner` supports vectored
    /// writes: chaining is cheap with vectored I/O, so a small threshold is
    /// used; without it, a larger threshold reduces fragmentation.
    pub fn new(inner: T) -> FramedWrite<T, B> {
        let chain_threshold = if inner.is_write_vectored() {
            CHAIN_THRESHOLD
        } else {
            CHAIN_THRESHOLD_WITHOUT_VECTORED_IO
        };
        FramedWrite {
            inner,
            encoder: Encoder {
                hpack: hpack::Encoder::default(),
                buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)),
                next: None,
                last_data_frame: None,
                max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
                chain_threshold,
                // Room for chained payload bytes plus one frame header.
                min_buffer_capacity: chain_threshold + frame::HEADER_LEN,
            },
        }
    }

    /// Returns `Ready` when `send` is able to accept a frame
    ///
    /// Calling this function may result in the current contents of the buffer
    /// to be flushed to `T`.
    pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        if !self.encoder.has_capacity() {
            // Try flushing
            ready!(self.flush(cx))?;

            if !self.encoder.has_capacity() {
                // NOTE(review): a completed flush clears the buffer and
                // `next`, so this branch looks unreachable in practice; if
                // it were hit, `Pending` is returned without a waker having
                // been registered — confirm intended.
                return Poll::Pending;
            }
        }

        Poll::Ready(Ok(()))
    }

    /// Buffer a frame.
    ///
    /// `poll_ready` must be called first to ensure that a frame may be
    /// accepted.
    pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
        self.encoder.buffer(item)
    }

    /// Flush buffered data to the wire
    ///
    /// Loops until the internal buffer (and any chained DATA payload) is
    /// drained, then flushes the upstream writer.
    pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        let span = tracing::trace_span!("FramedWrite::flush");
        let _e = span.enter();

        loop {
            while !self.encoder.is_empty() {
                match self.encoder.next {
                    Some(Next::Data(ref mut frame)) => {
                        tracing::trace!(queued_data_frame = true);
                        // Chain the DATA payload behind whatever remains in
                        // the buffer so both can go out in a single
                        // (possibly vectored) write.
                        let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut());
                        ready!(poll_write_buf(Pin::new(&mut self.inner), cx, &mut buf))?
                    }
                    _ => {
                        tracing::trace!(queued_data_frame = false);
                        ready!(poll_write_buf(
                            Pin::new(&mut self.inner),
                            cx,
                            &mut self.encoder.buf
                        ))?
                    }
                };
            }

            // Reset the buffer. A pending CONTINUATION frame re-fills it
            // and requires another pass around the loop.
            match self.encoder.unset_frame() {
                ControlFlow::Continue => (),
                ControlFlow::Break => break,
            }
        }

        tracing::trace!("flushing buffer");
        // Flush the upstream
        ready!(Pin::new(&mut self.inner).poll_flush(cx))?;

        Poll::Ready(Ok(()))
    }

    /// Close the codec
    ///
    /// Flushes anything still buffered, then shuts down the underlying
    /// I/O handle.
    pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        ready!(self.flush(cx))?;
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}
| |
/// Tells `flush` whether to keep looping (`Continue`) or stop (`Break`)
/// after the write buffer has been reset by `unset_frame`.
#[must_use]
enum ControlFlow {
    Continue,
    Break,
}
| |
impl<B> Encoder<B>
where
    B: Buf,
{
    /// Resets the write buffer and resolves the pending `next` frame.
    ///
    /// Returns `Continue` when a CONTINUATION frame was re-buffered and the
    /// caller should keep writing, `Break` when nothing is left to do.
    fn unset_frame(&mut self) -> ControlFlow {
        // Clear internal buffer
        self.buf.set_position(0);
        self.buf.get_mut().clear();

        // The data frame has been written, so unset it
        match self.next.take() {
            Some(Next::Data(frame)) => {
                // Keep the frame so the caller can reclaim the underlying
                // buffer via `take_last_data_frame`.
                self.last_data_frame = Some(frame);
                debug_assert!(self.is_empty());
                ControlFlow::Break
            }
            Some(Next::Continuation(frame)) => {
                // Buffer the continuation frame, then try to write again
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = frame.encode(&mut buf) {
                    // Still too big for one frame; keep the remainder as
                    // the next CONTINUATION.
                    self.next = Some(Next::Continuation(continuation));
                }
                ControlFlow::Continue
            }
            None => ControlFlow::Break,
        }
    }

    /// Encodes `item` into the write buffer.
    ///
    /// The caller must check `has_capacity` first (enforced by the assert).
    ///
    /// # Errors
    ///
    /// Returns `PayloadTooBig` when a DATA payload exceeds the peer's max
    /// frame size.
    fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
        // Ensure that we have enough capacity to accept the write.
        assert!(self.has_capacity());
        let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item);
        let _e = span.enter();

        tracing::debug!(frame = ?item, "send");

        match item {
            Frame::Data(mut v) => {
                // Ensure that the payload is not greater than the max frame.
                let len = v.payload().remaining();

                if len > self.max_frame_size() {
                    return Err(PayloadTooBig);
                }

                if len >= self.chain_threshold {
                    // Large payload: encode only the header now; the
                    // payload is chained at flush time instead of copied.
                    let head = v.head();

                    // Encode the frame head to the buffer
                    head.encode(len, self.buf.get_mut());

                    // NOTE(review): `get_ref().remaining()` counts the whole
                    // BytesMut while `self.buf.remaining()` excludes bytes
                    // already written out (cursor position); they differ
                    // after a partial flush — confirm this mix is intended.
                    if self.buf.get_ref().remaining() < self.chain_threshold {
                        // `buf.remaining() <= get_ref().remaining() <
                        // chain_threshold` inside this branch, so the
                        // subtraction cannot underflow. Pull some payload
                        // bytes into the buffer so the write is at least
                        // `chain_threshold` bytes.
                        let extra_bytes = self.chain_threshold - self.buf.remaining();
                        self.buf.get_mut().put(v.payload_mut().take(extra_bytes));
                    }

                    // Save the data frame
                    self.next = Some(Next::Data(v));
                } else {
                    // Small payload: copy it straight into the buffer.
                    v.encode_chunk(self.buf.get_mut());

                    // The chunk has been fully encoded, so there is no need to
                    // keep it around
                    assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded");

                    // Save off the last frame...
                    self.last_data_frame = Some(v);
                }
            }
            Frame::Headers(v) => {
                // HPACK-encode; any overflow comes back as a CONTINUATION
                // frame to be sent on subsequent writes.
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
            }
            Frame::PushPromise(v) => {
                // Same continuation handling as HEADERS.
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
            }
            Frame::Settings(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded settings");
            }
            Frame::GoAway(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded go_away");
            }
            Frame::Ping(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded ping");
            }
            Frame::WindowUpdate(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded window_update");
            }

            Frame::Priority(_) => {
                /*
                v.encode(self.buf.get_mut());
                tracing::trace!("encoded priority; rem={:?}", self.buf.remaining());
                */
                unimplemented!();
            }
            Frame::Reset(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded reset");
            }
        }

        Ok(())
    }

    /// Returns `true` when another frame can be buffered: no frame is
    /// pending and the buffer has at least `min_buffer_capacity` bytes of
    /// spare capacity.
    fn has_capacity(&self) -> bool {
        self.next.is_none()
            && (self.buf.get_ref().capacity() - self.buf.get_ref().len()
                >= self.min_buffer_capacity)
    }

    /// Returns `true` once everything queued (including a chained DATA
    /// payload) has been written out.
    fn is_empty(&self) -> bool {
        match self.next {
            // `Buf::chain` drains the buffer before the payload, so an
            // exhausted payload implies the buffer was written too.
            Some(Next::Data(ref frame)) => !frame.payload().has_remaining(),
            _ => !self.buf.has_remaining(),
        }
    }
}
| |
| impl<B> Encoder<B> { |
| fn max_frame_size(&self) -> usize { |
| self.max_frame_size as usize |
| } |
| } |
| |
| impl<T, B> FramedWrite<T, B> { |
| /// Returns the max frame size that can be sent |
| pub fn max_frame_size(&self) -> usize { |
| self.encoder.max_frame_size() |
| } |
| |
| /// Set the peer's max frame size. |
| pub fn set_max_frame_size(&mut self, val: usize) { |
| assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize); |
| self.encoder.max_frame_size = val as FrameSize; |
| } |
| |
| /// Set the peer's header table size. |
| pub fn set_header_table_size(&mut self, val: usize) { |
| self.encoder.hpack.update_max_size(val); |
| } |
| |
| /// Retrieve the last data frame that has been sent |
| pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> { |
| self.encoder.last_data_frame.take() |
| } |
| |
| pub fn get_mut(&mut self) -> &mut T { |
| &mut self.inner |
| } |
| } |
| |
| impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> { |
| fn poll_read( |
| mut self: Pin<&mut Self>, |
| cx: &mut Context<'_>, |
| buf: &mut ReadBuf, |
| ) -> Poll<io::Result<()>> { |
| Pin::new(&mut self.inner).poll_read(cx, buf) |
| } |
| } |
| |
// We never project the Pin to `B`, so `FramedWrite` is `Unpin` whenever
// `T` is, regardless of `B`.
impl<T: Unpin, B> Unpin for FramedWrite<T, B> {}
| |
#[cfg(feature = "unstable")]
mod unstable {
    use super::*;

    impl<T, B> FramedWrite<T, B> {
        /// Returns a shared reference to the underlying I/O handle.
        pub fn get_ref(&self) -> &T {
            &self.inner
        }
    }
}