blob: 61aadc93d9fbb0ece06a1bfbfac0a09b2d8c9044 [file] [log] [blame]
use super::{Connected, Server};
use crate::transport::service::ServerIo;
use hyper::server::{
accept::Accept,
conn::{AddrIncoming, AddrStream},
};
use std::{
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpListener,
};
use tokio_stream::{Stream, StreamExt};
#[cfg(not(feature = "tls"))]
pub(crate) fn tcp_incoming<IO, IE, L>(
incoming: impl Stream<Item = Result<IO, IE>>,
_server: Server<L>,
) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::Error>,
{
async_stream::try_stream! {
tokio::pin!(incoming);
while let Some(item) = incoming.next().await {
yield item.map(ServerIo::new_io)?
}
}
}
#[cfg(feature = "tls")]
pub(crate) fn tcp_incoming<IO, IE, L>(
incoming: impl Stream<Item = Result<IO, IE>>,
server: Server<L>,
) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::Error>,
{
async_stream::try_stream! {
tokio::pin!(incoming);
let mut tasks = tokio::task::JoinSet::new();
loop {
match select(&mut incoming, &mut tasks).await {
SelectOutput::Incoming(stream) => {
if let Some(tls) = &server.tls {
let tls = tls.clone();
tasks.spawn(async move {
let io = tls.accept(stream).await?;
Ok(ServerIo::new_tls_io(io))
});
} else {
yield ServerIo::new_io(stream);
}
}
SelectOutput::Io(io) => {
yield io;
}
SelectOutput::Err(e) => {
tracing::debug!(message = "Accept loop error.", error = %e);
}
SelectOutput::Done => {
break;
}
}
}
}
}
#[cfg(feature = "tls")]
async fn select<IO: 'static, IE>(
incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
tasks: &mut tokio::task::JoinSet<Result<ServerIo<IO>, crate::Error>>,
) -> SelectOutput<IO>
where
IE: Into<crate::Error>,
{
if tasks.is_empty() {
return match incoming.try_next().await {
Ok(Some(stream)) => SelectOutput::Incoming(stream),
Ok(None) => SelectOutput::Done,
Err(e) => SelectOutput::Err(e.into()),
};
}
tokio::select! {
stream = incoming.try_next() => {
match stream {
Ok(Some(stream)) => SelectOutput::Incoming(stream),
Ok(None) => SelectOutput::Done,
Err(e) => SelectOutput::Err(e.into()),
}
}
accept = tasks.join_next() => {
match accept.expect("JoinSet should never end") {
Ok(Ok(io)) => SelectOutput::Io(io),
Ok(Err(e)) => SelectOutput::Err(e),
Err(e) => SelectOutput::Err(e.into()),
}
}
}
}
#[cfg(feature = "tls")]
enum SelectOutput<A> {
Incoming(A),
Io(ServerIo<A>),
Err(crate::Error),
Done,
}
/// Binds a socket address for a [Router](super::Router)
///
/// An incoming stream, usable with [Router::serve_with_incoming](super::Router::serve_with_incoming),
/// of `AsyncRead + AsyncWrite` that communicate with clients that connect to a socket address.
#[derive(Debug)]
pub struct TcpIncoming {
inner: AddrIncoming,
}
impl TcpIncoming {
/// Creates an instance by binding (opening) the specified socket address
/// to which the specified TCP 'nodelay' and 'keepalive' parameters are applied.
/// Returns a TcpIncoming if the socket address was successfully bound.
///
/// # Examples
/// ```no_run
/// # use tower_service::Service;
/// # use http::{request::Request, response::Response};
/// # use tonic::{body::BoxBody, server::NamedService, transport::{Body, Server, server::TcpIncoming}};
/// # use core::convert::Infallible;
/// # use std::error::Error;
/// # fn main() { } // Cannot have type parameters, hence instead define:
/// # fn run<S>(some_service: S) -> Result<(), Box<dyn Error + Send + Sync>>
/// # where
/// # S: Service<Request<Body>, Response = Response<BoxBody>, Error = Infallible> + NamedService + Clone + Send + 'static,
/// # S::Future: Send + 'static,
/// # {
/// // Find a free port
/// let mut port = 1322;
/// let tinc = loop {
/// let addr = format!("127.0.0.1:{}", port).parse().unwrap();
/// match TcpIncoming::new(addr, true, None) {
/// Ok(t) => break t,
/// Err(_) => port += 1
/// }
/// };
/// Server::builder()
/// .add_service(some_service)
/// .serve_with_incoming(tinc);
/// # Ok(())
/// # }
pub fn new(
addr: SocketAddr,
nodelay: bool,
keepalive: Option<Duration>,
) -> Result<Self, crate::Error> {
let mut inner = AddrIncoming::bind(&addr)?;
inner.set_nodelay(nodelay);
inner.set_keepalive(keepalive);
Ok(TcpIncoming { inner })
}
/// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`.
pub fn from_listener(
listener: TcpListener,
nodelay: bool,
keepalive: Option<Duration>,
) -> Result<Self, crate::Error> {
let mut inner = AddrIncoming::from_listener(listener)?;
inner.set_nodelay(nodelay);
inner.set_keepalive(keepalive);
Ok(TcpIncoming { inner })
}
}
impl Stream for TcpIncoming {
type Item = Result<AddrStream, std::io::Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_accept(cx)
}
}
#[cfg(test)]
mod tests {
use crate::transport::server::TcpIncoming;
#[tokio::test]
async fn one_tcpincoming_at_a_time() {
let addr = "127.0.0.1:1322".parse().unwrap();
{
let _t1 = TcpIncoming::new(addr, true, None).unwrap();
let _t2 = TcpIncoming::new(addr, true, None).unwrap_err();
}
let _t3 = TcpIncoming::new(addr, true, None).unwrap();
}
}