| #!/usr/bin/env python3 |
| |
| import argparse |
| import collections |
| import copy |
| import json |
| from pathlib import Path |
| import pprint |
| import traceback |
| from typing import Iterable, List, Optional, Union |
| import sys |
| |
| from pdl import ast, core |
| |
# Maximum total octet size for generated variable-size arrays (bounds output).
MAX_ARRAY_SIZE = 256
# Maximum element count for generated variable-count arrays.
MAX_ARRAY_COUNT = 32
# Element count used when an array carries no size information at all.
DEFAULT_ARRAY_COUNT = 3
# Octet length used for payloads that have no size field to bound them.
DEFAULT_PAYLOAD_SIZE = 5
| |
| |
class BitSerializer:
    """Accumulates bit-level values and flushes them to a byte stream.

    Bits are buffered in `value` until a whole number of octets has been
    appended, then emitted to `stream` using the configured byte order.
    """

    def __init__(self, big_endian: bool):
        self.stream = []
        self.value = 0
        self.shift = 0
        self.byteorder = "big" if big_endian else "little"

    def append(self, value: int, width: int):
        """Append `value` occupying `width` bits to the pending buffer."""
        self.value |= value << self.shift
        self.shift += width

        # Flush whenever the pending bits form complete octets.
        if self.shift % 8 == 0:
            octets = self.shift // 8
            self.stream.extend(self.value.to_bytes(octets, byteorder=self.byteorder))
            self.value = 0
            self.shift = 0
| |
class Value:
    """A concrete or deferred field value together with its bit width.

    `value` may be an int, a list of Values, a Packet, None (absent
    optional field), or a callable evaluated lazily against the enclosing
    packet. `width` may likewise be an int or a callable of the parent.
    """

    def __init__(self, value: object, width: Optional[int] = None):
        self.value = value
        if width is not None:
            self.width = width
        elif isinstance(value, int) or callable(value):
            # Scalars and deferred values carry no intrinsic width.
            raise Exception("Creating scalar value of unspecified width")
        elif isinstance(value, list):
            self.width = sum([v.width for v in value])
        elif isinstance(value, Packet):
            self.width = value.width
        else:
            raise Exception(f"Malformed value {value}")

    def finalize(self, parent: "Packet"):
        """Resolve deferred (callable) width and value against `parent`."""
        if callable(self.width):
            self.width = self.width(parent)

        if callable(self.value):
            self.value = self.value(parent)
        elif isinstance(self.value, list):
            for v in self.value:
                v.finalize(parent)
        elif isinstance(self.value, Packet):
            self.value.finalize()

    def serialize_(self, serializer: BitSerializer):
        """Write this value into `serializer`; finalize() must run first."""
        if isinstance(self.value, int):
            serializer.append(self.value, self.width)
        elif isinstance(self.value, list):
            for v in self.value:
                v.serialize_(serializer)
        elif isinstance(self.value, Packet):
            self.value.serialize_(serializer)
        elif self.value is None:
            # Absent optional field: contributes no bits.
            pass
        else:
            raise Exception(f"Malformed value {self.value}")

    def show(self, indent: int = 0):
        """Pretty-print the value tree for debugging.

        Fix: no `name` attribute is ever assigned on Value in this file, so
        the original `self.name` raised AttributeError whenever show() was
        called. Fall back to a placeholder when the attribute is missing.
        """
        space = " " * indent
        name = getattr(self, "name", "<anonymous>")
        if isinstance(self.value, int):
            print(f"{space}{name}: {hex(self.value)}")
        elif isinstance(self.value, list):
            print(f"{space}{name}[{len(self.value)}]:")
            for v in self.value:
                v.show(indent + 2)
        elif isinstance(self.value, Packet):
            print(f"{space}{name}:")
            self.value.show(indent + 2)

    def to_json(self) -> object:
        """Convert to a JSON-compatible object (None for absent values)."""
        if isinstance(self.value, int):
            return self.value
        elif isinstance(self.value, list):
            return [v.to_json() for v in self.value]
        elif isinstance(self.value, Packet):
            return self.value.to_json()
