host: spawn each asynchronous task with the right aliveness
diff --git a/bumble/device.py b/bumble/device.py
index 5a16392..79b305a 100644
--- a/bumble/device.py
+++ b/bumble/device.py
@@ -535,7 +535,7 @@
self.on('disconnection_failure', abort.set_exception)
try:
- await asyncio.wait_for(abort, timeout)
+ await asyncio.wait_for(self.device.abort_on('flush', abort), timeout)
except asyncio.TimeoutError:
pass
@@ -1592,7 +1592,7 @@
if transport == BT_LE_TRANSPORT:
self.le_connecting = True
if timeout is None:
- return await pending_connection
+ return await self.abort_on('flush', pending_connection)
else:
try:
return await asyncio.wait_for(
@@ -1609,7 +1609,7 @@
)
try:
- return await pending_connection
+ return await self.abort_on('flush', pending_connection)
except ConnectionError:
raise TimeoutError()
finally:
@@ -1661,6 +1661,7 @@
try:
# Wait for a request or a completed connection
+ pending_request = self.abort_on('flush', pending_request)
result = await (
asyncio.wait_for(pending_request, timeout)
if timeout
@@ -1682,6 +1683,9 @@
# Otherwise, result came from `on_connection_request`
peer_address, class_of_device, link_type = result
+ # Create a future so that we can wait for the connection's result
+ pending_connection = asyncio.get_running_loop().create_future()
+
def on_connection(connection):
if (
connection.transport == BT_BR_EDR_TRANSPORT
@@ -1696,8 +1700,6 @@
):
pending_connection.set_exception(error)
- # Create a future so that we can wait for the connection's result
- pending_connection = asyncio.get_running_loop().create_future()
self.on('connection', on_connection)
self.on('connection_failure', on_connection_failure)
@@ -1713,7 +1715,7 @@
)
# Wait for connection complete
- return await pending_connection
+ return await self.abort_on('flush', pending_connection)
finally:
self.remove_listener('connection', on_connection)
@@ -1782,7 +1784,7 @@
# Wait for the disconnection process to complete
self.disconnecting = True
- return await pending_disconnection
+ return await self.abort_on('flush', pending_disconnection)
finally:
connection.remove_listener(
'disconnection', pending_disconnection.set_result
@@ -1910,7 +1912,7 @@
else:
return None
- return await peer_address
+ return await self.abort_on('flush', peer_address)
finally:
if handler is not None:
self.remove_listener(event_name, handler)
@@ -1994,7 +1996,7 @@
connection.authenticating = True
# Wait for the authentication to complete
- await pending_authentication
+ await connection.abort_on('disconnection', pending_authentication)
finally:
connection.authenticating = False
connection.remove_listener('connection_authentication', on_authentication)
@@ -2068,7 +2070,7 @@
raise HCI_StatusError(result)
# Wait for the result
- await pending_encryption
+ await connection.abort_on('disconnection', pending_encryption)
finally:
connection.remove_listener(
'connection_encryption_change', on_encryption_change
@@ -2116,11 +2118,18 @@
raise HCI_StatusError(result)
# Wait for the result
- return await pending_name
+ return await self.abort_on('flush', pending_name)
finally:
self.remove_listener('remote_name', handler)
self.remove_listener('remote_name_failure', failure_handler)
+ @host_event_handler
+ def on_flush(self):
+ self.emit('flush')
+ for _, connection in self.connections.items():
+ connection.emit('disconnection', 0)
+ self.connections = {}
+
# [Classic only]
@host_event_handler
def on_link_key(self, bd_addr, link_key, key_type):
@@ -2135,7 +2144,7 @@
except Exception as error:
logger.warn(f'!!! error while storing keys: {error}')
- asyncio.create_task(store_keys())
+ self.abort_on('flush', store_keys())
if connection := self.find_connection_by_bd_addr(
bd_addr, transport=BT_BR_EDR_TRANSPORT
@@ -2227,10 +2236,10 @@
async def new_connection():
# Figure out which PHY we're connected with
if self.host.supports_command(HCI_LE_READ_PHY_COMMAND):
- result = await self.send_command(
+ result = await asyncio.shield(self.send_command(
HCI_LE_Read_PHY_Command(connection_handle=connection_handle),
check_result=True,
- )
+ ))
phy = ConnectionPHY(
result.return_parameters.tx_phy, result.return_parameters.rx_phy
)
@@ -2261,7 +2270,7 @@
# Emit an event to notify listeners of the new connection
self.emit('connection', connection)
- asyncio.create_task(new_connection())
+ self.abort_on('flush', new_connection())
@host_event_handler
def on_connection_failure(self, transport, peer_address, error_code):
@@ -2338,7 +2347,7 @@
# Restart advertising if auto-restart is enabled
if self.auto_restart_advertising:
logger.debug('restarting advertising')
- asyncio.create_task(
+ self.abort_on('flush',
self.start_advertising(
advertising_type=self.advertising_type, auto_restart=True
)
@@ -2460,17 +2469,19 @@
if can_compare:
async def compare_numbers():
- numbers_match = await pairing_config.delegate.compare_numbers(
- code, digits=6
+ numbers_match = await connection.abort_on('disconnection',
+ pairing_config.delegate.compare_numbers(
+ code, digits=6
+ )
)
if numbers_match:
- self.host.send_command_sync(
+ await self.host.send_command(
HCI_User_Confirmation_Request_Reply_Command(
bd_addr=connection.peer_address
)
)
else:
- self.host.send_command_sync(
+ await self.host.send_command(
HCI_User_Confirmation_Request_Negative_Reply_Command(
bd_addr=connection.peer_address
)
@@ -2480,15 +2491,16 @@
else:
async def confirm():
- confirm = await pairing_config.delegate.confirm()
+ confirm = await connection.abort_on('disconnection',
+ pairing_config.delegate.confirm())
if confirm:
- self.host.send_command_sync(
+ await self.host.send_command(
HCI_User_Confirmation_Request_Reply_Command(
bd_addr=connection.peer_address
)
)
else:
- self.host.send_command_sync(
+ await self.host.send_command(
HCI_User_Confirmation_Request_Negative_Reply_Command(
bd_addr=connection.peer_address
)
@@ -2512,15 +2524,16 @@
if can_input:
async def get_number():
- number = await pairing_config.delegate.get_number()
+ number = await connection.abort_on('disconnection',
+ pairing_config.delegate.get_number())
if number is not None:
- self.host.send_command_sync(
+ await self.host.send_command(
HCI_User_Passkey_Request_Reply_Command(
bd_addr=connection.peer_address, numeric_value=number
)
)
else:
- self.host.send_command_sync(
+ await self.host.send_command(
HCI_User_Passkey_Request_Negative_Reply_Command(
bd_addr=connection.peer_address
)
@@ -2541,7 +2554,7 @@
# Ask what the pairing config should be for this connection
pairing_config = self.pairing_config_factory(connection)
- asyncio.create_task(pairing_config.delegate.display_number(passkey))
+ connection.abort_on('disconnection', pairing_config.delegate.display_number(passkey))
# [Classic only]
@host_event_handler
diff --git a/bumble/host.py b/bumble/host.py
index 354d5fb..d768379 100644
--- a/bumble/host.py
+++ b/bumble/host.py
@@ -17,7 +17,6 @@
# -----------------------------------------------------------------------------
import asyncio
import logging
-from pyee import EventEmitter
from colors import color
from .hci import *
@@ -26,6 +25,7 @@
from .gatt import *
from .smp import *
from .core import ConnectionParameters
+from .utils import AbortableEventEmitter
# -----------------------------------------------------------------------------
# Logging
@@ -65,7 +65,7 @@
# -----------------------------------------------------------------------------
-class Host(EventEmitter):
+class Host(AbortableEventEmitter):
def __init__(self, controller_source=None, controller_sink=None):
super().__init__()
@@ -96,7 +96,19 @@
if controller_sink:
self.set_packet_sink(controller_sink)
+ async def flush(self):
+ # Make sure no command is pending
+ await self.command_semaphore.acquire()
+
+ # Flush current host state, then release command semaphore
+ self.emit('flush')
+ self.command_semaphore.release()
+
async def reset(self):
+ if self.ready:
+ self.ready = False
+ await self.flush()
+
await self.send_command(HCI_Reset_Command(), check_result=True)
self.ready = True
@@ -604,9 +616,9 @@
logger.debug('no long term key provider')
long_term_key = None
else:
- long_term_key = await self.long_term_key_provider(
+ long_term_key = await self.abort_on('flush', self.long_term_key_provider(
connection.handle, event.random_number, event.encryption_diversifier
- )
+ ))
if long_term_key:
response = HCI_LE_Long_Term_Key_Request_Reply_Command(
connection_handle=event.connection_handle,
@@ -719,7 +731,7 @@
logger.debug('no link key provider')
link_key = None
else:
- link_key = await self.link_key_provider(event.bd_addr)
+ link_key = await self.abort_on('flush', self.link_key_provider(event.bd_addr))
if link_key:
response = HCI_Link_Key_Request_Reply_Command(
bd_addr=event.bd_addr, link_key=link_key
diff --git a/bumble/smp.py b/bumble/smp.py
index e9d6fe3..be3eea3 100644
--- a/bumble/smp.py
+++ b/bumble/smp.py
@@ -766,7 +766,7 @@
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
- asyncio.create_task(prompt())
+ self.connection.abort_on('disconnection', prompt())
def prompt_user_for_numeric_comparison(self, code, next_steps):
async def prompt():
@@ -783,7 +783,7 @@
self.send_pairing_failed(SMP_CONFIRM_VALUE_FAILED_ERROR)
- asyncio.create_task(prompt())
+ self.connection.abort_on('disconnection', prompt())
def prompt_user_for_number(self, next_steps):
async def prompt():
@@ -796,7 +796,7 @@
logger.warn(f'exception while prompting: {error}')
self.send_pairing_failed(SMP_PASSKEY_ENTRY_FAILED_ERROR)
- asyncio.create_task(prompt())
+ self.connection.abort_on('disconnection', prompt())
def display_passkey(self):
# Generate random Passkey/PIN code
@@ -808,7 +808,7 @@
self.tk = self.passkey.to_bytes(16, byteorder='little')
logger.debug(f'TK from passkey = {self.tk.hex()}')
- asyncio.create_task(
+ self.connection.abort_on('disconnection',
self.pairing_config.delegate.display_number(self.passkey, digits=6)
)
@@ -921,14 +921,12 @@
def start_encryption(self, key):
# We can now encrypt the connection with the short term key, so that we can
# distribute the long term and/or other keys over an encrypted connection
- asyncio.create_task(
- self.manager.device.host.send_command(
- HCI_LE_Enable_Encryption_Command(
- connection_handle=self.connection.handle,
- random_number=bytes(8),
- encrypted_diversifier=0,
- long_term_key=key,
- )
+ self.manager.device.host.send_command_sync(
+ HCI_LE_Enable_Encryption_Command(
+ connection_handle=self.connection.handle,
+ random_number=bytes(8),
+ encrypted_diversifier=0,
+ long_term_key=key
)
)
@@ -950,7 +948,7 @@
self.connection.transport == BT_BR_EDR_TRANSPORT
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
):
- self.ctkd_task = asyncio.create_task(self.derive_ltk())
+ self.ctkd_task = self.connection.abort_on('disconnection', self.derive_ltk())
elif not self.sc:
# Distribute the LTK, EDIV and RAND
if self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
@@ -997,7 +995,7 @@
self.connection.transport == BT_BR_EDR_TRANSPORT
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
):
- self.ctkd_task = asyncio.create_task(self.derive_ltk())
+ self.ctkd_task = self.connection.abort_on('disconnection', self.derive_ltk())
# Distribute the LTK, EDIV and RAND
elif not self.sc:
if self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG:
@@ -1094,7 +1092,7 @@
self.send_pairing_request_command()
# Wait for the pairing process to finish
- await self.pairing_result
+ await self.connection.abort_on('disconnection', self.pairing_result)
def on_disconnection(self, reason):
self.connection.remove_listener('disconnection', self.on_disconnection)
@@ -1112,7 +1110,7 @@
if self.is_initiator:
self.distribute_keys()
- asyncio.create_task(self.on_pairing())
+ self.connection.abort_on('disconnection', self.on_pairing())
def on_connection_encryption_change(self):
if self.connection.is_encrypted:
@@ -1219,7 +1217,7 @@
logger.error(color('SMP command not handled???', 'red'))
def on_smp_pairing_request_command(self, command):
- asyncio.create_task(self.on_smp_pairing_request_command_async(command))
+ self.connection.abort_on('disconnection', self.on_smp_pairing_request_command_async(command))
async def on_smp_pairing_request_command_async(self, command):
# Check if the request should proceed
@@ -1572,7 +1570,7 @@
self.wait_before_continuing = None
self.send_pairing_dhkey_check_command()
- asyncio.create_task(next_steps())
+ self.connection.abort_on('disconnection', next_steps())
else:
self.send_pairing_dhkey_check_command()
else:
@@ -1688,7 +1686,7 @@
except Exception as error:
logger.warn(f'!!! error while storing keys: {error}')
- asyncio.create_task(store_keys())
+ self.device.abort_on('flush', store_keys())
# Notify the device
self.device.on_pairing(session.connection.handle, keys, session.sc)
diff --git a/bumble/utils.py b/bumble/utils.py
index 92cef63..3345612 100644
--- a/bumble/utils.py
+++ b/bumble/utils.py
@@ -19,6 +19,8 @@
import logging
import traceback
import collections
+import sys
+from typing import Awaitable
from functools import wraps
from colors import color
from pyee import EventEmitter
@@ -62,7 +64,37 @@
# -----------------------------------------------------------------------------
-class CompositeEventEmitter(EventEmitter):
+class AbortableEventEmitter(EventEmitter):
+
+ def abort_on(self, event: str, awaitable: Awaitable):
+ """
+ Set a coroutine or future to abort when an event occur.
+ """
+ future = asyncio.ensure_future(awaitable)
+ if future.done():
+ return future
+
+ def on_event(*_):
+ msg = f'abort: {event} event occurred.'
+ if isinstance(future, asyncio.Task):
+ # python prior to 3.9 does not support passing a message on `Task.cancel`
+ if sys.version_info < (3, 9, 0):
+ future.cancel()
+ else:
+ future.cancel(msg)
+ else:
+ future.set_exception(asyncio.CancelledError(msg))
+
+ def on_done(_):
+ self.remove_listener(event, on_event)
+
+ self.on(event, on_event)
+ future.add_done_callback(on_done)
+ return future
+
+
+# -----------------------------------------------------------------------------
+class CompositeEventEmitter(AbortableEventEmitter):
def __init__(self):
super().__init__()
self._listener = None
diff --git a/tests/device_test.py b/tests/device_test.py
index 123df29..07aecdd 100644
--- a/tests/device_test.py
+++ b/tests/device_test.py
@@ -223,8 +223,16 @@
# -----------------------------------------------------------------------------
-async def run_test_device():
- await test_device_connect_parallel()
[email protected]
+async def test_flush():
+ d0 = Device(host=Host(None, None))
+ task = d0.abort_on('flush', asyncio.sleep(10000))
+ await d0.host.flush()
+ try:
+ await task
+ assert False
+ except asyncio.CancelledError:
+ pass
# -----------------------------------------------------------------------------
@@ -249,6 +257,14 @@
# -----------------------------------------------------------------------------
+async def run_test_device():
+ await test_device_connect_parallel()
+ await test_flush()
+ await test_gatt_services_with_gas()
+ await test_gatt_services_without_gas()
+
+
+# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run_test_device())