add benchmark tool and doc
diff --git a/apps/bench.py b/apps/bench.py
new file mode 100644
index 0000000..6a7aadd
--- /dev/null
+++ b/apps/bench.py
@@ -0,0 +1,1206 @@
+# Copyright 2021-2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# -----------------------------------------------------------------------------
+# Imports
+# -----------------------------------------------------------------------------
+import asyncio
+import enum
+import logging
+import os
+import struct
+import time
+
+import click
+
+from bumble.core import (
+ BT_BR_EDR_TRANSPORT,
+ BT_LE_TRANSPORT,
+ BT_L2CAP_PROTOCOL_ID,
+ BT_RFCOMM_PROTOCOL_ID,
+ UUID,
+ CommandTimeoutError,
+)
+from bumble.colors import color
+from bumble.device import Connection, ConnectionParametersPreferences, Device, Peer
+from bumble.gatt import Characteristic, CharacteristicValue, Service
+from bumble.hci import (
+ HCI_LE_1M_PHY,
+ HCI_LE_2M_PHY,
+ HCI_LE_CODED_PHY,
+ HCI_Constant,
+ HCI_Error,
+ HCI_StatusError,
+)
+from bumble.sdp import (
+ SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
+ SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
+ SDP_PUBLIC_BROWSE_ROOT,
+ SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
+ SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
+ DataElement,
+ ServiceAttribute,
+)
+from bumble.transport import open_transport_or_link
+import bumble.rfcomm
+import bumble.core
+
+
+# -----------------------------------------------------------------------------
+# Logging
+# -----------------------------------------------------------------------------
# Module logger. NOTE: nothing else in this file references `logger`; all
# user-facing output goes through print().
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Address/name given to the local device when running the `central` command.
DEFAULT_CENTRAL_ADDRESS = 'F0:F0:F0:F0:F0:F0'
DEFAULT_CENTRAL_NAME = 'Speed Central'
# Address/name given to the local device when running the `peripheral`
# command; the address is also the default `--peripheral` target of `central`.
DEFAULT_PERIPHERAL_ADDRESS = 'F1:F1:F1:F1:F1:F1'
DEFAULT_PERIPHERAL_NAME = 'Speed Peripheral'

# GATT "speed" service: the client sends packets by writing the TX
# characteristic and receives packets as notifications on the RX one.
SPEED_SERVICE_UUID = '50DB505C-8AC4-4738-8448-3B1D9CC09CC5'
SPEED_TX_UUID = 'E789C754-41A1-45F4-A948-A0A1A90DBA53'
SPEED_RX_UUID = '016A2CC7-E14B-4819-935F-1F56EAE4098D'

# Parameters of the L2CAP channel opened/served by the l2cap-* modes.
DEFAULT_L2CAP_PSM = 1234
DEFAULT_L2CAP_MAX_CREDITS = 128
DEFAULT_L2CAP_MTU = 1022
DEFAULT_L2CAP_MPS = 1024

# Seconds to keep the HCI transport open after the role has finished.
DEFAULT_LINGER_TIME = 1.0

# RFCOMM channel the rfcomm-server listens on and the rfcomm-client opens.
DEFAULT_RFCOMM_CHANNEL = 8
+
+# -----------------------------------------------------------------------------
+# Utils
+# -----------------------------------------------------------------------------
def parse_packet(packet):
    """Split a raw benchmark packet into ``(PacketType, payload)``.

    The first byte selects the packet type; everything after it is returned
    untouched as the payload.

    Raises:
        ValueError: the packet is empty, or its first byte is not a valid
            ``PacketType``. A diagnostic is printed before raising.
    """
    if not packet:
        print(
            color(f'!!! Packet too short (got {len(packet)} bytes, need >= 1)', 'red')
        )
        raise ValueError('packet too short')

    type_byte = packet[0]
    try:
        kind = PacketType(type_byte)
    except ValueError:
        print(color(f'!!! Invalid packet type 0x{type_byte:02X}', 'red'))
        raise

    payload = packet[1:]
    return kind, payload
+
+
def parse_packet_sequence(packet_data):
    """Decode the sequence header of a SEQUENCE or ACK payload.

    ``packet_data`` is the payload returned by ``parse_packet`` (the packet
    type byte already stripped). Its first 5 bytes are a signed flags byte and
    a big-endian unsigned 32-bit packet index; any trailing bytes (padding)
    are ignored.

    Returns:
        A ``(packet_flags, packet_index)`` tuple.

    Raises:
        ValueError: fewer than 5 bytes are available. A diagnostic is printed
            before raising.
    """
    if len(packet_data) < 5:
        print(
            color(
                f'!!! Packet too short (got {len(packet_data)} bytes, need >= 5)',
                'red',
            )
        )
        raise ValueError('packet too short')
    return struct.unpack_from('>bI', packet_data, 0)
+
+
def le_phy_name(phy_id):
    """Return a compact display name ('1M', '2M', 'CODED') for an LE PHY id.

    Ids outside those three fall back to ``HCI_Constant.le_phy_name``; the
    fallback name is computed on every call.
    """
    short_names = {
        HCI_LE_1M_PHY: '1M',
        HCI_LE_2M_PHY: '2M',
        HCI_LE_CODED_PHY: 'CODED',
    }
    fallback_name = HCI_Constant.le_phy_name(phy_id)
    return short_names.get(phy_id, fallback_name)
