| use crate::{ |
| engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodePaddingMode}, |
| DecodeError, PAD_BYTE, |
| }; |
| |
// The decoder consumes its input one u64 at a time: 8 encoded symbols per chunk, none of which
// may be padding.
const INPUT_CHUNK_LEN: usize = core::mem::size_of::<u64>();
// Every 4 encoded symbols carry 3 bytes of data, so one input chunk yields 6 decoded bytes.
const DECODED_CHUNK_LEN: usize = INPUT_CHUNK_LEN / 4 * 3;

// Each decoded chunk is stored with a single u64 write, so the bytes past the meaningful ones are
// scribbled on as well. They must exist in the output slice but do not count as written data.
const DECODED_CHUNK_SUFFIX: usize = INPUT_CHUNK_LEN - DECODED_CHUNK_LEN;

// Number of input u64s processed per iteration of the unrolled fast loop.
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;

// Encoded bytes consumed per iteration of the unrolled fast loop.
const INPUT_BLOCK_LEN: usize = INPUT_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK;

// Output bytes touched per iteration of the unrolled fast loop, including the scratch suffix of
// the final u64 write in the block.
const DECODED_BLOCK_LEN: usize =
    DECODED_CHUNK_LEN * CHUNKS_PER_FAST_LOOP_BLOCK + DECODED_CHUNK_SUFFIX;
