Henri Chataing | 6962663 | 2023-05-12 15:28:29 +0000 | [diff] [blame] | 1 | #!/usr/bin/env python3 |
| 2 | |
| 3 | # Copyright 2023 Google LLC |
| 4 | # |
| 5 | # Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | # you may not use this file except in compliance with the License. |
| 7 | # You may obtain a copy of the License at |
| 8 | # |
| 9 | # https://www.apache.org/licenses/LICENSE-2.0 |
| 10 | # |
| 11 | # Unless required by applicable law or agreed to in writing, software |
| 12 | # distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | # See the License for the specific language governing permissions and |
| 15 | # limitations under the License. |
| 16 | |
| 17 | import argparse |
| 18 | from dataclasses import dataclass, field |
| 19 | import json |
| 20 | from pathlib import Path |
| 21 | import sys |
| 22 | from textwrap import dedent |
| 23 | from typing import List, Tuple, Union, Optional |
| 24 | |
| 25 | from pdl import ast, core |
| 26 | from pdl.utils import indent, to_pascal_case |
| 27 | |
| 28 | |
def get_cxx_scalar_type(width: int) -> str:
    """Return the C++ unsigned integer type used to back a PDL scalar of `width` bits.

    The narrowest of uint8_t, uint16_t, uint32_t and uint64_t that can hold
    `width` bits is selected.

    Raises:
        ValueError: if `width` is larger than 64, i.e. the PDL type does not
            fit on any non-extended C++ scalar type.
    """
    for n in (8, 16, 32, 64):
        if width <= n:
            return f'uint{n}_t'
    # An explicit raise (instead of `assert False`) keeps this check active
    # under `python -O`, where the function would otherwise return None.
    raise ValueError(f'PDL type width {width} does not fit on non-extended scalar types')
