blob: c8461c691d56cb9448963193e76d3c280205f41c [file] [log] [blame]
Hyun Jae Moonbb8920e2023-03-29 16:48:48 +00001//! Copy-pasted from the internet
/// The base64 alphabets known to this module.
///
/// Nothing reads this type at the moment: `encode` always uses the standard
/// alphabet, while `decode` accepts both alphabets without being told which.
#[derive(Debug, Clone, Copy)]
enum _CharacterSet {
    /// Standard alphabet, whose last two characters are `+` and `/`.
    _Standard,
    /// URL- and filename-safe alphabet, whose last two characters are `-` and `_`.
    _UrlSafe,
}
10
/// Standard base64 alphabet (RFC 4648): index `i` is the character for the
/// 6-bit value `i`.
static STANDARD_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                                 abcdefghijklmnopqrstuvwxyz\
                                 0123456789+/";

/// URL- and filename-safe base64 alphabet: identical to `STANDARD_CHARS`
/// except that values 62 and 63 map to `-` and `_`. Currently unused.
static _URLSAFE_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
                                 abcdefghijklmnopqrstuvwxyz\
                                 0123456789-_";

/// Encodes `input` as padded base64 using the standard alphabet (`+` and `/`).
///
/// Every 3 input bytes become 4 output characters. A final group of 1 or 2
/// bytes becomes 2 or 3 characters followed by `==` or `=` respectively, so
/// the output length is always a multiple of 4 (empty input gives `""`).
pub fn encode(input: &[u8]) -> String {
    let alphabet = STANDARD_CHARS;
    // Character for the 6-bit group of `n` whose lowest bit sits at `shift`.
    let sextet = |n: u32, shift: u32| char::from(alphabet[((n >> shift) & 0x3F) as usize]);

    // Pushing ASCII `char`s builds a valid `String` directly, so no `unsafe`
    // UTF-8 conversion is needed.
    let mut out = String::with_capacity((input.len() + 2) / 3 * 4);

    let groups = input.chunks_exact(3);
    let tail = groups.remainder();
    for group in groups {
        // Pack three bytes into one 24-bit value, most significant byte first,
        // then emit it as four 6-bit characters.
        let n = (u32::from(group[0]) << 16) | (u32::from(group[1]) << 8) | u32::from(group[2]);
        for &shift in &[18, 12, 6, 0] {
            out.push(sextet(n, shift));
        }
    }

    // Leftover bytes are zero-extended to a full group; the missing output
    // characters are replaced by `=` padding.
    match *tail {
        [] => {}
        [a] => {
            let n = u32::from(a) << 16;
            out.push(sextet(n, 18));
            out.push(sextet(n, 12));
            out.push_str("==");
        }
        [a, b] => {
            let n = (u32::from(a) << 16) | (u32::from(b) << 8);
            out.push(sextet(n, 18));
            out.push(sextet(n, 12));
            out.push(sextet(n, 6));
            out.push('=');
        }
        _ => unreachable!("`chunks_exact(3)` leaves at most two bytes"),
    }

    out
}
75
76/// Errors that can occur when decoding a base64 encoded string
77#[derive(Clone, Copy, Debug, thiserror::Error)]
78pub enum FromBase64Error {
79 /// The input contained a character not part of the base64 format
80 #[error("Invalid base64 byte")]
81 InvalidBase64Byte(u8, usize),
82 /// The input had an invalid length
83 #[error("Invalid base64 length")]
84 InvalidBase64Length,
85}
86
87pub fn decode(input: &str) -> Result<Vec<u8>, FromBase64Error> {
88 let mut r = Vec::with_capacity(input.len());
89 let mut buf: u32 = 0;
90 let mut modulus = 0;
91
92 let mut it = input.as_bytes().iter();
93 for byte in it.by_ref() {
94 let code = DECODE_TABLE[*byte as usize];
95 if code >= SPECIAL_CODES_START {
96 match code {
97 NEWLINE_CODE => continue,
98 EQUALS_CODE => break,
99 INVALID_CODE => {
100 return Err(FromBase64Error::InvalidBase64Byte(
101 *byte,
102 (byte as *const _ as usize) - input.as_ptr() as usize,
103 ))
104 }
105 _ => unreachable!(),
106 }
107 }
108 buf = (buf | code as u32) << 6;
109 modulus += 1;
110 if modulus == 4 {
111 modulus = 0;
112 r.push((buf >> 22) as u8);
113 r.push((buf >> 14) as u8);
114 r.push((buf >> 6) as u8);
115 }
116 }
117
118 for byte in it {
119 match *byte {
120 b'=' | b'\r' | b'\n' => continue,
121 _ => {
122 return Err(FromBase64Error::InvalidBase64Byte(
123 *byte,
124 (byte as *const _ as usize) - input.as_ptr() as usize,
125 ))
126 }
127 }
128 }
129
130 match modulus {
131 2 => {
132 r.push((buf >> 10) as u8);
133 }
134 3 => {
135 r.push((buf >> 16) as u8);
136 r.push((buf >> 8) as u8);
137 }
138 0 => (),
139 _ => return Err(FromBase64Error::InvalidBase64Length),
140 }
141
142 Ok(r)
143}
144
/// Reverse lookup from an input byte to its 6-bit base64 value, or to one of
/// the special codes below.
///
/// Both alphabets share one table: `+` and `-` map to 62, and `/` and `_` map
/// to 63. The table is computed at compile time by `build_decode_table`.
const DECODE_TABLE: [u8; 256] = build_decode_table();

