| use crate::{tables, Config, PAD_BYTE}; |
| |
| #[cfg(any(feature = "alloc", feature = "std", test))] |
| use crate::STANDARD; |
| #[cfg(any(feature = "alloc", feature = "std", test))] |
| use alloc::vec::Vec; |
| use core::fmt; |
| #[cfg(any(feature = "std", test))] |
| use std::error; |
| |
| // decode logic operates on chunks of 8 input bytes without padding |
| const INPUT_CHUNK_LEN: usize = 8; |
| const DECODED_CHUNK_LEN: usize = 6; |
| // we read a u64 and write a u64, but a u64 of input only yields 6 bytes of output, so the last |
| // 2 bytes of any output u64 should not be counted as written to (but must be available in a |
| // slice). |
| const DECODED_CHUNK_SUFFIX: usize = 2; |
| |
| // how many u64's of input to handle at a time |
| const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4; |
| const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN; |
| // includes the trailing 2 bytes for the final u64 write |
| const DECODED_BLOCK_LEN: usize = |
| CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX; |
| |
| /// Errors that can occur while decoding. |
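| /// |
| /// # Example |
| /// |
| /// A small sketch of how an invalid symbol is reported (offset first, then the offending byte): |
| /// |
| /// ```rust |
| /// // '=' cannot begin a quad, so it is rejected as an invalid byte at offset 0 |
| /// assert_eq!( |
| ///     Err(base64::DecodeError::InvalidByte(0, b'=')), |
| ///     base64::decode("=AAA") |
| /// ); |
| /// ``` |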
| #[derive(Clone, Debug, PartialEq, Eq)] |
| pub enum DecodeError { |
| /// An invalid byte was found in the input. The offset and offending byte are provided. |
| InvalidByte(usize, u8), |
| /// The length of the input is invalid. |
| /// A typical cause of this is stray trailing whitespace or other separator bytes. |
| /// In the case where excess trailing bytes have produced an invalid length *and* the last byte |
| /// is also an invalid base64 symbol (as would be the case for whitespace, etc), `InvalidByte` |
| /// will be emitted instead of `InvalidLength` to make the issue easier to debug. |
| InvalidLength, |
| /// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded. |
| /// This is indicative of corrupted or truncated Base64. |
| /// Unlike `InvalidByte`, which reports symbols that aren't in the alphabet, this error is for |
| /// symbols that are in the alphabet but represent nonsensical encodings. |
| InvalidLastSymbol(usize, u8), |
| } |
| |
| impl fmt::Display for DecodeError { |
| fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
| match *self { |
| DecodeError::InvalidByte(index, byte) => { |
| write!(f, "Invalid byte {}, offset {}.", byte, index) |
| } |
| DecodeError::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."), |
| DecodeError::InvalidLastSymbol(index, byte) => { |
| write!(f, "Invalid last symbol {}, offset {}.", byte, index) |
| } |
| } |
| } |
| } |
| |
| #[cfg(any(feature = "std", test))] |
| impl error::Error for DecodeError { |
| fn description(&self) -> &str { |
| match *self { |
| DecodeError::InvalidByte(_, _) => "invalid byte", |
| DecodeError::InvalidLength => "invalid length", |
| DecodeError::InvalidLastSymbol(_, _) => "invalid last symbol", |
| } |
| } |
| |
| fn cause(&self) -> Option<&dyn error::Error> { |
| None |
| } |
| } |
| |
| /// Decode from a string reference as octets. |
| /// Returns a `Result` containing a `Vec<u8>`. |
| /// Convenience for `decode_config(input, base64::STANDARD)`. |
| /// |
| /// # Example |
| /// |
| /// ```rust |
| /// let bytes = base64::decode("aGVsbG8gd29ybGQ=").unwrap(); |
| /// println!("{:?}", bytes); |
| /// ``` |
| #[cfg(any(feature = "alloc", feature = "std", test))] |
| pub fn decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, DecodeError> { |
| decode_config(input, STANDARD) |
| } |
| |
| /// Decode from a string reference as octets. |
| /// Returns a `Result` containing a `Vec<u8>`. |
| /// |
| /// # Example |
| /// |
| /// ```rust |
| /// let bytes = base64::decode_config("aGVsbG8gd29ybGR+Cg==", base64::STANDARD).unwrap(); |
| /// println!("{:?}", bytes); |
| /// |
| /// let bytes_url = base64::decode_config("aGVsbG8gaW50ZXJuZXR-Cg==", base64::URL_SAFE).unwrap(); |
| /// println!("{:?}", bytes_url); |
| /// ``` |
| #[cfg(any(feature = "alloc", feature = "std", test))] |
| pub fn decode_config<T: AsRef<[u8]>>(input: T, config: Config) -> Result<Vec<u8>, DecodeError> { |
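| // Size hint only: this over-estimates (decoded output is at most 3/4 of the input length); |
| // `decode_config_buf` resizes as needed and truncates to the exact decoded length. |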
| let mut buffer = Vec::<u8>::with_capacity(input.as_ref().len() * 4 / 3); |
| |
| decode_config_buf(input, config, &mut buffer).map(|_| buffer) |
| } |
| |
| /// Decode from a string reference as octets. |
| /// Writes into the supplied buffer to avoid allocation. |
| /// Returns a `Result` containing `()` on success. |
| /// |
| /// # Example |
| /// |
| /// ```rust |
| /// let mut buffer = Vec::<u8>::new(); |
| /// base64::decode_config_buf("aGVsbG8gd29ybGR+Cg==", base64::STANDARD, &mut buffer).unwrap(); |
| /// println!("{:?}", buffer); |
| /// |
| /// buffer.clear(); |
| /// |
| /// base64::decode_config_buf("aGVsbG8gaW50ZXJuZXR-Cg==", base64::URL_SAFE, &mut buffer) |
| ///     .unwrap(); |
| /// println!("{:?}", buffer); |
| /// ``` |
| #[cfg(any(feature = "alloc", feature = "std", test))] |
| pub fn decode_config_buf<T: AsRef<[u8]>>( |
| input: T, |
| config: Config, |
| buffer: &mut Vec<u8>, |
| ) -> Result<(), DecodeError> { |
| let input_bytes = input.as_ref(); |
| |
| let starting_output_len = buffer.len(); |
| |
| let num_chunks = num_chunks(input_bytes); |
| let decoded_len_estimate = num_chunks |
| .checked_mul(DECODED_CHUNK_LEN) |
| .and_then(|p| p.checked_add(starting_output_len)) |
| .expect("Overflow when calculating output buffer length"); |
| buffer.resize(decoded_len_estimate, 0); |
| |
| let bytes_written; |
| { |
| let buffer_slice = &mut buffer.as_mut_slice()[starting_output_len..]; |
| bytes_written = decode_helper(input_bytes, num_chunks, config, buffer_slice)?; |
| } |
| |
| buffer.truncate(starting_output_len + bytes_written); |
| |
| Ok(()) |
| } |
| |
| /// Decode the input into the provided output slice. |
| /// |
| /// This will not write any bytes past exactly what is decoded (no stray garbage bytes at the end). |
| /// |
| /// If you don't know ahead of time what the decoded length should be, size your buffer with a |
| /// conservative estimate for the decoded length of an input: 3 bytes of output for every 4 bytes of |
| /// input, rounded up, or in other words `(input_len + 3) / 4 * 3`. |
| /// |
| /// If the slice is not large enough, this will panic. |
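| /// |
| /// # Example |
| /// |
| /// A short sketch of the conservative sizing described above: |
| /// |
| /// ```rust |
| /// let input = "aGVsbG8gd29ybGQ="; |
| /// // worst case: 3 output bytes for every 4 input bytes, rounded up |
| /// let mut buf = vec![0_u8; (input.len() + 3) / 4 * 3]; |
| /// let written = base64::decode_config_slice(input, base64::STANDARD, &mut buf).unwrap(); |
| /// assert_eq!(b"hello world", &buf[..written]); |
| /// ``` |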
| pub fn decode_config_slice<T: AsRef<[u8]>>( |
| input: T, |
| config: Config, |
| output: &mut [u8], |
| ) -> Result<usize, DecodeError> { |
| let input_bytes = input.as_ref(); |
| |
| decode_helper(input_bytes, num_chunks(input_bytes), config, output) |
| } |
| |
| /// Returns the number of input chunks (including a possibly partial final chunk) in the input. |
| fn num_chunks(input: &[u8]) -> usize { |
| input |
| .len() |
| .checked_add(INPUT_CHUNK_LEN - 1) |
| .expect("Overflow when calculating number of chunks in input") |
| / INPUT_CHUNK_LEN |
| } |
| |
| /// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs. |
| /// Returns the number of bytes written, or an error. |
| // We're on the fragile edge of compiler heuristics here. If this is not inlined, it's slow. If |
| // it's inlined(always), it's slow in a different way. Plain `#[inline]` makes the benchmarks |
| // happiest at the moment, but this is fragile and the best setting changes with only minor code |
| // modifications. |
| #[inline] |
| fn decode_helper( |
| input: &[u8], |
| num_chunks: usize, |
| config: Config, |
| output: &mut [u8], |
| ) -> Result<usize, DecodeError> { |
| let char_set = config.char_set; |
| let decode_table = char_set.decode_table(); |
| |
| let remainder_len = input.len() % INPUT_CHUNK_LEN; |
| |
| // Because the fast decode loop writes in groups of 8 bytes (unrolled to |
| // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of |
| // which only 6 are valid data), we need to be sure that we stop using the fast decode loop |
| // soon enough that there will always be 2 more bytes of valid data written after that loop. |
| let trailing_bytes_to_skip = match remainder_len { |
| // if input is a multiple of the chunk size, ignore the last chunk as it may have padding, |
| // and the fast decode logic cannot handle padding |
| 0 => INPUT_CHUNK_LEN, |
| // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte |
| 1 | 5 => { |
| // trailing whitespace is so common that it's worth it to check the last byte to |
| // possibly return a better error message |
| if let Some(b) = input.last() { |
| if *b != PAD_BYTE && decode_table[*b as usize] == tables::INVALID_VALUE { |
| return Err(DecodeError::InvalidByte(input.len() - 1, *b)); |
| } |
| } |
| |
| return Err(DecodeError::InvalidLength); |
| } |
| // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes |
| // written by the fast decode loop. So, we have to ignore both these 2 bytes and the |
| // previous chunk. |
| 2 => INPUT_CHUNK_LEN + 2, |
| // If this is 3 unpadded chars, then it would actually decode to 2 bytes. However, if this |
| // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail |
| // with an error, not panic from going past the bounds of the output slice, so we let it |
| // use stage 3 + 4. |
| 3 => INPUT_CHUNK_LEN + 3, |
| // This can also decode to one output byte because it may be 2 input chars + 2 padding |
| // chars, which would decode to 1 byte. |
| 4 => INPUT_CHUNK_LEN + 4, |
| // Everything else is a legal decode len (given that we don't require padding), and will |
| // decode to at least 2 bytes of output. |
| _ => remainder_len, |
| }; |
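| // For example, with 28 bytes of unpadded input (remainder 4), the fast loops stop after the |
| // first 8-byte chunk; stage 3 then decodes the next two chunks precisely, and stage 4 handles |
| // the final 4 symbols. |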
| |
| // rounded up to include partial chunks |
| let mut remaining_chunks = num_chunks; |
| |
| let mut input_index = 0; |
| let mut output_index = 0; |
| |
| { |
| let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip); |
| |
| // Fast loop, stage 1 |
| // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks |
| if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) { |
| while input_index <= max_start_index { |
| let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)]; |
| let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)]; |
| |
| decode_chunk( |
| &input_slice[0..], |
| input_index, |
| decode_table, |
| &mut output_slice[0..], |
| )?; |
| decode_chunk( |
| &input_slice[8..], |
| input_index + 8, |
| decode_table, |
| &mut output_slice[6..], |
| )?; |
| decode_chunk( |
| &input_slice[16..], |
| input_index + 16, |
| decode_table, |
| &mut output_slice[12..], |
| )?; |
| decode_chunk( |
| &input_slice[24..], |
| input_index + 24, |
| decode_table, |
| &mut output_slice[18..], |
| )?; |
| |
| input_index += INPUT_BLOCK_LEN; |
| output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX; |
| remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK; |
| } |
| } |
| |
| // Fast loop, stage 2 (aka still pretty fast loop) |
| // 8 bytes at a time for whatever we didn't do in stage 1. |
| if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) { |
| while input_index < max_start_index { |
| decode_chunk( |
| &input[input_index..(input_index + INPUT_CHUNK_LEN)], |
| input_index, |
| decode_table, |
| &mut output |
| [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)], |
| )?; |
| |
| output_index += DECODED_CHUNK_LEN; |
| input_index += INPUT_CHUNK_LEN; |
| remaining_chunks -= 1; |
| } |
| } |
| } |
| |
| // Stage 3 |
| // If input length was such that a chunk had to be deferred until after the fast loop |
| // because decoding it would have produced 2 trailing bytes that wouldn't then be |
| // overwritten, we decode that chunk here. This way is slower but doesn't write the 2 |
| // trailing bytes. |
| // However, we still need to avoid the last chunk (partial or complete) because it could |
| // have padding, so we always do 1 fewer to avoid the last chunk. |
| for _ in 1..remaining_chunks { |
| decode_chunk_precise( |
| &input[input_index..], |
| input_index, |
| decode_table, |
| &mut output[output_index..(output_index + DECODED_CHUNK_LEN)], |
| )?; |
| |
| input_index += INPUT_CHUNK_LEN; |
| output_index += DECODED_CHUNK_LEN; |
| } |
| |
| // at this point there is always one more (possibly partial) chunk of at most 8 input bytes left, |
| // unless the input was empty |
| debug_assert!(input.len() - input_index > 1 || input.is_empty()); |
| debug_assert!(input.len() - input_index <= 8); |
| |
| // Stage 4 |
| // Finally, decode any leftovers that aren't a complete input block of 8 bytes. |
| // Use a u64 as a stack-resident 8 byte buffer. |
| let mut leftover_bits: u64 = 0; |
| let mut morsels_in_leftover = 0; |
| let mut padding_bytes = 0; |
| let mut first_padding_index: usize = 0; |
| let mut last_symbol = 0_u8; |
| let start_of_leftovers = input_index; |
| for (i, b) in input[start_of_leftovers..].iter().enumerate() { |
| // '=' padding |
| if *b == PAD_BYTE { |
| // There can be bad padding in a few ways: |
| // 1 - Padding with non-padding characters after it |
| // 2 - Padding after zero or one non-padding characters before it |
| // in the current quad. |
| // 3 - More than two characters of padding. If 3 or 4 padding chars |
| // are in the same quad, that implies it will be caught by #2. |
| // If it spreads from one quad to another, it will be caught by |
| // #2 in the second quad. |
| |
| if i % 4 < 2 { |
| // Check for case #2. |
| let bad_padding_index = start_of_leftovers |
| + if padding_bytes > 0 { |
| // If we've already seen padding, report the first padding index. |
| // This is to be consistent with the faster logic above: it will report an |
| // error on the first padding character (since it doesn't expect to see |
| // anything but actual encoded data). |
| first_padding_index |
| } else { |
| // haven't seen padding before, just use where we are now |
| i |
| }; |
| return Err(DecodeError::InvalidByte(bad_padding_index, *b)); |
| } |
| |
| if padding_bytes == 0 { |
| first_padding_index = i; |
| } |
| |
| padding_bytes += 1; |
| continue; |
| } |
| |
| // Check for case #1. |
| // To make '=' handling consistent with the main loop, don't allow |
| // non-suffix '=' in trailing chunk either. Report error as first |
| // erroneous padding. |
| if padding_bytes > 0 { |
| return Err(DecodeError::InvalidByte( |
| start_of_leftovers + first_padding_index, |
| PAD_BYTE, |
| )); |
| } |
| last_symbol = *b; |
| |
| // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding. |
| // To minimize shifts, pack the leftovers from left to right. |
| let shift = 64 - (morsels_in_leftover + 1) * 6; |
| // tables are all 256 elements, lookup with a u8 index always succeeds |
| let morsel = decode_table[*b as usize]; |
| if morsel == tables::INVALID_VALUE { |
| return Err(DecodeError::InvalidByte(start_of_leftovers + i, *b)); |
| } |
| |
| leftover_bits |= (morsel as u64) << shift; |
| morsels_in_leftover += 1; |
| } |
| |
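| // A complete quad of 4 symbols yields 3 bytes; a partial quad of 2 or 3 symbols yields 1 or 2 |
| // bytes. Expressed in bits: floor(morsels * 6 / 8) whole bytes are ready to append. |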
| let leftover_bits_ready_to_append = match morsels_in_leftover { |
| 0 => 0, |
| 2 => 8, |
| 3 => 16, |
| 4 => 24, |
| 6 => 32, |
| 7 => 40, |
| 8 => 48, |
| _ => unreachable!( |
| "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths" |
| ), |
| }; |
| |
| // if there are bits set outside the bits we care about, last symbol encodes trailing bits that |
| // will not be included in the output |
| let mask = !0 >> leftover_bits_ready_to_append; |
| if !config.decode_allow_trailing_bits && (leftover_bits & mask) != 0 { |
| // last morsel is at `morsels_in_leftover` - 1 |
| return Err(DecodeError::InvalidLastSymbol( |
| start_of_leftovers + morsels_in_leftover - 1, |
| last_symbol, |
| )); |
| } |
| |
| let mut leftover_bits_appended_to_buf = 0; |
| while leftover_bits_appended_to_buf < leftover_bits_ready_to_append { |
| // `as` simply truncates the higher bits, which is what we want here |
| let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8; |
| output[output_index] = selected_bits; |
| output_index += 1; |
| |
| leftover_bits_appended_to_buf += 8; |
| } |
| |
| Ok(output_index) |
| } |
| |
| #[inline] |
| fn write_u64(output: &mut [u8], value: u64) { |
| output[..8].copy_from_slice(&value.to_be_bytes()); |
| } |
| |
| /// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the |
| /// first 6 of those contain meaningful data. |
| /// |
| /// `input` is the bytes to decode, of which the first 8 bytes will be processed. |
| /// `index_at_start_of_input` is the offset in the overall input (used for reporting errors |
| /// accurately) |
| /// `decode_table` is the lookup table for the particular base64 alphabet. |
| /// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded |
| /// data. |
| // yes, really inline (worth 30-50% speedup) |
| #[inline(always)] |
| fn decode_chunk( |
| input: &[u8], |
| index_at_start_of_input: usize, |
| decode_table: &[u8; 256], |
| output: &mut [u8], |
| ) -> Result<(), DecodeError> { |
| let mut accum: u64; |
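| // Morsels are packed from the most significant bits down: the first 6-bit value lands at |
| // bits 63..58 and the eighth at bits 21..16, leaving the low 16 bits (the 2 trailing output |
| // bytes) zero. |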
| |
| let morsel = decode_table[input[0] as usize]; |
| if morsel == tables::INVALID_VALUE { |
| return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); |
| } |
| accum = (morsel as u64) << 58; |
| |
| let morsel = decode_table[input[1] as usize]; |
| if morsel == tables::INVALID_VALUE { |
| return Err(DecodeError::InvalidByte( |
| index_at_start_of_input + 1, |
| input[1], |
| )); |
| } |
| accum |= (morsel as u64) << 52; |
| |
| let morsel = decode_table[input[2] as usize]; |
| if morsel == tables::INVALID_VALUE { |
| return Err(DecodeError::InvalidByte( |
| index_at_start_of_input + 2, |
| input[2], |
| )); |
| } |
| accum |= (morsel as u64) << 46; |
| |
| let morsel = decode_table[input[3] as usize]; |
| if morsel == tables::INVALID_VALUE { |
| return Err(DecodeError::InvalidByte( |
| index_at_start_of_input + 3, |
| input[3], |
| )); |
| } |
| accum |= (morsel as u64) << 40; |
| |
| let morsel = decode_table[input[4] as usize]; |
| if morsel == tables::INVALID_VALUE { |
| return Err(DecodeError::InvalidByte( |
| index_at_start_of_input + 4, |
| input[4], |
| )); |
| } |
| accum |= (morsel as u64) << 34; |
| |
| let morsel = decode_table[input[5] as usize]; |
| if morsel == tables::INVALID_VALUE { |
| return Err(DecodeError::InvalidByte( |
| index_at_start_of_input + 5, |
| input[5], |
| )); |
| } |
| accum |= (morsel as u64) << 28; |
| |
| let morsel = decode_table[input[6] as usize]; |
| if morsel == tables::INVALID_VALUE { |
| return Err(DecodeError::InvalidByte( |
| index_at_start_of_input + 6, |
| input[6], |
| )); |
| } |
| accum |= (morsel as u64) << 22; |
| |
| let morsel = decode_table[input[7] as usize]; |
| if morsel == tables::INVALID_VALUE { |
| return Err(DecodeError::InvalidByte( |
| index_at_start_of_input + 7, |
| input[7], |
| )); |
| } |
| accum |= (morsel as u64) << 16; |
| |
| write_u64(output, accum); |
| |
| Ok(()) |
| } |
| |
| /// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2 |
| /// trailing garbage bytes. |
| #[inline] |
| fn decode_chunk_precise( |
| input: &[u8], |
| index_at_start_of_input: usize, |
| decode_table: &[u8; 256], |
| output: &mut [u8], |
| ) -> Result<(), DecodeError> { |
| let mut tmp_buf = [0_u8; 8]; |
| |
| decode_chunk( |
| input, |
| index_at_start_of_input, |
| decode_table, |
| &mut tmp_buf[..], |
| )?; |
| |
| output[0..6].copy_from_slice(&tmp_buf[0..6]); |
| |
| Ok(()) |
| } |
| |
| #[cfg(test)] |
| mod tests { |
| use super::*; |
| use crate::{ |
| encode::encode_config_buf, |
| encode::encode_config_slice, |
| tests::{assert_encode_sanity, random_config}, |
| }; |
| |
| use rand::{ |
| distributions::{Distribution, Uniform}, |
| FromEntropy, Rng, |
| }; |
| |
| #[test] |
| fn decode_chunk_precise_writes_only_6_bytes() { |
| let input = b"Zm9vYmFy"; // "foobar" |
| let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; |
| decode_chunk_precise(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap(); |
| assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output); |
| } |
| |
| #[test] |
| fn decode_chunk_writes_8_bytes() { |
| let input = b"Zm9vYmFy"; // "foobar" |
| let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; |
| decode_chunk(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap(); |
| assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output); |
| } |
| |
| #[test] |
| fn decode_into_nonempty_vec_doesnt_clobber_existing_prefix() { |
| let mut orig_data = Vec::new(); |
| let mut encoded_data = String::new(); |
| let mut decoded_with_prefix = Vec::new(); |
| let mut decoded_without_prefix = Vec::new(); |
| let mut prefix = Vec::new(); |
| |
| let prefix_len_range = Uniform::new(0, 1000); |
| let input_len_range = Uniform::new(0, 1000); |
| |
| let mut rng = rand::rngs::SmallRng::from_entropy(); |
| |
| for _ in 0..10_000 { |
| orig_data.clear(); |
| encoded_data.clear(); |
| decoded_with_prefix.clear(); |
| decoded_without_prefix.clear(); |
| prefix.clear(); |
| |
| let input_len = input_len_range.sample(&mut rng); |
| |
| for _ in 0..input_len { |
| orig_data.push(rng.gen()); |
| } |
| |
| let config = random_config(&mut rng); |
| encode_config_buf(&orig_data, config, &mut encoded_data); |
| assert_encode_sanity(&encoded_data, config, input_len); |
| |
| let prefix_len = prefix_len_range.sample(&mut rng); |
| |
| // fill the buf with a prefix |
| for _ in 0..prefix_len { |
| prefix.push(rng.gen()); |
| } |
| |
| decoded_with_prefix.resize(prefix_len, 0); |
| decoded_with_prefix.copy_from_slice(&prefix); |
| |
| // decode into the non-empty buf |
| decode_config_buf(&encoded_data, config, &mut decoded_with_prefix).unwrap(); |
| // also decode into the empty buf |
| decode_config_buf(&encoded_data, config, &mut decoded_without_prefix).unwrap(); |
| |
| assert_eq!( |
| prefix_len + decoded_without_prefix.len(), |
| decoded_with_prefix.len() |
| ); |
| assert_eq!(orig_data, decoded_without_prefix); |
| |
| // append plain decode onto prefix |
| prefix.append(&mut decoded_without_prefix); |
| |
| assert_eq!(prefix, decoded_with_prefix); |
| } |
| } |
| |
| #[test] |
| fn decode_into_slice_doesnt_clobber_existing_prefix_or_suffix() { |
| let mut orig_data = Vec::new(); |
| let mut encoded_data = String::new(); |
| let mut decode_buf = Vec::new(); |
| let mut decode_buf_copy: Vec<u8> = Vec::new(); |
| |
| let input_len_range = Uniform::new(0, 1000); |
| |
| let mut rng = rand::rngs::SmallRng::from_entropy(); |
| |
| for _ in 0..10_000 { |
| orig_data.clear(); |
| encoded_data.clear(); |
| decode_buf.clear(); |
| decode_buf_copy.clear(); |
| |
| let input_len = input_len_range.sample(&mut rng); |
| |
| for _ in 0..input_len { |
| orig_data.push(rng.gen()); |
| } |
| |
| let config = random_config(&mut rng); |
| encode_config_buf(&orig_data, config, &mut encoded_data); |
| assert_encode_sanity(&encoded_data, config, input_len); |
| |
| // fill the buffer with random garbage, long enough to have some room before and after |
| for _ in 0..5000 { |
| decode_buf.push(rng.gen()); |
| } |
| |
| // keep a copy for later comparison |
| decode_buf_copy.extend(decode_buf.iter()); |
| |
| let offset = 1000; |
| |
| // decode into the non-empty buf |
| let decode_bytes_written = |
| decode_config_slice(&encoded_data, config, &mut decode_buf[offset..]).unwrap(); |
| |
| assert_eq!(orig_data.len(), decode_bytes_written); |
| assert_eq!( |
| orig_data, |
| &decode_buf[offset..(offset + decode_bytes_written)] |
| ); |
| assert_eq!(&decode_buf_copy[0..offset], &decode_buf[0..offset]); |
| assert_eq!( |
| &decode_buf_copy[offset + decode_bytes_written..], |
| &decode_buf[offset + decode_bytes_written..] |
| ); |
| } |
| } |
| |
| #[test] |
| fn decode_into_slice_fits_in_precisely_sized_slice() { |
| let mut orig_data = Vec::new(); |
| let mut encoded_data = String::new(); |
| let mut decode_buf = Vec::new(); |
| |
| let input_len_range = Uniform::new(0, 1000); |
| |
| let mut rng = rand::rngs::SmallRng::from_entropy(); |
| |
| for _ in 0..10_000 { |
| orig_data.clear(); |
| encoded_data.clear(); |
| decode_buf.clear(); |
| |
| let input_len = input_len_range.sample(&mut rng); |
| |
| for _ in 0..input_len { |
| orig_data.push(rng.gen()); |
| } |
| |
| let config = random_config(&mut rng); |
| encode_config_buf(&orig_data, config, &mut encoded_data); |
| assert_encode_sanity(&encoded_data, config, input_len); |
| |
| decode_buf.resize(input_len, 0); |
| |
| // decode into the non-empty buf |
| let decode_bytes_written = |
| decode_config_slice(&encoded_data, config, &mut decode_buf[..]).unwrap(); |
| |
| assert_eq!(orig_data.len(), decode_bytes_written); |
| assert_eq!(orig_data, decode_buf); |
| } |
| } |
| |
| #[test] |
| fn detect_invalid_last_symbol_two_bytes() { |
| let decode = |
| |input, forgiving| decode_config(input, STANDARD.decode_allow_trailing_bits(forgiving)); |
| |
| // example from https://github.com/marshallpierce/rust-base64/issues/75 |
| assert!(decode("iYU=", false).is_ok()); |
| // trailing 01 |
| assert_eq!( |
| Err(DecodeError::InvalidLastSymbol(2, b'V')), |
| decode("iYV=", false) |
| ); |
| assert_eq!(Ok(vec![137, 133]), decode("iYV=", true)); |
| // trailing 10 |
| assert_eq!( |
| Err(DecodeError::InvalidLastSymbol(2, b'W')), |
| decode("iYW=", false) |
| ); |
| assert_eq!(Ok(vec![137, 133]), decode("iYW=", true)); |
| // trailing 11 |
| assert_eq!( |
| Err(DecodeError::InvalidLastSymbol(2, b'X')), |
| decode("iYX=", false) |
| ); |
| assert_eq!(Ok(vec![137, 133]), decode("iYX=", true)); |
| |
| // also works when there are 2 quads in the last block |
| assert_eq!( |
| Err(DecodeError::InvalidLastSymbol(6, b'X')), |
| decode("AAAAiYX=", false) |
| ); |
| assert_eq!(Ok(vec![0, 0, 0, 137, 133]), decode("AAAAiYX=", true)); |
| } |
| |
| #[test] |
| fn detect_invalid_last_symbol_one_byte() { |
| // 0xFF -> "/w==", so all letters > w, 0-9, and '+', '/' should get InvalidLastSymbol |
| |
| assert!(decode("/w==").is_ok()); |
| // trailing 01 |
| assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'x')), decode("/x==")); |
| assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'z')), decode("/z==")); |
| assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'0')), decode("/0==")); |
| assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'9')), decode("/9==")); |
| assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'+')), decode("/+==")); |
| assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'/')), decode("//==")); |
| |
| // also works when there are 2 quads in the last block |
| assert_eq!( |
| Err(DecodeError::InvalidLastSymbol(5, b'x')), |
| decode("AAAA/x==") |
| ); |
| } |
| |
| #[test] |
| fn detect_invalid_last_symbol_every_possible_three_symbols() { |
| let mut base64_to_bytes = ::std::collections::HashMap::new(); |
| |
| let mut bytes = [0_u8; 2]; |
| for b1 in 0_u16..256 { |
| bytes[0] = b1 as u8; |
| for b2 in 0_u16..256 { |
| bytes[1] = b2 as u8; |
| let mut b64 = vec![0_u8; 4]; |
| assert_eq!(4, encode_config_slice(&bytes, STANDARD, &mut b64[..])); |
| let mut v = ::std::vec::Vec::with_capacity(2); |
| v.extend_from_slice(&bytes[..]); |
| |
| assert!(base64_to_bytes.insert(b64, v).is_none()); |
| } |
| } |
| |
| // every possible combination of symbols must either decode to 2 bytes or get InvalidLastSymbol |
| |
| let mut symbols = [0_u8; 4]; |
| for &s1 in STANDARD.char_set.encode_table().iter() { |
| symbols[0] = s1; |
| for &s2 in STANDARD.char_set.encode_table().iter() { |
| symbols[1] = s2; |
| for &s3 in STANDARD.char_set.encode_table().iter() { |
| symbols[2] = s3; |
| symbols[3] = PAD_BYTE; |
| |
| match base64_to_bytes.get(&symbols[..]) { |
| Some(bytes) => { |
| assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD)) |
| } |
| None => assert_eq!( |
| Err(DecodeError::InvalidLastSymbol(2, s3)), |
| decode_config(&symbols[..], STANDARD) |
| ), |
| } |
| } |
| } |
| } |
| } |
| |
| #[test] |
| fn detect_invalid_last_symbol_every_possible_two_symbols() { |
| let mut base64_to_bytes = ::std::collections::HashMap::new(); |
| |
| for b in 0_u16..256 { |
| let mut b64 = vec![0_u8; 4]; |
| assert_eq!(4, encode_config_slice(&[b as u8], STANDARD, &mut b64[..])); |
| let mut v = ::std::vec::Vec::with_capacity(1); |
| v.push(b as u8); |
| |
| assert!(base64_to_bytes.insert(b64, v).is_none()); |
| } |
| |
| // every possible combination of symbols must either decode to 1 byte or get InvalidLastSymbol |
| |
| let mut symbols = [0_u8; 4]; |
| for &s1 in STANDARD.char_set.encode_table().iter() { |
| symbols[0] = s1; |
| for &s2 in STANDARD.char_set.encode_table().iter() { |
| symbols[1] = s2; |
| symbols[2] = PAD_BYTE; |
| symbols[3] = PAD_BYTE; |
| |
| match base64_to_bytes.get(&symbols[..]) { |
| Some(bytes) => { |
| assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD)) |
| } |
| None => assert_eq!( |
| Err(DecodeError::InvalidLastSymbol(1, s2)), |
| decode_config(&symbols[..], STANDARD) |
| ), |
| } |
| } |
| } |
| } |
| } |