| |
| |
class Field:
    """Pairs a generated Value with the AST field it instantiates."""

    def __init__(self, value: Value, ref: ast.Field):
        self.value = value
        self.ref = ref

    def finalize(self, parent: "Packet"):
        # Resolve any deferred value/width against the enclosing packet.
        self.value.finalize(parent)

    def serialize_(self, serializer: BitSerializer):
        self.value.serialize_(serializer)

    def clone(self):
        # Shallow-copy the value so finalizing one packet instance does
        # not mutate the fields shared with another.
        duplicate = copy.copy(self.value)
        return Field(duplicate, self.ref)
| |
| |
class Packet:
    """A fully-instantiated packet: an ordered list of generated fields."""

    def __init__(self, fields: List[Field], ref: ast.Declaration):
        self.fields = fields
        self.ref = ref

    def finalize(self, parent: Optional["Packet"] = None):
        """Resolve all deferred field values, using this packet as context."""
        for field in self.fields:
            field.finalize(self)

    def serialize_(self, serializer: BitSerializer):
        for field in self.fields:
            field.serialize_(serializer)

    def serialize(self, big_endian: bool) -> bytes:
        """Serialize to bytes; raises if the width is not octet-aligned."""
        serializer = BitSerializer(big_endian)
        self.serialize_(serializer)
        if serializer.shift != 0:
            raise Exception("The packet size is not an integral number of octets")
        return bytes(serializer.stream)

    def show(self, indent: int = 0):
        for field in self.fields:
            field.value.show(indent)

    def to_json(self) -> dict:
        """Flatten to a dict keyed by field id; nested payload packets are
        merged in, raw payloads stored under "payload"."""
        result = dict()
        for field in self.fields:
            is_payload = isinstance(field.ref, (ast.PayloadField, ast.BodyField))
            if is_payload and isinstance(field.value.value, Packet):
                # Child packet: inline its fields at this level.
                result.update(field.value.to_json())
            elif is_payload:
                result["payload"] = field.value.to_json()
            elif hasattr(field.ref, "id"):
                result[field.ref.id] = field.value.to_json()
        return result

    @property
    def width(self) -> int:
        # Finalize first so deferred widths are concrete integers.
        self.finalize()
        return sum(field.value.width for field in self.fields)
| |
| |
class BitGenerator:
    """Deterministic bit source used to produce varied scalar field values.

    Draws bits from a counter byte (`value`) advanced each time eight bits
    have been consumed; the counter wraps modulo 0xFF.
    """

    def __init__(self):
        self.value = 0
        self.shift = 0

    def generate(self, width: int) -> Value:
        """Generate an integer value of the selected width."""
        acc = 0
        remaining = width
        while remaining > 0:
            take = min(8 - self.shift, remaining)
            mask = (1 << take) - 1
            acc = (acc << take) | ((self.value >> self.shift) & mask)
            remaining -= take
            self.shift += take
            if self.shift >= 8:
                # Current source byte exhausted: advance the counter.
                self.shift = 0
                self.value = (self.value + 1) % 0xFF
        return Value(acc, width)

    def generate_list(self, width: int, count: int) -> List[Value]:
        return [self.generate(width) for _ in range(count)]
| |
| |
| generator = BitGenerator() |
| |
| |
def generate_cond_field_values(field: ast.ScalarField) -> List[Value]:
    """Generate the value of a scalar field that flags an optional field.

    The value is deferred: at finalize time it inspects the parent packet
    to see whether the conditioned field is present (non-None) and emits
    the matching flag value.
    """
    present = field.cond_for.cond.value
    # Any other value marks absence; pick 0, or 1 when 0 itself means present.
    absent = 1 if present == 0 else 0

    def resolve(parent: Packet) -> int:
        for f in parent.fields:
            if f.ref is field.cond_for:
                return absent if f.value.value is None else present

    return [Value(resolve, field.width)]