/// Builds `DECODE_TABLE` during constant evaluation.
const fn build_decode_table() -> [u8; 256] {
    let mut table = [INVALID_CODE; 256];

    // `A`-`Z` map to 0..=25 and `a`-`z` map to 26..=51.
    let mut i = 0;
    while i < 26 {
        table[b'A' as usize + i] = i as u8;
        table[b'a' as usize + i] = 26 + i as u8;
        i += 1;
    }

    // `0`-`9` map to 52..=61.
    let mut d = 0;
    while d < 10 {
        table[b'0' as usize + d] = 52 + d as u8;
        d += 1;
    }

    // Values 62 and 63 are accepted in both standard and URL-safe spellings.
    table[b'+' as usize] = 62;
    table[b'-' as usize] = 62;
    table[b'/' as usize] = 63;
    table[b'_' as usize] = 63;

    table[b'=' as usize] = EQUALS_CODE;
    table[b'\n' as usize] = NEWLINE_CODE;
    table[b'\r' as usize] = NEWLINE_CODE;

    table
}

/// Marks a byte that is not valid base64 input.
const INVALID_CODE: u8 = 0xFF;
/// Marks the `=` padding character.
const EQUALS_CODE: u8 = 0xFE;
/// Marks `\r` / `\n`, which the decoder skips.
const NEWLINE_CODE: u8 = 0xFD;
/// Table values at or above this are special codes rather than 6-bit data.
const SPECIAL_CODES_START: u8 = NEWLINE_CODE;
167
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_basic() {
        assert_eq!(encode(b""), "");
        assert_eq!(encode(b"f"), "Zg==");
        assert_eq!(encode(b"fo"), "Zm8=");
        assert_eq!(encode(b"foo"), "Zm9v");
        assert_eq!(encode(b"foob"), "Zm9vYg==");
        assert_eq!(encode(b"fooba"), "Zm9vYmE=");
        assert_eq!(encode(b"foobar"), "Zm9vYmFy");
    }

    #[test]
    fn test_encode_standard_safe() {
        assert_eq!(encode(&[251, 255]), "+/8=");
    }

    #[test]
    fn test_decode_basic() {
        assert_eq!(decode("").unwrap(), b"");
        assert_eq!(decode("Zg==").unwrap(), b"f");
        assert_eq!(decode("Zm8=").unwrap(), b"fo");
        assert_eq!(decode("Zm9v").unwrap(), b"foo");
        assert_eq!(decode("Zm9vYg==").unwrap(), b"foob");
        assert_eq!(decode("Zm9vYmE=").unwrap(), b"fooba");
        assert_eq!(decode("Zm9vYmFy").unwrap(), b"foobar");
    }

    #[test]
    fn test_decode() {
        assert_eq!(decode("Zm9vYmFy").unwrap(), b"foobar");
    }

    #[test]
    fn test_decode_newlines() {
        assert_eq!(decode("Zm9v\r\nYmFy").unwrap(), b"foobar");
        assert_eq!(decode("Zm9vYg==\r\n").unwrap(), b"foob");
        assert_eq!(decode("Zm9v\nYmFy").unwrap(), b"foobar");
        assert_eq!(decode("Zm9vYg==\n").unwrap(), b"foob");
    }

    #[test]
    fn test_decode_urlsafe() {
        assert_eq!(decode("-_8").unwrap(), decode("+/8=").unwrap());
        assert_eq!(decode("-_8").unwrap(), [251, 255]);
    }

    #[test]
    fn test_decode_unpadded() {
        // Padding is optional: a 2- or 3-character tail decodes on its own.
        assert_eq!(decode("Zg").unwrap(), b"f");
        assert_eq!(decode("Zm8").unwrap(), b"fo");
    }

    #[test]
    fn test_from_base64_invalid_char() {
        assert!(decode("Zm$=").is_err());
        assert!(decode("Zg==$").is_err());
    }

    #[test]
    fn test_invalid_char_reports_byte_and_offset() {
        assert!(matches!(
            decode("Zm$="),
            Err(FromBase64Error::InvalidBase64Byte(b'$', 2))
        ));
        assert!(matches!(
            decode("Zg==$"),
            Err(FromBase64Error::InvalidBase64Byte(b'$', 4))
        ));
        // Data characters after padding are rejected as well.
        assert!(matches!(
            decode("Zg==Zg"),
            Err(FromBase64Error::InvalidBase64Byte(b'Z', 4))
        ));
    }

    #[test]
    fn test_decode_invalid_padding() {
        assert!(decode("Z===").is_err());
    }

    #[test]
    fn test_decode_invalid_length() {
        assert!(matches!(
            decode("Z"),
            Err(FromBase64Error::InvalidBase64Length)
        ));
        assert!(matches!(
            decode("Z==="),
            Err(FromBase64Error::InvalidBase64Length)
        ));
        assert!(matches!(
            decode("Zm9vY"),
            Err(FromBase64Error::InvalidBase64Length)
        ));
    }

    #[test]
    fn test_round_trip_all_lengths() {
        let data: Vec<u8> = (0..=255u8).collect();
        for len in 0..=data.len() {
            let slice = &data[..len];
            assert_eq!(decode(&encode(slice)).unwrap(), slice);
        }
    }

    #[test]
    fn test_full_alphabet_round_trip() {
        let alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        assert_eq!(encode(&decode(alphabet).unwrap()), alphabet);
    }
}