+
+
def print_connection(connection):
    """Print a one-line summary of a connection's link properties.

    For LE connections the line contains:
    - the connection parameters as interval/latency/timeout (the 1.25 and 10
      multipliers convert the interval and supervision timeout to ms);
    - the data length;
    - the RX/TX PHYs.

    The ATT MTU is always shown. Fields that do not apply are omitted, so the
    fields are separated by exactly one space.
    """
    fields = []
    if connection.transport == BT_LE_TRANSPORT:
        parameters = connection.parameters
        fields.append(
            'Parameters='
            f'{parameters.connection_interval * 1.25:.2f}/'
            f'{parameters.peripheral_latency}/'
            f'{parameters.supervision_timeout * 10}'
        )
        fields.append(f'DL={connection.data_length}')
        fields.append(
            'PHY='
            f'RX:{le_phy_name(connection.phy.rx_phy)}/'
            f'TX:{le_phy_name(connection.phy.tx_phy)}'
        )
    fields.append(f'MTU={connection.att_mtu}')

    print(f'{color("@@@ Connection:", "yellow")} ' + ' '.join(fields))
+
+
def make_sdp_records(channel):
    """Build the SDP records advertising the RFCOMM benchmark service.

    Returns a dict holding a single record (handle 0x00010001). Its attributes
    are:
    - the record handle;
    - membership of the public browse group;
    - the service class UUID;
    - a protocol descriptor list of L2CAP followed by RFCOMM on ``channel``.
    """
    record_handle = 0x00010001

    handle_attribute = ServiceAttribute(
        SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
        DataElement.unsigned_integer_32(record_handle),
    )
    browse_group_attribute = ServiceAttribute(
        SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
        DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
    )
    service_class_uuid = UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')
    service_class_attribute = ServiceAttribute(
        SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
        DataElement.sequence([DataElement.uuid(service_class_uuid)]),
    )

    # Protocol stack: L2CAP, then RFCOMM with the server channel number.
    l2cap_layer = DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)])
    rfcomm_layer = DataElement.sequence(
        [
            DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
            DataElement.unsigned_integer_8(channel),
        ]
    )
    protocol_attribute = ServiceAttribute(
        SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
        DataElement.sequence([l2cap_layer, rfcomm_layer]),
    )

    attributes = [
        handle_attribute,
        browse_group_attribute,
        service_class_attribute,
        protocol_attribute,
    ]
    return {record_handle: attributes}
+
+
class PacketType(enum.IntEnum):
    """Packet type tag carried in the first byte of every benchmark packet."""

    # Sent by the sender/ping side before its first SEQUENCE packet; the
    # receiver/pong side resets its state when it sees it.
    RESET = 0
    # Data packet; followed by a flags byte and a big-endian 32-bit index.
    SEQUENCE = 1
    # Acknowledgement sent back by the receiver/pong side (same layout as
    # SEQUENCE, echoing the flags and index).
    ACK = 2


# Flag bit set on the final SEQUENCE packet of a run.
PACKET_FLAG_LAST = 1
+
+# -----------------------------------------------------------------------------
+# Sender
+# -----------------------------------------------------------------------------
class Sender:
    """Throughput sender role.

    Once ``packet_io`` is ready (and after an optional start delay) it:
    1. sends a RESET packet;
    2. sends ``packet_count`` SEQUENCE packets of ``packet_size`` bytes each,
       the last one flagged with ``PACKET_FLAG_LAST``;
    3. waits for the peer's ACK and prints the average transmit speed in
       bytes per second, measured from just after the RESET until the ACK.
    """

    def __init__(self, packet_io, start_delay, packet_size, packet_count):
        self.tx_start_delay = start_delay
        self.tx_packet_size = packet_size
        self.tx_packet_count = packet_count
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.start_time = 0
        self.bytes_sent = 0
        self.done = asyncio.Event()

    def reset(self):
        """No per-connection state to clear; present for the role interface."""

    async def run(self):
        """Send the packet train, then wait for the final ACK."""
        if self.tx_packet_count <= 0:
            # No packet would carry PACKET_FLAG_LAST, so no ACK would ever
            # arrive and `done` would never be set.
            print(color('!!! Packet count must be at least 1', 'red'))
            return

        print(color('--- Waiting for I/O to be ready...', 'blue'))
        await self.packet_io.ready.wait()
        print(color('--- Go!', 'blue'))

        if self.tx_start_delay:
            print(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
            await asyncio.sleep(self.tx_start_delay)  # FIXME

        print(color('=== Sending RESET', 'magenta'))
        await self.packet_io.send_packet(bytes([PacketType.RESET]))
        # perf_counter() is monotonic and high resolution; time.time() can
        # jump with clock adjustments and is coarse on some platforms.
        self.start_time = time.perf_counter()
        for tx_i in range(self.tx_packet_count):
            packet_flags = PACKET_FLAG_LAST if tx_i == self.tx_packet_count - 1 else 0
            # 6-byte header (type, flags, 32-bit index), zero-padded up to the
            # requested packet size.
            packet = struct.pack(
                '>bbI',
                PacketType.SEQUENCE,
                packet_flags,
                tx_i,
            ) + bytes(self.tx_packet_size - 6)
            print(color(f'Sending packet {tx_i}: {len(packet)} bytes', 'yellow'))
            self.bytes_sent += len(packet)
            await self.packet_io.send_packet(packet)

        await self.done.wait()
        print(color('=== Done!', 'magenta'))

    def on_packet_received(self, packet):
        """Handle a packet from the peer; an ACK completes the run."""
        try:
            packet_type, _ = parse_packet(packet)
        except ValueError:
            return

        if packet_type == PacketType.ACK:
            elapsed = time.perf_counter() - self.start_time
            # Avoid ZeroDivisionError if the clock did not advance.
            average_tx_speed = self.bytes_sent / elapsed if elapsed > 0 else 0.0
            print(
                color(
                    f'@@@ Received ACK. Speed: average={average_tx_speed:.4f}'
                    f' ({self.bytes_sent} bytes in {elapsed:.2f} seconds)',
                    'green',
                )
            )
            self.done.set()
+
+
+# -----------------------------------------------------------------------------
+# Receiver
+# -----------------------------------------------------------------------------
class Receiver:
    """Throughput receiver role.

    - A RESET packet restarts the measurement.
    - For each SEQUENCE packet it prints the instantaneous speed (since the
      previous packet) and the average speed (since the RESET), both in bytes
      per second.
    - When the packet flagged with ``PACKET_FLAG_LAST`` arrives, it sends an
      ACK echoing that packet's flags and index, then finishes.
    """

    def __init__(self, packet_io):
        self.reset()
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.done = asyncio.Event()
        # The event loop keeps only weak references to tasks, so in-flight
        # send tasks are held here until they complete.
        self._pending_tasks = set()

    def reset(self):
        """Clear sequence tracking and throughput counters."""
        self.expected_packet_index = 0
        self.start_timestamp = 0.0
        self.last_timestamp = 0.0
        self.bytes_received = 0

    @staticmethod
    def _speed(byte_count, seconds):
        """Return bytes/second, or 0.0 when the interval is too short to measure."""
        return byte_count / seconds if seconds > 0 else 0.0

    def _send_in_background(self, packet):
        """Schedule ``packet`` for sending while keeping the task referenced."""
        task = asyncio.create_task(self.packet_io.send_packet(packet))
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)

    def on_packet_received(self, packet):
        """Account for one packet from the peer (malformed packets are ignored)."""
        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        # perf_counter() is monotonic and high resolution; time.time() is
        # coarse on some platforms, which made consecutive packets share a
        # timestamp and divide by zero.
        now = time.perf_counter()

        if packet_type == PacketType.RESET:
            print(color('=== Received RESET', 'magenta'))
            self.reset()
            # Start both reference points at the RESET so the first packet's
            # instantaneous speed is not measured against timestamp 0.
            self.start_timestamp = now
            self.last_timestamp = now
            return

        try:
            packet_flags, packet_index = parse_packet_sequence(packet_data)
        except ValueError:
            return
        print(
            f'<<< Received packet {packet_index}: '
            f'flags=0x{packet_flags:02X}, {len(packet)} bytes'
        )

        if packet_index != self.expected_packet_index:
            print(
                color(
                    f'!!! Unexpected packet, expected {self.expected_packet_index} '
                    f'but received {packet_index}',
                    'red',
                )
            )

        self.bytes_received += len(packet)
        instant_rx_speed = self._speed(len(packet), now - self.last_timestamp)
        average_rx_speed = self._speed(
            self.bytes_received, now - self.start_timestamp
        )
        print(
            color(
                f'Speed: instant={instant_rx_speed:.4f}, average={average_rx_speed:.4f}',
                'yellow',
            )
        )

        self.last_timestamp = now
        self.expected_packet_index = packet_index + 1

        if packet_flags & PACKET_FLAG_LAST:
            self._send_in_background(
                struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index)
            )
            print(color('@@@ Received last packet', 'green'))
            self.done.set()

    async def run(self):
        """Wait until the last packet has been received."""
        await self.done.wait()
        print(color('=== Done!', 'magenta'))
