classic: add BR/EDR accept connection logic
diff --git a/bumble/device.py b/bumble/device.py
index f68e330..13097a1 100644
--- a/bumble/device.py
+++ b/bumble/device.py
@@ -519,6 +519,7 @@
self.le_simultaneous_enabled = True
self.classic_sc_enabled = True
self.classic_ssp_enabled = True
+ self.classic_accept_any = True
self.connectable = True
self.discoverable = True
self.advertising_data = bytes(
@@ -539,6 +540,7 @@
self.le_simultaneous_enabled = config.get('le_simultaneous_enabled', self.le_simultaneous_enabled)
self.classic_sc_enabled = config.get('classic_sc_enabled', self.classic_sc_enabled)
self.classic_ssp_enabled = config.get('classic_ssp_enabled', self.classic_ssp_enabled)
+ self.classic_accept_any = config.get('classic_accept_any', self.classic_accept_any)
self.connectable = config.get('connectable', self.connectable)
self.discoverable = config.get('discoverable', self.discoverable)
@@ -630,6 +632,9 @@
def on_connection_failure(self, error):
pass
+ def on_connection_request(self, bd_addr, class_of_device, link_type):
+ pass
+
def on_characteristic_subscription(self, connection, characteristic, notify_enabled, indicate_enabled):
pass
@@ -680,6 +685,7 @@
self.classic_enabled = False
self.inquiry_response = None
self.address_resolver = None
+ self.classic_pending_accepts = { Address.ANY: [] } # Futures, by BD address OR [Futures] for Address.ANY
# Use the initial config or a default
self.public_address = Address('00:00:00:00:00:00')
@@ -700,6 +706,7 @@
self.classic_sc_enabled = config.classic_sc_enabled
self.discoverable = config.discoverable
self.connectable = config.connectable
+ self.classic_accept_any = config.classic_accept_any
# If a name is passed, override the name from the config
if name:
@@ -1300,6 +1307,89 @@
if transport == BT_LE_TRANSPORT:
self.le_connecting = False
+ async def accept(
+ self,
+ peer_address=Address.ANY,
+ role=BT_PERIPHERAL_ROLE,
+ timeout=DEVICE_DEFAULT_CONNECT_TIMEOUT
+ ):
+ '''
+ Wait and accept any incoming connection or a connection from `peer_address` when set.
+
+ Notes:
+ * A `connect` to the same peer will also complete this call.
+ * The `timeout` parameter is only handled while waiting for the connection request,
+ once received and accepeted, the controller shall issue a connection complete event.
+ '''
+
+ if type(peer_address) is str:
+ try:
+ peer_address = Address(peer_address)
+ except ValueError:
+ # If the address is not parsable, assume it is a name instead
+ logger.debug('looking for peer by name')
+ peer_address = await self.find_peer_by_name(peer_address, BT_BR_EDR_TRANSPORT) # TODO: timeout
+
+ if peer_address == Address.NIL:
+ raise ValueError('accept on nil address')
+
+ # Create a future so that we can wait for the request
+ pending_request = asyncio.get_running_loop().create_future()
+
+ if peer_address == Address.ANY:
+ self.classic_pending_accepts[Address.ANY].append(pending_request)
+ elif peer_address in self.classic_pending_accepts:
+ raise InvalidStateError('accept connection already pending')
+ else:
+ self.classic_pending_accepts[peer_address] = pending_request
+
+ try:
+ # Wait for a request or a completed connection
+ result = await (asyncio.wait_for(pending_request, timeout) if timeout else pending_request)
+
+ except:
+ # Remove future from device context
+ if peer_address == Address.ANY:
+ self.classic_pending_accepts[Address.ANY].remove(pending_request)
+ else:
+ self.classic_pending_accepts.pop(peer_address)
+ raise
+
+ # Result may already be a completed connection,
+ # see `on_connection` for details
+ if isinstance(result, Connection):
+ return result
+
+ # Otherwise, result came from `on_connection_request`
+ peer_address, class_of_device, link_type = result
+
+ def on_connection(connection):
+ if connection.transport == BT_BR_EDR_TRANSPORT and connection.peer_address == peer_address:
+ pending_connection.set_result(connection)
+
+ def on_connection_failure(error):
+ if error.transport == BT_BR_EDR_TRANSPORT and error.peer_address == peer_address:
+ pending_connection.set_exception(error)
+
+ # Create a future so that we can wait for the connection's result
+ pending_connection = asyncio.get_running_loop().create_future()
+ self.on('connection', on_connection)
+ self.on('connection_failure', on_connection_failure)
+
+ try:
+ # Accept connection request
+ await self.send_command(HCI_Accept_Connection_Request_Command(
+ bd_addr = peer_address,
+ role = role
+ ))
+
+ # Wait for connection complete
+ return await pending_connection
+
+ finally:
+ self.remove_listener('connection', on_connection)
+ self.remove_listener('connection_failure', on_connection_failure)
+
@asynccontextmanager
async def connect_as_gatt(self, peer_address):
async with AsyncExitStack() as stack:
@@ -1716,6 +1806,14 @@
)
self.connections[connection_handle] = connection
+ # We may have an accept ongoing waiting for a connection request for `peer_address`.
+ # Typicaly happen when using `connect` to the same `peer_address` we are waiting with
+ # an `accept` for.
+ # In this case, set the completed `connection` to the `accept` future result.
+ if peer_address in self.classic_pending_accepts:
+ future = self.classic_pending_accepts.pop(peer_address)
+ future.set_result(connection)
+
# Emit an event to notify listeners of the new connection
self.emit('connection', connection)
else:
@@ -1779,6 +1877,39 @@
)
self.emit('connection_failure', error)
+ # FIXME: Explore a delegate-model for BR/EDR wait connection #56.
+ @host_event_handler
+ def on_connection_request(self, bd_addr, class_of_device, link_type):
+ logger.debug(f'*** Connection request: {bd_addr}')
+
+ # match a pending future using `bd_addr`
+ if bd_addr in self.classic_pending_accepts:
+ future = self.classic_pending_accepts.pop(bd_addr)
+ future.set_result((bd_addr, class_of_device, link_type))
+
+ # match first pending future for ANY address
+ elif len(self.classic_pending_accepts[Address.ANY]) > 0:
+ future = self.classic_pending_accepts[Address.ANY].pop(0)
+ future.set_result((bd_addr, class_of_device, link_type))
+
+ # device configuration is set to accept any incoming connection
+ elif self.classic_accept_any:
+ self.host.send_command_sync(
+ HCI_Accept_Connection_Request_Command(
+ bd_addr = bd_addr,
+ role = 0x01 # Remain the peripheral
+ )
+ )
+
+ # reject incoming connection
+ else:
+ self.host.send_command_sync(
+ HCI_Reject_Connection_Request_Command(
+ bd_addr = bd_addr,
+ reason = HCI_CONNECTION_REJECTED_DUE_TO_LIMITED_RESOURCES_ERROR
+ )
+ )
+
@host_event_handler
@with_connection_from_handle
def on_disconnection(self, connection, reason):
diff --git a/bumble/hci.py b/bumble/hci.py
index af26374..d4cf7cc 100644
--- a/bumble/hci.py
+++ b/bumble/hci.py
@@ -1652,6 +1652,16 @@
ADDRESS_TYPE_SPEC = {'size': 1, 'mapper': lambda x: Address.address_type_name(x)}
+ @classmethod
+ @property
+ def ANY(cls):
+ return cls(b"\xff\xff\xff\xff\xff\xff", cls.PUBLIC_DEVICE_ADDRESS)
+
+ @classmethod
+ @property
+ def NIL(cls):
+ return cls(b"\x00\x00\x00\x00\x00\x00", cls.PUBLIC_DEVICE_ADDRESS)
+
@staticmethod
def address_type_name(address_type):
return name_or_number(Address.ADDRESS_TYPE_NAMES, address_type)
@@ -1937,6 +1947,17 @@
# -----------------------------------------------------------------------------
@HCI_Command.command([
+ ('bd_addr', Address.parse_address),
+ ('reason', {'size': 1, 'mapper': HCI_Constant.error_name})
+])
+class HCI_Reject_Connection_Request_Command(HCI_Command):
+ '''
+ See Bluetooth spec @ 7.1.9 Reject Connection Request Command
+ '''
+
+
+# -----------------------------------------------------------------------------
+@HCI_Command.command([
('bd_addr', Address.parse_address),
('link_key', 16)
])
diff --git a/bumble/host.py b/bumble/host.py
index 01c25a4..32b2194 100644
--- a/bumble/host.py
+++ b/bumble/host.py
@@ -347,13 +347,12 @@
# Classic only
def on_hci_connection_request_event(self, event):
- # For now, just accept everything
- # TODO: delegate the decision
- self.send_command_sync(
- HCI_Accept_Connection_Request_Command(
- bd_addr = event.bd_addr,
- role = 0x01 # Remain the peripheral
- )
+ # Notify the listeners
+ self.emit(
+ 'connection_request',
+ event.bd_addr,
+ event.class_of_device,
+ event.link_type,
)
def on_hci_le_connection_complete_event(self, event):
diff --git a/tests/device_test.py b/tests/device_test.py
index cd72c4c..acf4446 100644
--- a/tests/device_test.py
+++ b/tests/device_test.py
@@ -158,16 +158,23 @@
d1.host.set_packet_sink(Sink(d1_flow()))
d2.host.set_packet_sink(Sink(d2_flow()))
- [c1, c2] = await asyncio.gather(*[
+ [c01, c02, a10, a20, a01] = await asyncio.gather(*[
asyncio.create_task(d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)),
asyncio.create_task(d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)),
+ asyncio.create_task(d1.accept(peer_address=d0.public_address)),
+ asyncio.create_task(d2.accept()),
+ asyncio.create_task(d0.accept(peer_address=d1.public_address)),
])
- assert type(c1) == Connection
- assert type(c2) == Connection
+ assert type(c01) == Connection
+ assert type(c02) == Connection
+ assert type(a10) == Connection
+ assert type(a20) == Connection
+ assert type(a01) == Connection
- assert c1.handle == 0x100
- assert c2.handle == 0x101
+ assert c01.handle == a10.handle and c01.handle == 0x100
+ assert c02.handle == a20.handle and c02.handle == 0x101
+ assert a01 == c01
# -----------------------------------------------------------------------------