| # Copyright 2021-2022 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| # ----------------------------------------------------------------------------- |
| # Imports |
| # ----------------------------------------------------------------------------- |
import asyncio
import collections
import logging

from colors import color
from pyee import EventEmitter

from .hci import *
from .l2cap import *
from .att import *
from .gatt import *
from .smp import *
from .core import ConnectionParameters
| |
| # ----------------------------------------------------------------------------- |
| # Logging |
| # ----------------------------------------------------------------------------- |
# Module logger; Host uses it to trace HCI traffic and to report warnings.
logger = logging.getLogger(__name__)
| |
| |
| # ----------------------------------------------------------------------------- |
| # Constants |
| # ----------------------------------------------------------------------------- |
# Initial HCI ACL buffer parameters assumed by the Host until reset() reads the
# actual values from the controller.
HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27  # max bytes per outgoing ACL fragment
HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS = 1  # max outgoing ACL packets in flight
HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH = 27  # non-LE counterpart of the length above
HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1  # non-LE counterpart of the count above
| |
| |
| # ----------------------------------------------------------------------------- |
class Connection:
    """Per-handle state the Host keeps for an HCI connection.

    Incoming ACL fragments are handed to an HCI_AclDataPacketAssembler; each
    reassembled PDU is decoded as L2CAP and routed back to the owning Host
    according to its channel id (ATT, SMP, or any other channel).
    """

    def __init__(self, host, handle, role, peer_address):
        self.host = host
        self.handle = handle
        self.role = role
        self.peer_address = peer_address
        # on_acl_pdu is registered as the assembler's callback
        # (presumably invoked once per fully reassembled PDU)
        self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)

    def on_hci_acl_data_packet(self, packet):
        """Feed one ACL data packet from the controller to the reassembler."""
        self.assembler.feed_packet(packet)

    def on_acl_pdu(self, pdu):
        """Decode a reassembled L2CAP PDU and dispatch it to the Host by channel."""
        frame = L2CAP_PDU.from_bytes(pdu)
        channel = frame.cid

        if channel == ATT_CID:
            self.host.on_gatt_pdu(self, frame.payload)
            return
        if channel == SMP_CID:
            self.host.on_smp_pdu(self, frame.payload)
            return
        self.host.on_l2cap_pdu(self, channel, frame.payload)
