| #!/usr/bin/env python3 |
| |
| # Copyright 2023 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| import argparse |
| from dataclasses import dataclass, field |
| import json |
| from pathlib import Path |
| import sys |
| from textwrap import dedent |
| from typing import List, Tuple, Union, Optional |
| |
| from pdl import ast, core |
| from pdl.utils import indent |
| |
| |
def mask(width: int) -> str:
    """Return the hexadecimal literal of a bit mask of `width` low bits.

    For example ``mask(4)`` is ``'0xf'`` and ``mask(0)`` is ``'0x0'``.
    """
    all_ones = (1 << width) - 1
    return format(all_ones, '#x')
| |
| |
def generate_prelude() -> str:
    """Return the Python source emitted at the top of every generated module.

    The prelude imports the runtime helpers used by generated code and
    defines the ``Packet`` base dataclass with ``parse_all``, a placeholder
    ``size`` property and a tree-style ``show`` pretty-printer.

    ``show`` computes the alignment width with ``max(..., default=0)`` so that
    packets whose only field is ``payload`` can be shown instead of raising
    ``ValueError`` on an empty sequence.
    """
    return dedent("""\
        from dataclasses import dataclass, field, fields
        from typing import Optional, List, Tuple, Union
        import enum
        import inspect
        import math

        @dataclass
        class Packet:
            payload: Optional[bytes] = field(repr=False, default_factory=bytes, compare=False)

            @classmethod
            def parse_all(cls, span: bytes) -> 'Packet':
                packet, remain = getattr(cls, 'parse')(span)
                if len(remain) > 0:
                    raise Exception('Unexpected parsing remainder')
                return packet

            @property
            def size(self) -> int:
                pass

            def show(self, prefix: str = ''):
                print(f'{self.__class__.__name__}')

                def print_val(p: str, pp: str, name: str, align: int, typ, val):
                    if name == 'payload':
                        pass

                    # Scalar fields.
                    elif typ is int:
                        print(f'{p}{name:{align}} = {val} (0x{val:x})')

                    # Byte fields.
                    elif typ is bytes:
                        print(f'{p}{name:{align}} = [', end='')
                        line = ''
                        n_pp = ''
                        for (idx, b) in enumerate(val):
                            if idx > 0 and idx % 8 == 0:
                                print(f'{n_pp}{line}')
                                line = ''
                                n_pp = pp + (' ' * (align + 4))
                            line += f' {b:02x}'
                        print(f'{n_pp}{line} ]')

                    # Enum fields.
                    elif inspect.isclass(typ) and issubclass(typ, enum.IntEnum):
                        print(f'{p}{name:{align}} = {typ.__name__}::{val.name} (0x{val:x})')

                    # Struct fields.
                    elif inspect.isclass(typ) and issubclass(typ, globals().get('Packet')):
                        print(f'{p}{name:{align}} = ', end='')
                        val.show(prefix=pp)

                    # Array fields.
                    elif getattr(typ, '__origin__', None) == list:
                        print(f'{p}{name:{align}}')
                        last = len(val) - 1
                        align = 5
                        for (idx, elt) in enumerate(val):
                            n_p = pp + ('├── ' if idx != last else '└── ')
                            n_pp = pp + ('│   ' if idx != last else '    ')
                            print_val(n_p, n_pp, f'[{idx}]', align, typ.__args__[0], val[idx])

                    # Custom fields.
                    elif inspect.isclass(typ):
                        print(f'{p}{name:{align}} = {repr(val)}')

                    else:
                        print(f'{p}{name:{align}} = ##{typ}##')

                last = len(fields(self)) - 1
                align = max((len(f.name) for f in fields(self) if f.name != 'payload'), default=0)

                for (idx, f) in enumerate(fields(self)):
                    p = prefix + ('├── ' if idx != last else '└── ')
                    pp = prefix + ('│   ' if idx != last else '    ')
                    val = getattr(self, f.name)

                    print_val(p, pp, f.name, align, f.type, val)
        """)
