Import platform/external/rust/crates/pdl-compiler
Import version 0.1.1 of crate pdl-compiler
Cargo.lock is removed to make sure the emulator
build works with Corrosion.
Bug: 283153347
Test: m
Change-Id: I6d0d099c1548762b608e933b148f5b2f5261f5e2
diff --git a/scripts/generate_cxx_backend_tests.py b/scripts/generate_cxx_backend_tests.py
new file mode 100755
index 0000000..1f90600
--- /dev/null
+++ b/scripts/generate_cxx_backend_tests.py
@@ -0,0 +1,319 @@
+#!/usr/bin/env python3
+
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+from dataclasses import dataclass, field
+import json
+from pathlib import Path
+import sys
+from textwrap import dedent
+from typing import List, Tuple, Union, Optional
+
+from pdl import ast, core
+from pdl.utils import indent, to_pascal_case
+
+
def get_cxx_scalar_type(width: int) -> str:
    """Return the cxx scalar type to be used to back a PDL type.

    Args:
        width: Bit width of the value to store (e.g. the element width of a
            scalar array field).

    Returns:
        The narrowest unsigned fixed-width C++ integer type among
        ``uint8_t``, ``uint16_t``, ``uint32_t`` and ``uint64_t`` whose size
        is at least ``width`` bits.

    Raises:
        ValueError: if ``width`` is larger than 64 bits, i.e. the PDL type
            does not fit on non-extended scalar types.
    """
    for n in (8, 16, 32, 64):
        if width <= n:
            return f'uint{n}_t'
    # An explicit raise (instead of `assert False`) still fires under
    # `python -O`, where the function would otherwise silently return None.
    raise ValueError(f'PDL type of width {width} does not fit on non-extended scalar types')
