add support for field arrays in hci packet definitions
diff --git a/bumble/hci.py b/bumble/hci.py
index cba207d..0dbb127 100644
--- a/bumble/hci.py
+++ b/bumble/hci.py
@@ -1445,8 +1445,14 @@
@staticmethod
def init_from_fields(hci_object, fields, values):
if isinstance(values, dict):
- for field_name, _ in fields:
- setattr(hci_object, field_name, values[field_name])
+ for field in fields:
+ if isinstance(field, list):
+ # The field is an array, up-level the array field names
+ for sub_field_name, _ in field:
+ setattr(hci_object, sub_field_name, values[sub_field_name])
+ else:
+ field_name = field[0]
+ setattr(hci_object, field_name, values[field_name])
else:
for field_name, field_value in zip(fields, values):
setattr(hci_object, field_name, field_value)
@@ -1457,132 +1463,160 @@
HCI_Object.init_from_fields(hci_object, parsed.keys(), parsed.values())
@staticmethod
+ def parse_field(data, offset, field_type):
+ # The field_type may be a dictionary with a mapper, parser, and/or size
+ if isinstance(field_type, dict):
+ if 'size' in field_type:
+ field_type = field_type['size']
+ elif 'parser' in field_type:
+ field_type = field_type['parser']
+
+ # Parse the field
+ if field_type == '*':
+ # The rest of the bytes
+ field_value = data[offset:]
+ return (field_value, len(field_value))
+ if field_type == 1:
+ # 8-bit unsigned
+ return (data[offset], 1)
+ if field_type == -1:
+ # 8-bit signed
+ return (struct.unpack_from('b', data, offset)[0], 1)
+ if field_type == 2:
+ # 16-bit unsigned
+ return (struct.unpack_from('<H', data, offset)[0], 2)
+ if field_type == '>2':
+ # 16-bit unsigned big-endian
+ return (struct.unpack_from('>H', data, offset)[0], 2)
+ if field_type == -2:
+ # 16-bit signed
+ return (struct.unpack_from('<h', data, offset)[0], 2)
+ if field_type == 3:
+ # 24-bit unsigned
+ padded = data[offset : offset + 3] + bytes([0])
+ return (struct.unpack('<I', padded)[0], 3)
+ if field_type == 4:
+ # 32-bit unsigned
+ return (struct.unpack_from('<I', data, offset)[0], 4)
+ if field_type == '>4':
+ # 32-bit unsigned big-endian
+ return (struct.unpack_from('>I', data, offset)[0], 4)
+ if isinstance(field_type, int) and 4 < field_type <= 256:
+ # Byte array (from 5 up to 256 bytes)
+ return (data[offset : offset + field_type], field_type)
+ if callable(field_type):
+ new_offset, field_value = field_type(data, offset)
+ return (field_value, new_offset - offset)
+
+ raise ValueError(f'unknown field type {field_type}')
+
+ @staticmethod
def dict_from_bytes(data, offset, fields):
result = collections.OrderedDict()
- for (field_name, field_type) in fields:
- # The field_type may be a dictionary with a mapper, parser, and/or size
- if isinstance(field_type, dict):
- if 'size' in field_type:
- field_type = field_type['size']
- elif 'parser' in field_type:
- field_type = field_type['parser']
-
- # Parse the field
- if field_type == '*':
- # The rest of the bytes
- field_value = data[offset:]
- offset += len(field_value)
- elif field_type == 1:
- # 8-bit unsigned
- field_value = data[offset]
+ for field in fields:
+ if isinstance(field, list):
+ # This is an array field, starting with a 1-byte item count.
+ item_count = data[offset]
offset += 1
- elif field_type == -1:
- # 8-bit signed
- field_value = struct.unpack_from('b', data, offset)[0]
- offset += 1
- elif field_type == 2:
- # 16-bit unsigned
- field_value = struct.unpack_from('<H', data, offset)[0]
- offset += 2
- elif field_type == '>2':
- # 16-bit unsigned big-endian
- field_value = struct.unpack_from('>H', data, offset)[0]
- offset += 2
- elif field_type == -2:
- # 16-bit signed
- field_value = struct.unpack_from('<h', data, offset)[0]
- offset += 2
- elif field_type == 3:
- # 24-bit unsigned
- padded = data[offset : offset + 3] + bytes([0])
- field_value = struct.unpack('<I', padded)[0]
- offset += 3
- elif field_type == 4:
- # 32-bit unsigned
- field_value = struct.unpack_from('<I', data, offset)[0]
- offset += 4
- elif field_type == '>4':
- # 32-bit unsigned big-endian
- field_value = struct.unpack_from('>I', data, offset)[0]
- offset += 4
- elif isinstance(field_type, int) and 4 < field_type <= 256:
- # Byte array (from 5 up to 256 bytes)
- field_value = data[offset : offset + field_type]
- offset += field_type
- elif callable(field_type):
- offset, field_value = field_type(data, offset)
- else:
- raise ValueError(f'unknown field type {field_type}')
+ for _ in range(item_count):
+ for sub_field_name, sub_field_type in field:
+ value, size = HCI_Object.parse_field(
+ data, offset, sub_field_type
+ )
+ result.setdefault(sub_field_name, []).append(value)
+ offset += size
+ continue
+ field_name, field_type = field
+ field_value, field_size = HCI_Object.parse_field(data, offset, field_type)
result[field_name] = field_value
+ offset += field_size
return result
@staticmethod
+ def serialize_field(field_value, field_type):
+ # The field_type may be a dictionary with a mapper, parser, serializer,
+ # and/or size
+ serializer = None
+ if isinstance(field_type, dict):
+ if 'serializer' in field_type:
+ serializer = field_type['serializer']
+ if 'size' in field_type:
+ field_type = field_type['size']
+
+ # Serialize the field
+ if serializer:
+ field_bytes = serializer(field_value)
+ elif field_type == 1:
+ # 8-bit unsigned
+ field_bytes = bytes([field_value])
+ elif field_type == -1:
+ # 8-bit signed
+ field_bytes = struct.pack('b', field_value)
+ elif field_type == 2:
+ # 16-bit unsigned
+ field_bytes = struct.pack('<H', field_value)
+ elif field_type == '>2':
+ # 16-bit unsigned big-endian
+ field_bytes = struct.pack('>H', field_value)
+ elif field_type == -2:
+ # 16-bit signed
+ field_bytes = struct.pack('<h', field_value)
+ elif field_type == 3:
+ # 24-bit unsigned
+ field_bytes = struct.pack('<I', field_value)[0:3]
+ elif field_type == 4:
+ # 32-bit unsigned
+ field_bytes = struct.pack('<I', field_value)
+ elif field_type == '>4':
+ # 32-bit unsigned big-endian
+ field_bytes = struct.pack('>I', field_value)
+ elif field_type == '*':
+ if isinstance(field_value, int):
+ if 0 <= field_value <= 255:
+ field_bytes = bytes([field_value])
+ else:
+ raise ValueError('value too large for *-typed field')
+ else:
+ field_bytes = bytes(field_value)
+ elif isinstance(field_value, (bytes, bytearray)) or hasattr(
+ field_value, 'to_bytes'
+ ):
+ field_bytes = bytes(field_value)
+ if isinstance(field_type, int) and 4 < field_type <= 256:
+ # Truncate or pad with zeros if the field is too long or too short
+ if len(field_bytes) < field_type:
+ field_bytes += bytes(field_type - len(field_bytes))
+ elif len(field_bytes) > field_type:
+ field_bytes = field_bytes[:field_type]
+ else:
+ raise ValueError(f"don't know how to serialize type {type(field_value)}")
+
+ return field_bytes
+
+ @staticmethod
def dict_to_bytes(hci_object, fields):
result = bytearray()
- for (field_name, field_type) in fields:
- # The field_type may be a dictionary with a mapper, parser, serializer,
- # and/or size
- serializer = None
- if isinstance(field_type, dict):
- if 'serializer' in field_type:
- serializer = field_type['serializer']
- if 'size' in field_type:
- field_type = field_type['size']
-
- # Serialize the field
- field_value = hci_object[field_name]
- if serializer:
- field_bytes = serializer(field_value)
- elif field_type == 1:
- # 8-bit unsigned
- field_bytes = bytes([field_value])
- elif field_type == -1:
- # 8-bit signed
- field_bytes = struct.pack('b', field_value)
- elif field_type == 2:
- # 16-bit unsigned
- field_bytes = struct.pack('<H', field_value)
- elif field_type == '>2':
- # 16-bit unsigned big-endian
- field_bytes = struct.pack('>H', field_value)
- elif field_type == -2:
- # 16-bit signed
- field_bytes = struct.pack('<h', field_value)
- elif field_type == 3:
- # 24-bit unsigned
- field_bytes = struct.pack('<I', field_value)[0:3]
- elif field_type == 4:
- # 32-bit unsigned
- field_bytes = struct.pack('<I', field_value)
- elif field_type == '>4':
- # 32-bit unsigned big-endian
- field_bytes = struct.pack('>I', field_value)
- elif field_type == '*':
- if isinstance(field_value, int):
- if 0 <= field_value <= 255:
- field_bytes = bytes([field_value])
- else:
- raise ValueError('value too large for *-typed field')
- else:
- field_bytes = bytes(field_value)
- elif isinstance(field_value, (bytes, bytearray)) or hasattr(
- field_value, 'to_bytes'
- ):
- field_bytes = bytes(field_value)
- if isinstance(field_type, int) and 4 < field_type <= 256:
- # Truncate or Pad with zeros if the field is too long or too short
- if len(field_bytes) < field_type:
- field_bytes += bytes(field_type - len(field_bytes))
- elif len(field_bytes) > field_type:
- field_bytes = field_bytes[:field_type]
- else:
- raise ValueError(
- f"don't know how to serialize type {type(field_value)}"
+ for field in fields:
+ if isinstance(field, list):
+ # The field is an array. The serialized form starts with a 1-byte
+ # item count. We use the length of the first array field as the
+ # array count, since all array fields have the same number of items.
+ item_count = len(hci_object[field[0][0]])
+ result += bytes([item_count]) + b''.join(
+ b''.join(
+ HCI_Object.serialize_field(
+ hci_object[sub_field_name][i], sub_field_type
+ )
+ for sub_field_name, sub_field_type in field
+ )
+ for i in range(item_count)
)
+ continue
- result += field_bytes
+ (field_name, field_type) = field
+ result += HCI_Object.serialize_field(hci_object[field_name], field_type)
return bytes(result)
@@ -1617,48 +1651,73 @@
return str(value)
@staticmethod
- def format_fields(hci_object, keys, indentation='', value_mappers=None):
- if not keys:
- return ''
+ def stringify_field(
+ field_name, field_type, field_value, indentation, value_mappers
+ ):
+ value_mapper = None
+ if isinstance(field_type, dict):
+ # Get the value mapper from the specifier
+ value_mapper = field_type.get('mapper')
- # Measure the widest field name
- max_field_name_length = max(
- (len(key[0] if isinstance(key, tuple) else key) for key in keys)
+ # Check if there's a matching mapper passed
+ if value_mappers:
+ value_mapper = value_mappers.get(field_name, value_mapper)
+
+ # Map the value if we have a mapper
+ if value_mapper is not None:
+ field_value = value_mapper(field_value)
+
+ # Get the string representation of the value
+ return HCI_Object.format_field_value(
+ field_value, indentation=indentation + ' '
)
+ @staticmethod
+ def format_fields(hci_object, fields, indentation='', value_mappers=None):
+ if not fields:
+ return ''
+
# Build array of formatted key:value pairs
- fields = []
- for key in keys:
- value_mapper = None
- if isinstance(key, tuple):
- # The key has an associated specifier
- key, specifier = key
+ field_strings = []
+ for field in fields:
+ if isinstance(field, list):
+ for sub_field in field:
+ sub_field_name, sub_field_type = sub_field
+ item_count = len(hci_object[sub_field_name])
+ for i in range(item_count):
+ field_strings.append(
+ (
+ f'{sub_field_name}[{i}]',
+ HCI_Object.stringify_field(
+ sub_field_name,
+ sub_field_type,
+ hci_object[sub_field_name][i],
+ indentation,
+ value_mappers,
+ ),
+ ),
+ )
+ continue
- # Get the value mapper from the specifier
- if isinstance(specifier, dict):
- value_mapper = specifier.get('mapper')
-
- # Get the value for the field
- value = hci_object[key]
-
- # Check if there's a matching mapper passed
- if value_mappers:
- value_mapper = value_mappers.get(key, value_mapper)
-
- # Map the value if we have a mapper
- if value_mapper is not None:
- value = value_mapper(value)
-
- # Get the string representation of the value
- value_str = HCI_Object.format_field_value(
- value, indentation=indentation + ' '
+ field_name, field_type = field
+ field_value = hci_object[field_name]
+ field_strings.append(
+ (
+ field_name,
+ HCI_Object.stringify_field(
+ field_name, field_type, field_value, indentation, value_mappers
+ ),
+ ),
)
- # Add the field to the formatted result
- key_str = color(f'{key + ":":{1 + max_field_name_length}}', 'cyan')
- fields.append(f'{indentation}{key_str} {value_str}')
-
- return '\n'.join(fields)
+ # Measure the widest field name
+ max_field_name_length = max(len(s[0]) for s in field_strings)
+ sep = ':'
+ return '\n'.join(
+ f'{indentation}'
+ f'{color(f"{field_name + sep:{1 + max_field_name_length}}", "cyan")} {field_value}'
+ for field_name, field_value in field_strings
+ )
def __bytes__(self):
return self.to_bytes()
@@ -3769,9 +3828,7 @@
'advertising_data',
{
'parser': HCI_Object.parse_length_prefixed_bytes,
- 'serializer': functools.partial(
- HCI_Object.serialize_length_prefixed_bytes
- ),
+ 'serializer': HCI_Object.serialize_length_prefixed_bytes,
},
),
]
@@ -3819,9 +3876,7 @@
'scan_response_data',
{
'parser': HCI_Object.parse_length_prefixed_bytes,
- 'serializer': functools.partial(
- HCI_Object.serialize_length_prefixed_bytes
- ),
+ 'serializer': HCI_Object.serialize_length_prefixed_bytes,
},
),
]
@@ -3849,73 +3904,21 @@
# -----------------------------------------------------------------------------
-@HCI_Command.command(fields=None)
+@HCI_Command.command(
+ [
+ ('enable', 1),
+ [
+ ('advertising_handles', 1),
+ ('durations', 2),
+ ('max_extended_advertising_events', 1),
+ ],
+ ]
+)
class HCI_LE_Set_Extended_Advertising_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.56 LE Set Extended Advertising Enable Command
'''
- @classmethod
- def from_parameters(cls, parameters):
- enable = parameters[0]
- num_sets = parameters[1]
- advertising_handles = []
- durations = []
- max_extended_advertising_events = []
- offset = 2
- for _ in range(num_sets):
- advertising_handles.append(parameters[offset])
- durations.append(struct.unpack_from('<H', parameters, offset + 1)[0])
- max_extended_advertising_events.append(parameters[offset + 3])
- offset += 4
-
- return cls(
- enable, advertising_handles, durations, max_extended_advertising_events
- )
-
- def __init__(
- self, enable, advertising_handles, durations, max_extended_advertising_events
- ):
- super().__init__(HCI_LE_SET_EXTENDED_ADVERTISING_ENABLE_COMMAND)
- self.enable = enable
- self.advertising_handles = advertising_handles
- self.durations = durations
- self.max_extended_advertising_events = max_extended_advertising_events
-
- self.parameters = bytes([enable, len(advertising_handles)]) + b''.join(
- [
- struct.pack(
- '<BHB',
- advertising_handles[i],
- durations[i],
- max_extended_advertising_events[i],
- )
- for i in range(len(advertising_handles))
- ]
- )
-
- def __str__(self):
- fields = [('enable:', self.enable)]
- for i, advertising_handle in enumerate(self.advertising_handles):
- fields.append(
- (f'advertising_handle[{i}]: ', advertising_handle)
- )
- fields.append((f'duration[{i}]: ', self.durations[i]))
- fields.append(
- (
- f'max_extended_advertising_events[{i}]:',
- self.max_extended_advertising_events[i],
- )
- )
-
- return (
- color(self.name, 'green')
- + ':\n'
- + '\n'.join(
- [color(field[0], 'cyan') + ' ' + str(field[1]) for field in fields]
- )
- )
-
# -----------------------------------------------------------------------------
@HCI_Command.command(
@@ -4066,7 +4069,10 @@
color(self.name, 'green')
+ ':\n'
+ '\n'.join(
- [color(field[0], 'cyan') + ' ' + str(field[1]) for field in fields]
+ [
+ color(' ' + field[0], 'cyan') + ' ' + str(field[1])
+ for field in fields
+ ]
)
)
@@ -4242,7 +4248,10 @@
color(self.name, 'green')
+ ':\n'
+ '\n'.join(
- [color(field[0], 'cyan') + ' ' + str(field[1]) for field in fields]
+ [
+ color(' ' + field[0], 'cyan') + ' ' + str(field[1])
+ for field in fields
+ ]
)
)
@@ -5205,7 +5214,7 @@
def __str__(self):
lines = [
color(self.name, 'magenta') + ':',
- color(' number_of_handles: ', 'cyan')
+ color(' number_of_handles: ', 'cyan')
+ f'{len(self.connection_handles)}',
]
for i, connection_handle in enumerate(self.connection_handles):
diff --git a/tests/hci_test.py b/tests/hci_test.py
index af68e86..c648592 100644
--- a/tests/hci_test.py
+++ b/tests/hci_test.py
@@ -46,6 +46,7 @@
HCI_LE_Set_Advertising_Parameters_Command,
HCI_LE_Set_Default_PHY_Command,
HCI_LE_Set_Event_Mask_Command,
+ HCI_LE_Set_Extended_Advertising_Enable_Command,
HCI_LE_Set_Extended_Scan_Parameters_Command,
HCI_LE_Set_Random_Address_Command,
HCI_LE_Set_Scan_Enable_Command,
@@ -423,6 +424,25 @@
# -----------------------------------------------------------------------------
+def test_HCI_LE_Set_Extended_Advertising_Enable_Command():
+ command = HCI_Packet.from_bytes(
+ bytes.fromhex('0139200e010301050008020600090307000a')
+ )
+ assert command.enable == 1
+ assert command.advertising_handles == [1, 2, 3]
+ assert command.durations == [5, 6, 7]
+ assert command.max_extended_advertising_events == [8, 9, 10]
+
+ command = HCI_LE_Set_Extended_Advertising_Enable_Command(
+ enable=1,
+ advertising_handles=[1, 2, 3],
+ durations=[5, 6, 7],
+ max_extended_advertising_events=[8, 9, 10],
+ )
+ basic_check(command)
+
+
+# -----------------------------------------------------------------------------
def test_address():
a = Address('C4:F2:17:1A:1D:BB')
assert not a.is_public
@@ -478,6 +498,7 @@
test_HCI_LE_Read_Remote_Features_Command()
test_HCI_LE_Set_Default_PHY_Command()
test_HCI_LE_Set_Extended_Scan_Parameters_Command()
+ test_HCI_LE_Set_Extended_Advertising_Enable_Command()
# -----------------------------------------------------------------------------