| |
| |
@dataclass
class FieldParser:
    """Generate the Python statements of a packet/struct ``parse()`` body.

    Fields are submitted in declaration order through :meth:`parse`. The
    generated statements accumulate in ``code``; they read from a local
    variable named ``span`` and store decoded values in a local dict named
    ``fields``. Call :meth:`done` after the last field to flush pending reads.
    """

    # Byte order used by the generated int.from_bytes() calls.
    byteorder: str
    # Offset in bytes of the next unread byte of `span`; the span is only
    # sliced (see consume_span_) when this offset needs to be reset.
    offset: int = 0
    # Number of bits accumulated in `chunk` since the last octet boundary.
    shift: int = 0
    # Pending bit fields as (bit shift, bit width, field) tuples.
    chunk: List[Tuple[int, int, ast.Field]] = field(default_factory=lambda: [])
    # Generated statements reading `span` at static offsets that still need
    # to be preceded by a span size guard (see check_code_).
    unchecked_code: List[str] = field(default_factory=lambda: [])
    # Generated statements, one line per entry.
    code: List[str] = field(default_factory=lambda: [])

    def unchecked_append_(self, line: str):
        """Append unchecked field parsing code.
        The function check_size_ must be called to generate a size guard
        after parsing is completed."""
        self.unchecked_code.append(line)

    def append_(self, code: str):
        """Append field parsing code.
        There must be no unchecked code left before this function is called."""
        assert len(self.unchecked_code) == 0
        self.code.extend(code.split('\n'))

    def check_size_(self, size: str):
        """Generate a check of the current span size.

        `size` is inserted verbatim in the generated code, so it may be an
        integer or a Python expression string."""
        self.append_(f"if len(span) < {size}:")
        self.append_(f"    raise Exception('Invalid packet size')")

    def check_code_(self):
        """Generate a size check for pending field parsing, then flush the
        pending unchecked statements after it."""
        if len(self.unchecked_code) > 0:
            assert len(self.chunk) == 0
            unchecked_code = self.unchecked_code
            self.unchecked_code = []
            self.check_size_(str(self.offset))
            self.code.extend(unchecked_code)

    def consume_span_(self, keep: int = 0) -> None:
        """Skip consumed span bytes.

        Flushes pending unchecked reads, then slices `span` so that the
        current offset becomes 0. The last `keep` consumed bytes are retained
        in the span (used when a payload shares its first bytes with
        preceding bit fields)."""
        if self.offset > 0:
            self.check_code_()
            self.append_(f'span = span[{self.offset - keep}:]')
            self.offset = 0

    def parse_array_element_dynamic_(self, field: ast.ArrayField, span: str):
        """Parse a single array field element of variable size."""
        if isinstance(field.type, ast.StructDeclaration):
            self.append_(f"    element, {span} = {field.type_id}.parse({span})")
            self.append_(f"    {field.id}.append(element)")
        else:
            raise Exception(f'Unexpected array element type {field.type_id} {field.width}')

    def parse_array_element_static_(self, field: ast.ArrayField, span: str):
        """Parse a single array field element of constant size."""
        if field.width is not None:
            element = f"int.from_bytes({span}, byteorder='{self.byteorder}')"
            self.append_(f"    {field.id}.append({element})")
        elif isinstance(field.type, ast.EnumDeclaration):
            element = f"int.from_bytes({span}, byteorder='{self.byteorder}')"
            element = f"{field.type_id}({element})"
            self.append_(f"    {field.id}.append({element})")
        else:
            element = f"{field.type_id}.parse_all({span})"
            self.append_(f"    {field.id}.append({element})")

    def parse_byte_array_field_(self, field: ast.ArrayField):
        """Parse the selected u8 array field."""
        array_size = core.get_array_field_size(field)
        padded_size = field.padded_size

        # Shift the span to reset the offset to 0.
        self.consume_span_()

        # Derive the array size.
        if isinstance(array_size, int):
            size = array_size
        elif isinstance(array_size, ast.SizeField):
            size = f'{field.id}_size - {field.size_modifier}' if field.size_modifier else f'{field.id}_size'
        elif isinstance(array_size, ast.CountField):
            size = f'{field.id}_count'
        else:
            size = None

        # Parse from the padded array if padding is present.
        if padded_size and size is not None:
            self.check_size_(padded_size)
            self.append_(f"if {size} > {padded_size}:")
            self.append_("    raise Exception('Array size is larger than the padding size')")
            self.append_(f"fields['{field.id}'] = list(span[:{size}])")
            self.append_(f"span = span[{padded_size}:]")

        elif size is not None:
            self.check_size_(size)
            self.append_(f"fields['{field.id}'] = list(span[:{size}])")
            self.append_(f"span = span[{size}:]")

        else:
            self.append_(f"fields['{field.id}'] = list(span)")
            self.append_(f"span = bytes()")

    def parse_array_field_(self, field: ast.ArrayField):
        """Parse the selected array field."""
        array_size = core.get_array_field_size(field)
        element_width = core.get_array_element_size(field)
        padded_size = field.padded_size

        if element_width:
            if element_width % 8 != 0:
                raise Exception('Array element size is not a multiple of 8')
            element_width = int(element_width / 8)

        if isinstance(array_size, int):
            size = None
            count = array_size
        elif isinstance(array_size, ast.SizeField):
            size = f'{field.id}_size'
            count = None
        elif isinstance(array_size, ast.CountField):
            size = None
            count = f'{field.id}_count'
        else:
            size = None
            count = None

        # Shift the span to reset the offset to 0.
        self.consume_span_()

        # Apply the size modifier.
        if field.size_modifier and size:
            self.append_(f"{size} = {size} - {field.size_modifier}")

        # Parse from the padded array if padding is present.
        if padded_size:
            self.check_size_(padded_size)
            self.append_(f"remaining_span = span[{padded_size}:]")
            self.append_(f"span = span[:{padded_size}]")

        # The element width is not known, but the array full octet size
        # is known by size field. Parse elements item by item as a vector.
        if element_width is None and size is not None:
            self.check_size_(size)
            self.append_(f"array_span = span[:{size}]")
            self.append_(f"{field.id} = []")
            self.append_("while len(array_span) > 0:")
            self.parse_array_element_dynamic_(field, 'array_span')
            self.append_(f"fields['{field.id}'] = {field.id}")
            self.append_(f"span = span[{size}:]")

        # The element width is not known, but the array element count
        # is known statically or by count field.
        # Parse elements item by item as a vector.
        elif element_width is None and count is not None:
            self.append_(f"{field.id} = []")
            self.append_(f"for n in range({count}):")
            self.parse_array_element_dynamic_(field, 'span')
            self.append_(f"fields['{field.id}'] = {field.id}")

        # Neither the count nor the size is known,
        # parse elements until the end of the span.
        elif element_width is None:
            self.append_(f"{field.id} = []")
            self.append_("while len(span) > 0:")
            self.parse_array_element_dynamic_(field, 'span')
            self.append_(f"fields['{field.id}'] = {field.id}")

        # The element width is known, and the array element count is known
        # statically, or by count field.
        elif count is not None:
            array_size = (f'{count}' if element_width == 1 else f'{count} * {element_width}')
            self.check_size_(array_size)
            self.append_(f"{field.id} = []")
            self.append_(f"for n in range({count}):")
            span = ('span[n:n + 1]' if element_width == 1 else f'span[n * {element_width}:(n + 1) * {element_width}]')
            self.parse_array_element_static_(field, span)
            self.append_(f"fields['{field.id}'] = {field.id}")
            self.append_(f"span = span[{array_size}:]")

        # The element width is known, and the array full size is known
        # by size field, or unknown (in which case it is the remaining span
        # length).
        else:
            if size is not None:
                self.check_size_(size)
            array_size = size or 'len(span)'
            if element_width != 1:
                self.append_(f"if {array_size} % {element_width} != 0:")
                self.append_("    raise Exception('Array size is not a multiple of the element size')")
                self.append_(f"{field.id}_count = int({array_size} / {element_width})")
                array_count = f'{field.id}_count'
            else:
                array_count = array_size
            self.append_(f"{field.id} = []")
            self.append_(f"for n in range({array_count}):")
            span = ('span[n:n + 1]' if element_width == 1 else f'span[n * {element_width}:(n + 1) * {element_width}]')
            self.parse_array_element_static_(field, span)
            self.append_(f"fields['{field.id}'] = {field.id}")
            if size is not None:
                self.append_(f"span = span[{size}:]")
            else:
                self.append_(f"span = bytes()")

        # Drop the padding
        if padded_size:
            self.append_(f"span = remaining_span")

    def parse_optional_field_(self, field: ast.Field):
        """Parse the selected optional field.
        Optional fields must start and end on a byte boundary."""

        if self.shift != 0:
            raise Exception('Optional field does not start on an octet boundary')
        if (isinstance(field, ast.TypedefField) and
                isinstance(field.type, ast.StructDeclaration) and
                field.type.parent_id is not None):
            raise Exception('Derived struct used in optional typedef field')

        self.consume_span_()

        if isinstance(field, ast.ScalarField):
            self.append_(dedent("""
                if {cond_id} == {cond_value}:
                    if len(span) < {size}:
                        raise Exception('Invalid packet size')
                    fields['{field_id}'] = int.from_bytes(span[:{size}], byteorder='{byteorder}')
                    span = span[{size}:]
                """.format(size=int(field.width / 8),
                           field_id=field.id,
                           cond_id=field.cond.id,
                           cond_value=field.cond.value,
                           byteorder=self.byteorder)))

        elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
            self.append_(dedent("""
                if {cond_id} == {cond_value}:
                    if len(span) < {size}:
                        raise Exception('Invalid packet size')
                    fields['{field_id}'] = {type_id}(
                        int.from_bytes(span[:{size}], byteorder='{byteorder}'))
                    span = span[{size}:]
                """.format(size=int(field.type.width / 8),
                           field_id=field.id,
                           type_id=field.type_id,
                           cond_id=field.cond.id,
                           cond_value=field.cond.value,
                           byteorder=self.byteorder)))

        elif isinstance(field, ast.TypedefField):
            self.append_(dedent("""
                if {cond_id} == {cond_value}:
                    {field_id}, span = {type_id}.parse(span)
                    fields['{field_id}'] = {field_id}
                """.format(field_id=field.id,
                           type_id=field.type_id,
                           cond_id=field.cond.id,
                           cond_value=field.cond.value)))

        else:
            raise Exception(f"unsupported field type {field.__class__.__name__}")

    def parse_bit_field_(self, field: ast.Field):
        """Parse the selected field as a bit field.
        The field is added to the current chunk. When a byte boundary
        is reached all saved fields are extracted together."""

        # Add to current chunk.
        width = core.get_field_size(field)
        self.chunk.append((self.shift, width, field))
        self.shift += width

        # Wait for more fields if not on a byte boundary.
        if (self.shift % 8) != 0:
            return

        # Parse the backing integer using the configured endianness,
        # extract field values.
        size = int(self.shift / 8)
        end_offset = self.offset + size

        if size == 1:
            value = f"span[{self.offset}]"
        else:
            span = f"span[{self.offset}:{end_offset}]"
            self.unchecked_append_(f"value_ = int.from_bytes({span}, byteorder='{self.byteorder}')")
            value = "value_"

        for shift, width, chunk_field in self.chunk:
            v = (value if len(self.chunk) == 1 and shift == 0 else f"({value} >> {shift}) & {mask(width)}")

            if chunk_field.cond_for:
                self.unchecked_append_(f"{chunk_field.id} = {v}")
            elif isinstance(chunk_field, ast.ScalarField):
                self.unchecked_append_(f"fields['{chunk_field.id}'] = {v}")
            elif isinstance(chunk_field, ast.FixedField) and chunk_field.enum_id:
                self.unchecked_append_(f"if {v} != {chunk_field.enum_id}.{chunk_field.tag_id}:")
                self.unchecked_append_(f"    raise Exception('Unexpected fixed field value')")
            elif isinstance(chunk_field, ast.FixedField):
                self.unchecked_append_(f"if {v} != {hex(chunk_field.value)}:")
                self.unchecked_append_(f"    raise Exception('Unexpected fixed field value')")
            elif isinstance(chunk_field, ast.TypedefField):
                self.unchecked_append_(f"fields['{chunk_field.id}'] = {chunk_field.type_id}.from_int({v})")
            elif isinstance(chunk_field, ast.SizeField):
                self.unchecked_append_(f"{chunk_field.field_id}_size = {v}")
            elif isinstance(chunk_field, ast.CountField):
                self.unchecked_append_(f"{chunk_field.field_id}_count = {v}")
            elif isinstance(chunk_field, ast.ReservedField):
                pass
            else:
                raise Exception(f'Unsupported bit field type {chunk_field.kind}')

        # Reset state.
        self.offset = end_offset
        self.shift = 0
        self.chunk = []

    def parse_typedef_field_(self, field: ast.TypedefField):
        """Parse a typedef field, to the exclusion of Enum fields."""

        if self.shift != 0:
            raise Exception('Typedef field does not start on an octet boundary')
        if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None):
            raise Exception('Derived struct used in typedef field')

        width = core.get_declaration_size(field.type)
        if width is None:
            self.consume_span_()
            self.append_(f"{field.id}, span = {field.type_id}.parse(span)")
            self.append_(f"fields['{field.id}'] = {field.id}")
        else:
            if width % 8 != 0:
                raise Exception('Typedef field type size is not a multiple of 8')
            width = int(width / 8)
            end_offset = self.offset + width
            # Checksum value field is generated alongside checksum start.
            # Deal with this field as padding.
            if not isinstance(field.type, ast.ChecksumDeclaration):
                span = f'span[{self.offset}:{end_offset}]'
                self.unchecked_append_(f"fields['{field.id}'] = {field.type_id}.parse_all({span})")
            self.offset = end_offset

    def parse_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField]):
        """Parse body and payload fields."""

        payload_size = core.get_payload_field_size(field)
        offset_from_end = core.get_field_offset_from_end(field)

        # If the payload is not byte aligned, do parse the bit fields
        # that can be extracted, but do not consume the input bytes as
        # they will also be included in the payload span.
        if self.shift != 0:
            if payload_size:
                raise Exception("Unexpected payload size for non byte aligned payload")

            rounded_size = int((self.shift + 7) / 8)
            padding_bits = 8 * rounded_size - self.shift
            self.parse_bit_field_(core.make_reserved_field(padding_bits))
            self.consume_span_(rounded_size)
        else:
            self.consume_span_()

        # The payload or body has a known size.
        # Consume the payload and update the span in case
        # fields are placed after the payload.
        if payload_size:
            if getattr(field, 'size_modifier', None):
                self.append_(f"{field.id}_size -= {field.size_modifier}")
            self.check_size_(f'{field.id}_size')
            self.append_(f"payload = span[:{field.id}_size]")
            self.append_(f"span = span[{field.id}_size:]")
        # The payload or body is the last field of a packet,
        # consume the remaining span.
        elif offset_from_end == 0:
            self.append_(f"payload = span")
            self.append_(f"span = bytes([])")
        # The payload or body is followed by fields of static size.
        # Consume the span that is not reserved for the following fields.
        elif offset_from_end is not None:
            if (offset_from_end % 8) != 0:
                raise Exception('Payload field offset from end of packet is not a multiple of 8')
            offset_from_end = int(offset_from_end / 8)
            self.check_size_(f'{offset_from_end}')
            self.append_(f"payload = span[:-{offset_from_end}]")
            self.append_(f"span = span[-{offset_from_end}:]")
        self.append_(f"fields['payload'] = payload")

    def parse_checksum_field_(self, field: ast.ChecksumField):
        """Generate a checksum check."""

        # The checksum value field can be read starting from the current
        # offset if the fields in between are of fixed size, or from the end
        # of the span otherwise.
        self.consume_span_()
        value_field = core.get_packet_field(field.parent, field.field_id)
        offset_from_start = 0
        offset_from_end = 0
        start_index = field.parent.fields.index(field)
        value_index = field.parent.fields.index(value_field)
        value_size = int(core.get_field_size(value_field) / 8)

        for f in field.parent.fields[start_index + 1:value_index]:
            size = core.get_field_size(f)
            if size is None:
                offset_from_start = None
                break
            else:
                offset_from_start += size

        # The trailing fields include the checksum value field itself, so
        # offset_from_end is the distance from the start of the value field
        # to the end of the span.
        trailing_fields = field.parent.fields[value_index:]
        trailing_fields.reverse()
        for f in trailing_fields:
            size = core.get_field_size(f)
            if size is None:
                offset_from_end = None
                break
            else:
                offset_from_end += size

        if offset_from_start is not None:
            if offset_from_start % 8 != 0:
                raise Exception('Checksum value field is not aligned to an octet boundary')
            offset_from_start = int(offset_from_start / 8)
            checksum_span = f'span[:{offset_from_start}]'
            if value_size > 1:
                start = offset_from_start
                end = offset_from_start + value_size
                value = f"int.from_bytes(span[{start}:{end}], byteorder='{self.byteorder}')"
            else:
                value = f'span[{offset_from_start}]'
            self.check_size_(offset_from_start + value_size)

        elif offset_from_end is not None:
            if offset_from_end % 8 != 0:
                raise Exception('Checksum value field is not aligned to an octet boundary')
            offset_from_end = int(offset_from_end / 8)
            checksum_span = f'span[:-{offset_from_end}]'
            if value_size > 1:
                start = offset_from_end
                end = offset_from_end - value_size
                # When the value field is the last field, end == 0 and the
                # slice `span[-start:-0]` would be empty: read to the end of
                # the span instead.
                value_span = f'span[-{start}:]' if end == 0 else f'span[-{start}:-{end}]'
                value = f"int.from_bytes({value_span}, byteorder='{self.byteorder}')"
            else:
                value = f'span[-{offset_from_end}]'
            self.check_size_(offset_from_end)

        else:
            raise Exception('Checksum value field cannot be read at constant offset')

        self.append_(f"{value_field.id} = {value}")
        self.append_(f"fields['{value_field.id}'] = {value_field.id}")
        self.append_(f"computed_{value_field.id} = {value_field.type.function}({checksum_span})")
        self.append_(f"if computed_{value_field.id} != {value_field.id}:")
        self.append_("    raise Exception(f'Invalid checksum computation:" +
                     f" {{computed_{value_field.id}}} != {{{value_field.id}}}')")

    def parse(self, field: ast.Field):
        """Generate the parsing code for one field, dispatching on its kind."""
        if field.cond:
            self.parse_optional_field_(field)

        # Field has bit granularity.
        # Append the field to the current chunk,
        # check if a byte boundary was reached.
        elif core.is_bit_field(field):
            self.parse_bit_field_(field)

        # Padding fields.
        elif isinstance(field, ast.PaddingField):
            pass

        # Array fields.
        elif isinstance(field, ast.ArrayField) and field.width == 8:
            self.parse_byte_array_field_(field)

        elif isinstance(field, ast.ArrayField):
            self.parse_array_field_(field)

        # Other typedef fields.
        elif isinstance(field, ast.TypedefField):
            self.parse_typedef_field_(field)

        # Payload and body fields.
        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
            self.parse_payload_field_(field)

        # Checksum fields.
        elif isinstance(field, ast.ChecksumField):
            self.parse_checksum_field_(field)

        else:
            raise Exception(f'Unimplemented field type {field.kind}')

    def done(self):
        """Flush pending reads and slice off the consumed span bytes."""
        self.consume_span_()