| |
| |
def generate_size_field_values(field: ast.SizeField) -> List[Value]:
    """Generate the deferred value of a size field: the octet size of the
    sized field (payload, body, or named field) plus its size modifier."""

    def resolve_size(parent: Packet, field_id: str) -> int:
        for f in parent.fields:
            matches = (
                (field_id == "_payload_" and isinstance(f.ref, ast.PayloadField))
                or (field_id == "_body_" and isinstance(f.ref, ast.BodyField))
                or getattr(f.ref, "id", None) == field_id
            )
            if matches:
                assert f.value.width % 8 == 0
                modifier = int(getattr(f.ref, "size_modifier", None) or 0)
                return f.value.width // 8 + modifier
        raise Exception(
            "Field {} not found in packet {}".format(field_id, parent.ref.id)
        )

    return [Value(lambda p: resolve_size(p, field.field_id), field.width)]
| |
| |
def generate_count_field_values(field: ast.CountField) -> List[Value]:
    """Generate the deferred value of a count field: the element count of
    the counted array field in the parent packet."""

    def resolve_count(parent: Packet, field_id: str) -> int:
        for f in parent.fields:
            if getattr(f.ref, "id", None) != field_id:
                continue
            assert isinstance(f.value.value, list)
            return len(f.value.value)
        raise Exception(
            "Field {} not found in packet {}".format(field_id, parent.ref.id)
        )

    return [Value(lambda p: resolve_count(p, field.field_id), field.width)]
| |
| |
def generate_checksum_field_values(field: ast.TypedefField) -> List[Value]:
    """Generate the deferred value of a checksum field.

    The checksum covers the fields located between the checksum-start
    marker (ChecksumField) and the checksum value field itself.
    """
    field_width = core.get_field_size(field)

    def additive_checksum(data: bytes, width: int):
        # Only the 8-bit additive checksum is implemented.
        assert width == 8
        return sum(data) % 256

    def resolve(parent: Packet, field_id: str) -> int:
        serializer = None
        for f in parent.fields:
            if isinstance(f.ref, ast.ChecksumField) and f.ref.field_id == field_id:
                # Start of the checksummed range: open a serializer.
                serializer = BitSerializer(
                    f.ref.parent.file.endianness.value == "big_endian"
                )
            elif isinstance(f.ref, ast.TypedefField) and f.ref.id == field_id:
                # Reached the checksum field: fold the serialized bytes.
                return additive_checksum(serializer.stream, field_width)
            elif serializer:
                # Field inside the range: accumulate its serialization.
                f.value.serialize_(serializer)
        raise Exception("malformed checksum")

    return [Value(lambda p: resolve(p, field.id), field_width)]
| |
| |
def generate_padding_field_values(field: ast.PaddingField) -> List[Value]:
    """Generate the padding that follows a padded field.

    The produced Value is 0 with a deferred width: the difference in bits
    between the declared padded size and the actual width of the padded
    field, resolved once the parent packet is known.
    """
    preceding_field_id = field.padded_field.id

    # Fix: the helper was annotated `-> List[Value]` but returns the padding
    # width in bits (an int), used as the deferred width of the Value below.
    def get_padding(parent: Packet, field_id: str, width: int) -> int:
        for f in parent.fields:
            if (
                (field_id == "_payload_" and isinstance(f.ref, ast.PayloadField))
                or (field_id == "_body_" and isinstance(f.ref, ast.BodyField))
                or (getattr(f.ref, "id", None) == field_id)
            ):
                assert f.value.width % 8 == 0
                assert f.value.width <= width
                return width - f.value.width
        raise Exception(
            "Field {} not found in packet {}".format(field_id, parent.ref.id)
        )

    return [Value(0, lambda p: get_padding(p, preceding_field_id, 8 * field.size))]
| |
| |
def generate_payload_field_values(
    field: Union[ast.PayloadField, ast.BodyField]
) -> List[Value]:
    """Generate candidate payload values: an empty payload, plus one of the
    maximum (or default) octet size."""
    payload_size = core.get_payload_field_size(field)
    size_modifier = int(getattr(field, "size_modifier", None) or 0)

    # If the payload has a size field, the second candidate fills the
    # maximum representable size; otherwise it uses the default size.
    # The size modifier is subtracted in either case.
    if payload_size:
        max_size = (1 << payload_size.width) - 1 - size_modifier
    else:
        max_size = DEFAULT_PAYLOAD_SIZE - size_modifier

    assert max_size > 0
    return [Value([]), Value(generator.generate_list(8, max_size))]