| |
| |
| # ----------------------------------------------------------------------------- |
| class Host(EventEmitter): |
| def __init__(self, controller_source = None, controller_sink = None): |
| super().__init__() |
| |
| self.hci_sink = None |
| self.ready = False # True when we can accept incoming packets |
| self.connections = {} # Connections, by connection handle |
| self.pending_command = None |
| self.pending_response = None |
| self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH |
| self.hc_total_num_le_acl_data_packets = HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS |
| self.hc_acl_data_packet_length = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH |
| self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS |
| self.acl_packet_queue = collections.deque() |
| self.acl_packets_in_flight = 0 |
| self.local_supported_commands = bytes(64) |
| self.command_semaphore = asyncio.Semaphore(1) |
| self.long_term_key_provider = None |
| self.link_key_provider = None |
| self.pairing_io_capability_provider = None # Classic only |
| |
| # Connect to the source and sink if specified |
| if controller_source: |
| controller_source.set_packet_sink(self) |
| if controller_sink: |
| self.set_packet_sink(controller_sink) |
| |
| async def reset(self): |
| await self.send_command(HCI_Reset_Command()) |
| self.ready = True |
| |
| response = await self.send_command(HCI_Read_Local_Supported_Commands_Command()) |
| if response.return_parameters.status != HCI_SUCCESS: |
| raise ProtocolError(response.return_parameters.status, 'hci') |
| self.local_supported_commands = response.return_parameters.supported_commands |
| |
| await self.send_command(HCI_Set_Event_Mask_Command(event_mask = bytes.fromhex('FFFFFFFFFFFFFFFF'))) |
| await self.send_command(HCI_LE_Set_Event_Mask_Command(le_event_mask = bytes.fromhex('FFFFF00000000000'))) |
| await self.send_command(HCI_Read_Local_Version_Information_Command()) |
| await self.send_command(HCI_Write_LE_Host_Support_Command(le_supported_host = 1, simultaneous_le_host = 0)) |
| |
| response = await self.send_command(HCI_LE_Read_Buffer_Size_Command()) |
| if response.return_parameters.status == HCI_SUCCESS: |
| self.hc_le_acl_data_packet_length = response.return_parameters.hc_le_acl_data_packet_length |
| self.hc_total_num_le_acl_data_packets = response.return_parameters.hc_total_num_le_acl_data_packets |
| logger.debug(f'HCI LE ACL flow control: hc_le_acl_data_packet_length={response.return_parameters.hc_le_acl_data_packet_length}, hc_total_num_le_acl_data_packets={response.return_parameters.hc_total_num_le_acl_data_packets}') |
| else: |
| logger.warn(f'HCI_LE_Read_Buffer_Size_Command failed: {response.return_parameters.status}') |
| if response.return_parameters.hc_le_acl_data_packet_length == 0 or response.return_parameters.hc_total_num_le_acl_data_packets == 0: |
| # Read the non-LE-specific values |
| response = await self.send_command(HCI_Read_Buffer_Size_Command()) |
| if response.return_parameters.status == HCI_SUCCESS: |
| self.hc_acl_data_packet_length = response.return_parameters.hc_le_acl_data_packet_length |
| self.hc_le_acl_data_packet_length = self.hc_le_acl_data_packet_length or self.hc_acl_data_packet_length |
| self.hc_total_num_acl_data_packets = response.return_parameters.hc_total_num_le_acl_data_packets |
| self.hc_total_num_le_acl_data_packets = self.hc_total_num_le_acl_data_packets or self.hc_total_num_acl_data_packets |
| logger.debug(f'HCI LE ACL flow control: hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length}, hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}') |
| else: |
| logger.warn(f'HCI_Read_Buffer_Size_Command failed: {response.return_parameters.status}') |
| |
| self.reset_done = True |
| |
| @property |
| def controller(self): |
| return self.hci_sink |
| |
| @controller.setter |
| def controller(self, controller): |
| self.set_packet_sink(controller) |
| if controller: |
| controller.set_packet_sink(self) |
| |
| def set_packet_sink(self, sink): |
| self.hci_sink = sink |
| |
| def send_hci_packet(self, packet): |
| self.hci_sink.on_packet(packet.to_bytes()) |
| |
    async def send_command(self, command):
        """Send an HCI command and wait for the event that completes it.

        Commands are serialized: command_semaphore allows only one outstanding
        command, and the completing event (Command Complete or Command Status) is
        delivered by on_command_processed() through the pending_response future.

        Returns:
            The completing HCI event, or None if an exception was raised while
            sending or waiting (the exception is logged, not propagated).
        """
        logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}')

        # Wait until we can send (only one pending command at a time)
        async with self.command_semaphore:
            assert(self.pending_command is None)
            assert(self.pending_response is None)

            # Create a future value to hold the eventual response
            # (resolved by on_command_processed when the completion event arrives)
            self.pending_response = asyncio.get_running_loop().create_future()
            self.pending_command = command

            try:
                self.send_hci_packet(command)
                response = await self.pending_response
                # TODO: check error values
                return response
            except Exception as error:
                logger.warning(f'{color("!!! Exception while sending HCI packet:", "red")} {error}')
                # NOTE(review): the exception is swallowed and None is returned;
                # callers such as reset() dereference the result without a None check.
                # raise error
            finally:
                # Clear the slot so the next (semaphore-serialized) command can proceed
                self.pending_command = None
                self.pending_response = None
