blob: d8ddd44a62188c16a89ee25ed198e2c1d25e93ca [file] [log] [blame]
#!/usr/bin/env python3

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
import argparse
from dataclasses import dataclass, field
import json
from pathlib import Path
import sys
from textwrap import dedent
from typing import List, Tuple, Union, Optional, TextIO

from pdl import ast, core
from pdl.utils import indent, to_pascal_case
27
28
def get_cxx_scalar_type(width: int) -> str:
    """Return the cxx scalar type to be used to back a PDL type.

    Args:
        width: bit width of the PDL scalar value.

    Returns:
        The smallest of `uint8_t`, `uint16_t`, `uint32_t`, `uint64_t`
        that can hold `width` bits.

    Raises:
        ValueError: if `width` is larger than 64 bits, i.e. the PDL type does
            not fit on a non-extended scalar type.
    """
    for n in (8, 16, 32, 64):
        if width <= n:
            return f'uint{n}_t'
    # Raise rather than `assert`: asserts are stripped under `python -O`,
    # which would silently return None and emit invalid C++.
    raise ValueError(f'PDL type of width {width} does not fit on a non-extended scalar type')
36
37
def generate_packet_parser_test(parser_test_suite: str, packet: ast.PacketDeclaration, tests: List[object]) -> str:
    """Generate the implementation of unit tests for the selected packet.

    One gtest case is emitted per test vector. Each case wraps the `packed`
    bytes in a `pdl::packet::slice`, parses them through the chain of
    generated `View::Create` calls (outermost ancestor first), asserts that
    the view is valid, then compares the fields listed in `unpacked` with
    their expected values.

    Args:
        parser_test_suite: name of the gtest fixture class used by the cases.
        packet: declaration the test vectors were registered for; expected
            fields are resolved against this declaration.
        tests: test vectors, each a dict with a `packed` hex string, an
            `unpacked` dict of expected field values, and an optional
            `packet` key naming the packet to parse (defaults to `packet.id`).

    Returns:
        The C++ source of all generated test cases, concatenated.
    """

    def parse_packet(decl: ast.PacketDeclaration) -> str:
        # Nest the View constructors so that the outermost ancestor consumes
        # the raw input first, e.g. `ChildView::Create(ParentView::Create(input))`.
        parent = parse_packet(decl.parent) if decl.parent else "input"
        return f"{decl.id}View::Create({parent})"

    def hex_rows(values: Union[bytes, List[int]], per_row: int) -> List[str]:
        # Render integers as rows of comma-terminated C++ hex literals,
        # `per_row` literals per row.
        return [' '.join(f'0x{v:x},' for v in values[i:i + per_row]) for i in range(0, len(values), per_row)]

    def get_field(decl: ast.Declaration, var: str, field_id: str) -> str:
        # Struct fields are read as `<id>_` data members, packet fields
        # through the generated `Get<Id>()` accessors.
        if isinstance(decl, ast.StructDeclaration):
            return f"{var}.{field_id}_"
        return f"{var}.Get{to_pascal_case(field_id)}()"

    def check_literal(type_name: str, field_var: str, accessor: str, rows: List[str]) -> List[str]:
        # Declare `expected_<field_var>` from pre-rendered initializer rows,
        # then compare it with the parsed value.
        return ([f"{type_name} expected_{field_var} {{"] + [f"    {row}" for row in rows] +
                ["};", f"ASSERT_EQ({accessor}, expected_{field_var});"])

    def array_type(field: ast.ArrayField, element_type: str) -> str:
        # Arrays with a `size` map to std::array, all others to std::vector.
        if field.size:
            return f"std::array<{element_type}, {field.size}>"
        return f"std::vector<{element_type}>"

    def check_members(decl: ast.Declaration, var: str, expected: dict) -> List[str]:
        # Emit the C++ statements checking every field listed in `expected`
        # against the parsed object `var` of declaration `decl`; recurses into
        # struct fields and struct array elements.
        checks = []
        # Prefix for generated local names; `x[0]` (array element) -> `x_0`.
        sanitized_var = var.replace('[', '_').replace(']', '')
        for (field_id, value) in expected.items():
            field = core.get_packet_field(decl, field_id)
            field_var = f'{sanitized_var}_{field_id}'
            accessor = get_field(decl, var, field_id)

            if isinstance(field, ast.ScalarField):
                if field.cond:
                    # Conditional scalars are compared as std::optional values.
                    value = f"std::make_optional({value})" if value is not None else "std::nullopt"
                checks.append(f"ASSERT_EQ({accessor}, {value});")

            elif isinstance(field, ast.TypedefField):
                if field.cond and isinstance(field.type, ast.EnumDeclaration):
                    expected_value = (f"std::make_optional({field.type_id}({value}))"
                                      if value is not None else "std::nullopt")
                    checks.append(f"ASSERT_EQ({accessor}, {expected_value});")
                elif isinstance(field.type, (ast.EnumDeclaration, ast.CustomFieldDeclaration, ast.ChecksumDeclaration)):
                    checks.append(f"ASSERT_EQ({accessor}, {field.type_id}({value}));")
                elif field.cond and value is None:
                    checks.append(f"ASSERT_TRUE(!{accessor}.has_value());")
                else:
                    # Bind the (possibly optional) nested declaration to a
                    # local reference, then check its members recursively.
                    unwrap = ".value()" if field.cond else ""
                    checks.append(f"{field.type_id} const& {field_var} = {accessor}{unwrap};")
                    checks.extend(check_members(field.type, field_var, value))

            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                checks.extend(check_literal("std::vector<uint8_t>", field_var, accessor, hex_rows(value, 16)))

            elif isinstance(field, ast.ArrayField):
                if field.width:
                    element_type = get_cxx_scalar_type(field.width)
                    # 16 bytes' worth of elements per initializer row.
                    rows = hex_rows(value, 128 // field.width)
                    checks.extend(check_literal(array_type(field, element_type), field_var, accessor, rows))
                elif isinstance(field.type, ast.EnumDeclaration):
                    rows = [f"{field.type_id}({v})," for v in value]
                    checks.extend(check_literal(array_type(field, field.type_id), field_var, accessor, rows))
                else:
                    checks.append(f"{array_type(field, field.type_id)} {field_var} = {accessor};")
                    checks.append(f"ASSERT_EQ({field_var}.size(), {len(value)});")
                    for (n, element) in enumerate(value):
                        checks.extend(check_members(field.type, f"{field_var}[{n}]", element))

            # Any other field kind, or an identifier not resolved in `decl`,
            # produces no check.

        return checks

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        # NOTE(review): expected fields are resolved against `packet`, not
        # `child_packet`; identifiers not found there fall through without a
        # check — confirm this is intended for child-packet test vectors.
        generated_tests.append(
            dedent("""\

            TEST_F({parser_test_suite}, {packet_id}_Case{test_nr}) {{
                pdl::packet::slice input(std::shared_ptr<std::vector<uint8_t>>(new std::vector<uint8_t> {{
                    {input_bytes}
                }}));
                {child_packet_id}View packet = {parse_packet};
                ASSERT_TRUE(packet.IsValid());
                {checks}
            }}
            """).format(parser_test_suite=parser_test_suite,
                        packet_id=packet.id,
                        child_packet_id=child_packet_id,
                        test_nr=test_nr,
                        input_bytes=indent(hex_rows(bytes.fromhex(test['packed']), 16), 2),
                        parse_packet=parse_packet(child_packet),
                        checks=indent(check_members(packet, 'packet', test['unpacked']), 1)))

    return ''.join(generated_tests)
172
173
def generate_packet_serializer_test(serializer_test_suite: str, packet: ast.PacketDeclaration,
                                    tests: List[object]) -> str:
    """Generate the implementation of unit tests for the selected packet.

    One gtest case is emitted per test vector. Each case declares the
    `packed` bytes as the expected output, constructs the builder of the
    tested packet from the `unpacked` values, and asserts that
    `SerializeToBytes()` returns the expected bytes.

    Args:
        serializer_test_suite: name of the gtest fixture class used by the cases.
        packet: declaration the test vectors were registered for.
        tests: test vectors, each a dict with a `packed` hex string, an
            `unpacked` dict of field values, and an optional `packet` key
            naming the packet to build (defaults to `packet.id`).

    Returns:
        The C++ source of all generated test cases, concatenated.
    """

    def hex_rows(values: Union[bytes, List[int]], per_row: int) -> List[str]:
        # Render integers as rows of comma-terminated C++ hex literals,
        # `per_row` literals per row.
        return [' '.join(f'0x{v:x},' for v in values[i:i + per_row]) for i in range(0, len(values), per_row)]

    def literal_declaration(type_name: str, name: str, rows: List[str]) -> List[str]:
        # Declare local `name` from pre-rendered initializer rows.
        return [f"{type_name} {name} {{"] + [f"    {row}" for row in rows] + ["};"]

    def array_type(field: ast.ArrayField, element_type: str) -> str:
        # Arrays with a `size` map to std::array, all others to std::vector.
        if field.size:
            return f"std::array<{element_type}, {field.size}>"
        return f"std::vector<{element_type}>"

    def build_packet(decl: ast.Declaration, var: str, initializer: dict) -> Tuple[str, List[str]]:
        # Return the C++ constructor expression building `decl` from
        # `initializer`, together with the local declarations (byte buffers,
        # arrays, nested elements) it references, in emission order.
        fields = core.get_unconstrained_parent_fields(decl) + decl.fields
        declarations = []
        parameters = []
        sanitized_var = var.replace('[', '_').replace(']', '')
        for field in fields:
            field_id = getattr(field, 'id', None)
            field_var = f'{sanitized_var}_{field_id}'
            if isinstance(field, (ast.PayloadField, ast.BodyField)):
                value = initializer['payload']
            else:
                value = initializer.get(field_id, None)

            if field.cond_for:
                # Fields referenced by another field's condition are not passed
                # to the builder (presumably derived from the optional field's
                # presence — TODO confirm).
                continue

            if field.cond and value is None:
                parameters.append("std::nullopt")
            elif isinstance(field, ast.ScalarField):
                parameters.append(f"std::make_optional({value})" if field.cond else f"{value}")
            elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
                enum_value = f"{field.type_id}({value})"
                parameters.append(f"std::make_optional({enum_value})" if field.cond else enum_value)
            elif isinstance(field, ast.TypedefField):
                (element, nested_declarations) = build_packet(field.type, field_var, value)
                declarations.extend(nested_declarations)
                parameters.append(element)
            elif isinstance(field, (ast.PayloadField, ast.BodyField)):
                declarations.extend(literal_declaration("std::vector<uint8_t>", field_var, hex_rows(value, 16)))
                parameters.append(f"std::move({field_var})")
            elif isinstance(field, ast.ArrayField):
                if field.width:
                    element_type = get_cxx_scalar_type(field.width)
                    # 16 bytes' worth of elements per initializer row.
                    rows = hex_rows(value, 128 // field.width)
                elif isinstance(field.type, ast.EnumDeclaration):
                    element_type = field.type_id
                    rows = [f"{field.type_id}({v})," for v in value]
                else:
                    # Build each element first so that its own declarations
                    # precede the array declaration.
                    element_type = field.type_id
                    rows = []
                    for (n, element_value) in enumerate(value):
                        (element, nested_declarations) = build_packet(field.type, f'{field_var}_{n}', element_value)
                        declarations.extend(nested_declarations)
                        rows.append(f"{element},")
                declarations.extend(literal_declaration(array_type(field, element_type), field_var, rows))
                parameters.append(f"std::move({field_var})")
            # Any other field kind contributes no builder parameter.

        constructor_name = f'{decl.id}Builder' if isinstance(decl, ast.PacketDeclaration) else decl.id
        return (f"{constructor_name}({', '.join(parameters)})", declarations)

    generated_tests = []
    for (test_nr, test) in enumerate(tests):
        child_packet_id = test.get('packet', packet.id)
        child_packet = packet.file.packet_scope[child_packet_id]

        (built_packet, intermediate_declarations) = build_packet(child_packet, 'packet', test['unpacked'])
        generated_tests.append(
            dedent("""\

            TEST_F({serializer_test_suite}, {packet_id}_Case{test_nr}) {{
                std::vector<uint8_t> expected_output {{
                    {output_bytes}
                }};
                {intermediate_declarations}
                {child_packet_id}Builder packet = {built_packet};
                ASSERT_EQ(packet.SerializeToBytes(), expected_output);
            }}
            """).format(serializer_test_suite=serializer_test_suite,
                        packet_id=packet.id,
                        child_packet_id=child_packet_id,
                        test_nr=test_nr,
                        output_bytes=indent(hex_rows(bytes.fromhex(test['packed']), 16), 2),
                        built_packet=built_packet,
                        intermediate_declarations=indent(intermediate_declarations, 1)))

    return ''.join(generated_tests)
312
313
def run(input: TextIO, output: TextIO, test_vectors: TextIO, include_header: List[str],
        using_namespace: List[str], namespace: str, parser_test_suite: str, serializer_test_suite: str) -> None:
    """Generate the C++ parser and serializer unit tests for a PDL file.

    Args:
        input: PDL-JSON source stream.
        output: stream the generated C++ test file is written to.
        test_vectors: JSON test vectors: a list of objects, each with a
            `packet` identifier and a `tests` list of vectors.
        include_header: headers emitted as `#include <...>` directives.
        using_namespace: namespaces emitted as `using namespace ...;` statements.
        namespace: C++ namespace wrapping the generated tests.
        parser_test_suite: gtest fixture name for the parser tests.
        serializer_test_suite: gtest fixture name for the serializer tests.
    """
    pdl_file = ast.File.from_json(json.load(input))
    test_groups = json.load(test_vectors)
    core.desugar(pdl_file)

    include_directives = '\n'.join(f'#include <{header}>' for header in include_header)
    using_directives = '\n'.join(f'using namespace {ns};' for ns in using_namespace)

    # Declarations whose test vectors are not generated (presumably not yet
    # supported by the cxx backend — TODO confirm).
    skipped_tests = frozenset([
        'Packet_Checksum_Field_FromStart',
        'Packet_Checksum_Field_FromEnd',
        'Struct_Checksum_Field_FromStart',
        'Struct_Checksum_Field_FromEnd',
        'PartialParent5',
        'PartialParent12',
        'Packet_Array_Field_VariableElementSize_ConstantSize',
        'Packet_Array_Field_VariableElementSize_VariableSize',
        'Packet_Array_Field_VariableElementSize_VariableCount',
        'Packet_Array_Field_VariableElementSize_UnknownSize',
    ])

    output.write(
        dedent("""\
        // File generated from {input_name} and {test_vectors_name}, with the command:
        // {input_command}
        // /!\\ Do not edit by hand

        #include <cstdint>
        #include <string>
        #include <gtest/gtest.h>
        #include <packet_runtime.h>

        {include_header}
        {using_namespace}

        namespace {namespace} {{

        class {parser_test_suite} : public testing::Test {{}};
        class {serializer_test_suite} : public testing::Test {{}};
        """).format(parser_test_suite=parser_test_suite,
                    serializer_test_suite=serializer_test_suite,
                    input_name=input.name,
                    input_command=' '.join(sys.argv),
                    test_vectors_name=test_vectors.name,
                    include_header=include_directives,
                    using_namespace=using_directives,
                    namespace=namespace))

    for decl in pdl_file.declarations:
        if decl.id in skipped_tests or not isinstance(decl, ast.PacketDeclaration):
            continue

        # Flatten the vectors of every group registered for this packet.
        matching_tests = [test for group in test_groups if group['packet'] == decl.id for test in group['tests']]
        if matching_tests:
            output.write(generate_packet_parser_test(parser_test_suite, decl, matching_tests))
            output.write(generate_packet_serializer_test(serializer_test_suite, decl, matching_tests))

    output.write(f"}} // namespace {namespace}\n")
376
377
def main() -> int:
    """Generate C++ parser and serializer unit tests from PDL test vectors."""
    # The module has no docstring, so `__doc__` here was None; use this
    # function's docstring as the --help description instead.
    parser = argparse.ArgumentParser(description=main.__doc__)
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output C++ file')
    parser.add_argument('--test-vectors', type=argparse.FileType('r'), required=True, help='Input PDL test file')
    parser.add_argument('--namespace', type=str, default='pdl', help='Namespace of the generated file')
    parser.add_argument('--parser-test-suite', type=str, default='ParserTest', help='Name of the parser test suite')
    parser.add_argument('--serializer-test-suite',
                        type=str,
                        default='SerializerTest',
                        help='Name of the serializer test suite')
    parser.add_argument('--include-header', type=str, default=[], action='append', help='Added include directives')
    parser.add_argument('--using-namespace',
                        type=str,
                        default=[],
                        action='append',
                        help='Added using namespace statements')
    run(**vars(parser.parse_args()))
    # `run` returns None; return an explicit success status to honour `-> int`.
    return 0
397
398
if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())