| |
| #[doc(hidden)] |
| pub struct GeneralPurposeEstimate { |
| /// Total number of decode chunks, including a possibly partial last chunk |
| num_chunks: usize, |
| decoded_len_estimate: usize, |
| } |
| |
| impl GeneralPurposeEstimate { |
| pub(crate) fn new(encoded_len: usize) -> Self { |
| Self { |
| num_chunks: encoded_len |
| .checked_add(INPUT_CHUNK_LEN - 1) |
| .expect("Overflow when calculating number of chunks in input") |
| / INPUT_CHUNK_LEN, |
| decoded_len_estimate: encoded_len |
| .checked_add(3) |
| .expect("Overflow when calculating decoded len estimate") |
| / 4 |
| * 3, |
| } |
| } |
| } |
| |
impl DecodeEstimate for GeneralPurposeEstimate {
    /// Returns the decoded-length estimate that was computed when this value was constructed.
    fn decoded_len_estimate(&self) -> usize {
        let Self {
            decoded_len_estimate,
            ..
        } = self;
        *decoded_len_estimate
    }
}
| |
/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
/// Returns the number of bytes written, or an error.
///
/// Decoding walks `input` front to back in four stages:
///
/// 1. a fast loop that handles `INPUT_BLOCK_LEN` input bytes per iteration,
/// 2. a fast loop that handles one `INPUT_CHUNK_LEN`-byte chunk per iteration,
/// 3. a precise per-chunk loop for every chunk still outstanding except the last one,
/// 4. `decode_suffix` for the final (possibly partial, possibly padded) chunk.
///
/// The fast stages store 8 bytes per chunk although only 6 are decoded data, so they are stopped
/// early enough that later stages overwrite those extra bytes.
///
/// * `input` - the encoded bytes.
/// * `estimate` - supplies `num_chunks`, the chunk count that stage 3 counts down from
///   (presumably built from `input.len()`; confirm against the caller).
/// * `output` - destination buffer. It is indexed with plain slice ranges, so a buffer that is
///   too short for the decoded data causes a panic rather than an error.
/// * `decode_table` - maps each input byte to its 6-bit value, or `INVALID_VALUE`.
/// * `decode_allow_trailing_bits`, `padding_mode` - forwarded unchanged to `decode_suffix`.
///
/// # Errors
///
/// When `input.len() % 8` is 1 or 5: `DecodeError::InvalidByte` if the last byte is neither
/// padding nor in the alphabet, otherwise `DecodeError::InvalidLength`. Errors from the chunk
/// decoders and from `decode_suffix` are propagated as-is.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
// but this is fragile and the best setting changes with only minor code modifications.
#[inline]
pub(crate) fn decode_helper(
    input: &[u8],
    estimate: GeneralPurposeEstimate,
    output: &mut [u8],
    decode_table: &[u8; 256],
    decode_allow_trailing_bits: bool,
    padding_mode: DecodePaddingMode,
) -> Result<usize, DecodeError> {
    // number of input bytes left over after the last complete 8-byte chunk
    let remainder_len = input.len() % INPUT_CHUNK_LEN;

    // Because the fast decode loop writes in groups of 8 bytes (unrolled to
    // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of
    // which only 6 are valid data), we need to be sure that we stop using the fast decode loop
    // soon enough that there will always be 2 more bytes of valid data written after that loop.
    let trailing_bytes_to_skip = match remainder_len {
        // if input is a multiple of the chunk size, ignore the last chunk as it may have padding,
        // and the fast decode logic cannot handle padding
        0 => INPUT_CHUNK_LEN,
        // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte
        1 | 5 => {
            // trailing whitespace is so common that it's worth it to check the last byte to
            // possibly return a better error message
            if let Some(b) = input.last() {
                if *b != PAD_BYTE && decode_table[*b as usize] == INVALID_VALUE {
                    return Err(DecodeError::InvalidByte(input.len() - 1, *b));
                }
            }

            return Err(DecodeError::InvalidLength);
        }
        // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes
        // written by the fast decode loop. So, we have to ignore both these 2 bytes and the
        // previous chunk.
        2 => INPUT_CHUNK_LEN + 2,
        // If this is 3 un-padded chars, then it would actually decode to 2 bytes. However, if this
        // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail
        // with an error, not panic from going past the bounds of the output slice, so we let it
        // use stage 3 + 4.
        3 => INPUT_CHUNK_LEN + 3,
        // This can also decode to one output byte because it may be 2 input chars + 2 padding
        // chars, which would decode to 1 byte.
        4 => INPUT_CHUNK_LEN + 4,
        // Everything else is a legal decode len (given that we don't require padding), and will
        // decode to at least 2 bytes of output.
        _ => remainder_len,
    };

    // rounded up to include partial chunks
    let mut remaining_chunks = estimate.num_chunks;

    // read position in `input` and write position in `output`, shared by all stages
    let mut input_index = 0;
    let mut output_index = 0;

    {
        // only this many leading input bytes may go through the fast loops; saturating_sub keeps
        // inputs shorter than the skipped tail at zero instead of underflowing
        let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);

        // Fast loop, stage 1
        // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
            while input_index <= max_start_index {
                // slice once per block so the per-chunk sub-slicing below uses constant offsets
                let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
                let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];

                // chunks advance 8 bytes in the input but only 6 in the output, so each 8-byte
                // store overwrites the 2 scratch bytes left by the previous one
                decode_chunk(
                    &input_slice[0..],
                    input_index,
                    decode_table,
                    &mut output_slice[0..],
                )?;
                decode_chunk(
                    &input_slice[8..],
                    input_index + 8,
                    decode_table,
                    &mut output_slice[6..],
                )?;
                decode_chunk(
                    &input_slice[16..],
                    input_index + 16,
                    decode_table,
                    &mut output_slice[12..],
                )?;
                decode_chunk(
                    &input_slice[24..],
                    input_index + 24,
                    decode_table,
                    &mut output_slice[18..],
                )?;

                input_index += INPUT_BLOCK_LEN;
                // the block's trailing scratch bytes are not counted as written output
                output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
                remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
            }
        }

        // Fast loop, stage 2 (aka still pretty fast loop)
        // 8 bytes at a time for whatever we didn't do in stage 1.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
            // note the strict `<`: the chunk starting exactly at max_start_index is not decoded
            // here and stays counted in remaining_chunks, leaving it for stage 3
            while input_index < max_start_index {
                decode_chunk(
                    &input[input_index..(input_index + INPUT_CHUNK_LEN)],
                    input_index,
                    decode_table,
                    &mut output
                        [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
                )?;

                output_index += DECODED_CHUNK_LEN;
                input_index += INPUT_CHUNK_LEN;
                remaining_chunks -= 1;
            }
        }
    }

    // Stage 3
    // If input length was such that a chunk had to be deferred until after the fast loop
    // because decoding it would have produced 2 trailing bytes that wouldn't then be
    // overwritten, we decode that chunk here. This way is slower but doesn't write the 2
    // trailing bytes.
    // However, we still need to avoid the last chunk (partial or complete) because it could
    // have padding, so we always do 1 fewer to avoid the last chunk.
    for _ in 1..remaining_chunks {
        decode_chunk_precise(
            &input[input_index..],
            input_index,
            decode_table,
            &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
        )?;

        input_index += INPUT_CHUNK_LEN;
        output_index += DECODED_CHUNK_LEN;
    }

    // always have one more (possibly partial) block of 8 input
    debug_assert!(input.len() - input_index > 1 || input.is_empty());
    debug_assert!(input.len() - input_index <= 8);

    // Stage 4: the final chunk goes through the slower suffix decoder, which receives the padding
    // and trailing-bits settings; its result is this function's result
    super::decode_suffix::decode_suffix(
        input,
        input_index,
        output,
        output_index,
        decode_table,
        decode_allow_trailing_bits,
        padding_mode,
    )
}
| |
/// Decodes one 8-symbol chunk with a single 8-byte store.
///
/// The first 8 bytes of `input` are translated through `decode_table` into 6-bit values, which
/// are packed from the most significant end of a `u64` (symbol 0 in bits 63..58, symbol 7 in
/// bits 21..16). That `u64` is then written big-endian to the first 8 bytes of `output`, so the
/// first 6 output bytes hold decoded data and the last 2 are overwritten with zero filler.
///
/// `index_at_start_of_input` is the position of `input[0]` within the complete input; it is only
/// used to report the absolute offset of an invalid byte.
///
/// Returns `DecodeError::InvalidByte` for the first symbol (lowest offset) that maps to
/// `INVALID_VALUE`. Nothing is written to `output` in that case.
// yes, really inline (worth 30-50% speedup)
#[inline(always)]
fn decode_chunk(
    input: &[u8],
    index_at_start_of_input: usize,
    decode_table: &[u8; 256],
    output: &mut [u8],
) -> Result<(), DecodeError> {
    // Expands to the 6-bit value of the symbol at `$offset`, widened to u64, or returns
    // `InvalidByte` from `decode_chunk`. A macro (not a loop or closure) keeps the expansion as
    // the same straight-line code as eight hand-written lookups.
    macro_rules! sextet {
        ($offset:expr) => {{
            let symbol = input[$offset];
            let value = decode_table[symbol as usize];
            if value == INVALID_VALUE {
                return Err(DecodeError::InvalidByte(
                    index_at_start_of_input + $offset,
                    symbol,
                ));
            }
            u64::from(value)
        }};
    }

    // Operands are evaluated left to right, so the lowest invalid offset is the one reported.
    let packed: u64 = (sextet!(0) << 58)
        | (sextet!(1) << 52)
        | (sextet!(2) << 46)
        | (sextet!(3) << 40)
        | (sextet!(4) << 34)
        | (sextet!(5) << 28)
        | (sextet!(6) << 22)
        | (sextet!(7) << 16);

    write_u64(output, packed);

    Ok(())
}
| |
/// Decodes one 8-symbol chunk while touching exactly 6 bytes of `output`.
///
/// The chunk is decoded into a stack scratch buffer via `decode_chunk`, and only the 6 meaningful
/// bytes are copied out, so `output[6..]` is left untouched. Errors from `decode_chunk` are
/// returned before anything is written to `output`.
#[inline]
fn decode_chunk_precise(
    input: &[u8],
    index_at_start_of_input: usize,
    decode_table: &[u8; 256],
    output: &mut [u8],
) -> Result<(), DecodeError> {
    // room for the full 8-byte store performed by decode_chunk
    let mut scratch = [0_u8; 8];
    decode_chunk(input, index_at_start_of_input, decode_table, &mut scratch)?;

    // keep the decoded data, drop the 2 filler bytes
    let (decoded, _filler) = scratch.split_at(6);
    output[..6].copy_from_slice(decoded);

    Ok(())
}
| |
/// Stores `value` in big-endian byte order into the first 8 bytes of `output`.
///
/// Bytes beyond the first 8 are left alone. Panics if `output` holds fewer than 8 bytes.
#[inline]
fn write_u64(output: &mut [u8], value: u64) {
    let big_endian = value.to_be_bytes();
    let destination = &mut output[..8];
    destination.copy_from_slice(&big_endian);
}
| |
#[cfg(test)]
mod tests {
    use super::*;