| |
| |
@dataclass
class FieldSerializer:
    """Generate the Python statements of a packet/struct ``serialize()`` body.

    Fields are submitted in declaration order through :meth:`serialize`.
    The generated statements accumulate in ``code`` and append bytes to a
    local ``bytearray`` named ``_span``. Bit fields are buffered in ``value``
    as already-shifted integer expressions until an octet boundary is
    reached, then packed together.
    """

    # Byte order used by the generated int.to_bytes() calls.
    byteorder: str
    # Number of bits accumulated in `value` since the last octet boundary.
    shift: int = 0
    # Pending bit field expressions, each shifted into position.
    value: List[str] = field(default_factory=lambda: [])
    # Generated statements, one line per entry.
    code: List[str] = field(default_factory=lambda: [])
    # Indentation level applied to generated statements (4 spaces per level).
    indent: int = 0

    def indent_(self):
        """Increase the indentation level of subsequently generated code."""
        self.indent += 1

    def unindent_(self):
        """Decrease the indentation level of subsequently generated code."""
        self.indent -= 1

    def append_(self, line: str):
        """Append field serializing code, indented at the current level."""
        lines = line.split('\n')
        self.code.extend(['    ' * self.indent + line for line in lines])

    def extend_(self, value: str, length: int):
        """Append data to the span being constructed."""
        if length == 1:
            self.append_(f"_span.append({value})")
        else:
            self.append_(f"_span.extend(int.to_bytes({value}, length={length}, byteorder='{self.byteorder}'))")

    def serialize_array_element_(self, field: ast.ArrayField):
        """Serialize a single array field element."""
        if field.width is not None:
            length = int(field.width / 8)
            self.extend_('_elt', length)
        elif isinstance(field.type, ast.EnumDeclaration):
            length = int(field.type.width / 8)
            self.extend_('_elt', length)
        else:
            self.append_("_span.extend(_elt.serialize())")

    def serialize_array_field_(self, field: ast.ArrayField):
        """Serialize the selected array field."""
        if field.padded_size:
            self.append_(f"_{field.id}_start = len(_span)")

        if field.width == 8:
            self.append_(f"_span.extend(self.{field.id})")
        else:
            self.append_(f"for _elt in self.{field.id}:")
            self.indent_()
            self.serialize_array_element_(field)
            self.unindent_()

        if field.padded_size:
            # An array larger than its padding would otherwise be emitted
            # silently ([0] * negative is empty) and be rejected by parse().
            self.append_(f"if len(_span) - _{field.id}_start > {field.padded_size}:")
            self.append_("    raise Exception('Array size is larger than the padding size')")
            self.append_(f"_span.extend([0] * ({field.padded_size} - len(_span) + _{field.id}_start))")

    def serialize_optional_field_(self, field: ast.Field):
        """Serialize an optional field, emitted only when its value is not None."""
        if isinstance(field, ast.ScalarField):
            self.append_(dedent(
                """
                if self.{field_id} is not None:
                    _span.extend(int.to_bytes(self.{field_id}, length={size}, byteorder='{byteorder}'))
                """.format(field_id=field.id,
                           size=int(field.width / 8),
                           byteorder=self.byteorder)))

        elif isinstance(field, ast.TypedefField) and isinstance(field.type, ast.EnumDeclaration):
            self.append_(dedent(
                """
                if self.{field_id} is not None:
                    _span.extend(int.to_bytes(self.{field_id}, length={size}, byteorder='{byteorder}'))
                """.format(field_id=field.id,
                           size=int(field.type.width / 8),
                           byteorder=self.byteorder)))

        elif isinstance(field, ast.TypedefField):
            self.append_(dedent(
                """
                if self.{field_id} is not None:
                    _span.extend(self.{field_id}.serialize())
                """.format(field_id=field.id)))

        else:
            raise Exception(f"unsupported field type {field.__class__.__name__}")

    def serialize_bit_field_(self, field: ast.Field):
        """Serialize the selected field as a bit field.
        The field is added to the current chunk. When a byte boundary
        is reached all saved fields are serialized together."""

        # Add to current chunk.
        width = core.get_field_size(field)
        shift = self.shift

        if isinstance(field, str):
            self.value.append(f"({field} << {shift})")
        elif field.cond_for:
            # Scalar field used as condition for an optional field.
            # The width is always 1, the value is determined from
            # the presence or absence of the optional field.
            value_present = field.cond_for.cond.value
            value_absent = 0 if field.cond_for.cond.value else 1
            self.value.append(f"(({value_absent} if self.{field.cond_for.id} is None else {value_present}) << {shift})")
        elif isinstance(field, ast.ScalarField):
            max_value = (1 << field.width) - 1
            self.append_(f"if self.{field.id} > {max_value}:")
            self.append_(f"    print(f\"Invalid value for field {field.parent.id}::{field.id}:" +
                         f" {{self.{field.id}}} > {max_value}; the value will be truncated\")")
            self.append_(f"    self.{field.id} &= {max_value}")
            self.value.append(f"(self.{field.id} << {shift})")
        elif isinstance(field, ast.FixedField) and field.enum_id:
            self.value.append(f"({field.enum_id}.{field.tag_id} << {shift})")
        elif isinstance(field, ast.FixedField):
            self.value.append(f"({field.value} << {shift})")
        elif isinstance(field, ast.TypedefField):
            self.value.append(f"(self.{field.id} << {shift})")

        elif isinstance(field, ast.SizeField):
            max_size = (1 << field.width) - 1
            value_field = core.get_packet_field(field.parent, field.field_id)
            size_modifier = ''

            if getattr(value_field, 'size_modifier', None):
                size_modifier = f' + {value_field.size_modifier}'

            if isinstance(value_field, (ast.PayloadField, ast.BodyField)):
                self.append_(f"_payload_size = len(payload or self.payload or []){size_modifier}")
                self.append_(f"if _payload_size > {max_size}:")
                self.append_(f"    print(f\"Invalid length for payload field:" +
                             f" {{_payload_size}} > {max_size}; the packet cannot be generated\")")
                self.append_(f"    raise Exception(\"Invalid payload length\")")
                array_size = "_payload_size"
            elif isinstance(value_field, ast.ArrayField):
                array_size = f"_{value_field.id}_size"
                if value_field.width:
                    element_size = int(value_field.width / 8)
                    self.append_(f"{array_size} = len(self.{value_field.id}) * {element_size}{size_modifier}")
                elif isinstance(value_field.type, ast.EnumDeclaration):
                    element_size = int(value_field.type.width / 8)
                    self.append_(f"{array_size} = len(self.{value_field.id}) * {element_size}{size_modifier}")
                else:
                    self.append_(
                        f"{array_size} = sum([elt.size for elt in self.{value_field.id}]){size_modifier}")
                # An oversized value would overflow into the bits of the
                # neighbouring fields once packed; reject it like payloads.
                self.append_(f"if {array_size} > {max_size}:")
                self.append_(f"    print(f\"Invalid length for field {field.parent.id}::{value_field.id}:" +
                             f" {{{array_size}}} > {max_size}; the packet cannot be generated\")")
                self.append_(f"    raise Exception(\"Invalid array size\")")
            else:
                raise Exception("Unsupported field type")
            self.value.append(f"({array_size} << {shift})")

        elif isinstance(field, ast.CountField):
            max_count = (1 << field.width) - 1
            self.append_(f"if len(self.{field.field_id}) > {max_count}:")
            self.append_(f"    print(f\"Invalid length for field {field.parent.id}::{field.field_id}:" +
                         f" {{len(self.{field.field_id})}} > {max_count}; the array will be truncated\")")
            self.append_(f"    del self.{field.field_id}[{max_count}:]")
            self.value.append(f"(len(self.{field.field_id}) << {shift})")
        elif isinstance(field, ast.ReservedField):
            pass
        else:
            raise Exception(f'Unsupported bit field type {field.kind}')

        # Check if a byte boundary is reached.
        self.shift += width
        if (self.shift % 8) == 0:
            self.pack_bit_fields_()

    def pack_bit_fields_(self):
        """Pack serialized bit fields."""

        # Should have an integral number of bytes now.
        assert (self.shift % 8) == 0

        # Generate the backing integer, and serialize it
        # using the configured endianness.
        size = int(self.shift / 8)

        if len(self.value) == 0:
            self.append_(f"_span.extend([0] * {size})")
        elif len(self.value) == 1:
            self.extend_(self.value[0], size)
        else:
            self.append_(f"_value = (")
            self.append_("    " + " |\n    ".join(self.value))
            self.append_(")")
            self.extend_('_value', size)

        # Reset state.
        self.shift = 0
        self.value = []

    def serialize_typedef_field_(self, field: ast.TypedefField):
        """Serialize a typedef field, to the exclusion of Enum fields."""

        if self.shift != 0:
            raise Exception('Typedef field does not start on an octet boundary')
        if (isinstance(field.type, ast.StructDeclaration) and field.type.parent_id is not None):
            raise Exception('Derived struct used in typedef field')

        if isinstance(field.type, ast.ChecksumDeclaration):
            size = int(field.type.width / 8)
            self.append_(f"_checksum = {field.type.function}(_span[_checksum_start:])")
            self.extend_('_checksum', size)
        else:
            self.append_(f"_span.extend(self.{field.id}.serialize())")

    def serialize_payload_field_(self, field: Union[ast.BodyField, ast.PayloadField]):
        """Serialize body and payload fields."""

        if self.shift != 0 and self.byteorder == 'big':
            raise Exception('Payload field does not start on an octet boundary')

        if self.shift == 0:
            self.append_(f"_span.extend(payload or self.payload or [])")
        else:
            # Supported case of packet inheritance;
            # the incomplete fields are serialized into
            # the payload, rather than separately.
            # First extract the padding bits from the payload,
            # then recombine them with the bit fields to be serialized.
            rounded_size = int((self.shift + 7) / 8)
            padding_bits = 8 * rounded_size - self.shift
            self.append_(f"_payload = payload or self.payload or bytes()")
            self.append_(f"if len(_payload) < {rounded_size}:")
            self.append_(f"    raise Exception(f\"Invalid length for payload field:" +
                         f" {{len(_payload)}} < {rounded_size}\")")
            self.append_(
                f"_padding = int.from_bytes(_payload[:{rounded_size}], byteorder='{self.byteorder}') >> {self.shift}")
            self.value.append(f"(_padding << {self.shift})")
            self.shift += padding_bits
            self.pack_bit_fields_()
            self.append_(f"_span.extend(_payload[{rounded_size}:])")

    def serialize_checksum_field_(self, field: ast.ChecksumField):
        """Record the span offset where checksum computation starts."""

        self.append_("_checksum_start = len(_span)")

    def serialize(self, field: ast.Field):
        """Generate the serialization code for one field, dispatching on its kind."""
        if field.cond:
            self.serialize_optional_field_(field)

        # Field has bit granularity.
        # Append the field to the current chunk,
        # check if a byte boundary was reached.
        elif core.is_bit_field(field):
            self.serialize_bit_field_(field)

        # Padding fields.
        elif isinstance(field, ast.PaddingField):
            pass

        # Array fields.
        elif isinstance(field, ast.ArrayField):
            self.serialize_array_field_(field)

        # Other typedef fields.
        elif isinstance(field, ast.TypedefField):
            self.serialize_typedef_field_(field)

        # Payload and body fields.
        elif isinstance(field, (ast.PayloadField, ast.BodyField)):
            self.serialize_payload_field_(field)

        # Checksum fields.
        elif isinstance(field, ast.ChecksumField):
            self.serialize_checksum_field_(field)

        else:
            raise Exception(f'Unimplemented field type {field.kind}')