+
+
+# -----------------------------------------------------------------------------
+# Ping
+# -----------------------------------------------------------------------------
class Ping:
    """Round-trip latency role.

    After a RESET, it sends one SEQUENCE packet at a time. Each ACK records
    the round-trip latency (in ms) and triggers the next packet. This repeats
    until the ACK for the packet flagged ``PACKET_FLAG_LAST`` arrives; the
    average latency is then printed.
    """

    def __init__(self, packet_io, start_delay, packet_size, packet_count):
        self.tx_start_delay = start_delay
        self.tx_packet_size = packet_size
        self.tx_packet_count = packet_count
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.done = asyncio.Event()
        self.current_packet_index = 0
        self.ping_sent_time = 0.0
        self.latencies = []
        # The event loop keeps only weak references to tasks, so in-flight
        # send tasks are held here until they complete.
        self._pending_tasks = set()

    def reset(self):
        """No per-connection state to clear; present for the role interface."""

    async def run(self):
        """Run the ping/pong exchange and report the average latency."""
        if self.tx_packet_count <= 0:
            # No ping would carry PACKET_FLAG_LAST, so the exchange would
            # never end.
            print(color('!!! Packet count must be at least 1', 'red'))
            return

        print(color('--- Waiting for I/O to be ready...', 'blue'))
        await self.packet_io.ready.wait()
        print(color('--- Go!', 'blue'))

        if self.tx_start_delay:
            print(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
            await asyncio.sleep(self.tx_start_delay)  # FIXME

        print(color('=== Sending RESET', 'magenta'))
        await self.packet_io.send_packet(bytes([PacketType.RESET]))

        await self.send_next_ping()

        await self.done.wait()
        if self.latencies:
            average_latency = sum(self.latencies) / len(self.latencies)
            print(color(f'@@@ Average latency: {average_latency:.2f}'))
        print(color('=== Done!', 'magenta'))

    async def send_next_ping(self):
        """Send the SEQUENCE packet for ``current_packet_index`` and timestamp it."""
        is_last = self.current_packet_index == self.tx_packet_count - 1
        # 6-byte header (type, flags, 32-bit index), zero-padded up to the
        # requested packet size.
        packet = struct.pack(
            '>bbI',
            PacketType.SEQUENCE,
            PACKET_FLAG_LAST if is_last else 0,
            self.current_packet_index,
        ) + bytes(self.tx_packet_size - 6)
        print(color(f'Sending packet {self.current_packet_index}', 'yellow'))
        # perf_counter() is monotonic and high resolution, unlike time.time().
        self.ping_sent_time = time.perf_counter()
        await self.packet_io.send_packet(packet)

    def on_packet_received(self, packet):
        """Handle an ACK: record latency, then finish or send the next ping."""
        # Take the arrival time before any parsing/printing work.
        elapsed = time.perf_counter() - self.ping_sent_time

        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        try:
            packet_flags, packet_index = parse_packet_sequence(packet_data)
        except ValueError:
            return

        if packet_type != PacketType.ACK:
            return

        latency = elapsed * 1000
        self.latencies.append(latency)
        print(
            color(
                f'@@@ Received ACK [{packet_index}], latency={latency:.2f}ms',
                'green',
            )
        )

        if packet_index == self.current_packet_index:
            self.current_packet_index += 1
        else:
            print(
                color(
                    f'!!! Unexpected packet, expected {self.current_packet_index} '
                    f'but received {packet_index}'
                )
            )

        if packet_flags & PACKET_FLAG_LAST:
            self.done.set()
            return

        task = asyncio.create_task(self.send_next_ping())
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)
+
+
+# -----------------------------------------------------------------------------
+# Pong
+# -----------------------------------------------------------------------------
class Pong:
    """Latency responder role.

    Answers every SEQUENCE packet with an ACK that echoes its flags and index,
    and finishes after acknowledging the packet flagged ``PACKET_FLAG_LAST``.
    A RESET packet restarts the expected sequence.
    """

    def __init__(self, packet_io):
        self.reset()
        self.packet_io = packet_io
        self.packet_io.packet_listener = self
        self.done = asyncio.Event()
        # The event loop keeps only weak references to tasks, so in-flight
        # ACK send tasks are held here until they complete.
        self._pending_tasks = set()

    def reset(self):
        """Restart sequence tracking."""
        self.expected_packet_index = 0

    def on_packet_received(self, packet):
        """Acknowledge a SEQUENCE packet (malformed packets are ignored)."""
        try:
            packet_type, packet_data = parse_packet(packet)
        except ValueError:
            return

        if packet_type == PacketType.RESET:
            print(color('=== Received RESET', 'magenta'))
            self.reset()
            return

        try:
            packet_flags, packet_index = parse_packet_sequence(packet_data)
        except ValueError:
            return
        print(
            color(
                f'<<< Received packet {packet_index}: '
                f'flags=0x{packet_flags:02X}, {len(packet)} bytes',
                'green',
            )
        )

        if packet_index != self.expected_packet_index:
            print(
                color(
                    f'!!! Unexpected packet, expected {self.expected_packet_index} '
                    f'but received {packet_index}',
                    'red',
                )
            )

        self.expected_packet_index = packet_index + 1

        task = asyncio.create_task(
            self.packet_io.send_packet(
                struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index)
            )
        )
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)

        if packet_flags & PACKET_FLAG_LAST:
            self.done.set()

    async def run(self):
        """Wait until the last packet has been acknowledged."""
        await self.done.wait()
        print(color('=== Done!', 'magenta'))
