| //! Networking as a capability |
| //! |
| //! All code that wants to hit the network should go through this module. |
| //! |
| //! Currently it produces the output all at once, but in the future it would |
| //! ideally provide hooks to give you streaming access to the download so |
| //! that you could do a streaming parse and reduce latency on network-bound |
| //! tasks. |
| |
| use std::{ |
| ffi::{OsStr, OsString}, |
| io::Write, |
| path::{Path, PathBuf}, |
| sync::Mutex, |
| time::Duration, |
| }; |
| |
| use base64_stream::FromBase64Writer; |
| use bytes::Bytes; |
| use reqwest::{Client, Url}; |
| use tokio::io::AsyncWriteExt; |
| |
| use crate::{ |
| errors::{DownloadError, SourceFile}, |
| PartialConfig, |
| }; |
| |
| /// Wrapper for the pair of a `reqwest::Response` and the `SemaphorePermit` used |
| /// to limit concurrent connections, with a test-only variant for mocking. |
| enum Response<'a> { |
| Real(reqwest::Response, tokio::sync::SemaphorePermit<'a>), |
| #[cfg(test)] |
| Mock(Option<Bytes>), |
| } |
| |
| impl Response<'_> { |
| fn has_header(&self, name: &str) -> bool { |
| match self { |
| Response::Real(response, _) => response.headers().contains_key(name), |
| #[cfg(test)] |
| Response::Mock(_) => false, |
| } |
| } |
| |
| /// Get the next chunk in the response stream, or `None` if at the end of |
| /// the stream. |
| async fn chunk(&mut self) -> Result<Option<Bytes>, DownloadError> { |
| match self { |
| Response::Real(res, _) => { |
| res.chunk() |
| .await |
| .map_err(|error| DownloadError::FailedToReadDownload { |
| url: Box::new(res.url().clone()), |
| error, |
| }) |
| } |
| #[cfg(test)] |
| Response::Mock(data) => Ok(data.take()), |
| } |
| } |
| } |
| |
| pub struct Network { |
| /// The HTTP client all requests go through |
| client: Client, |
| /// Semaphore preventing exceeding the maximum number of connections. |
| connection_semaphore: tokio::sync::Semaphore, |
| /// Cache of source files downloaded by Url |
| source_file_cache: Mutex<std::collections::HashMap<Url, SourceFile>>, |
| /// Test-only override for download requests. |
| #[cfg(test)] |
| mock_network: Option<std::collections::HashMap<Url, Bytes>>, |
| } |
| |
| const DEFAULT_TIMEOUT_SECS: u64 = 60; |
| const MAX_CONCURRENT_CONNECTIONS: usize = 40; |
| const USER_AGENT: &str = concat!( |
| env!("CARGO_PKG_NAME"), |
| "/", |
| env!("CARGO_PKG_VERSION"), |
| " (", |
| env!("CARGO_PKG_HOMEPAGE"), |
| ")" |
| ); |
| |
| /// The network payload encoding. |
| /// |
| /// This is only used in `download` (not `download_and_persist`) because (for now) it's only needed |
| /// in downloading imports (not packages, which is what `download_and_persist` is used for) to |
| /// workaround known server shortcomings. It could be added to `download_and_persist`, but due to |
| /// the use of `tokio::io::File` it gets messy and either this or that would need a refactor. |
| #[derive(Debug, Clone, Copy, PartialEq, Eq)] |
| pub enum PayloadEncoding { |
| Plaintext, |
| Base64, |
| } |
| |
| impl PayloadEncoding { |
| fn for_response(response: &Response) -> Self { |
| // gitiles always encodes content in base64 |
| if response.has_header("x-gitiles-object-type") { |
| Self::Base64 |
| } else { |
| Self::Plaintext |
| } |
| } |
| |
| pub fn to_plaintext<'a, W: Write + 'a>(&self, target: W) -> Box<dyn Write + 'a> { |
| match self { |
| Self::Plaintext => Box::new(target), |
| Self::Base64 => Box::new(FromBase64Writer::new(target)), |
| } |
| } |
| } |
| |
| impl std::fmt::Display for PayloadEncoding { |
| fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
| match self { |
| Self::Plaintext => write!(f, "plaintext"), |
| Self::Base64 => write!(f, "base64"), |
| } |
| } |
| } |
| |
| impl Network { |
| /// Acquire access to the network |
| /// |
| /// There should only ever be one Network instance instantiated. Do it early |
| /// and then pass it around by-ref. |
| pub fn acquire(cfg: &PartialConfig) -> Option<Self> { |
| if cfg.cli.frozen { |
| None |
| } else { |
| // TODO: make this configurable on the CLI or something |
| let timeout = Duration::from_secs(DEFAULT_TIMEOUT_SECS); |
| // TODO: make this configurable on the CLI or something |
| let client = Client::builder() |
| .user_agent(USER_AGENT) |
| .timeout(timeout) |
| .build() |
| .expect("Couldn't construct HTTP Client?"); |
| Some(Self { |
| client, |
| connection_semaphore: tokio::sync::Semaphore::new(MAX_CONCURRENT_CONNECTIONS), |
| source_file_cache: Default::default(), |
| #[cfg(test)] |
| mock_network: None, |
| }) |
| } |
| } |
| |
| /// Download a file and persist it to disk |
| pub async fn download_and_persist( |
| &self, |
| url: Url, |
| persist_to: &Path, |
| ) -> Result<(), DownloadError> { |
| let download_tmp_path = PathBuf::from(OsString::from_iter([ |
| persist_to.as_os_str(), |
| OsStr::new(".part"), |
| ])); |
| { |
| let mut res = self.fetch_core(url).await?; |
| |
| let mut download_tmp = |
| tokio::fs::File::create(&download_tmp_path) |
| .await |
| .map_err(|error| DownloadError::FailedToCreateDownload { |
| target: download_tmp_path.clone(), |
| error, |
| })?; |
| |
| while let Some(chunk) = res.chunk().await? { |
| download_tmp.write_all(&chunk[..]).await.map_err(|error| { |
| DownloadError::FailedToWriteDownload { |
| target: download_tmp_path.clone(), |
| error, |
| } |
| })?; |
| } |
| } |
| |
| // Rename the downloaded file into the final location. |
| match tokio::fs::rename(&download_tmp_path, &persist_to).await { |
| Ok(()) => {} |
| Err(err) => { |
| let _ = tokio::fs::remove_file(&download_tmp_path).await; |
| return Err(err).map_err(|error| DownloadError::FailedToFinalizeDownload { |
| target: persist_to.to_owned(), |
| error, |
| })?; |
| } |
| } |
| |
| Ok(()) |
| } |
| |
| /// Download a file into memory |
| pub async fn download(&self, url: Url) -> Result<Vec<u8>, DownloadError> { |
| let mut res = self.fetch_core(url).await?; |
| |
| let encoding = PayloadEncoding::for_response(&res); |
| |
| let mut output = vec![]; |
| { |
| let mut writer = encoding.to_plaintext(&mut output); |
| while let Some(chunk) = res.chunk().await? { |
| writer |
| .write_all(&chunk[..]) |
| .map_err(|error| DownloadError::InvalidEncoding { encoding, error })?; |
| } |
| writer |
| .flush() |
| .map_err(|error| DownloadError::InvalidEncoding { encoding, error })?; |
| } |
| Ok(output) |
| } |
| |
| /// Download a file into memory as a SourceFile, with in-memory caching |
| pub async fn download_source_file_cached(&self, url: Url) -> Result<SourceFile, DownloadError> { |
| if let Some(source_file) = self.source_file_cache.lock().unwrap().get(&url) { |
| return Ok(source_file.clone()); |
| } |
| |
| let bytes = self.download(url.clone()).await?; |
| match String::from_utf8(bytes) { |
| Ok(string) => { |
| let source_file = SourceFile::new(url.as_str(), string); |
| self.source_file_cache |
| .lock() |
| .unwrap() |
| .insert(url, source_file.clone()); |
| Ok(source_file) |
| } |
| Err(error) => Err(DownloadError::InvalidText { |
| url: Box::new(url), |
| error, |
| }), |
| } |
| } |
| |
| /// Internal core implementation of network fetching which is shared between |
| /// `download` and `download_and_persist`. |
| async fn fetch_core(&self, url: Url) -> Result<Response, DownloadError> { |
| #[cfg(test)] |
| if let Some(mock_network) = &self.mock_network { |
| let chunk = mock_network |
| .get(&url) |
| .cloned() |
| // The error is complete nonsense, but this is test-only. |
| .ok_or_else(|| { |
| tracing::warn!("Attempt to fetch unsupported URL from mock network: {url}"); |
| DownloadError::FailedToWriteDownload { |
| target: url.to_string().into(), |
| error: std::io::Error::new( |
| std::io::ErrorKind::Other, |
| format!("mock network does not support URL: {url}"), |
| ), |
| } |
| })?; |
| return Ok(Response::Mock(Some(chunk))); |
| } |
| |
| let permit = self |
| .connection_semaphore |
| .acquire() |
| .await |
| .expect("Semaphore dropped?!"); |
| |
| let res = self |
| .client |
| .get(url.clone()) |
| .send() |
| .await |
| .and_then(|res| res.error_for_status()) |
| .map_err(|error| DownloadError::FailedToStartDownload { |
| url: Box::new(url.clone()), |
| error, |
| })?; |
| |
| Ok(Response::Real(res, permit)) |
| } |
| } |
| |
| #[cfg(test)] |
| impl Network { |
| /// Create a new Network which is serving mocked out resources. |
| pub(crate) fn new_mock() -> Self { |
| let mut network = Network { |
| client: Client::new(), |
| connection_semaphore: tokio::sync::Semaphore::new(MAX_CONCURRENT_CONNECTIONS), |
| source_file_cache: Default::default(), |
| #[cfg(test)] |
| mock_network: Some(Default::default()), |
| }; |
| // Serve an empty registry by default. |
| network.mock_serve_toml( |
| crate::storage::REGISTRY_URL, |
| &crate::format::RegistryFile::default(), |
| ); |
| network |
| } |
| |
| /// Add a new resource to be served by a mocked-out network. |
| pub(crate) fn mock_serve(&mut self, url: impl AsRef<str>, data: impl AsRef<[u8]>) { |
| self.mock_network |
| .as_mut() |
| .expect("not a mock network") |
| .insert( |
| url.as_ref().parse().unwrap(), |
| Bytes::copy_from_slice(data.as_ref()), |
| ); |
| } |
| |
| /// Add a new toml resource to be served by a mocked-out network. |
| pub(crate) fn mock_serve_toml(&mut self, url: impl AsRef<str>, data: &impl serde::Serialize) { |
| self.mock_serve( |
| url, |
| crate::serialization::to_formatted_toml(data, None) |
| .unwrap() |
| .to_string(), |
| ); |
| } |
| |
| /// Add a new json resource to be served by a mocked-out network. |
| pub(crate) fn mock_serve_json(&mut self, url: impl AsRef<str>, data: &impl serde::Serialize) { |
| self.mock_serve(url, serde_json::to_string(data).unwrap()); |
| } |
| } |