| |
| |
def generate_toplevel_packet_serializer(packet: ast.Declaration) -> List[str]:
    """Return the body lines of serialize() for a toplevel Packet or Struct.

    The generated code writes every field into a local `_span` bytearray
    and returns it converted to `bytes`.
    """
    field_serializer = FieldSerializer(byteorder=packet.file.byteorder)
    for packet_field in packet.fields:
        field_serializer.serialize(packet_field)
    return ['_span = bytearray()', *field_serializer.code, 'return bytes(_span)']
| |
| |
def generate_derived_packet_serializer(packet: ast.Declaration) -> List[str]:
    """Return the body lines of serialize() for a derived Packet or Struct.

    The derived declaration's own fields are written into `_span`, which is
    then passed as `payload` to the parent declaration's serialize().

    Raises:
        Exception: if a big-endian packet has a non-zero body shift.
    """
    shift = core.get_packet_shift(packet)
    if shift and packet.file.byteorder == 'big':
        raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift")

    field_serializer = FieldSerializer(byteorder=packet.file.byteorder, shift=shift)
    for packet_field in packet.fields:
        field_serializer.serialize(packet_field)

    # Delegate the rest of the encoding to the parent declaration.
    delegate = f'return {packet.parent.id}.serialize(self, payload = bytes(_span))'
    return ['_span = bytearray()', *field_serializer.code, delegate]