+
+
+# -----------------------------------------------------------------------------
+# GattClient
+# -----------------------------------------------------------------------------
class GattClient:
    """Packet I/O over GATT, client side.

    On connection it:
    - optionally requests an ATT MTU;
    - discovers the speed service and its TX/RX characteristics;
    - subscribes to RX notifications, which are forwarded to
      ``packet_listener``.

    Outgoing packets are written to the TX characteristic. ``ready`` is set
    once all of that succeeded and cleared on disconnection.
    """

    def __init__(self, _device, att_mtu=None):
        self.att_mtu = att_mtu
        # Characteristic proxies, filled in by on_connection().
        self.speed_rx = None
        self.speed_tx = None
        self.packet_listener = None
        self.ready = asyncio.Event()

    async def on_connection(self, connection):
        """Discover the speed service and subscribe; bail out with a message on failure."""
        peer = Peer(connection)

        if self.att_mtu:
            print(color(f'*** Requesting MTU update: {self.att_mtu}', 'blue'))
            await peer.request_mtu(self.att_mtu)

        print(color('*** Discovering services...', 'blue'))
        await peer.discover_services()

        services = peer.get_services_by_uuid(SPEED_SERVICE_UUID)
        if not services:
            print(color('!!! Speed Service not found', 'red'))
            return
        service = services[0]

        print(color('*** Discovering characteristics...', 'blue'))
        await service.discover_characteristics()

        tx_candidates = service.get_characteristics_by_uuid(SPEED_TX_UUID)
        if not tx_candidates:
            print(color('!!! Speed TX not found', 'red'))
            return
        self.speed_tx = tx_candidates[0]

        rx_candidates = service.get_characteristics_by_uuid(SPEED_RX_UUID)
        if not rx_candidates:
            print(color('!!! Speed RX not found', 'red'))
            return
        self.speed_rx = rx_candidates[0]

        print(color('*** Subscribing to RX', 'blue'))
        await self.speed_rx.subscribe(self.on_packet_received)

        print(color('*** Discovery complete', 'blue'))

        connection.on('disconnection', self.on_disconnection)
        self.ready.set()

    def on_disconnection(self, _):
        self.ready.clear()

    def on_packet_received(self, packet):
        """Forward an RX notification to the listener, if any."""
        listener = self.packet_listener
        if listener:
            listener.on_packet_received(packet)

    async def send_packet(self, packet):
        """Send ``packet`` by writing it to the TX characteristic."""
        await self.speed_tx.write_value(packet)