| |
| |
def generate_scalar_array_field_values(field: ast.ArrayField) -> List[Value]:
    """Generate candidate values for an array of scalar elements.

    Depending on how the array is sized (statically, by a count field, by a
    size field, or unbounded) this returns one or two arrays covering the
    minimum and maximum sizes.
    """
    # Fix: the original guard referenced `element_width` before assignment,
    # raising NameError instead of the intended exception.
    if field.width % 8 != 0:
        raise Exception("Array element size is not a multiple of 8")

    array_size = core.get_array_field_size(field)
    element_width = field.width // 8

    # TODO
    # The array might also be bounded if it is included in the sized payload
    # of a packet.

    # Apply the size modifiers.
    size_modifier = int(getattr(field, "size_modifier", None) or 0)

    # The element width is known, and the array element count is known
    # statically.
    if isinstance(array_size, int):
        return [Value(generator.generate_list(field.width, array_size))]

    # The element width is known, and the array element count is known
    # by count field: generate an empty and a maximal array.
    elif isinstance(array_size, ast.CountField):
        max_count = (1 << array_size.width) - 1
        return [Value([]), Value(generator.generate_list(field.width, max_count))]

    # The element width is known, and the array full octet size is known
    # by size field: generate an empty and a maximal array.
    elif isinstance(array_size, ast.SizeField):
        max_size = (1 << array_size.width) - 1 - size_modifier
        max_count = max_size // element_width
        return [Value([]), Value(generator.generate_list(field.width, max_count))]

    # The element width is known, but the array size is unknown.
    # Generate two arrays: one empty and one of default length.
    else:
        return [
            Value([]),
            Value(generator.generate_list(field.width, DEFAULT_ARRAY_COUNT)),
        ]
