| # Copyright 2021-2022 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| # ---------------------------------------------------------------------------- |
| # Imports |
| # ---------------------------------------------------------------------------- |
| import sys |
| import websockets |
| import logging |
| import json |
| import asyncio |
| import argparse |
| import uuid |
| import os |
| from urllib.parse import urlparse |
| from colors import color |
| |
| # ----------------------------------------------------------------------------- |
| # Logging |
| # ----------------------------------------------------------------------------- |
# Module-level logger, named after this module
logger = logging.getLogger(__name__)
| |
| |
| # ---------------------------------------------------------------------------- |
| # Constants |
| # ---------------------------------------------------------------------------- |
# Port the relay listens on when --port is not given on the command line
DEFAULT_RELAY_PORT = 10723
| |
| |
| # ---------------------------------------------------------------------------- |
| # Utils |
| # ---------------------------------------------------------------------------- |
def error_to_json(error):
    """
    Serialize an error value as a JSON object with a single 'error' key
    """
    wrapped = dict(error=error)
    return json.dumps(wrapped)
| |
| |
def error_to_result(error):
    """
    Format an error as a 'result:' reply whose body is the JSON-encoded error
    """
    encoded = error_to_json(error)
    return 'result:' + encoded
| |
| |
async def broadcast_message(message, connections):
    """
    Send the same message to every connection in `connections`, concurrently.

    Does nothing when `connections` is empty.
    """
    pending = []
    for peer in connections:
        pending.append(peer.send_message(message))

    # Nothing to wait for when there are no recipients
    if not pending:
        return

    await asyncio.gather(*pending)
| |
| |
| # ---------------------------------------------------------------------------- |
| # Connection class |
| # ---------------------------------------------------------------------------- |
class Connection:
    """
    One websocket client attached to the relay, optionally bound to a room
    """

    def __init__(self, room, websocket):
        # A random UUID is used as the address until set_address() replaces it
        self.address = str(uuid.uuid4())
        self.websocket = websocket
        self.room = room

    async def send_message(self, message):
        """
        Send a message to the client.

        On a websocket error the failure is logged, the connection is removed
        from its room, and None is returned.
        """
        try:
            logger.debug(color(f'->{self.address}: {message}', 'yellow'))
            outcome = await self.websocket.send(message)
        except websockets.exceptions.WebSocketException as error:
            logger.info(f'! client "{self}" disconnected: {error}')
            await self.cleanup()
            return None
        return outcome

    async def send_error(self, error):
        """
        Send an error to the client as a 'result:' message with a JSON body
        """
        body = error_to_json(error)
        return await self.send_message(f'result:{body}')

    async def receive_message(self):
        """
        Wait for the next message from the client.

        On a websocket error the failure is logged, the connection is removed
        from its room, and None is returned.
        """
        try:
            incoming = await self.websocket.recv()
            logger.debug(color(f'<-{self.address}: {incoming}', 'blue'))
        except websockets.exceptions.WebSocketException as error:
            logger.info(color(f'! client "{self}" disconnected: {error}', 'red'))
            await self.cleanup()
            return None
        return incoming

    async def cleanup(self):
        """
        Detach this connection from its room, if it has one
        """
        if not self.room:
            return
        await self.room.remove_connection(self)

    def set_address(self, address):
        """
        Replace the connection's address, logging the change
        """
        previous = self.address
        logger.info(f'Connection address changed: {previous} -> {address}')
        self.address = address

    def __str__(self):
        remote = self.websocket.remote_address
        host, port = remote[0], remote[1]
        return f'Connection(address="{self.address}", client={host}:{port})'
