#!/usr/bin/env python3

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

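"""Generate C++ unit tests for the PDL parser and serializer backends.

Example invocation (file names illustrative):
  generate_cxx_backend_tests.py --input packets.json --test-vectors test_vectors.json \\
      --output packet_tests.cc --include-header packets.h --namespace pdl
"""
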
import argparse
from dataclasses import dataclass, field
import json
from pathlib import Path
import sys
from textwrap import dedent
from typing import List, Tuple, Union, Optional

from pdl import ast, core
from pdl.utils import indent, to_pascal_case


def get_cxx_scalar_type(width: int) -> str:
    """Return the cxx scalar type to be used to back a PDL type."""
    for n in [8, 16, 32, 64]:
        if width <= n:
            return f'uint{n}_t'
    # PDL type does not fit on non-extended scalar types.
    assert False


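# Each test vector entry provides 'packed', the serialized packet as a hex
# string, and 'unpacked', an object mapping field identifiers to their expected
# values; an optional 'packet' key selects the derived packet to parse or build.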
def generate_packet_parser_test(parser_test_suite: str, packet: ast.PacketDeclaration, tests: List[object]) -> str:
    """Generate the implementation of unit tests for the selected packet."""

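    # Derived packet views are created from the parent view, so the parse
    # expression nests the Create calls, e.g.
    # ChildView::Create(ParentView::Create(input)) (names illustrative).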
    def parse_packet(packet: ast.PacketDeclaration) -> str:
        parent = parse_packet(packet.parent) if packet.parent else "input"
        return f"{packet.id}View::Create({parent})"

    def input_bytes(input: str) -> List[str]:
        input = bytes.fromhex(input)
        input_bytes = []
        for i in range(0, len(input), 16):
            input_bytes.append(' '.join(f'0x{b:x},' for b in input[i:i + 16]))
        return input_bytes

    def get_field(decl: ast.Declaration, var: str, id: str) -> str:
        if isinstance(decl, ast.StructDeclaration):
            return f"{var}.{id}_"
        else:
            return f"{var}.Get{to_pascal_case(id)}()"

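    # Emit one gtest assertion per expected field, recursing into typedef and
    # array fields so that nested structures are checked member by member.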
    def check_members(decl: ast.Declaration, var: str, expected: object) -> List[str]:
        checks = []
        for (id, value) in expected.items():
            field = core.get_packet_field(decl, id)
            sanitized_var = var.replace('[', '_').replace(']', '')
            field_var = f'{sanitized_var}_{id}'

            if isinstance(field, ast.ScalarField) and field.cond:
                value = f"std::make_optional({value})" if value is not None else "std::nullopt"
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, {value});")

            elif isinstance(field, ast.ScalarField):
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, {value});")

            elif (isinstance(field, ast.TypedefField) and
                  isinstance(field.type, ast.EnumDeclaration) and
                  field.cond):
                value = f"std::make_optional({field.type_id}({value}))" if value is not None else "std::nullopt"
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, {value});")

            elif (isinstance(field, ast.TypedefField) and
                  isinstance(field.type, (ast.EnumDeclaration, ast.CustomFieldDeclaration, ast.ChecksumDeclaration))):
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, {field.type_id}({value}));")

            elif isinstance(field, ast.TypedefField) and field.cond and value is None:
                checks.append(f"ASSERT_TRUE(!{get_field(decl, var, id)}.has_value());")

            elif isinstance(field, ast.TypedefField) and field.cond and value is not None:
                checks.append(f"{field.type_id} const& {field_var} = {get_field(decl, var, id)}.value();")
                checks.extend(check_members(field.type, field_var, value))

            elif isinstance(field, ast.TypedefField):
                checks.append(f"{field.type_id} const& {field_var} = {get_field(decl, var, id)};")
                checks.extend(check_members(field.type, field_var, value))

            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                checks.append(f"std::vector<uint8_t> expected_{field_var} {{")
                for i in range(0, len(value), 16):
                    checks.append(' ' + ' '.join([f"0x{v:x}," for v in value[i:i + 16]]))
                checks.append("};")
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField) and field.size and field.width:
                checks.append(f"std::array<{get_cxx_scalar_type(field.width)}, {field.size}> expected_{field_var} {{")
                step = int(16 * 8 / field.width)
                for i in range(0, len(value), step):
                    checks.append(' ' + ' '.join([f"0x{v:x}," for v in value[i:i + step]]))
                checks.append("};")
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField) and field.width:
                checks.append(f"std::vector<{get_cxx_scalar_type(field.width)}> expected_{field_var} {{")
                step = int(16 * 8 / field.width)
                for i in range(0, len(value), step):
                    checks.append(' ' + ' '.join([f"0x{v:x}," for v in value[i:i + step]]))
                checks.append("};")
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, expected_{field_var});")

            elif (isinstance(field, ast.ArrayField) and field.size and isinstance(field.type, ast.EnumDeclaration)):
                checks.append(f"std::array<{field.type_id}, {field.size}> expected_{field_var} {{")
                for v in value:
                    checks.append(f" {field.type_id}({v}),")
                checks.append("};")
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, expected_{field_var});")

            elif (isinstance(field, ast.ArrayField) and isinstance(field.type, ast.EnumDeclaration)):
                checks.append(f"std::vector<{field.type_id}> expected_{field_var} {{")
                for v in value:
                    checks.append(f" {field.type_id}({v}),")
                checks.append("};")
                checks.append(f"ASSERT_EQ({get_field(decl, var, id)}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField) and field.size:
                checks.append(f"std::array<{field.type_id}, {field.size}> {field_var} = {get_field(decl, var, id)};")
                checks.append(f"ASSERT_EQ({field_var}.size(), {len(value)});")
                for (n, value) in enumerate(value):
                    checks.extend(check_members(field.type, f"{field_var}[{n}]", value))

            elif isinstance(field, ast.ArrayField):
                checks.append(f"std::vector<{field.type_id}> {field_var} = {get_field(decl, var, id)};")
                checks.append(f"ASSERT_EQ({field_var}.size(), {len(value)});")
                for (n, value) in enumerate(value):
                    checks.extend(check_members(field.type, f"{field_var}[{n}]", value))

            else:
                # Other field kinds (e.g. size, count, fixed, reserved fields)
                # are not checked.
                pass

        return checks

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        generated_tests.append(
            dedent("""\

                TEST_F({parser_test_suite}, {packet_id}_Case{test_nr}) {{
                    pdl::packet::slice input(std::shared_ptr<std::vector<uint8_t>>(new std::vector<uint8_t> {{
                {input_bytes}
                    }}));
                    {child_packet_id}View packet = {parse_packet};
                    ASSERT_TRUE(packet.IsValid());
                {checks}
                }}
                """).format(parser_test_suite=parser_test_suite,
                            packet_id=packet.id,
                            child_packet_id=child_packet_id,
                            test_nr=test_nr,
                            input_bytes=indent(input_bytes(test['packed']), 2),
                            parse_packet=parse_packet(child_packet),
                            checks=indent(check_members(packet, 'packet', test['unpacked']), 1)))

    return ''.join(generated_tests)


def generate_packet_serializer_test(serializer_test_suite: str, packet: ast.PacketDeclaration,
                                    tests: List[object]) -> str:
    """Generate the implementation of unit tests for the selected packet."""

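    # Build the C++ expression that constructs `decl` from the JSON initializer.
    # Returns the constructor call expression together with the intermediate
    # variable declarations that must be emitted before it.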
    def build_packet(decl: ast.Declaration, var: str, initializer: object) -> Tuple[str, List[str]]:
        fields = core.get_unconstrained_parent_fields(decl) + decl.fields
        declarations = []
        parameters = []
        for field in fields:
            sanitized_var = var.replace('[', '_').replace(']', '')
            field_id = getattr(field, 'id', None)
            field_var = f'{sanitized_var}_{field_id}'
            value = (initializer['payload'] if isinstance(field, (ast.PayloadField, ast.BodyField)) else
                     initializer.get(field_id, None))

            if field.cond_for:
                # Fields that act as the condition flag for an optional field
                # are not passed to the builder.
                pass

            elif field.cond and value is None:
                parameters.append("std::nullopt")

            elif isinstance(field, ast.ScalarField) and field.cond:
                parameters.append(f"std::make_optional({value})")

            elif isinstance(field, ast.ScalarField):
                parameters.append(f"{value}")

            elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration) and field.cond:
                parameters.append(f"std::make_optional({field.type_id}({value}))")

            elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
                parameters.append(f"{field.type_id}({value})")

            elif isinstance(field, ast.TypedefField):
                (element, intermediate_declarations) = build_packet(field.type, field_var, value)
                declarations.extend(intermediate_declarations)
                parameters.append(element)

            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                declarations.append(f"std::vector<uint8_t> {field_var} {{")
                for i in range(0, len(value), 16):
                    declarations.append(' ' + ' '.join([f"0x{v:x}," for v in value[i:i + 16]]))
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField) and field.size and field.width:
                declarations.append(f"std::array<{get_cxx_scalar_type(field.width)}, {field.size}> {field_var} {{")
                step = int(16 * 8 / field.width)
                for i in range(0, len(value), step):
                    declarations.append(' ' + ' '.join([f"0x{v:x}," for v in value[i:i + step]]))
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField) and field.width:
                declarations.append(f"std::vector<{get_cxx_scalar_type(field.width)}> {field_var} {{")
                step = int(16 * 8 / field.width)
                for i in range(0, len(value), step):
                    declarations.append(' ' + ' '.join([f"0x{v:x}," for v in value[i:i + step]]))
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField) and field.size and isinstance(field.type, ast.EnumDeclaration):
                declarations.append(f"std::array<{field.type_id}, {field.size}> {field_var} {{")
                for v in value:
                    declarations.append(f" {field.type_id}({v}),")
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField) and isinstance(field.type, ast.EnumDeclaration):
                declarations.append(f"std::vector<{field.type_id}> {field_var} {{")
                for v in value:
                    declarations.append(f" {field.type_id}({v}),")
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField) and field.size:
                elements = []
                for (n, value) in enumerate(value):
                    (element, intermediate_declarations) = build_packet(field.type, f'{field_var}_{n}', value)
                    elements.append(element)
                    declarations.extend(intermediate_declarations)
                declarations.append(f"std::array<{field.type_id}, {field.size}> {field_var} {{")
                for element in elements:
                    declarations.append(f" {element},")
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField):
                elements = []
                for (n, value) in enumerate(value):
                    (element, intermediate_declarations) = build_packet(field.type, f'{field_var}_{n}', value)
                    elements.append(element)
                    declarations.extend(intermediate_declarations)
                declarations.append(f"std::vector<{field.type_id}> {field_var} {{")
                for element in elements:
                    declarations.append(f" {element},")
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            else:
                # Other field kinds (e.g. size, count, fixed, reserved fields)
                # take no builder parameter.
                pass

        constructor_name = f'{decl.id}Builder' if isinstance(decl, ast.PacketDeclaration) else decl.id
        return (f"{constructor_name}({', '.join(parameters)})", declarations)

    def output_bytes(output: str) -> List[str]:
        output = bytes.fromhex(output)
        output_bytes = []
        for i in range(0, len(output), 16):
            output_bytes.append(' '.join(f'0x{b:x},' for b in output[i:i + 16]))
        return output_bytes

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        (built_packet, intermediate_declarations) = build_packet(child_packet, 'packet', test['unpacked'])
        generated_tests.append(
            dedent("""\

                TEST_F({serializer_test_suite}, {packet_id}_Case{test_nr}) {{
                    std::vector<uint8_t> expected_output {{
                {output_bytes}
                    }};
                {intermediate_declarations}
                    {child_packet_id}Builder packet = {built_packet};
                    ASSERT_EQ(packet.SerializeToBytes(), expected_output);
                }}
                """).format(serializer_test_suite=serializer_test_suite,
                            packet_id=packet.id,
                            child_packet_id=child_packet_id,
                            test_nr=test_nr,
                            output_bytes=indent(output_bytes(test['packed']), 2),
                            built_packet=built_packet,
                            intermediate_declarations=indent(intermediate_declarations, 1)))

    return ''.join(generated_tests)


def run(input: argparse.FileType, output: argparse.FileType, test_vectors: argparse.FileType, include_header: List[str],
        using_namespace: List[str], namespace: str, parser_test_suite: str, serializer_test_suite: str):

    file = ast.File.from_json(json.load(input))
    tests = json.load(test_vectors)
    core.desugar(file)

    include_header = '\n'.join([f'#include <{header}>' for header in include_header])
    using_namespace = '\n'.join([f'using namespace {namespace};' for namespace in using_namespace])

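    # Declarations excluded from test generation; any packet listed here is
    # skipped when iterating over the file declarations below.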
    skipped_tests = [
        'Packet_Checksum_Field_FromStart',
        'Packet_Checksum_Field_FromEnd',
        'Struct_Checksum_Field_FromStart',
        'Struct_Checksum_Field_FromEnd',
        'PartialParent5',
        'PartialParent12',
        'Packet_Array_Field_VariableElementSize_ConstantSize',
        'Packet_Array_Field_VariableElementSize_VariableSize',
        'Packet_Array_Field_VariableElementSize_VariableCount',
        'Packet_Array_Field_VariableElementSize_UnknownSize',
    ]

    output.write(
        dedent("""\
            // File generated from {input_name} and {test_vectors_name}, with the command:
            // {input_command}
            // /!\\ Do not edit by hand

            #include <cstdint>
            #include <string>
            #include <gtest/gtest.h>
            #include <packet_runtime.h>

            {include_header}
            {using_namespace}

            namespace {namespace} {{

            class {parser_test_suite} : public testing::Test {{}};
            class {serializer_test_suite} : public testing::Test {{}};
            """).format(parser_test_suite=parser_test_suite,
                        serializer_test_suite=serializer_test_suite,
                        input_name=input.name,
                        input_command=' '.join(sys.argv),
                        test_vectors_name=test_vectors.name,
                        include_header=include_header,
                        using_namespace=using_namespace,
                        namespace=namespace))

    for decl in file.declarations:
        if decl.id in skipped_tests:
            continue

        if isinstance(decl, ast.PacketDeclaration):
            matching_tests = [test['tests'] for test in tests if test['packet'] == decl.id]
            matching_tests = [test for test_list in matching_tests for test in test_list]
            if matching_tests:
                output.write(generate_packet_parser_test(parser_test_suite, decl, matching_tests))
                output.write(generate_packet_serializer_test(serializer_test_suite, decl, matching_tests))

    output.write(f"}} // namespace {namespace}\n")


def main() -> int:
    """Generate the C++ unit tests for the PDL backend."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output C++ file')
    parser.add_argument('--test-vectors', type=argparse.FileType('r'), required=True, help='Input PDL test file')
    parser.add_argument('--namespace', type=str, default='pdl', help='Namespace of the generated file')
    parser.add_argument('--parser-test-suite', type=str, default='ParserTest', help='Name of the parser test suite')
    parser.add_argument('--serializer-test-suite',
                        type=str,
                        default='SerializerTest',
                        help='Name of the serializer test suite')
    parser.add_argument('--include-header', type=str, default=[], action='append', help='Added include directives')
    parser.add_argument('--using-namespace',
                        type=str,
                        default=[],
                        action='append',
                        help='Added using namespace statements')
    run(**vars(parser.parse_args()))
    return 0


if __name__ == '__main__':
    sys.exit(main())