+
+
+# -----------------------------------------------------------------------------
+# GattServer
+# -----------------------------------------------------------------------------
class GattServer:
    """Packet I/O over GATT, server side.

    Registers the speed service on ``device``:
    - packets the client writes to the TX characteristic go to
      ``packet_listener``;
    - outgoing packets are sent as notifications on the RX characteristic.

    ``ready`` tracks the client's RX notification subscription and is cleared
    on disconnection.
    """

    def __init__(self, device):
        self.device = device
        self.packet_listener = None
        self.ready = asyncio.Event()

        # Incoming direction: client writes are delivered to on_tx_write.
        self.speed_tx = Characteristic(
            SPEED_TX_UUID,
            Characteristic.WRITE,
            Characteristic.WRITEABLE,
            CharacteristicValue(write=self.on_tx_write),
        )
        # Outgoing direction: notify-only characteristic (permissions 0).
        self.speed_rx = Characteristic(SPEED_RX_UUID, Characteristic.NOTIFY, 0)

        characteristics = [self.speed_tx, self.speed_rx]
        device.add_services([Service(SPEED_SERVICE_UUID, characteristics)])

        self.speed_rx.on('subscription', self.on_rx_subscription)

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        self.ready.clear()

    def on_rx_subscription(self, _connection, notify_enabled, _indicate_enabled):
        """Become ready when the client enables RX notifications, unready otherwise."""
        if not notify_enabled:
            print(color('*** RX un-subscription', 'blue'))
            self.ready.clear()
            return

        print(color('*** RX subscription', 'blue'))
        self.ready.set()

    def on_tx_write(self, _, value):
        """Forward a TX write from the client to the listener, if any."""
        listener = self.packet_listener
        if listener:
            listener.on_packet_received(value)

    async def send_packet(self, packet):
        """Notify ``packet`` to the subscribers of the RX characteristic."""
        await self.device.notify_subscribers(self.speed_rx, packet)
+
+
+# -----------------------------------------------------------------------------
+# StreamedPacketIO
+# -----------------------------------------------------------------------------
class StreamedPacketIO:
    """Base class for packet I/O over a byte stream (L2CAP channel, RFCOMM DLC).

    On the wire, each packet is a 2-byte big-endian length followed by that
    many bytes:
    - ``on_packet`` reassembles packets from arbitrarily chunked stream data
      and hands each one to ``packet_listener``;
    - ``send_packet`` frames a packet and passes it to ``io_sink``.
    """

    def __init__(self):
        self.packet_listener = None
        self.io_sink = None
        # Reassembly state: body collected so far, length header collected
        # so far, and body bytes still missing.
        self.rx_packet = b''
        self.rx_packet_header = b''
        self.rx_packet_need = 0

    def on_packet(self, packet):
        """Consume a chunk of stream data, delivering every completed packet."""
        while packet:
            if self.rx_packet_need:
                # Inside a packet body: take only what is still missing.
                chunk = packet[: self.rx_packet_need]
                self.rx_packet += chunk
                packet = packet[len(chunk) :]
                self.rx_packet_need -= len(chunk)
                if not self.rx_packet_need:
                    self._complete_packet()
            else:
                # Collect the 2-byte length header, which may itself be split
                # across chunks.
                header_bytes = packet[: 2 - len(self.rx_packet_header)]
                self.rx_packet_header += header_bytes
                packet = packet[len(header_bytes) :]
                if len(self.rx_packet_header) != 2:
                    return
                self.rx_packet_need = struct.unpack('>H', self.rx_packet_header)[0]
                if not self.rx_packet_need:
                    # A zero-length packet is complete as soon as its header
                    # is. Without this, the header is never cleared and the
                    # loop spins forever consuming zero bytes.
                    self._complete_packet()

    def _complete_packet(self):
        """Reset the reassembly state and hand the finished packet to the listener."""
        packet = self.rx_packet
        self.rx_packet = b''
        self.rx_packet_header = b''
        if self.packet_listener:
            self.packet_listener.on_packet_received(packet)

    async def send_packet(self, packet):
        """Frame ``packet`` with its length and write it to ``io_sink``.

        Packets are dropped (with a message) while no sink is attached.
        """
        if not self.io_sink:
            print(color('!!! No sink, dropping packet', 'red'))
            return

        # pylint: disable-next=not-callable
        self.io_sink(struct.pack('>H', len(packet)) + packet)
+
+
+# -----------------------------------------------------------------------------
+# L2capClient
+# -----------------------------------------------------------------------------
class L2capClient(StreamedPacketIO):
    """Packet I/O over an L2CAP channel, client side.

    When connected, it opens a channel on ``psm`` with the given credit, MTU
    and MPS settings. It then routes the channel's incoming data into the
    stream framer and uses the channel's ``write`` as the outgoing sink.
    ``ready`` is set once the channel is open.
    """

    def __init__(
        self,
        _device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        self.psm = psm
        self.max_credits = max_credits
        self.mtu = mtu
        self.mps = mps
        self.ready = asyncio.Event()

    async def on_connection(self, connection):
        """Open the L2CAP channel; on failure, report it and stay not-ready."""
        connection.on('disconnection', self.on_disconnection)

        print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
        try:
            channel = await connection.open_l2cap_channel(
                psm=self.psm,
                max_credits=self.max_credits,
                mtu=self.mtu,
                mps=self.mps,
            )
            print(color('*** L2CAP channel:', 'cyan'), channel)
        except Exception as error:  # pylint: disable=broad-except
            print(color(f'!!! Connection failed: {error}', 'red'))
            return

        # Incoming data -> framer; framer output -> channel.
        channel.sink = self.on_packet
        channel.on('close', self.on_l2cap_close)
        self.io_sink = channel.write

        self.ready.set()

    def on_disconnection(self, _):
        pass

    def on_l2cap_close(self):
        print(color('*** L2CAP channel closed', 'red'))
+
+
+# -----------------------------------------------------------------------------
+# L2capServer
+# -----------------------------------------------------------------------------
class L2capServer(StreamedPacketIO):
    """Packet I/O over an L2CAP channel, server side.

    Registers a channel server on ``psm`` with ``device``. The incoming
    channel becomes the packet transport: its data feeds the stream framer
    and its ``write`` is the outgoing sink. ``ready`` is set while a channel
    is open.
    """

    def __init__(
        self,
        device,
        psm=DEFAULT_L2CAP_PSM,
        max_credits=DEFAULT_L2CAP_MAX_CREDITS,
        mtu=DEFAULT_L2CAP_MTU,
        mps=DEFAULT_L2CAP_MPS,
    ):
        super().__init__()
        self.l2cap_channel = None
        self.ready = asyncio.Event()

        # Listen for incoming L2CAP CoC connections
        device.register_l2cap_channel_server(
            psm=psm,
            server=self.on_l2cap_channel,
            max_credits=max_credits,
            mtu=mtu,
            mps=mps,
        )
        print(color(f'### Listening for CoC connection on PSM {psm}', 'yellow'))

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        pass

    def on_l2cap_channel(self, l2cap_channel):
        """Adopt a newly opened channel as the packet transport."""
        print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)

        # Keep the channel we are using (previously this attribute was only
        # ever reset to None, never assigned).
        self.l2cap_channel = l2cap_channel
        self.io_sink = l2cap_channel.write
        l2cap_channel.on('close', self.on_l2cap_close)
        l2cap_channel.sink = self.on_packet

        self.ready.set()

    def on_l2cap_close(self):
        """Forget the closed channel and stop using it for output."""
        print(color('*** L2CAP channel closed', 'red'))
        self.l2cap_channel = None
        # Don't write into a closed channel: send_packet() will report
        # "No sink" and drop packets instead.
        self.io_sink = None
        self.ready.clear()