| |
| |
def generate_packet_parser(packet: ast.Declaration) -> List[str]:
    """Return the body lines of parse() for a Packet or Struct declaration.

    The generated body, in order:
      1. initializes the `fields` dict (toplevel declarations only),
      2. rejects inputs whose constraint fields do not hold the expected
         values,
      3. parses this declaration's own fields,
      4. tries each derived packet's parse() in turn and returns the first
         success, falling back to an instance of this declaration.

    Raises:
        Exception: if a big-endian packet has a non-zero body shift.
    """
    shift = core.get_packet_shift(packet)
    if shift and packet.file.byteorder == 'big':
        raise Exception(f"Big-endian packet {packet.id} has an unsupported body shift")

    # A constraint is compared either with an integer value or with an enum
    # tag of the constrained field's type.
    validation = []
    constraints = core.get_all_packet_constraints(packet)
    if constraints:
        mismatches = []
        for constraint in constraints:
            if constraint.value is None:
                constrained = core.get_packet_field(packet, constraint.id)
                expected = f"{constrained.type_id}.{constraint.tag_id}"
            else:
                expected = hex(constraint.value)
            mismatches.append(f"fields['{constraint.id}'] != {expected}")
        validation = [
            f"if {' or '.join(mismatches)}:",
            "    raise Exception(\"Invalid constraint field values\")",
        ]

    field_parser = FieldParser(byteorder=packet.file.byteorder, shift=shift)
    for packet_field in packet.fields:
        field_parser.parse(packet_field)
    field_parser.done()

    children = core.get_derived_packets(packet)
    # Toplevel declarations create `fields`; derived ones receive it from
    # their parent's parse().
    prologue = [] if packet.parent_id else ["fields = {'payload': None}"]

    # Failures of a child parser are deliberately ignored so that the next
    # candidate, and ultimately this declaration itself, can be tried.
    # TODO: order child packets by decreasing size in case no constraint
    # is given for specialization.
    specialization = []
    for _, child in children:
        specialization += [
            "try:",
            f"    return {child.id}.parse(fields.copy(), payload)",
            "except Exception as exn:",
            "    pass",
        ]

    fallback = [f"return {packet.id}(**fields), span"]
    return prologue + validation + field_parser.code + specialization + fallback