| |
| # Use this method to send a command from a task |
| def send_command_sync(self, command): |
| async def send_command(command): |
| await self.send_command(command) |
| |
| asyncio.create_task(send_command(command)) |
| |
| def send_l2cap_pdu(self, connection_handle, cid, pdu): |
| l2cap_pdu = L2CAP_PDU(cid, pdu).to_bytes() |
| |
| # Send the data to the controller via ACL packets |
| bytes_remaining = len(l2cap_pdu) |
| offset = 0 |
| pb_flag = 0 |
| while bytes_remaining: |
| data_total_length = min(bytes_remaining, self.hc_le_acl_data_packet_length) |
| acl_packet = HCI_AclDataPacket( |
| connection_handle = connection_handle, |
| pb_flag = pb_flag, |
| bc_flag = 0, |
| data_total_length = data_total_length, |
| data = l2cap_pdu[offset:offset + data_total_length] |
| ) |
| logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}') |
| self.queue_acl_packet(acl_packet) |
| pb_flag = 1 |
| offset += data_total_length |
| bytes_remaining -= data_total_length |
| |
| def queue_acl_packet(self, acl_packet): |
| self.acl_packet_queue.appendleft(acl_packet) |
| self.check_acl_packet_queue() |
| |
| if len(self.acl_packet_queue): |
| logger.debug(f'{self.acl_packets_in_flight} ACL packets in flight, {len(self.acl_packet_queue)} in queue') |
| |
| def check_acl_packet_queue(self): |
| # Send all we can |
| while len(self.acl_packet_queue) > 0 and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets: |
| packet = self.acl_packet_queue.pop() |
| self.send_hci_packet(packet) |
| self.acl_packets_in_flight += 1 |
| |
| # Packet Sink protocol (packets coming from the controller via HCI) |
| def on_packet(self, packet): |
| hci_packet = HCI_Packet.from_bytes(packet) |
| if self.ready or ( |
| hci_packet.hci_packet_type == HCI_EVENT_PACKET and |
| hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT and |
| hci_packet.command_opcode == HCI_RESET_COMMAND |
| ): |
| self.on_hci_packet(hci_packet) |
| else: |
| logger.debug('reset not done, ignoring packet from controller') |
| |
| def on_hci_packet(self, packet): |
| logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}') |
| |
| # If the packet is a command, invoke the handler for this packet |
| if packet.hci_packet_type == HCI_COMMAND_PACKET: |
| self.on_hci_command_packet(packet) |
| elif packet.hci_packet_type == HCI_EVENT_PACKET: |
| self.on_hci_event_packet(packet) |
| elif packet.hci_packet_type == HCI_ACL_DATA_PACKET: |
| self.on_hci_acl_data_packet(packet) |
| else: |
| logger.warning(f'!!! unknown packet type {packet.hci_packet_type}') |
| |
| def on_hci_command_packet(self, command): |
| logger.warning(f'!!! unexpected command packet: {command}') |
| |
| def on_hci_event_packet(self, event): |
| handler_name = f'on_{event.name.lower()}' |
| handler = getattr(self, handler_name, self.on_hci_event) |
| handler(event) |
| |
| def on_hci_acl_data_packet(self, packet): |
| # Look for the connection to which this data belongs |
| if connection := self.connections.get(packet.connection_handle): |
| connection.on_hci_acl_data_packet(packet) |
| |
| def on_gatt_pdu(self, connection, pdu): |
| self.emit('gatt_pdu', connection.handle, pdu) |
| |
| def on_smp_pdu(self, connection, pdu): |
| self.emit('smp_pdu', connection.handle, pdu) |
| |
| def on_l2cap_pdu(self, connection, cid, pdu): |
| self.emit('l2cap_pdu', connection.handle, cid, pdu) |
| |
| def on_command_processed(self, event): |
| if self.pending_response: |
| # Check that it is what we were expecting |
| if self.pending_command.op_code != event.command_opcode: |
| logger.warning(f'!!! command result mismatch, expected 0x{self.pending_command.op_code:X} but got 0x{event.command_opcode:X}') |
| |
| self.pending_response.set_result(event) |
| else: |
| logger.warning('!!! no pending response future to set') |
| |
| ############################################################ |
| # HCI handlers |
| ############################################################ |
| def on_hci_event(self, event): |
| logger.warning(f'{color(f"--- Ignoring event {event}", "red")}') |
| |
| def on_hci_command_complete_event(self, event): |
| if event.command_opcode == 0: |
| # This is used just for the Num_HCI_Command_Packets field, not related to an actual command |
| logger.debug('no-command event') |
| else: |
| return self.on_command_processed(event) |
| |
| def on_hci_command_status_event(self, event): |
| return self.on_command_processed(event) |
| |
| def on_hci_number_of_completed_packets_event(self, event): |
| total_packets = sum(event.num_completed_packets) |
| if total_packets <= self.acl_packets_in_flight: |
| self.acl_packets_in_flight -= total_packets |
| self.check_acl_packet_queue() |
| else: |
| logger.warning(color(f'!!! {total_packets} completed but only {self.acl_packets_in_flight} in flight')) |
| self.acl_packets_in_flight = 0 |
| |
| # Classic only |
| def on_hci_connection_request_event(self, event): |
| # For now, just accept everything |
| # TODO: delegate the decision |
| self.send_command_sync( |
| HCI_Accept_Connection_Request_Command( |
| bd_addr = event.bd_addr, |
| role = 0x01 # Remain the peripheral |
| ) |
| ) |
| |
| def on_hci_le_connection_complete_event(self, event): |
| # Check if this is a cancellation |
| if event.status == HCI_SUCCESS: |
| # Create/update the connection |
| logger.debug(f'### CONNECTION: [0x{event.connection_handle:04X}] {event.peer_address} as {HCI_Constant.role_name(event.role)}') |
| |
| connection = self.connections.get(event.connection_handle) |
| if connection is None: |
| connection = Connection(self, event.connection_handle, event.role, event.peer_address) |
| self.connections[event.connection_handle] = connection |
| |
| # Notify the client |
| connection_parameters = ConnectionParameters( |
| event.conn_interval, |
| event.conn_latency, |
| event.supervision_timeout |
| ) |
| self.emit( |
| 'connection', |
| event.connection_handle, |
| BT_LE_TRANSPORT, |
| event.peer_address, |
| None, |
| event.role, |
| connection_parameters |
| ) |
| else: |
| logger.debug(f'### CONNECTION FAILED: {event.status}') |
| |
| # Notify the listeners |
| self.emit('connection_failure', event.status) |
| |
| def on_hci_le_enhanced_connection_complete_event(self, event): |
| # Just use the same implementation as for the non-enhanced event for now |
| self.on_hci_le_connection_complete_event(event) |
| |
| def on_hci_connection_complete_event(self, event): |
| if event.status == HCI_SUCCESS: |
| # Create/update the connection |
| logger.debug(f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] {event.bd_addr}') |
| |
| connection = self.connections.get(event.connection_handle) |
| if connection is None: |
| connection = Connection(self, event.connection_handle, BT_CENTRAL_ROLE, event.bd_addr) |
| self.connections[event.connection_handle] = connection |
| |
| # Notify the client |
| self.emit( |
| 'connection', |
| event.connection_handle, |
| BT_BR_EDR_TRANSPORT, |
| event.bd_addr, |
| None, |
| BT_CENTRAL_ROLE, |
| None |
| ) |
| else: |
| logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}') |
| |
| # Notify the client |
| self.emit('connection_failure', event.connection_handle, event.status) |
| |
| def on_hci_disconnection_complete_event(self, event): |
| # Find the connection |
| if (connection := self.connections.get(event.connection_handle)) is None: |
| logger.warning('!!! DISCONNECTION COMPLETE: unknown handle') |
| return |
| |
| if event.status == HCI_SUCCESS: |
| logger.debug(f'### DISCONNECTION: [0x{event.connection_handle:04X}] {connection.peer_address} as {HCI_Constant.role_name(connection.role)}, reason={event.reason}') |
| del self.connections[event.connection_handle] |
| |
| # Notify the listeners |
| self.emit('disconnection', event.connection_handle, event.reason) |
| else: |
| logger.debug(f'### DISCONNECTION FAILED: {event.status}') |
| |
| # Notify the listeners |
| self.emit('disconnection_failure', event.status) |
| |
| def on_hci_le_connection_update_complete_event(self, event): |
| if (connection := self.connections.get(event.connection_handle)) is None: |
| logger.warning('!!! CONNECTION PARAMETERS UPDATE COMPLETE: unknown handle') |
| return |
| |
| # Notify the client |
| if event.status == HCI_SUCCESS: |
| connection_parameters = ConnectionParameters( |
| event.conn_interval, |
| event.conn_latency, |
| event.supervision_timeout |
| ) |
| self.emit('connection_parameters_update', connection.handle, connection_parameters) |
| else: |
| self.emit('connection_parameters_update_failure', connection.handle, event.status) |
| |
| def on_hci_le_phy_update_complete_event(self, event): |
| if (connection := self.connections.get(event.connection_handle)) is None: |
| logger.warning('!!! CONNECTION PHY UPDATE COMPLETE: unknown handle') |
| return |
| |
| # Notify the client |
| if event.status == HCI_SUCCESS: |
| connection_phy = ConnectionPHY(event.tx_phy, event.rx_phy) |
| self.emit('connection_phy_update', connection.handle, connection_phy) |
| else: |
| self.emit('connection_phy_update_failure', connection.handle, event.status) |
| |
| def on_hci_le_advertising_report_event(self, event): |
| for report in event.reports: |
| self.emit( |
| 'advertising_report', |
| report.address, |
| report.data, |
| report.rssi, |
| report.event_type |
| ) |
| |
| def on_hci_le_remote_connection_parameter_request_event(self, event): |
| if event.connection_handle not in self.connections: |
| logger.warning('!!! REMOTE CONNECTION PARAMETER REQUEST: unknown handle') |
| return |
| |
| # For now, just accept everything |
| # TODO: delegate the decision |
| self.send_command_sync( |
| HCI_LE_Remote_Connection_Parameter_Request_Reply_Command( |
| connection_handle = event.connection_handle, |
| interval_min = event.interval_min, |
| interval_max = event.interval_max, |
| latency = event.latency, |
| timeout = event.timeout, |
| minimum_ce_length = 0, |
| maximum_ce_length = 0 |
| ) |
| ) |
| |
| def on_hci_le_long_term_key_request_event(self, event): |
| if (connection := self.connections.get(event.connection_handle)) is None: |
| logger.warning('!!! LE LONG TERM KEY REQUEST: unknown handle') |
| return |
| |
| async def send_long_term_key(): |
| if self.long_term_key_provider is None: |
| logger.debug('no long term key provider') |
| long_term_key = None |
| else: |
| long_term_key = await self.long_term_key_provider( |
| connection.handle, |
| event.random_number, |
| event.encryption_diversifier |
| ) |
| if long_term_key: |
| response = HCI_LE_Long_Term_Key_Request_Reply_Command( |
| connection_handle = event.connection_handle, |
| long_term_key = long_term_key |
| ) |
| else: |
| response = HCI_LE_Long_Term_Key_Request_Negative_Reply_Command( |
| connection_handle = event.connection_handle |
| ) |
| |
| await self.send_command(response) |
| |
| asyncio.create_task(send_long_term_key()) |
| |
| def on_hci_synchronous_connection_complete_event(self, event): |
| pass |
| |
| def on_hci_synchronous_connection_changed_event(self, event): |
| pass |
| |
| def on_hci_role_change_event(self, event): |
| if event.status == HCI_SUCCESS: |
| logger.debug(f'role change for {event.bd_addr}: {HCI_Constant.role_name(event.new_role)}') |
| # TODO: lookup the connection and update the role |
| else: |
| logger.debug(f'role change for {event.bd_addr} failed: {HCI_Constant.error_name(event.status)}') |
| |
| def on_hci_le_data_length_change_event(self, event): |
| self.emit( |
| 'connection_data_length_change', |
| event.connection_handle, |
| event.max_tx_octets, |
| event.max_tx_time, |
| event.max_rx_octets, |
| event.max_rx_time |
| ) |
| |
| def on_hci_authentication_complete_event(self, event): |
| # Notify the client |
| if event.status == HCI_SUCCESS: |
| self.emit('connection_authentication', event.connection_handle) |
| else: |
| self.emit('connection_authentication_failure', event.connection_handle, event.status) |
| |
| def on_hci_encryption_change_event(self, event): |
| # Notify the client |
| if event.status == HCI_SUCCESS: |
| self.emit('connection_encryption_change', event.connection_handle, event.encryption_enabled) |
| else: |
| self.emit('connection_encryption_failure', event.connection_handle, event.status) |
| |
| def on_hci_encryption_key_refresh_complete_event(self, event): |
| # Notify the client |
| if event.status == HCI_SUCCESS: |
| self.emit('connection_encryption_key_refresh', event.connection_handle) |
| else: |
| self.emit('connection_encryption_key_refresh_failure', event.connection_handle, event.status) |
| |
| def on_hci_link_supervision_timeout_changed_event(self, event): |
| pass |
| |
| def on_hci_max_slots_change_event(self, event): |
| pass |
| |
| def on_hci_page_scan_repetition_mode_change_event(self, event): |
| pass |
| |
| def on_hci_link_key_notification_event(self, event): |
| logger.debug(f'link key for {event.bd_addr}: {event.link_key.hex()}, type={HCI_Constant.link_key_type_name(event.key_type)}') |
| self.emit('link_key', event.bd_addr, event.link_key, event.key_type) |
| |
| def on_hci_simple_pairing_complete_event(self, event): |
| logger.debug(f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}') |
| |
| def on_hci_pin_code_request_event(self, event): |
| # For now, just refuse all requests |
| # TODO: delegate the decision |
| self.send_command_sync( |
| HCI_PIN_Code_Request_Negative_Reply_Command( |
| bd_addr = event.bd_addr |
| ) |
| ) |
| |
| def on_hci_link_key_request_event(self, event): |
| async def send_link_key(): |
| if self.link_key_provider is None: |
| logger.debug('no link key provider') |
| link_key = None |
| else: |
| link_key = await self.link_key_provider(event.bd_addr) |
| if link_key: |
| response = HCI_Link_Key_Request_Reply_Command( |
| bd_addr = event.bd_addr, |
| link_key = link_key |
| ) |
| else: |
| response = HCI_Link_Key_Request_Negative_Reply_Command( |
| bd_addr = event.bd_addr |
| ) |
| |
| await self.send_command(response) |
| |
| asyncio.create_task(send_link_key()) |
| |
| def on_hci_io_capability_request_event(self, event): |
| self.emit('authentication_io_capability_request', event.bd_addr) |
| |
| def on_hci_io_capability_response_event(self, event): |
| pass |
| |
| def on_hci_user_confirmation_request_event(self, event): |
| self.emit('authentication_user_confirmation_request', event.bd_addr, event.numeric_value) |
| |
| def on_hci_user_passkey_request_event(self, event): |
| self.emit('authentication_user_passkey_request', event.bd_addr) |
| |
| def on_hci_inquiry_complete_event(self, event): |
| self.emit('inquiry_complete') |
| |
| def on_hci_inquiry_result_with_rssi_event(self, event): |
| for response in event.responses: |
| self.emit( |
| 'inquiry_result', |
| response.bd_addr, |
| response.class_of_device, |
| b'', |
| response.rssi |
| ) |
| |
| def on_hci_extended_inquiry_result_event(self, event): |
| self.emit( |
| 'inquiry_result', |
| event.bd_addr, |
| event.class_of_device, |
| event.extended_inquiry_response, |
| event.rssi |
| ) |
| |
| def on_hci_remote_name_request_complete_event(self, event): |
| if event.status != HCI_SUCCESS: |
| self.emit('remote_name_failure', event.bd_addr, event.status) |
| else: |
| self.emit('remote_name', event.bd_addr, event.remote_name) |