| |
| |
| # ---------------------------------------------------------------------------- |
| # Room class |
| # ---------------------------------------------------------------------------- |
class Room:
    """
    A Room is a collection of bridged connections.

    Clients talk to the room with two kinds of text messages:
      '@<address> <payload>'   deliver <payload> to the connections that have
                               that address ('@*' targets every connection in
                               the room except the sender)
      '/<command> [<params>]'  RPC request, dispatched to the method named
                               on_<command>_command
    """

    def __init__(self, relay, name):
        self.relay = relay
        self.name = name
        self.observers = []
        self.connections = []

    async def add_connection(self, connection):
        """
        Add a connection to the room and announce 'joined:<address>' to the
        other participants.
        """
        logger.info(f'New participant in {self.name}: {connection}')
        self.connections.append(connection)
        await self.broadcast_message(connection, f'joined:{connection.address}')

    async def remove_connection(self, connection):
        """
        Remove a connection from the room (if present) and announce
        'left:<address>' to the remaining participants.
        """
        if connection in self.connections:
            self.connections.remove(connection)
            await self.broadcast_message(connection, f'left:{connection.address}')

    def find_connections_by_address(self, address):
        """
        Return the list of connections in the room whose address is `address`
        """
        return [c for c in self.connections if c.address == address]

    async def bridge_connection(self, connection):
        """
        Receive and dispatch messages from `connection` until it is closed.
        """
        while True:
            # Wait for a message
            message = await connection.receive_message()

            # receive_message() returns None once the websocket has failed or
            # closed: stop bridging
            if message is None:
                return

            # Binary frames are received as bytes; they are not part of the
            # protocol and must not reach the str-only parsing below
            if not isinstance(message, str):
                await connection.send_error('error: invalid message')
                continue

            # Parse the message to decide how to handle it
            if message.startswith('@'):
                # This is a targeted message
                await self.on_targetted_message(connection, message)
            elif message.startswith('/'):
                # This is an RPC request
                await self.on_rpc_request(connection, message)
            else:
                await connection.send_error('error: invalid message')

    async def broadcast_message(self, sender, message):
        """
        Send to all connections in the room except back to the sender
        """
        await broadcast_message(message, [c for c in self.connections if c != sender])

    async def on_rpc_request(self, connection, message):
        """
        Handle a '/<command> [<params>]' request and reply to the requester.

        The command name is lower-cased and '-' is mapped to '_' to find the
        handler method on_<command>_command. The reply is the handler's return
        value, or 'result:{}' when the handler returns nothing. Unknown
        commands and handler exceptions are reported as 'result:' errors.
        """
        command, *params = message.split(' ', 1)
        handler_name = f'on_{command[1:].lower().replace("-", "_")}_command'
        handler = getattr(self, handler_name, None)
        if handler is None:
            await connection.send_error('unknown command')
            return

        try:
            result = await handler(connection, params)
        except Exception as error:
            # Exception objects are not JSON-serializable: report their text
            await connection.send_error(str(error))
            return

        await connection.send_message(result or 'result:{}')

    async def on_targetted_message(self, connection, message):
        """
        Handle a '@<target> <payload>' message by forwarding
        'message:<sender address>/<payload>' to the target(s).

        The sender receives an error reply when the payload is missing, and
        'unreachable:<target>' when no connection has the target address.
        """
        target, *payload = message.split(' ', 1)
        if not payload:
            # Report the problem to the sender instead of dropping it silently
            await connection.send_error('missing arguments')
            return
        payload = payload[0]
        target = target[1:]

        # Determine what targets to send to
        if target == '*':
            # Send to all connections in the room except the connection from
            # which the message was received
            connections = [c for c in self.connections if c != connection]
        else:
            connections = self.find_connections_by_address(target)
            if not connections:
                # Unicast with no recipient, let the sender know
                await connection.send_message(f'unreachable:{target}')
                return

        # Send to targets
        await broadcast_message(f'message:{connection.address}/{payload}', connections)

    async def on_set_address_command(self, connection, params):
        """
        RPC handler for '/set-address <address>'.

        Changes the requester's address and announces
        'address-changed:from=<old>,to=<new>' to the other participants.
        Returns an error result when no address was supplied, else None.
        """
        if not params:
            return error_to_result('missing address')

        current_address = connection.address
        new_address = params[0]
        connection.set_address(new_address)
        await self.broadcast_message(connection, f'address-changed:from={current_address},to={new_address}')