+
+
def generate_packet_parser_test(parser_test_suite: str, packet: ast.PacketDeclaration, tests: List[object]) -> str:
    """Generate the implementation of parser unit tests for the selected packet.

    Each test vector produces one ``TEST_F`` case. The case parses the packed
    bytes through the chain of generated ``<Packet>View::Create`` calls,
    asserts that the resulting view is valid, and then checks the parsed
    field values against the unpacked values.

    Args:
        parser_test_suite: Name of the gtest fixture class used by the cases.
        packet: Packet declaration the test vectors are attached to.
        tests: Test vectors. Each is a dict with a hex-encoded ``'packed'``
            string, an ``'unpacked'`` dict of expected field values, and an
            optional ``'packet'`` key naming the packet to parse as (defaults
            to ``packet.id``; looked up in ``packet.file.packet_scope``).

    Returns:
        The C++ source of all generated test cases, concatenated.
    """

    def format_hex_rows(values: Union[bytes, List[int]], per_row: int, prefix: str = '') -> List[str]:
        """Format integers as rows of comma-terminated C++ hex literals."""
        return [prefix + ' '.join(f'0x{v:x},' for v in values[start:start + per_row])
                for start in range(0, len(values), per_row)]

    def parse_packet(decl: ast.PacketDeclaration) -> str:
        """Return the nested C++ expression parsing `input` down to `decl`."""
        parent = parse_packet(decl.parent) if decl.parent else "input"
        return f"{decl.id}View::Create({parent})"

    def get_field(decl: ast.Declaration, var: str, field_id: str) -> str:
        """Return the C++ expression reading field `field_id` of `var`."""
        # Struct fields are read as `<id>_` members, packet fields via getters.
        if isinstance(decl, ast.StructDeclaration):
            return f"{var}.{field_id}_"
        return f"{var}.Get{to_pascal_case(field_id)}()"

    def check_members(decl: ast.Declaration, var: str, expected: dict) -> List[str]:
        """Return C++ statements asserting that the fields of `var` match `expected`."""
        checks = []
        # C++ identifiers cannot contain brackets (e.g. `packet_x[0]`).
        sanitized_var = var.replace('[', '_').replace(']', '')
        for (field_id, value) in expected.items():
            field = core.get_packet_field(decl, field_id)
            field_var = f'{sanitized_var}_{field_id}'
            getter = get_field(decl, var, field_id)

            if isinstance(field, ast.ScalarField):
                checks.append(f"ASSERT_EQ({getter}, {value});")

            elif (isinstance(field, ast.TypedefField) and
                  isinstance(field.type, (ast.EnumDeclaration, ast.CustomFieldDeclaration, ast.ChecksumDeclaration))):
                checks.append(f"ASSERT_EQ({getter}, {field.type_id}({value}));")

            elif isinstance(field, ast.TypedefField):
                # Nested struct: bind it to a local and check its members recursively.
                checks.append(f"{field.type_id} const& {field_var} = {getter};")
                checks.extend(check_members(field.type, field_var, value))

            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                checks.append(f"std::vector<uint8_t> expected_{field_var} {{")
                checks.extend(format_hex_rows(value, 16, ' '))
                checks.append("};")
                checks.append(f"ASSERT_EQ({getter}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField) and field.width:
                checks.append(f"std::vector<{get_cxx_scalar_type(field.width)}> expected_{field_var} {{")
                # Emit 128 bits worth of elements per row.
                checks.extend(format_hex_rows(value, (16 * 8) // field.width, ' '))
                checks.append("};")
                checks.append(f"ASSERT_EQ({getter}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField) and isinstance(field.type, ast.EnumDeclaration):
                checks.append(f"std::vector<{field.type_id}> expected_{field_var} {{")
                checks.extend(f" {field.type_id}({v})," for v in value)
                checks.append("};")
                checks.append(f"ASSERT_EQ({getter}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField):
                # Array of structs: check the size, then each element recursively.
                checks.append(f"std::vector<{field.type_id}> {field_var} = {getter};")
                checks.append(f"ASSERT_EQ({field_var}.size(), {len(value)});")
                for (n, element) in enumerate(value):
                    checks.extend(check_members(field.type, f"{field_var}[{n}]", element))

            # Any other field kind (or an id not resolved on `decl`) is
            # intentionally left unchecked.

        return checks

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]
        input_rows = format_hex_rows(bytes.fromhex(test['packed']), 16)
        # NOTE(review): fields are resolved on `packet`, not `child_packet`;
        # if core.get_packet_field only searches a declaration and its
        # ancestors, child-only fields are silently unchecked — confirm intent.
        checks = check_members(packet, 'packet', test['unpacked'])

        generated_tests.append(
            dedent("""\

            TEST_F({parser_test_suite}, {packet_id}_Case{test_nr}) {{
                pdl::packet::slice input(std::shared_ptr<std::vector<uint8_t>>(new std::vector<uint8_t> {{
                    {input_bytes}
                }}));
                {child_packet_id}View packet = {parse_packet};
                ASSERT_TRUE(packet.IsValid());
                {checks}
            }}
            """).format(parser_test_suite=parser_test_suite,
                        packet_id=packet.id,
                        child_packet_id=child_packet_id,
                        test_nr=test_nr,
                        input_bytes=indent(input_rows, 2),
                        parse_packet=parse_packet(child_packet),
                        checks=indent(checks, 1)))

    return ''.join(generated_tests)
+
+
def generate_packet_serializer_test(serializer_test_suite: str, packet: ast.PacketDeclaration,
                                    tests: List[object]) -> str:
    """Generate the implementation of serializer unit tests for the selected packet.

    Each test vector produces one ``TEST_F`` case. The case constructs the
    packet through the generated ``<Packet>Builder`` from the unpacked values
    and asserts that its serialization equals the packed bytes.

    Args:
        serializer_test_suite: Name of the gtest fixture class used by the cases.
        packet: Packet declaration the test vectors are attached to.
        tests: Test vectors. Each is a dict with a hex-encoded ``'packed'``
            string, an ``'unpacked'`` dict of field values, and an optional
            ``'packet'`` key naming the packet to build (defaults to
            ``packet.id``; looked up in ``packet.file.packet_scope``).

    Returns:
        The C++ source of all generated test cases, concatenated.
    """

    def format_hex_rows(values: Union[bytes, List[int]], per_row: int, prefix: str = '') -> List[str]:
        """Format integers as rows of comma-terminated C++ hex literals."""
        return [prefix + ' '.join(f'0x{v:x},' for v in values[start:start + per_row])
                for start in range(0, len(values), per_row)]

    def build_packet(decl: ast.Declaration, var: str, initializer: dict) -> Tuple[str, List[str]]:
        """Return the C++ constructor expression for `decl` built from
        `initializer`, together with the C++ declarations of the temporaries
        that the expression refers to."""
        fields = core.get_unconstrained_parent_fields(decl) + decl.fields
        # C++ identifiers cannot contain brackets.
        sanitized_var = var.replace('[', '_').replace(']', '')
        declarations = []
        parameters = []
        for field in fields:
            field_id = getattr(field, 'id', None)
            field_var = f'{sanitized_var}_{field_id}'
            # Payload and body contents are both stored under the 'payload' key.
            if isinstance(field, (ast.PayloadField, ast.BodyField)):
                value = initializer['payload']
            else:
                value = initializer.get(field_id, None)

            if isinstance(field, ast.ScalarField):
                parameters.append(f"{value}")

            elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
                parameters.append(f"{field.type_id}({value})")

            elif isinstance(field, ast.TypedefField):
                (element, nested_declarations) = build_packet(field.type, field_var, value)
                declarations.extend(nested_declarations)
                parameters.append(element)

            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                declarations.append(f"std::vector<uint8_t> {field_var} {{")
                declarations.extend(format_hex_rows(value, 16, ' '))
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField) and field.width:
                declarations.append(f"std::vector<{get_cxx_scalar_type(field.width)}> {field_var} {{")
                # Emit 128 bits worth of elements per row.
                declarations.extend(format_hex_rows(value, (16 * 8) // field.width, ' '))
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField) and isinstance(field.type, ast.EnumDeclaration):
                declarations.append(f"std::vector<{field.type_id}> {field_var} {{")
                declarations.extend(f" {field.type_id}({v})," for v in value)
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField):
                # Array of structs: build every element first, then the vector.
                elements = []
                for (n, element_value) in enumerate(value):
                    (element, nested_declarations) = build_packet(field.type, f'{field_var}_{n}', element_value)
                    elements.append(element)
                    declarations.extend(nested_declarations)
                declarations.append(f"std::vector<{field.type_id}> {field_var} {{")
                declarations.extend(f" {element}," for element in elements)
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            # Any other field kind is not passed to the constructor.

        constructor_name = f'{decl.id}Builder' if isinstance(decl, ast.PacketDeclaration) else decl.id
        return (f"{constructor_name}({', '.join(parameters)})", declarations)

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]
        output_rows = format_hex_rows(bytes.fromhex(test['packed']), 16)

        (built_packet, intermediate_declarations) = build_packet(child_packet, 'packet', test['unpacked'])
        generated_tests.append(
            dedent("""\

            TEST_F({serializer_test_suite}, {packet_id}_Case{test_nr}) {{
                std::vector<uint8_t> expected_output {{
                    {output_bytes}
                }};
                {intermediate_declarations}
                {child_packet_id}Builder packet = {built_packet};
                ASSERT_EQ(packet.pdl::packet::Builder::Serialize(), expected_output);
            }}
            """).format(serializer_test_suite=serializer_test_suite,
                        packet_id=packet.id,
                        child_packet_id=child_packet_id,
                        test_nr=test_nr,
                        output_bytes=indent(output_rows, 2),
                        built_packet=built_packet,
                        intermediate_declarations=indent(intermediate_declarations, 1)))

    return ''.join(generated_tests)
+
+
def run(input: argparse.FileType, output: argparse.FileType, test_vectors: argparse.FileType, include_header: List[str],
        using_namespace: List[str], namespace: str, parser_test_suite: str, serializer_test_suite: str):
    """Write a C++ gtest file exercising the generated parsers and serializers.

    Args:
        input: Open file containing the PDL-JSON source (the file objects are
            the ones produced by ``argparse.FileType``).
        output: Open file the generated C++ source is written to.
        test_vectors: Open JSON file holding a list of entries of the form
            ``{'packet': <packet id>, 'tests': [<test vector>, ...]}``.
        include_header: Headers emitted as ``#include <...>`` directives.
        using_namespace: Namespaces emitted as ``using namespace ...;``.
        namespace: Namespace wrapping the generated test code.
        parser_test_suite: Name of the parser gtest fixture class.
        serializer_test_suite: Name of the serializer gtest fixture class.
    """
    file = ast.File.from_json(json.load(input))
    tests = json.load(test_vectors)
    core.desugar(file)

    include_directives = '\n'.join(f'#include <{header}>' for header in include_header)
    using_directives = '\n'.join(f'using namespace {ns};' for ns in using_namespace)

    # Declarations for which no test is generated.
    skipped_tests = frozenset([
        'Packet_Checksum_Field_FromStart',
        'Packet_Checksum_Field_FromEnd',
        'Struct_Checksum_Field_FromStart',
        'Struct_Checksum_Field_FromEnd',
        'PartialParent5',
        'PartialParent12',
    ])

    # Index the test-vector entries by packet id once (preserving their
    # order) instead of rescanning the whole list for every declaration.
    entries_by_packet = {}
    for entry in tests:
        entries_by_packet.setdefault(entry['packet'], []).append(entry)

    output.write(
        dedent("""\
        // File generated from {input_name} and {test_vectors_name}, with the command:
        // {input_command}
        // /!\\ Do not edit by hand

        #include <cstdint>
        #include <string>
        #include <gtest/gtest.h>
        #include <packet_runtime.h>

        {include_header}
        {using_namespace}

        namespace {namespace} {{

        class {parser_test_suite} : public testing::Test {{}};
        class {serializer_test_suite} : public testing::Test {{}};
        """).format(parser_test_suite=parser_test_suite,
                    serializer_test_suite=serializer_test_suite,
                    input_name=input.name,
                    input_command=' '.join(sys.argv),
                    test_vectors_name=test_vectors.name,
                    include_header=include_directives,
                    using_namespace=using_directives,
                    namespace=namespace))

    for decl in file.declarations:
        if decl.id in skipped_tests or not isinstance(decl, ast.PacketDeclaration):
            continue

        matching_tests = [test for entry in entries_by_packet.get(decl.id, []) for test in entry['tests']]
        if matching_tests:
            output.write(generate_packet_parser_test(parser_test_suite, decl, matching_tests))
            output.write(generate_packet_serializer_test(serializer_test_suite, decl, matching_tests))

    output.write(f"}} // namespace {namespace}\n")
+
+
def main() -> int:
    """Generate C++ parser and serializer unit tests for the cxx PDL backend."""
    # The module has no docstring (`__doc__` is None), so describe the tool
    # with this function's docstring instead.
    parser = argparse.ArgumentParser(description=main.__doc__)
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output C++ file')
    parser.add_argument('--test-vectors', type=argparse.FileType('r'), required=True, help='Input PDL test file')
    parser.add_argument('--namespace', type=str, default='pdl', help='Namespace of the generated file')
    parser.add_argument('--parser-test-suite', type=str, default='ParserTest', help='Name of the parser test suite')
    parser.add_argument('--serializer-test-suite',
                        type=str,
                        default='SerializerTest',
                        help='Name of the serializer test suite')
    parser.add_argument('--include-header', type=str, default=[], action='append', help='Added include directives')
    parser.add_argument('--using-namespace',
                        type=str,
                        default=[],
                        action='append',
                        help='Added using namespace statements')
    args = parser.parse_args()
    try:
        run(**vars(args))
    finally:
        # argparse.FileType opens files but never closes them; close the ones
        # it opened so the generated output is flushed deterministically.
        for stream in (args.input, args.output, args.test_vectors):
            if stream is not sys.stdin and stream is not sys.stdout:
                stream.close()
    return 0
+
+
if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())