| |
| |
def generate_packet_size_getter(packet: ast.Declaration) -> List[str]:
    """Return the body lines of the `size` property for a Packet or Struct.

    The generated property evaluates to the serialized size in bytes.
    Statically sized fields (as reported by `core.get_field_size`, in bits)
    are summed into one constant. Dynamically sized and optional fields
    each contribute a runtime expression that evaluates to a byte count.

    Raises:
        Exception: if a field type is not supported.
    """
    constant_width = 0  # In bits; converted to bytes once at the end.
    variable_width = []  # Python expressions, each evaluating to bytes.
    for f in packet.fields:
        field_size = core.get_field_size(f)
        if f.cond:
            # Optional fields contribute only when present. Keep scanning the
            # remaining fields: they still count towards the size.
            # NOTE(review): assumes optional scalar/enum widths are byte
            # multiples — confirm against the PDL optional-field rules.
            if isinstance(f, ast.ScalarField):
                variable_width.append(f"(0 if self.{f.id} is None else {f.width // 8})")
            elif isinstance(f, ast.TypedefField) and isinstance(f.type, ast.EnumDeclaration):
                variable_width.append(f"(0 if self.{f.id} is None else {f.type.width // 8})")
            elif isinstance(f, ast.TypedefField):
                variable_width.append(f"(0 if self.{f.id} is None else self.{f.id}.size)")
            else:
                raise Exception(f"unsupported field type {f.__class__.__name__}")
        elif field_size is not None:
            constant_width += field_size
        elif isinstance(f, (ast.PayloadField, ast.BodyField)):
            variable_width.append("len(self.payload)")
        elif isinstance(f, ast.TypedefField):
            variable_width.append(f"self.{f.id}.size")
        elif isinstance(f, ast.ArrayField) and isinstance(f.type, (ast.StructDeclaration, ast.CustomFieldDeclaration)):
            variable_width.append(f"sum([elt.size for elt in self.{f.id}])")
        elif isinstance(f, ast.ArrayField) and isinstance(f.type, ast.EnumDeclaration):
            # Enum widths are declared in bits; each element takes width/8 bytes.
            variable_width.append(f"len(self.{f.id}) * {f.type.width // 8}")
        elif isinstance(f, ast.ArrayField):
            variable_width.append(f"len(self.{f.id}) * {f.width // 8}")
        else:
            raise Exception("Unsupported field type")

    constant_width = constant_width // 8
    if not variable_width:
        return [f"return {constant_width}"]
    if len(variable_width) == 1:
        if constant_width:
            return [f"return {variable_width[0]} + {constant_width}"]
        return [f"return {variable_width[0]}"]

    # Several dynamic terms: one term per generated line, inside parentheses.
    terms = " +\n    ".join(variable_width).split("\n")
    head = f"return {constant_width} + (" if constant_width else "return ("
    return [head] + terms + [")"]