| |
| |
def generate_typedef_array_field_values(field: ast.ArrayField) -> List[Value]:
    """Generate candidate values for an array of typedef (struct/enum)
    elements.

    Element values are drawn from generate_typedef_values() and packed into
    chunks bounded by the array's size/count constraints.
    """
    array_size = core.get_array_field_size(field)
    element_width = core.get_array_element_size(field)
    if element_width:
        if element_width % 8 != 0:
            raise Exception("Array element size is not a multiple of 8")
        element_width = int(element_width / 8)

    # Element type declaration used to produce element values.
    type_decl = field.parent.file.typedef_scope[field.type_id]

    def generate_list(count: Optional[int]) -> List[Value]:
        """Generate an array of specified length.
        If the count is None all possible array items are returned."""
        element_values = generate_typedef_values(type_decl)

        # Requested a variable count, send everything in one chunk.
        if count is None:
            return [Value(element_values)]
        # Have more items than the requested count.
        # Slice the possible array values in multiple slices.
        elif len(element_values) > count:
            # Record the original count, then add more elements so the last
            # slice can wrap over without running short.
            elements_count = len(element_values)
            element_values.extend(generate_typedef_values(type_decl))
            # Fix: the original referenced an undefined name `elements` here,
            # raising NameError; `elements_count` is the intended value.
            chunk_count = int((elements_count + count - 1) / count)
            return [
                Value(element_values[n * count : (n + 1) * count])
                for n in range(chunk_count)
            ]
        # Have less items than the requested count.
        # Generate additional items to fill the gap.
        else:
            chunk = element_values
            while len(chunk) < count:
                chunk.extend(generate_typedef_values(type_decl))
            return [Value(chunk[:count])]

    # TODO
    # The array might also be bounded if it is included in the sized payload
    # of a packet.

    # Apply the size modifier.
    size_modifier = int(getattr(field, "size_modifier", None) or 0)

    min_size = 0
    max_size = MAX_ARRAY_SIZE
    min_count = 0
    max_count = MAX_ARRAY_COUNT

    if field.padded_size:
        max_size = field.padded_size

    if isinstance(array_size, ast.SizeField):
        max_size = (1 << array_size.width) - 1 - size_modifier
        min_size = size_modifier
    elif isinstance(array_size, ast.CountField):
        max_count = (1 << array_size.width) - 1
    elif isinstance(array_size, int):
        min_count = array_size
        max_count = array_size

    values = []
    chunk = []
    chunk_size = 0

    # Pack generated element values into chunks respecting the count and
    # size bounds; loop until at least one full chunk is produced.
    while not values:
        element_values = generate_typedef_values(type_decl)
        for element_value in element_values:
            element_size = int(element_value.width / 8)

            if len(chunk) >= max_count or chunk_size + element_size > max_size:
                assert len(chunk) >= min_count
                values.append(Value(chunk))
                chunk = []
                chunk_size = 0

            chunk.append(element_value)
            chunk_size += element_size

    if min_count == 0:
        values.append(Value([]))

    return values

    # NOTE(review): everything below is unreachable — the function returns
    # unconditionally above. It reads like a stale copy of an earlier
    # implementation (it is also generate_list's only caller); confirm the
    # intended structure before deleting.

    # The element width is not known, but the array full octet size
    # is known by size field. Generate two arrays: of minimal and maximum
    # size. All unused element values are packed into arrays of varying size.
    if element_width is None and isinstance(array_size, ast.SizeField):
        element_values = generate_typedef_values(type_decl)
        chunk = []
        chunk_size = 0
        values = [Value([])]
        for element_value in element_values:
            assert element_value.width % 8 == 0
            element_size = int(element_value.width / 8)
            if chunk_size + element_size > max_size:
                values.append(Value(chunk))
                chunk = []
            chunk.append(element_value)
            chunk_size += element_size
        if chunk:
            values.append(Value(chunk))
        return values

    # The element width is not known, but the array element count
    # is known statically or by count field. Generate two arrays:
    # of minimal and maximum length. All unused element values are packed into
    # arrays of varying count.
    elif element_width is None and isinstance(array_size, ast.CountField):
        return [Value([])] + generate_list(max_count)

    # The element width is not known, and the array element count is known
    # statically.
    elif element_width is None and isinstance(array_size, int):
        return generate_list(array_size)

    # Neither the count not size is known,
    # generate two arrays: one empty and one including all possible element
    # values.
    elif element_width is None:
        return [Value([])] + generate_list(None)

    # The element width is known, and the array element count is known
    # statically.
    elif isinstance(array_size, int):
        return generate_list(array_size)

    # The element width is known, and the array element count is known
    # by count field.
    elif isinstance(array_size, ast.CountField):
        return [Value([])] + generate_list(max_count)

    # The element width is known, and the array full size is known
    # by size field.
    elif isinstance(array_size, ast.SizeField):
        return [Value([])] + generate_list(max_count)

    # The element width is known, but the array size is unknown.
    # Generate two arrays: one empty and one including all possible element
    # values.
    else:
        return [Value([])] + generate_list(None)
| |
| |
def generate_array_field_values(field: ast.ArrayField) -> List[Value]:
    """Dispatch array value generation on the element kind: scalar elements
    carry an inline width, typedef elements do not."""
    if field.width is None:
        return generate_typedef_array_field_values(field)
    return generate_scalar_array_field_values(field)
| |
| |
def generate_typedef_field_values(
    field: ast.TypedefField, constraints: List[ast.Constraint]
) -> List[Value]:
    """Generate values for a typedef field, honoring enum constraints and
    checksum placement."""
    type_decl = field.parent.file.typedef_scope[field.type_id]

    # An enum field pinned by a constraint yields exactly that tag's value.
    if isinstance(type_decl, ast.EnumDeclaration):
        for constraint in constraints:
            if constraint.id != field.id:
                continue
            for tag in type_decl.tags:
                if tag.id == constraint.tag_id:
                    return [Value(tag.value, type_decl.width)]
            raise Exception("undefined enum tag")

    # Checksum fields need to know the checksum range.
    if isinstance(type_decl, ast.ChecksumDeclaration):
        return generate_checksum_field_values(field)

    return generate_typedef_values(type_decl)
