blob: d044edfa1a7e4458d5cc718a2e9ffb568d6a2c34 [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Generate Rust unit tests for canonical test vectors.
use quote::{format_ident, quote};
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, Deserialize)]
struct Packet {
#[serde(rename = "packet")]
name: String,
tests: Vec<TestVector>,
}
#[derive(Debug, Deserialize)]
struct TestVector {
packed: String,
unpacked: Value,
packet: Option<String>,
}
/// Convert a string of hexadecimal characters into a Rust vector of
/// bytes.
///
/// The string `"80038302"` becomes `vec![0x80, 0x03, 0x83, 0x02]`.
fn hexadecimal_to_vec(hex: &str) -> proc_macro2::TokenStream {
assert!(hex.len() % 2 == 0, "Expects an even number of hex digits");
let bytes = hex.as_bytes().chunks_exact(2).map(|chunk| {
let number = format!("0x{}", std::str::from_utf8(chunk).unwrap());
syn::parse_str::<syn::LitInt>(&number).unwrap()
});
quote! {
vec![#(#bytes),*]
}
}
/// Convert `value` to a JSON string literal.
///
/// The string literal is a raw literal to avoid escaping
/// double-quotes.
fn to_json<T: Serialize>(value: &T) -> syn::LitStr {
let json = serde_json::to_string(value).unwrap();
assert!(!json.contains("\"#"), "Please increase number of # for {json:?}");
syn::parse_str::<syn::LitStr>(&format!("r#\" {json} \"#")).unwrap()
}
fn generate_unit_tests(input: &str, packet_names: &[&str], module_name: &str) {
eprintln!("Reading test vectors from {input}, will use {} packets", packet_names.len());
let data = std::fs::read_to_string(input)
.unwrap_or_else(|err| panic!("Could not read {input}: {err}"));
let packets: Vec<Packet> = serde_json::from_str(&data).expect("Could not parse JSON");
let module = syn::parse_str::<syn::Path>(module_name).unwrap();
let mut tests = Vec::new();
for packet in &packets {
for (i, test_vector) in packet.tests.iter().enumerate() {
let test_packet = test_vector.packet.as_deref().unwrap_or(packet.name.as_str());
if !packet_names.contains(&test_packet) {
eprintln!("Skipping packet {}", test_packet);
continue;
}
eprintln!("Generating tests for packet {}", test_packet);
let parse_test_name = format_ident!(
"test_parse_{}_vector_{}_0x{}",
test_packet,
i + 1,
&test_vector.packed
);
let serialize_test_name = format_ident!(
"test_serialize_{}_vector_{}_0x{}",
test_packet,
i + 1,
&test_vector.packed
);
let packed = hexadecimal_to_vec(&test_vector.packed);
let packet_name = format_ident!("{}", test_packet);
let builder_name = format_ident!("{}Builder", test_packet);
let object = test_vector.unpacked.as_object().unwrap_or_else(|| {
panic!("Expected test vector object, found: {}", test_vector.unpacked)
});
let assertions = object.iter().map(|(key, value)| {
let getter = format_ident!("get_{key}");
let expected = format_ident!("expected_{key}");
let json = to_json(&value);
quote! {
let #expected: serde_json::Value = serde_json::from_str(#json)
.expect("Could not create expected value from canonical JSON data");
assert_eq!(json!(actual.#getter()), #expected);
}
});
let json = to_json(&object);
tests.push(quote! {
#[test]
fn #parse_test_name() {
let packed = #packed;
let actual = #module::#packet_name::parse(&packed).unwrap();
#(#assertions)*
}
#[test]
fn #serialize_test_name() {
let builder: #module::#builder_name = serde_json::from_str(#json)
.expect("Could not create builder from canonical JSON data");
let packet = builder.build();
let packed: Vec<u8> = #packed;
assert_eq!(packet.to_vec(), packed);
}
});
}
}
// TODO(mgeisler): make the generated code clean from warnings.
let code = quote! {
#![allow(warnings, missing_docs)]
use pdl_runtime::Packet;
use serde_json::json;
#(#tests)*
};
let syntax_tree = syn::parse2::<syn::File>(code).expect("Could not parse {code:#?}");
println!("{}", prettyplease::unparse(&syntax_tree));
}
fn main() {
let input_path = std::env::args().nth(1).expect("Need path to JSON file with test vectors");
let module_name = std::env::args().nth(2).expect("Need name for the generated module");
// TODO(mgeisler): remove the `packet_names` argument when we
// support all canonical packets.
generate_unit_tests(
&input_path,
&[
"EnumChild_A",
"EnumChild_B",
"Packet_Array_Field_ByteElement_ConstantSize",
"Packet_Array_Field_ByteElement_UnknownSize",
"Packet_Array_Field_ByteElement_VariableCount",
"Packet_Array_Field_ByteElement_VariableSize",
"Packet_Array_Field_EnumElement",
"Packet_Array_Field_EnumElement_ConstantSize",
"Packet_Array_Field_EnumElement_UnknownSize",
"Packet_Array_Field_EnumElement_VariableCount",
"Packet_Array_Field_EnumElement_VariableCount",
"Packet_Array_Field_ScalarElement",
"Packet_Array_Field_ScalarElement_ConstantSize",
"Packet_Array_Field_ScalarElement_UnknownSize",
"Packet_Array_Field_ScalarElement_VariableCount",
"Packet_Array_Field_ScalarElement_VariableSize",
"Packet_Array_Field_SizedElement_ConstantSize",
"Packet_Array_Field_SizedElement_UnknownSize",
"Packet_Array_Field_SizedElement_VariableCount",
"Packet_Array_Field_SizedElement_VariableSize",
"Packet_Array_Field_UnsizedElement_ConstantSize",
"Packet_Array_Field_UnsizedElement_UnknownSize",
"Packet_Array_Field_UnsizedElement_VariableCount",
"Packet_Array_Field_UnsizedElement_VariableSize",
"Packet_Array_Field_SizedElement_VariableSize_Padded",
"Packet_Array_Field_UnsizedElement_VariableCount_Padded",
"Packet_Optional_Scalar_Field",
"Packet_Optional_Enum_Field",
"Packet_Optional_Struct_Field",
"Packet_Body_Field_UnknownSize",
"Packet_Body_Field_UnknownSize_Terminal",
"Packet_Body_Field_VariableSize",
"Packet_Count_Field",
"Packet_Enum8_Field",
"Packet_Enum_Field",
"Packet_FixedEnum_Field",
"Packet_FixedScalar_Field",
"Packet_Payload_Field_UnknownSize",
"Packet_Payload_Field_UnknownSize_Terminal",
"Packet_Payload_Field_VariableSize",
"Packet_Reserved_Field",
"Packet_Scalar_Field",
"Packet_Size_Field",
"Packet_Struct_Field",
"ScalarChild_A",
"ScalarChild_B",
"Struct_Count_Field",
"Struct_Array_Field_ByteElement_ConstantSize",
"Struct_Array_Field_ByteElement_UnknownSize",
"Struct_Array_Field_ByteElement_UnknownSize",
"Struct_Array_Field_ByteElement_VariableCount",
"Struct_Array_Field_ByteElement_VariableCount",
"Struct_Array_Field_ByteElement_VariableSize",
"Struct_Array_Field_ByteElement_VariableSize",
"Struct_Array_Field_EnumElement_ConstantSize",
"Struct_Array_Field_EnumElement_UnknownSize",
"Struct_Array_Field_EnumElement_UnknownSize",
"Struct_Array_Field_EnumElement_VariableCount",
"Struct_Array_Field_EnumElement_VariableCount",
"Struct_Array_Field_EnumElement_VariableSize",
"Struct_Array_Field_EnumElement_VariableSize",
"Struct_Array_Field_ScalarElement_ConstantSize",
"Struct_Array_Field_ScalarElement_UnknownSize",
"Struct_Array_Field_ScalarElement_UnknownSize",
"Struct_Array_Field_ScalarElement_VariableCount",
"Struct_Array_Field_ScalarElement_VariableCount",
"Struct_Array_Field_ScalarElement_VariableSize",
"Struct_Array_Field_ScalarElement_VariableSize",
"Struct_Array_Field_SizedElement_ConstantSize",
"Struct_Array_Field_SizedElement_UnknownSize",
"Struct_Array_Field_SizedElement_UnknownSize",
"Struct_Array_Field_SizedElement_VariableCount",
"Struct_Array_Field_SizedElement_VariableCount",
"Struct_Array_Field_SizedElement_VariableSize",
"Struct_Array_Field_SizedElement_VariableSize",
"Struct_Array_Field_UnsizedElement_ConstantSize",
"Struct_Array_Field_UnsizedElement_UnknownSize",
"Struct_Array_Field_UnsizedElement_UnknownSize",
"Struct_Array_Field_UnsizedElement_VariableCount",
"Struct_Array_Field_UnsizedElement_VariableCount",
"Struct_Array_Field_UnsizedElement_VariableSize",
"Struct_Array_Field_UnsizedElement_VariableSize",
"Struct_Array_Field_SizedElement_VariableSize_Padded",
"Struct_Array_Field_UnsizedElement_VariableCount_Padded",
"Struct_Optional_Scalar_Field",
"Struct_Optional_Enum_Field",
"Struct_Optional_Struct_Field",
"Struct_Enum_Field",
"Struct_FixedEnum_Field",
"Struct_FixedScalar_Field",
"Struct_Size_Field",
"Struct_Struct_Field",
],
&module_name,
);
}