| |
| |
def generate_packet_post_init(decl: ast.Declaration) -> List[str]:
    """Return the body lines of __post_init__ for a Packet or Struct.

    Each constraint gathered from the parent declarations becomes an
    assignment that pins the constrained field to its value: either an
    integer, or an enum tag of the field's type. Without constraints the
    body is a single `pass`.
    """
    constraints = core.get_all_packet_constraints(decl)
    if not constraints:
        return ["pass"]

    assignments = []
    for constraint in constraints:
        if constraint.value is None:
            enum_type = core.get_packet_field(decl, constraint.id).type_id
            value = f"{enum_type}.{constraint.tag_id}"
        else:
            value = f"{constraint.value}"
        assignments.append(f"self.{constraint.id} = {value}")
    return assignments
| |
| |
def generate_enum_declaration(decl: ast.EnumDeclaration) -> str:
    """Return the Python source of an `enum.IntEnum` for an enum declaration.

    Only tags with a concrete value become enum members. Python enums are
    closed and cannot represent ranges, so the generated `from_int` helper
    returns the plain int for non-member values that are still accepted:
    any value for an open enum, otherwise any value inside a declared tag
    range. All other values re-raise the original ValueError.
    """
    members = [f"{tag.id} = {hex(tag.value)}" for tag in decl.tags if tag.value is not None]

    if core.is_open_enum(decl):
        fallback = ["return v"]
    else:
        fallback = []
        for tag in decl.tags:
            if tag.range is None:
                continue
            low, high = tag.range[0], tag.range[1]
            fallback += [f"if v >= 0x{low:x} and v <= 0x{high:x}:", "    return v"]
        fallback.append("raise exn")

    template = dedent("""\

        class {enum_name}(enum.IntEnum):
            {tag_decls}

            @staticmethod
            def from_int(v: int) -> Union[int, '{enum_name}']:
                try:
                    return {enum_name}(v)
                except ValueError as exn:
                    {unknown_handler}

        """)
    return template.format(enum_name=decl.id,
                           tag_decls=indent(members, 1),
                           unknown_handler=indent(fallback, 3))