    use crate::engine::general_purpose::STANDARD;

    // "foobar" in the standard alphabet: exactly one 8-symbol chunk
    const ENCODED_FOOBAR: &[u8; 8] = b"Zm9vYmFy";

    // Output buffer pre-filled with distinct values so untouched bytes are recognisable.
    fn marked_output() -> [u8; 8] {
        [0, 1, 2, 3, 4, 5, 6, 7]
    }

    #[test]
    fn decode_chunk_precise_writes_only_6_bytes() {
        let mut output = marked_output();

        decode_chunk_precise(ENCODED_FOOBAR, 0, &STANDARD.decode_table, &mut output).unwrap();

        // the last two marker bytes must survive
        let (decoded, untouched) = output.split_at(6);
        assert_eq!(&b"foobar"[..], decoded);
        assert_eq!(&[6_u8, 7][..], untouched);
    }

    #[test]
    fn decode_chunk_writes_8_bytes() {
        let mut output = marked_output();

        decode_chunk(ENCODED_FOOBAR, 0, &STANDARD.decode_table, &mut output).unwrap();

        // the last two marker bytes are clobbered with zero filler
        let (decoded, filler) = output.split_at(6);
        assert_eq!(&b"foobar"[..], decoded);
        assert_eq!(&[0_u8, 0][..], filler);
    }
}