| |
| |
def generate_field_values(
    field: ast.Field, constraints: List[ast.Constraint], payload: Optional[List[Packet]]
) -> List[Value]:
    """Generate candidate values for a single packet field.

    `constraints` are inherited from derived packet declarations; `payload`
    carries pre-generated child packets for the payload/body field, or None
    when raw payload bytes should be synthesized.
    """
    if field.cond_for:
        # Scalar flag signalling the presence of an optional field.
        return generate_cond_field_values(field)

    elif isinstance(field, ast.ChecksumField):
        # Checksum fields are just markers.
        return [Value(0, 0)]

    elif isinstance(field, ast.PaddingField):
        return generate_padding_field_values(field)

    elif isinstance(field, ast.SizeField):
        return generate_size_field_values(field)

    elif isinstance(field, ast.CountField):
        return generate_count_field_values(field)

    elif isinstance(field, (ast.BodyField, ast.PayloadField)):
        # Prefer pre-generated child packets; otherwise synthesize bytes.
        if payload:
            return [Value(p) for p in payload]
        return generate_payload_field_values(field)

    elif isinstance(field, ast.FixedField) and field.enum_id:
        enum_decl = field.parent.file.typedef_scope[field.enum_id]
        for tag in enum_decl.tags:
            if tag.id == field.tag_id:
                return [Value(tag.value, enum_decl.width)]
        raise Exception("undefined enum tag")

    elif isinstance(field, ast.FixedField) and field.width:
        return [Value(field.value, field.width)]

    elif isinstance(field, ast.ReservedField):
        return [Value(0, field.width)]

    elif isinstance(field, ast.ArrayField):
        return generate_array_field_values(field)

    elif isinstance(field, ast.ScalarField):
        constraint = next((c for c in constraints if c.id == field.id), None)
        if constraint is not None:
            return [Value(constraint.value, field.width)]
        # Unconstrained scalar: all-zeros, all-ones, and a generated value.
        all_ones = (1 << field.width) - 1
        return [
            Value(0, field.width),
            Value(all_ones, field.width),
            generator.generate(field.width),
        ]

    elif isinstance(field, ast.TypedefField):
        return generate_typedef_field_values(field, constraints)

    else:
        raise Exception("unsupported field kind")
| |
| |
def generate_fields(
    decl: ast.Declaration,
    constraints: List[ast.Constraint],
    payload: Optional[List[Packet]],
) -> List[List[Field]]:
    """For each field of `decl`, produce the list of candidate Fields.
    Optional (cond) fields additionally get an absent (None) candidate."""
    options = []
    for f in decl.fields:
        candidates = [
            Field(v, f) for v in generate_field_values(f, constraints, payload)
        ]
        if f.cond:
            # Optional field: prepend a zero-width absent variant.
            candidates.insert(0, Field(Value(None, 0), f))
        options.append(candidates)
    return options
| |
| |
def generate_fields_recursive(
    scope: dict,
    decl: ast.Declaration,
    constraints: Optional[List[ast.Constraint]] = None,
    payload: Optional[List[Packet]] = None,
) -> List[List[Field]]:
    """Generate field candidates for `decl`, recursing into parents.

    Child packet combinations become the payload candidates of the parent
    declaration; the field options of the topmost ancestor are returned.

    Fix: the constraints default was a mutable `[]` literal (never mutated
    here, but a classic pitfall); use None as the sentinel instead.
    """
    constraints = [] if constraints is None else constraints
    fields = generate_fields(decl, constraints, payload)

    if not decl.parent_id:
        return fields

    # Pack the child field combinations into packets and recurse into the
    # parent declaration with the accumulated constraints.
    packets = [Packet(combination, decl) for combination in product(fields)]
    parent_decl = scope[decl.parent_id]
    return generate_fields_recursive(
        scope, parent_decl, constraints + decl.constraints, payload=packets
    )
| |
| |
def generate_struct_values(decl: ast.StructDeclaration) -> List[Packet]:
    """Generate all candidate packet instances for a struct declaration."""
    options = generate_fields_recursive(decl.file.typedef_scope, decl)
    return [Packet(combination, decl) for combination in product(options)]
| |
| |
def generate_packet_values(decl: ast.PacketDeclaration) -> List[Packet]:
    """Generate all candidate packet instances for a packet declaration."""
    options = generate_fields_recursive(decl.file.packet_scope, decl)
    return [Packet(combination, decl) for combination in product(options)]