| 36 | |
| 37 | |
def generate_packet_parser_test(parser_test_suite: str, packet: ast.PacketDeclaration, tests: List[object]) -> str:
    """Generate the implementation of parser unit tests for the selected packet.

    One gtest case is emitted per entry of `tests`. Each case:
    - wraps the entry's 'packed' hex string into a `pdl::packet::slice`;
    - parses it with the View of the declaration named by the entry's optional
      'packet' key (defaulting to `packet` itself);
    - asserts that the view is valid;
    - checks every field listed in the entry's 'unpacked' dictionary.

    Args:
        parser_test_suite: name of the gtest fixture the cases belong to.
        packet: declaration the test vectors were collected for.
        tests: test vector entries (dicts with 'packed', 'unpacked' and
            optionally 'packet' keys).

    Returns:
        The concatenated C++ source of all generated test cases.
    """

    def hex_rows(data, per_row: int) -> List[str]:
        # Format integer values as comma-terminated C++ hex literals, `per_row` per line.
        return [' '.join(f'0x{b:x},' for b in data[i:i + per_row]) for i in range(0, len(data), per_row)]

    def parse_packet(decl: ast.PacketDeclaration) -> str:
        # Each view is created from its parent's view, starting from the raw input slice.
        parent = parse_packet(decl.parent) if decl.parent else "input"
        return f"{decl.id}View::Create({parent})"

    def get_field(decl: ast.Declaration, var: str, id: str) -> str:
        # Struct members are read directly; packet views expose Get<Field>() accessors.
        if isinstance(decl, ast.StructDeclaration):
            return f"{var}.{id}_"
        return f"{var}.Get{to_pascal_case(id)}()"

    def check_typedef(field: ast.TypedefField, getter: str, field_var: str, value: object) -> List[str]:
        # Checks for a field whose type is another declaration (enum, custom, checksum or struct).
        if isinstance(field.type, ast.EnumDeclaration) and field.cond:
            expected_value = f"std::make_optional({field.type_id}({value}))" if value is not None else "std::nullopt"
            return [f"ASSERT_EQ({getter}, {expected_value});"]
        if isinstance(field.type, (ast.EnumDeclaration, ast.CustomFieldDeclaration, ast.ChecksumDeclaration)):
            return [f"ASSERT_EQ({getter}, {field.type_id}({value}));"]
        if field.cond and value is None:
            return [f"ASSERT_TRUE(!{getter}.has_value());"]
        # Struct value: bind a const reference to it, then recurse into its members.
        accessor = f"{getter}.value()" if field.cond else getter
        return [f"{field.type_id} const& {field_var} = {accessor};"] + check_members(field.type, field_var, value)

    def check_array(field: ast.ArrayField, getter: str, field_var: str, value: List[object]) -> List[str]:
        # Arrays with a `size` are exposed as std::array, all others as std::vector.
        element_type = get_cxx_scalar_type(field.width) if field.width else field.type_id
        container = f"std::array<{element_type}, {field.size}>" if field.size else f"std::vector<{element_type}>"
        if field.width:
            # Scalar elements: compare with a literal, 16 * 8 bits worth of values per line.
            return ([f"{container} expected_{field_var} {{"] +
                    [' ' + row for row in hex_rows(value, 16 * 8 // field.width)] +
                    ["};", f"ASSERT_EQ({getter}, expected_{field_var});"])
        if isinstance(field.type, ast.EnumDeclaration):
            return ([f"{container} expected_{field_var} {{"] +
                    [f" {field.type_id}({v})," for v in value] +
                    ["};", f"ASSERT_EQ({getter}, expected_{field_var});"])
        # Compound elements: check the element count, then each element member by member.
        checks = [f"{container} {field_var} = {getter};", f"ASSERT_EQ({field_var}.size(), {len(value)});"]
        for (n, element) in enumerate(value):
            checks.extend(check_members(field.type, f"{field_var}[{n}]", element))
        return checks

    def check_members(decl: ast.Declaration, var: str, expected: dict) -> List[str]:
        # Generate assertions comparing each member of `var` (an instance of `decl`)
        # with `expected`; identifiers whose field kind is not handled below produce no check.
        checks = []
        # Array element expressions such as `x[0]` are turned into valid identifiers.
        sanitized_var = var.replace('[', '_').replace(']', '')
        for (id, value) in expected.items():
            field = core.get_packet_field(decl, id)
            field_var = f'{sanitized_var}_{id}'
            getter = get_field(decl, var, id)

            if isinstance(field, ast.ScalarField):
                expected_value = value
                if field.cond:
                    expected_value = f"std::make_optional({value})" if value is not None else "std::nullopt"
                checks.append(f"ASSERT_EQ({getter}, {expected_value});")

            elif isinstance(field, ast.TypedefField):
                checks.extend(check_typedef(field, getter, field_var, value))

            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                checks.append(f"std::vector<uint8_t> expected_{field_var} {{")
                checks.extend(' ' + row for row in hex_rows(value, 16))
                checks.append("};")
                checks.append(f"ASSERT_EQ({getter}, expected_{field_var});")

            elif isinstance(field, ast.ArrayField):
                checks.extend(check_array(field, getter, field_var, value))

        return checks

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        # NOTE(review): members are looked up on `packet`, not `child_packet`.
        # Fields declared only on the child may not be found by
        # core.get_packet_field and would then get no check; confirm whether
        # this is intended.
        generated_tests.append(
            dedent("""\

                TEST_F({parser_test_suite}, {packet_id}_Case{test_nr}) {{
                pdl::packet::slice input(std::shared_ptr<std::vector<uint8_t>>(new std::vector<uint8_t> {{
                {input_bytes}
                }}));
                {child_packet_id}View packet = {parse_packet};
                ASSERT_TRUE(packet.IsValid());
                {checks}
                }}
                """).format(parser_test_suite=parser_test_suite,
                            packet_id=packet.id,
                            child_packet_id=child_packet_id,
                            test_nr=test_nr,
                            input_bytes=indent(hex_rows(bytes.fromhex(test['packed']), 16), 2),
                            parse_packet=parse_packet(child_packet),
                            checks=indent(check_members(packet, 'packet', test['unpacked']), 1)))

    return ''.join(generated_tests)
| 172 | |
| 173 | |
def generate_packet_serializer_test(serializer_test_suite: str, packet: ast.PacketDeclaration,
                                    tests: List[object]) -> str:
    """Generate the implementation of serializer unit tests for the selected packet.

    One gtest case is emitted per entry of `tests`. Each case:
    - builds the declaration named by the entry's optional 'packet' key
      (defaulting to `packet`) from the entry's 'unpacked' dictionary, using
      its generated Builder;
    - compares the serialized bytes against the entry's 'packed' hex string.

    Args:
        serializer_test_suite: name of the gtest fixture the cases belong to.
        packet: declaration the test vectors were collected for.
        tests: test vector entries (dicts with 'packed', 'unpacked' and
            optionally 'packet' keys).

    Returns:
        The concatenated C++ source of all generated test cases.
    """

    def hex_rows(data, per_row: int) -> List[str]:
        # Format integer values as comma-terminated C++ hex literals, `per_row` per line.
        return [' '.join(f'0x{b:x},' for b in data[i:i + per_row]) for i in range(0, len(data), per_row)]

    def build_array(field: ast.ArrayField, field_var: str, value: List[object]) -> List[str]:
        # Declare the local std::array (when the array has a `size`) or std::vector
        # named `field_var` that holds the array elements.
        element_type = get_cxx_scalar_type(field.width) if field.width else field.type_id
        container = f"std::array<{element_type}, {field.size}>" if field.size else f"std::vector<{element_type}>"
        declarations = []
        if field.width:
            rows = [' ' + row for row in hex_rows(value, 16 * 8 // field.width)]
        elif isinstance(field.type, ast.EnumDeclaration):
            rows = [f" {field.type_id}({v})," for v in value]
        else:
            # Compound elements are built first; their own helper declarations
            # must precede the array declaration that refers to them.
            rows = []
            for (n, element_value) in enumerate(value):
                (element, nested_declarations) = build_packet(field.type, f'{field_var}_{n}', element_value)
                declarations.extend(nested_declarations)
                rows.append(f" {element},")
        declarations.append(f"{container} {field_var} {{")
        declarations.extend(rows)
        declarations.append("};")
        return declarations

    def build_packet(decl: ast.Declaration, var: str, initializer: object) -> Tuple[str, List[str]]:
        # Return the C++ constructor expression for `decl` initialized from
        # `initializer`, together with the local declarations (payload and
        # array initializers) that the expression refers to.
        fields = core.get_unconstrained_parent_fields(decl) + decl.fields
        # Array element names such as `x[0]` are turned into valid identifiers.
        sanitized_var = var.replace('[', '_').replace(']', '')
        declarations = []
        parameters = []
        for field in fields:
            field_id = getattr(field, 'id', None)
            field_var = f'{sanitized_var}_{field_id}'
            if isinstance(field, (ast.PayloadField, ast.BodyField)):
                value = initializer['payload']
            else:
                value = initializer.get(field_id, None)

            if field.cond_for:
                # No constructor parameter is emitted for fields that other fields
                # use as a condition (presumably derived from the optional field).
                continue

            if field.cond and value is None:
                parameters.append("std::nullopt")

            elif isinstance(field, ast.ScalarField):
                parameters.append(f"std::make_optional({value})" if field.cond else f"{value}")

            elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration) and field.cond:
                parameters.append(f"std::make_optional({field.type_id}({value}))")

            elif isinstance(field, ast.TypedefField) and isinstance(
                    field.type, (ast.EnumDeclaration, ast.CustomFieldDeclaration, ast.ChecksumDeclaration)):
                # Custom and checksum values are constructed from their integer
                # value, as the parser tests do; they have no member fields to
                # recurse into.
                parameters.append(f"{field.type_id}({value})")

            elif isinstance(field, ast.TypedefField):
                (element, nested_declarations) = build_packet(field.type, field_var, value)
                declarations.extend(nested_declarations)
                parameters.append(element)

            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                declarations.append(f"std::vector<uint8_t> {field_var} {{")
                declarations.extend(' ' + row for row in hex_rows(value, 16))
                declarations.append("};")
                parameters.append(f"std::move({field_var})")

            elif isinstance(field, ast.ArrayField):
                declarations.extend(build_array(field, field_var, value))
                parameters.append(f"std::move({field_var})")

        constructor_name = f'{decl.id}Builder' if isinstance(decl, ast.PacketDeclaration) else decl.id
        return (f"{constructor_name}({', '.join(parameters)})", declarations)

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        (built_packet, intermediate_declarations) = build_packet(child_packet, 'packet', test['unpacked'])
        generated_tests.append(
            dedent("""\

                TEST_F({serializer_test_suite}, {packet_id}_Case{test_nr}) {{
                std::vector<uint8_t> expected_output {{
                {output_bytes}
                }};
                {intermediate_declarations}
                {child_packet_id}Builder packet = {built_packet};
                ASSERT_EQ(packet.SerializeToBytes(), expected_output);
                }}
                """).format(serializer_test_suite=serializer_test_suite,
                            packet_id=packet.id,
                            child_packet_id=child_packet_id,
                            test_nr=test_nr,
                            output_bytes=indent(hex_rows(bytes.fromhex(test['packed']), 16), 2),
                            built_packet=built_packet,
                            intermediate_declarations=indent(intermediate_declarations, 1)))

    return ''.join(generated_tests)
| 312 | |
| 313 | |
def run(input: argparse.FileType, output: argparse.FileType, test_vectors: argparse.FileType, include_header: List[str],
        using_namespace: List[str], namespace: str, parser_test_suite: str, serializer_test_suite: str):
    """Write C++ parser and serializer unit tests for the packets of a PDL file.

    Args:
        input: readable file with the PDL-JSON description of the declarations.
        output: writable file receiving the generated C++ source.
        test_vectors: readable JSON file; a list of entries, each with a
            'packet' identifier and a list of 'tests'.
        include_header: headers emitted as `#include <...>` directives.
        using_namespace: namespaces emitted as `using namespace ...;` statements.
        namespace: C++ namespace wrapping the generated tests.
        parser_test_suite: name of the generated parser test fixture.
        serializer_test_suite: name of the generated serializer test fixture.
    """
    pdl_file = ast.File.from_json(json.load(input))
    vectors = json.load(test_vectors)
    core.desugar(pdl_file)

    include_directives = '\n'.join(f'#include <{header}>' for header in include_header)
    using_directives = '\n'.join(f'using namespace {name};' for name in using_namespace)

    # Declarations for which no test is generated.
    skipped_tests = {
        'Packet_Checksum_Field_FromStart',
        'Packet_Checksum_Field_FromEnd',
        'Struct_Checksum_Field_FromStart',
        'Struct_Checksum_Field_FromEnd',
        'PartialParent5',
        'PartialParent12',
        'Packet_Array_Field_VariableElementSize_ConstantSize',
        'Packet_Array_Field_VariableElementSize_VariableSize',
        'Packet_Array_Field_VariableElementSize_VariableCount',
        'Packet_Array_Field_VariableElementSize_UnknownSize',
    }

    preamble = dedent("""\
        // File generated from {input_name} and {test_vectors_name}, with the command:
        // {input_command}
        // /!\\ Do not edit by hand

        #include <cstdint>
        #include <string>
        #include <gtest/gtest.h>
        #include <packet_runtime.h>

        {include_header}
        {using_namespace}

        namespace {namespace} {{

        class {parser_test_suite} : public testing::Test {{}};
        class {serializer_test_suite} : public testing::Test {{}};
        """).format(parser_test_suite=parser_test_suite,
                    serializer_test_suite=serializer_test_suite,
                    input_name=input.name,
                    input_command=' '.join(sys.argv),
                    test_vectors_name=test_vectors.name,
                    include_header=include_directives,
                    using_namespace=using_directives,
                    namespace=namespace)
    output.write(preamble)

    for decl in pdl_file.declarations:
        if decl.id in skipped_tests or not isinstance(decl, ast.PacketDeclaration):
            continue
        # Collect, in file order, the cases of every vector entry targeting this packet.
        cases = [case for entry in vectors if entry['packet'] == decl.id for case in entry['tests']]
        if not cases:
            continue
        output.write(generate_packet_parser_test(parser_test_suite, decl, cases))
        output.write(generate_packet_serializer_test(serializer_test_suite, decl, cases))

    output.write(f"}} // namespace {namespace}\n")
| 376 | |
| 377 | |
def main() -> int:
    """Generate C++ unit tests for the PDL cxx backend from PDL test vectors."""
    # The module has no docstring, so `description=__doc__` left the --help
    # description empty; use this function's docstring instead.
    parser = argparse.ArgumentParser(description=main.__doc__)
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output C++ file')
    parser.add_argument('--test-vectors', type=argparse.FileType('r'), required=True, help='Input PDL test file')
    parser.add_argument('--namespace', type=str, default='pdl', help='Namespace of the generated file')
    parser.add_argument('--parser-test-suite', type=str, default='ParserTest', help='Name of the parser test suite')
    parser.add_argument('--serializer-test-suite',
                        type=str,
                        default='SerializerTest',
                        help='Name of the serializer test suite')
    parser.add_argument('--include-header', type=str, default=[], action='append', help='Added include directives')
    parser.add_argument('--using-namespace',
                        type=str,
                        default=[],
                        action='append',
                        help='Added using namespace statements')
    run(**vars(parser.parse_args()))
    # run() returns nothing; report success explicitly to honour the `int`
    # return type (the process exit status stays 0, as before).
    return 0
| 397 | |
| 398 | |
if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())