+ self.l2cap_channel = None
+
+
+# -----------------------------------------------------------------------------
+# RfcommClient
+# -----------------------------------------------------------------------------
class RfcommClient(StreamedPacketIO):
    """Packet I/O over RFCOMM, client side.

    When connected, it starts an RFCOMM client and opens a DLC on
    ``DEFAULT_RFCOMM_CHANNEL``. The DLC feeds the stream framer and its
    ``write`` is the outgoing sink. ``ready`` is set once the DLC is open.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device
        self.ready = asyncio.Event()

    async def on_connection(self, connection):
        """Start the RFCOMM multiplexer and open the benchmark DLC."""
        connection.on('disconnection', self.on_disconnection)

        print(color('*** Starting RFCOMM client...', 'blue'))
        client = bumble.rfcomm.Client(self.device, connection)
        multiplexer = await client.start()
        print(color('*** Started', 'blue'))

        dlc_channel = DEFAULT_RFCOMM_CHANNEL
        print(color(f'### Opening session for channel {dlc_channel}...', 'yellow'))
        try:
            session = await multiplexer.open_dlc(dlc_channel)
            print(color('### Session open', 'yellow'), session)
        except bumble.core.ConnectionError as error:
            print(color(f'!!! Session open failed: {error}', 'red'))
            await multiplexer.disconnect()
            return

        # Incoming data -> framer; framer output -> DLC.
        session.sink = self.on_packet
        self.io_sink = session.write

        self.ready.set()

    def on_disconnection(self, _):
        pass
+
+
+# -----------------------------------------------------------------------------
+# RfcommServer
+# -----------------------------------------------------------------------------
class RfcommServer(StreamedPacketIO):
    """Packet I/O over RFCOMM, server side.

    Listens for DLCs on ``DEFAULT_RFCOMM_CHANNEL`` and publishes a matching
    SDP record. An incoming DLC feeds the stream framer and its ``write`` is
    the outgoing sink. ``ready`` is set once a DLC is connected.
    """

    def __init__(self, device):
        super().__init__()
        self.ready = asyncio.Event()

        # Create and register a server
        rfcomm_server = bumble.rfcomm.Server(device)

        # Listen for incoming DLC connections
        channel_number = rfcomm_server.listen(self.on_dlc, DEFAULT_RFCOMM_CHANNEL)

        # Setup the SDP to advertise this channel
        device.sdp_service_records = make_sdp_records(channel_number)

        print(
            color(
                f'### Listening for RFComm connection on channel {channel_number}',
                'yellow',
            )
        )

    async def on_connection(self, connection):
        connection.on('disconnection', self.on_disconnection)

    def on_disconnection(self, _):
        pass

    def on_dlc(self, dlc):
        """Adopt a newly connected DLC as the packet transport."""
        print(color('*** DLC connected:', 'blue'), dlc)
        dlc.sink = self.on_packet
        self.io_sink = dlc.write
        # Signal readiness like the other server modes do; without this, a
        # Sender or Ping role on this side waits on `ready` forever.
        self.ready.set()
+
+
+# -----------------------------------------------------------------------------
+# Central
+# -----------------------------------------------------------------------------
class Central(Connection.Listener):
    """Benchmark central: connects to the peripheral and runs the role.

    ``mode_factory`` builds the packet-I/O mode from the local device, and
    ``role_factory`` builds the benchmark role from that mode. ``classic``
    selects a BR/EDR connection instead of LE.

    ``connection_interval`` and ``phy`` ('1m', '2m' or 'coded') are optional
    LE settings. When given, ``connection_interval`` is used as both the min
    and max preferred interval.
    """

    def __init__(
        self,
        transport,
        peripheral_address,
        classic,
        role_factory,
        mode_factory,
        connection_interval,
        phy,
    ):
        super().__init__()
        self.transport = transport
        self.peripheral_address = peripheral_address
        self.classic = classic
        self.role_factory = role_factory
        self.mode_factory = mode_factory
        self.device = None
        self.connection = None

        if phy:
            self.phy = {
                '1m': HCI_LE_1M_PHY,
                '2m': HCI_LE_2M_PHY,
                'coded': HCI_LE_CODED_PHY,
            }[phy]
        else:
            self.phy = None

        if connection_interval:
            preferences = ConnectionParametersPreferences()
            # Same value for min and max pins the requested interval.
            preferences.connection_interval_min = connection_interval
            preferences.connection_interval_max = connection_interval

            # Preferences for the 1M PHY are always set.
            self.connection_parameter_preferences = {HCI_LE_1M_PHY: preferences}

            if self.phy not in (None, HCI_LE_1M_PHY):
                # Add a connection parameters entry for the requested PHY too.
                self.connection_parameter_preferences[self.phy] = preferences
        else:
            self.connection_parameter_preferences = None

    async def run(self):
        """Open the transport, connect, set up the mode and run the role."""
        print(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport_or_link(self.transport) as (
            hci_source,
            hci_sink,
        ):
            print(color('>>> Connected', 'green'))

            central_address = DEFAULT_CENTRAL_ADDRESS
            self.device = Device.with_hci(
                DEFAULT_CENTRAL_NAME, central_address, hci_source, hci_sink
            )
            mode = self.mode_factory(self.device)
            role = self.role_factory(mode)
            self.device.classic_enabled = self.classic

            await self.device.power_on()

            print(color(f'### Connecting to {self.peripheral_address}...', 'cyan'))
            try:
                self.connection = await self.device.connect(
                    self.peripheral_address,
                    connection_parameters_preferences=self.connection_parameter_preferences,
                    transport=BT_BR_EDR_TRANSPORT if self.classic else BT_LE_TRANSPORT,
                )
            except CommandTimeoutError:
                print(color('!!! Connection timed out', 'red'))
                return
            except bumble.core.ConnectionError as error:
                print(color(f'!!! Connection error: {error}', 'red'))
                return
            except HCI_StatusError as error:
                print(color(f'!!! Connection failed: {error.error_name}', 'red'))
                return
            print(color('### Connected', 'cyan'))
            self.connection.listener = self
            print_connection(self.connection)

            await mode.on_connection(self.connection)

            # Set the PHY if requested (LE only: the PHY values are LE PHYs).
            if self.phy is not None:
                if self.classic:
                    print(
                        color(
                            '!!! PHY selection only applies to LE connections, ignoring',
                            'yellow',
                        )
                    )
                else:
                    try:
                        await self.connection.set_phy(
                            tx_phys=[self.phy], rx_phys=[self.phy]
                        )
                    except HCI_Error as error:
                        print(
                            color(
                                f'!!! Unable to set the PHY: {error.error_name}',
                                'yellow',
                            )
                        )

            await role.run()
            await asyncio.sleep(DEFAULT_LINGER_TIME)

    def on_disconnection(self, reason):
        print(color(f'!!! Disconnection: reason={reason}', 'red'))
        self.connection = None

    def on_connection_parameters_update(self):
        print_connection(self.connection)

    def on_connection_phy_update(self):
        print_connection(self.connection)

    def on_connection_att_mtu_update(self):
        print_connection(self.connection)

    def on_connection_data_length_change(self):
        print_connection(self.connection)
+
+
+# -----------------------------------------------------------------------------
+# Peripheral
+# -----------------------------------------------------------------------------
class Peripheral(Device.Listener, Connection.Listener):
    """Benchmark peripheral: waits for a connection and runs the role.

    It builds the packet-I/O mode and the role from the factories, powers on,
    and waits for a connection. For BR/EDR (``classic``) it becomes
    discoverable and connectable; otherwise it advertises. It then runs the
    role over the first connection.
    """

    def __init__(self, transport, classic, role_factory, mode_factory):
        super().__init__()
        self.transport = transport
        self.classic = classic
        self.role_factory = role_factory
        self.role = None
        self.mode_factory = mode_factory
        self.mode = None
        self.device = None
        self.connection = None
        self.connected = asyncio.Event()

    async def run(self):
        """Open the transport, wait for a connection and run the role."""
        print(color('>>> Connecting to HCI...', 'green'))
        async with await open_transport_or_link(self.transport) as (
            hci_source,
            hci_sink,
        ):
            print(color('>>> Connected', 'green'))

            peripheral_address = DEFAULT_PERIPHERAL_ADDRESS
            self.device = Device.with_hci(
                DEFAULT_PERIPHERAL_NAME, peripheral_address, hci_source, hci_sink
            )
            self.device.listener = self
            self.mode = self.mode_factory(self.device)
            self.role = self.role_factory(self.mode)
            self.device.classic_enabled = self.classic

            await self.device.power_on()

            if self.classic:
                await self.device.set_discoverable(True)
                await self.device.set_connectable(True)
                waiting_on = self.device.public_address
            else:
                await self.device.start_advertising(auto_restart=True)
                waiting_on = peripheral_address
            print(color(f'### Waiting for connection on {waiting_on}...', 'cyan'))

            await self.connected.wait()
            print(color('### Connected', 'cyan'))

            await self.mode.on_connection(self.connection)
            await self.role.run()
            await asyncio.sleep(DEFAULT_LINGER_TIME)

    def on_connection(self, connection):
        connection.listener = self
        self.connection = connection
        self.connected.set()

    def on_disconnection(self, reason):
        print(color(f'!!! Disconnection: reason={reason}', 'red'))
        self.connection = None
        if self.role is not None:
            self.role.reset()

    def on_connection_parameters_update(self):
        print_connection(self.connection)

    def on_connection_phy_update(self):
        print_connection(self.connection)

    def on_connection_att_mtu_update(self):
        print_connection(self.connection)

    def on_connection_data_length_change(self):
        print_connection(self.connection)
+
+
+# -----------------------------------------------------------------------------
def create_mode_factory(ctx, default_mode):
    """Return a callable that builds the packet-I/O mode object for a device.

    The mode name is ``ctx.obj['mode']``, or ``default_mode`` when that is
    None. The returned callable takes the device and raises
    ``ValueError('invalid mode')`` for an unknown name.
    """
    requested_mode = ctx.obj['mode']
    selected_mode = default_mode if requested_mode is None else requested_mode

    def create_mode(device):
        # Zero-argument builders so only the selected class is instantiated.
        builders = {
            'gatt-client': lambda: GattClient(device, att_mtu=ctx.obj['att_mtu']),
            'gatt-server': lambda: GattServer(device),
            'l2cap-client': lambda: L2capClient(device),
            'l2cap-server': lambda: L2capServer(device),
            'rfcomm-client': lambda: RfcommClient(device),
            'rfcomm-server': lambda: RfcommServer(device),
        }
        builder = builders.get(selected_mode)
        if builder is None:
            raise ValueError('invalid mode')
        return builder()

    return create_mode
+
+
+# -----------------------------------------------------------------------------
def create_role_factory(ctx, default_role):
    """Return a callable that builds the benchmark role object for a packet I/O.

    The role name is ``ctx.obj['role']``, or ``default_role`` when that is
    None. The 'sender' and 'ping' roles also take the start delay, packet size
    and packet count from ``ctx.obj``. The returned callable raises
    ``ValueError('invalid role')`` for an unknown name.
    """
    requested_role = ctx.obj['role']
    selected_role = default_role if requested_role is None else requested_role

    def create_role(packet_io):
        def with_tx_options(role_class):
            # Options shared by the roles that originate traffic.
            return role_class(
                packet_io,
                start_delay=ctx.obj['start_delay'],
                packet_size=ctx.obj['packet_size'],
                packet_count=ctx.obj['packet_count'],
            )

        # Zero-argument builders so only the selected class is instantiated.
        builders = {
            'sender': lambda: with_tx_options(Sender),
            'receiver': lambda: Receiver(packet_io),
            'ping': lambda: with_tx_options(Ping),
            'pong': lambda: Pong(packet_io),
        }
        builder = builders.get(selected_role)
        if builder is None:
            raise ValueError('invalid role')
        return builder()

    return create_role
+
+
+# -----------------------------------------------------------------------------
+# Main
+# -----------------------------------------------------------------------------
@click.group()
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.option('--role', type=click.Choice(['sender', 'receiver', 'ping', 'pong']))
@click.option(
    '--mode',
    type=click.Choice(
        [
            'gatt-client',
            'gatt-server',
            'l2cap-client',
            'l2cap-server',
            'rfcomm-client',
            'rfcomm-server',
        ]
    ),
)
@click.option(
    '--att-mtu',
    metavar='MTU',
    type=click.IntRange(23, 517),
    help='GATT MTU (gatt-client mode)',
)
@click.option(
    '--packet-size',
    '-s',
    metavar='SIZE',
    # Every SEQUENCE packet starts with a 6-byte header.
    type=click.IntRange(8, 4096),
    default=500,
    help='Packet size (sender and ping roles)',
)
@click.option(
    '--packet-count',
    '-c',
    metavar='COUNT',
    # At least 1: the last packet carries the flag that ends the run, so a
    # count of 0 would wait forever.
    type=click.IntRange(min=1),
    default=10,
    help='Packet count (sender and ping roles)',
)
@click.option(
    '--start-delay',
    '-sd',
    metavar='SECONDS',
    type=int,
    default=1,
    help='Start delay (sender and ping roles)',
)
@click.pass_context
def bench(
    ctx, device_config, role, mode, att_mtu, packet_size, packet_count, start_delay
):
    """Bumble throughput and latency benchmark."""
    ctx.ensure_object(dict)
    # NOTE(review): device_config is stored here but nothing in this file
    # reads it; Central/Peripheral always use Device.with_hci with the
    # built-in default name and address.
    ctx.obj['device_config'] = device_config
    ctx.obj['role'] = role
    ctx.obj['mode'] = mode
    ctx.obj['att_mtu'] = att_mtu
    ctx.obj['packet_size'] = packet_size
    ctx.obj['packet_count'] = packet_count
    ctx.obj['start_delay'] = start_delay

    # RFCOMM modes run over BR/EDR ("classic"); all other modes use LE.
    ctx.obj['classic'] = mode in ('rfcomm-client', 'rfcomm-server')
+
+
@bench.command()
@click.argument('transport')
@click.option(
    '--peripheral',
    'peripheral_address',
    metavar='ADDRESS_OR_NAME',
    default=DEFAULT_PERIPHERAL_ADDRESS,
    help='Address or name to connect to',
)
@click.option(
    '--connection-interval',
    '--ci',
    metavar='CONNECTION_INTERVAL',
    type=int,
    help='Connection interval (in ms)',
)
@click.option('--phy', type=click.Choice(['1m', '2m', 'coded']), help='PHY to use')
@click.pass_context
def central(ctx, transport, peripheral_address, connection_interval, phy):
    """Run as a central (initiates the connection)"""
    # Unless overridden by --role/--mode, the central sends over GATT.
    make_role = create_role_factory(ctx, 'sender')
    make_mode = create_mode_factory(ctx, 'gatt-client')
    app = Central(
        transport,
        peripheral_address,
        ctx.obj['classic'],
        make_role,
        make_mode,
        connection_interval,
        phy,
    )
    asyncio.run(app.run())
+
+
@bench.command()
@click.argument('transport')
@click.pass_context
def peripheral(ctx, transport):
    """Run as a peripheral (waits for a connection)"""
    # Unless overridden by --role/--mode, the peripheral receives over GATT.
    make_role = create_role_factory(ctx, 'receiver')
    make_mode = create_mode_factory(ctx, 'gatt-server')
    app = Peripheral(transport, ctx.obj['classic'], make_role, make_mode)
    asyncio.run(app.run())
+
+
def main():
    """Configure logging from BUMBLE_LOGLEVEL (default INFO) and run the CLI."""
    level_name = os.environ.get('BUMBLE_LOGLEVEL', 'INFO')
    logging.basicConfig(level=level_name.upper())
    bench()  # pylint: disable=no-value-for-parameter


# -----------------------------------------------------------------------------
if __name__ == "__main__":
    main()