| |
| |
| # ---------------------------------------------------------------------------- |
class Relay:
    """
    A relay accepts connections with the following url: ws://<hostname>/<room>.
    Participants in a room can communicate with each other
    """

    def __init__(self, port):
        self.port = port
        self.rooms = {}  # Rooms, keyed by room name
        self.observers = []

    def start(self):
        """
        Create the websocket server for this relay.

        Returns the object produced by websockets.serve(), listening on all
        IPv4 interfaces at self.port; the caller is responsible for awaiting
        it.
        """
        logger.info(f'Starting Relay on port {self.port}')

        return websockets.serve(self.serve, '0.0.0.0', self.port, ping_interval=None)

    async def serve_as_controller(self, connection):
        """
        Serve a controller client (one that connected to the root path '/').

        Not implemented: the connection is accepted and nothing is done.
        """
        pass

    async def serve(self, websocket, path=None):
        """
        Handle one incoming websocket connection.

        `path` is the request path. It is optional so that the handler also
        works with websockets releases that call handlers with the connection
        only; in that case the path is read from the connection object.
        """
        if path is None:
            # Newer connection objects expose the handshake request, older
            # ones expose the path directly
            request = getattr(websocket, 'request', None)
            path = request.path if request is not None else websocket.path

        logger.debug(f'New connection with path {path}')

        # Parse the path
        parsed = urlparse(path)

        # Check if this is a controller client
        if parsed.path == '/':
            return await self.serve_as_controller(Connection('', websocket))

        # Find or create a room for this connection; the room name is the
        # first component of the path
        room_name = parsed.path[1:].split('/')[0]
        if room_name not in self.rooms:
            self.rooms[room_name] = Room(self, room_name)
        room = self.rooms[room_name]

        # Add the connection to the room
        connection = Connection(room, websocket)
        await room.add_connection(connection)

        # Bridge until the connection is closed
        await room.bridge_connection(connection)
| |
| |
| # ---------------------------------------------------------------------------- |
def main():
    """
    Command-line entry point: parse the arguments, configure logging and run
    a relay until the process is interrupted.

    The log level comes from --log-level, falling back to the BUMBLE_LOGLEVEL
    environment variable and then to INFO. When --log-config is given, the
    logging setup is read from that file instead.
    """
    # Check the Python version
    if sys.version_info < (3, 6, 1):
        print('ERROR: Python 3.6.1 or higher is required')
        sys.exit(1)

    default_log_level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()
    logging.basicConfig(level=default_log_level)

    # Parse arguments
    arg_parser = argparse.ArgumentParser(description='Bumble Link Relay')
    arg_parser.add_argument('--log-level', default=default_log_level, help='logger level')
    arg_parser.add_argument(
        '--log-config',
        help='logger config file (logging.config.fileConfig INI format)',
    )
    arg_parser.add_argument(
        '--port',
        type=int,
        default=DEFAULT_RELAY_PORT,
        help='Port to listen on',
    )
    args = arg_parser.parse_args()

    # Setup logger
    if args.log_config:
        from logging import config

        config.fileConfig(args.log_config)
    else:
        # basicConfig() does nothing once the root logger has handlers, so
        # the requested level must be set on the root logger directly
        logging.getLogger().setLevel(args.log_level.upper())

    # Start a relay
    relay = Relay(args.port)

    async def run_relay():
        # Keep a reference to the server for as long as the relay runs
        server = await relay.start()  # noqa: F841
        # Wait on a future that never completes: serve until interrupted
        await asyncio.get_running_loop().create_future()

    asyncio.run(run_relay())
| |
| |
| # ---------------------------------------------------------------------------- |
# Run the relay only when this file is executed as a script, not on import
if __name__ == '__main__':
    main()