| |
| |
def _declare_packet_field(f: ast.Field) -> Optional[str]:
    """Return the dataclass attribute line for field `f`, or None when the
    field is not exposed as a dataclass attribute.

    Raises:
        Exception: for a typedef field whose type kind is not supported.
    """
    if f.cond:
        # Optional fields are None when absent; other optional kinds are
        # not declared.
        if isinstance(f, ast.ScalarField):
            return f"{f.id}: Optional[int] = field(kw_only=True, default=None)"
        if isinstance(f, ast.TypedefField):
            return f"{f.id}: Optional[{f.type_id}] = field(kw_only=True, default=None)"
        return None
    if f.cond_for:
        # Condition flags are tied to the presence of the optional field they
        # control, so they are not exposed.
        return None
    if isinstance(f, ast.ScalarField):
        return f"{f.id}: int = field(kw_only=True, default=0)"
    if isinstance(f, ast.TypedefField):
        typ = f.type
        if isinstance(typ, ast.EnumDeclaration):
            first_tag = typ.tags[0]
            if first_tag.range:
                # Range tags are not enum members; default to the range start.
                return f"{f.id}: {f.type_id} = field(kw_only=True, default={first_tag.range[0]})"
            return f"{f.id}: {f.type_id} = field(kw_only=True, default={f.type_id}.{first_tag.id})"
        if isinstance(typ, ast.ChecksumDeclaration):
            return f"{f.id}: int = field(kw_only=True, default=0)"
        if isinstance(typ, (ast.StructDeclaration, ast.CustomFieldDeclaration)):
            return f"{f.id}: {f.type_id} = field(kw_only=True, default_factory={f.type_id})"
        raise Exception("Unsupported typedef field type")
    if isinstance(f, ast.ArrayField):
        if f.width == 8:
            return f"{f.id}: bytearray = field(kw_only=True, default_factory=bytearray)"
        if f.width:
            return f"{f.id}: List[int] = field(kw_only=True, default_factory=list)"
        if f.type_id:
            return f"{f.id}: List[{f.type_id}] = field(kw_only=True, default_factory=list)"
    return None


def generate_packet_declaration(packet: ast.Declaration) -> str:
    """Return the Python source of the dataclass for a Packet or Struct.

    The generated class derives from the parent declaration, or from the
    prelude `Packet` class for toplevel declarations. It declares the
    exposed fields as keyword-only dataclass attributes and provides
    __post_init__, parse, serialize and the `size` property.
    """
    field_decls = [line for line in map(_declare_packet_field, packet.fields) if line is not None]

    if packet.parent_id:
        parent_name, parent_fields = packet.parent_id, 'fields: dict, '
        serializer = generate_derived_packet_serializer(packet)
    else:
        parent_name, parent_fields = 'Packet', ''
        serializer = generate_toplevel_packet_serializer(packet)

    parser = generate_packet_parser(packet)
    size = generate_packet_size_getter(packet)
    post_init = generate_packet_post_init(packet)

    substitutions = dict(packet_name=packet.id,
                         parent_name=parent_name,
                         parent_fields=parent_fields,
                         field_decls=indent(field_decls, 1),
                         post_init=indent(post_init, 2),
                         parser=indent(parser, 2),
                         serializer=indent(serializer, 2),
                         size=indent(size, 2))
    return dedent("""\

        @dataclass
        class {packet_name}({parent_name}):
            {field_decls}

            def __post_init__(self):
                {post_init}

            @staticmethod
            def parse({parent_fields}span: bytes) -> Tuple['{packet_name}', bytes]:
                {parser}

            def serialize(self, payload: bytes = None) -> bytes:
                {serializer}

            @property
            def size(self) -> int:
                {size}
        """).format(**substitutions)
| |
| |
def generate_custom_field_declaration_check(decl: ast.CustomFieldDeclaration) -> str:
    """Return load-time code that validates a user-provided custom field type.

    The generated snippet runs when the generated module is imported. It
    raises immediately, with a clear message, unless the custom type named
    by `decl.id` exposes callable `parse` and `parse_all` attributes.
    """
    template = dedent("""\

        if (not callable(getattr({custom_field_name}, 'parse', None)) or
            not callable(getattr({custom_field_name}, 'parse_all', None))):
            raise Exception('The custom field type {custom_field_name} does not implement the parse method')
        """)
    return template.format(custom_field_name=decl.id)
| |
| |
def generate_checksum_declaration_check(decl: ast.ChecksumDeclaration) -> str:
    """Return load-time code that validates a user-provided checksum function.

    The generated snippet runs when the generated module is imported and
    raises immediately if the name `decl.id` is not bound to a callable.
    """
    checksum_name = decl.id
    template = dedent("""\

        if not callable({checksum_name}):
            raise Exception('{checksum_name} is not callable')
        """)
    return template.format(checksum_name=checksum_name)
| |
| |
def run(input: argparse.FileType, output: argparse.FileType, custom_type_location: Optional[str], exclude_declaration: List[str]):
    """Generate the Python backend for a PDL-JSON file.

    Reads the PDL-JSON AST from `input`, desugars it, and writes to `output`,
    in order:
      - a header comment naming the input and the command line,
      - an import of the custom field / checksum types from
        `custom_type_location` (only if both are present),
      - the shared prelude,
      - load-time checks for the custom types,
      - one definition per enum, packet or struct declaration.
    Declarations whose id appears in `exclude_declaration` are skipped.
    """
    file = ast.File.from_json(json.load(input))
    core.desugar(file)
    selected = [d for d in file.declarations if d.id not in exclude_declaration]

    custom_types = []
    custom_type_checks = []
    for decl in selected:
        if isinstance(decl, ast.CustomFieldDeclaration):
            check = generate_custom_field_declaration_check(decl)
        elif isinstance(decl, ast.ChecksumDeclaration):
            check = generate_checksum_declaration_check(decl)
        else:
            continue
        custom_types.append(decl.id)
        custom_type_checks.append(check)

    output.write(f"# File generated from {input.name}, with the command:\n")
    output.write(f"# {' '.join(sys.argv)}\n")
    output.write("# /!\\ Do not edit by hand.\n")
    if custom_types and custom_type_location:
        output.write(f"\nfrom {custom_type_location} import {', '.join(custom_types)}\n")
    output.write(generate_prelude())
    output.write(''.join(custom_type_checks))

    for decl in selected:
        if isinstance(decl, ast.EnumDeclaration):
            output.write(generate_enum_declaration(decl))
        elif isinstance(decl, (ast.PacketDeclaration, ast.StructDeclaration)):
            output.write(generate_packet_declaration(decl))
| |
| |
def main() -> int:
    """Generate python PDL backend."""
    # The module has no docstring, so `__doc__` at this scope is None. Use this
    # function's own docstring as the --help description instead.
    parser = argparse.ArgumentParser(description=main.__doc__)
    parser.add_argument('--input', type=argparse.FileType('r'), default=sys.stdin, help='Input PDL-JSON source')
    parser.add_argument('--output', type=argparse.FileType('w'), default=sys.stdout, help='Output Python file')
    parser.add_argument('--custom-type-location',
                        type=str,
                        required=False,
                        help='Module of declaration of custom types')
    parser.add_argument('--exclude-declaration',
                        type=str,
                        default=[],
                        action='append',
                        help='Exclude declaration from the generated output')
    run(**vars(parser.parse_args()))
    # run() returns nothing; report success explicitly to honor `-> int`.
    return 0
| |
| |
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())