| |
| |
def generate_typedef_values(decl: ast.Declaration) -> List[Value]:
    """Generate candidate values for a typedef declaration."""
    if isinstance(decl, ast.EnumDeclaration):
        # One candidate per enum tag.
        return [Value(tag.value, decl.width) for tag in decl.tags]

    if isinstance(decl, ast.ChecksumDeclaration):
        raise Exception("ChecksumDeclaration handled in typedef field")

    if isinstance(decl, ast.CustomFieldDeclaration):
        raise Exception("TODO custom field")

    if isinstance(decl, ast.StructDeclaration):
        return [Value(p) for p in generate_struct_values(decl)]

    raise Exception("unsupported typedef declaration type")
| |
| |
def product(fields: List[List[Field]]) -> List[List[Field]]:
    """Perform a cartesian product of generated options for packet field values.

    Every combination contains cloned Fields so that finalizing one
    combination never mutates the Values shared with another.
    """

    def aux(vec: List[List[Field]]) -> List[List[Field]]:
        if len(vec) == 0:
            return [[]]
        return [[item.clone()] + items for item in vec[0] for items in aux(vec[1:])]

    count = 1
    max_len = 0
    for f in fields:
        count *= len(f)
        max_len = max(max_len, len(f))

    # Limit products to 32 elements to prevent combinatorial
    # explosion.
    if count <= 32:
        return aux(fields)

    # If too many products, select samples which test all field values
    # at the minimum.
    else:
        # Fix: clone the sampled fields, as aux() does. Reusing the same
        # Field object across combinations would let finalize() on one
        # combination mutate the Value shared with another.
        return [
            [f[idx % len(f)].clone() for f in fields] for idx in range(0, max_len + 1)
        ]
| |
| |
def serialize_values(file: ast.File, values: List[Packet]) -> List[dict]:
    """Serialize generated packets into test-vector dicts.

    Each result holds the hex-encoded packed bytes and the unpacked JSON
    representation; derived packets also record their concrete packet id.

    Fix: the parameter was annotated List[Value], but the items are Packets
    (serialize() and ref.parent_id are Packet attributes).
    """
    results = []
    for packet in values:
        packet.finalize()
        packed = packet.serialize(file.endianness.value == "big_endian")
        result = {
            # bytes.hex() emits the same lowercase two-digit encoding as the
            # original per-byte f-string join.
            "packed": packed.hex(),
            "unpacked": packet.to_json(),
        }
        if packet.ref.parent_id:
            result["packet"] = packet.ref.id
        results.append(result)
    return results
| |
| |
def run(input: Path, packet: List[str]):
    """Generate test vectors for the packets of a PDL-JSON file and dump
    them as JSON to stdout.

    Only leaf packets (without derived packets) are processed; a non-empty
    `packet` list restricts generation to the named packets.
    """
    with open(input) as f:
        file = ast.File.from_json(json.load(f))
        core.desugar(file)

    results = dict()
    for decl in file.packet_scope.values():
        # Skip non-leaf packets and packets not selected on the command line.
        if core.get_derived_packets(decl) or (packet and decl.id not in packet):
            continue

        try:
            values = generate_packet_values(decl)
            ancestor = core.get_packet_ancestor(decl)
            results.setdefault(ancestor.id, []).extend(
                serialize_values(file, values)
            )
        except Exception as exn:
            # Best effort: report and move on to the next packet.
            print(
                f"Skipping packet {decl.id}; cannot generate values: {exn}",
                file=sys.stderr,
            )

    json.dump(
        [{"packet": k, "tests": v} for (k, v) in results.items()],
        sys.stdout,
        indent=2,
    )
| |
| |
def main() -> int:
    """Generate test vectors for top-level PDL packets."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--input", type=Path, required=True, help="Input PDL-JSON source"
    )
    # The flag may be repeated and each occurrence may hold a
    # comma-separated list; action="extend" flattens them into one list.
    parser.add_argument(
        "--packet",
        type=lambda names: names.split(","),
        required=False,
        action="extend",
        default=[],
        help="Select PDL packet to test",
    )
    options = parser.parse_args()
    return run(**vars(options))
| |
| |
if __name__ == "__main__":
    # Script entry point: exit with the status returned by main().
    sys.exit(main())