//
// Copyright (C) 2020 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
| |
| #include "host/frontend/webrtc/libdevice/server_connection.h" |
| |
| #include <android-base/logging.h> |
| #include <libwebsockets.h> |
| |
| #include "common/libs/fs/shared_fd.h" |
| #include "common/libs/fs/shared_select.h" |
| #include "common/libs/utils/files.h" |
| |
| namespace cuttlefish { |
| namespace webrtc_streaming { |
| |
| // ServerConnection over Unix socket |
| class UnixServerConnection : public ServerConnection { |
| public: |
| UnixServerConnection(const std::string& addr, |
| std::weak_ptr<ServerConnectionObserver> observer); |
| ~UnixServerConnection() override; |
| |
| bool Send(const Json::Value& msg) override; |
| |
| private: |
| void Connect() override; |
| void StopThread(); |
| void ReadLoop(); |
| |
| const std::string addr_; |
| SharedFD conn_; |
| std::mutex write_mtx_; |
| std::weak_ptr<ServerConnectionObserver> observer_; |
| // The event fd must be declared before the thread to ensure it's initialized |
| // before the thread starts and is safe to be accessed from it. |
| SharedFD thread_notifier_; |
| std::atomic_bool running_ = false; |
| std::thread thread_; |
| }; |
| |
| // ServerConnection using websockets |
| class WsConnectionContext; |
| |
| class WsConnection : public std::enable_shared_from_this<WsConnection> { |
| public: |
| struct CreateConnectionSul { |
| lws_sorted_usec_list_t sul = {}; |
| std::weak_ptr<WsConnection> weak_this; |
| }; |
| |
| WsConnection(int port, const std::string& addr, const std::string& path, |
| ServerConfig::Security secure, |
| std::weak_ptr<ServerConnectionObserver> observer, |
| std::shared_ptr<WsConnectionContext> context); |
| |
| ~WsConnection(); |
| |
| void Connect(); |
| bool Send(const Json::Value& msg); |
| |
| void ConnectInner(); |
| |
| void OnError(const std::string& error); |
| void OnReceive(const uint8_t* data, size_t len, bool is_binary); |
| void OnOpen(); |
| void OnClose(); |
| void OnWriteable(); |
| |
| private: |
| struct WsBuffer { |
| WsBuffer() = default; |
| WsBuffer(const uint8_t* data, size_t len, bool binary) |
| : buffer_(LWS_PRE + len), is_binary_(binary) { |
| memcpy(&buffer_[LWS_PRE], data, len); |
| } |
| |
| uint8_t* data() { return &buffer_[LWS_PRE]; } |
| bool is_binary() const { return is_binary_; } |
| size_t size() const { return buffer_.size() - LWS_PRE; } |
| |
| private: |
| std::vector<uint8_t> buffer_; |
| bool is_binary_; |
| }; |
| bool Send(const uint8_t* data, size_t len, bool binary = false); |
| |
| CreateConnectionSul extended_sul_; |
| struct lws* wsi_; |
| const int port_; |
| const std::string addr_; |
| const std::string path_; |
| const ServerConfig::Security security_; |
| |
| std::weak_ptr<ServerConnectionObserver> observer_; |
| |
| // each element contains the data to be sent and whether it's binary or not |
| std::deque<WsBuffer> write_queue_; |
| std::mutex write_queue_mutex_; |
| // The connection object should not outlive the context object. This reference |
| // guarantees it. |
| std::shared_ptr<WsConnectionContext> context_; |
| }; |
| |
| class WsConnectionContext |
| : public std::enable_shared_from_this<WsConnectionContext> { |
| public: |
| static std::shared_ptr<WsConnectionContext> Create(); |
| |
| WsConnectionContext(struct lws_context* lws_ctx); |
| ~WsConnectionContext(); |
| |
| std::unique_ptr<ServerConnection> CreateConnection( |
| int port, const std::string& addr, const std::string& path, |
| ServerConfig::Security secure, |
| std::weak_ptr<ServerConnectionObserver> observer); |
| |
| void RememberConnection(void*, std::weak_ptr<WsConnection>); |
| void ForgetConnection(void*); |
| std::shared_ptr<WsConnection> GetConnection(void*); |
| |
| struct lws_context* lws_context() { |
| return lws_context_; |
| } |
| |
| private: |
| void Start(); |
| |
| std::map<void*, std::weak_ptr<WsConnection>> weak_by_ptr_; |
| std::mutex map_mutex_; |
| struct lws_context* lws_context_; |
| std::thread message_loop_; |
| }; |
| |
| std::unique_ptr<ServerConnection> ServerConnection::Connect( |
| const ServerConfig& conf, |
| std::weak_ptr<ServerConnectionObserver> observer) { |
| std::unique_ptr<ServerConnection> ret; |
| // If the provided address points to an existing UNIX socket in the file |
| // system connect to it, otherwise assume it's a network address and connect |
| // using websockets |
| if (FileIsSocket(conf.addr)) { |
| ret.reset(new UnixServerConnection(conf.addr, observer)); |
| } else { |
| // This can be a local variable since the ws connection will keep a |
| // reference to it. |
| auto ws_context = WsConnectionContext::Create(); |
| CHECK(ws_context) << "Failed to create websocket context"; |
| ret = ws_context->CreateConnection(conf.port, conf.addr, conf.path, |
| conf.security, observer); |
| } |
| ret->Connect(); |
| return ret; |
| } |
| |
| void ServerConnection::Reconnect() { Connect(); } |
| |
| // UnixServerConnection implementation |
| |
| UnixServerConnection::UnixServerConnection( |
| const std::string& addr, std::weak_ptr<ServerConnectionObserver> observer) |
| : addr_(addr), observer_(observer) {} |
| |
| UnixServerConnection::~UnixServerConnection() { |
| StopThread(); |
| } |
| |
| bool UnixServerConnection::Send(const Json::Value& msg) { |
| Json::StreamWriterBuilder factory; |
| auto str = Json::writeString(factory, msg); |
| std::lock_guard<std::mutex> lock(write_mtx_); |
| auto res = |
| conn_->Send(reinterpret_cast<const uint8_t*>(str.c_str()), str.size(), 0); |
| if (res < 0) { |
| LOG(ERROR) << "Failed to send data to signaling server: " |
| << conn_->StrError(); |
| // Don't call OnError() here, the receiving thread probably did it already |
| // or is about to do it. |
| } |
| // A SOCK_SEQPACKET unix socket will send the entire message or fail, but it |
| // won't send a partial message. |
| return res == str.size(); |
| } |
| |
| void UnixServerConnection::Connect() { |
| // The thread could be running if this is a Reconnect |
| StopThread(); |
| |
| conn_ = SharedFD::SocketLocalClient(addr_, false, SOCK_SEQPACKET); |
| if (!conn_->IsOpen()) { |
| LOG(ERROR) << "Failed to connect to unix socket: " << conn_->StrError(); |
| if (auto o = observer_.lock(); o) { |
| o->OnError("Failed to connect to unix socket"); |
| } |
| return; |
| } |
| thread_notifier_ = SharedFD::Event(); |
| if (!thread_notifier_->IsOpen()) { |
| LOG(ERROR) << "Failed to create eventfd for background thread: " |
| << thread_notifier_->StrError(); |
| if (auto o = observer_.lock(); o) { |
| o->OnError("Failed to create eventfd for background thread"); |
| } |
| return; |
| } |
| if (auto o = observer_.lock(); o) { |
| o->OnOpen(); |
| } |
| // Start the thread |
| running_ = true; |
| thread_ = std::thread([this](){ReadLoop();}); |
| } |
| |
| void UnixServerConnection::StopThread() { |
| running_ = false; |
| if (!thread_notifier_->IsOpen()) { |
| // The thread won't be running if this isn't open |
| return; |
| } |
| if (thread_notifier_->EventfdWrite(1) < 0) { |
| LOG(ERROR) << "Failed to notify background thread, this thread may block"; |
| } |
| if (thread_.joinable()) { |
| thread_.join(); |
| } |
| } |
| |
| void UnixServerConnection::ReadLoop() { |
| if (!thread_notifier_->IsOpen()) { |
| LOG(ERROR) << "The UnixServerConnection's background thread is unable to " |
| "receive notifications so it can't run"; |
| return; |
| } |
| std::vector<uint8_t> buffer(4096, 0); |
| while (running_) { |
| SharedFDSet rset; |
| rset.Set(thread_notifier_); |
| rset.Set(conn_); |
| auto res = Select(&rset, nullptr, nullptr, nullptr); |
| if (res < 0) { |
| LOG(ERROR) << "Failed to select from background thread"; |
| break; |
| } |
| if (rset.IsSet(thread_notifier_)) { |
| eventfd_t val; |
| auto res = thread_notifier_->EventfdRead(&val); |
| if (res < 0) { |
| LOG(ERROR) << "Error reading from event fd: " |
| << thread_notifier_->StrError(); |
| break; |
| } |
| } |
| if (rset.IsSet(conn_)) { |
| auto size = conn_->Recv(buffer.data(), 0, MSG_TRUNC | MSG_PEEK); |
| if (size > buffer.size()) { |
| // Enlarge enough to accommodate size bytes and be a multiple of 4096 |
| auto new_size = (size + 4095) & ~4095; |
| buffer.resize(new_size); |
| } |
| auto res = conn_->Recv(buffer.data(), buffer.size(), MSG_TRUNC); |
| if (res < 0) { |
| LOG(ERROR) << "Failed to read from server: " << conn_->StrError(); |
| if (auto observer = observer_.lock(); observer) { |
| observer->OnError(conn_->StrError()); |
| } |
| return; |
| } |
| if (res == 0) { |
| auto observer = observer_.lock(); |
| if (observer) { |
| observer->OnClose(); |
| } |
| break; |
| } |
| auto observer = observer_.lock(); |
| if (observer) { |
| observer->OnReceive(buffer.data(), res, false); |
| } |
| } |
| } |
| } |
| |
| // WsConnection implementation |
| |
| int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user, |
| void* in, size_t len); |
| void CreateConnectionCallback(lws_sorted_usec_list_t* sul); |
| |
| namespace { |
| |
| constexpr char kProtocolName[] = "cf-webrtc-device"; |
| constexpr int kBufferSize = 65536; |
| |
| const uint32_t backoff_ms[] = {1000, 2000, 3000, 4000, 5000}; |
| |
| const lws_retry_bo_t kRetry = { |
| .retry_ms_table = backoff_ms, |
| .retry_ms_table_count = LWS_ARRAY_SIZE(backoff_ms), |
| .conceal_count = LWS_ARRAY_SIZE(backoff_ms), |
| |
| .secs_since_valid_ping = 3, /* force PINGs after secs idle */ |
| .secs_since_valid_hangup = 10, /* hangup after secs idle */ |
| |
| .jitter_percent = 20, |
| }; |
| |
| const struct lws_protocols kProtocols[2] = { |
| {kProtocolName, LwsCallback, 0, kBufferSize, 0, NULL, 0}, |
| {NULL, NULL, 0, 0, 0, NULL, 0}}; |
| |
| } // namespace |
| |
| std::shared_ptr<WsConnectionContext> WsConnectionContext::Create() { |
| struct lws_context_creation_info context_info = {}; |
| context_info.port = CONTEXT_PORT_NO_LISTEN; |
| context_info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT; |
| context_info.protocols = kProtocols; |
| struct lws_context* lws_ctx = lws_create_context(&context_info); |
| if (!lws_ctx) { |
| return nullptr; |
| } |
| return std::shared_ptr<WsConnectionContext>(new WsConnectionContext(lws_ctx)); |
| } |
| |
| WsConnectionContext::WsConnectionContext(struct lws_context* lws_ctx) |
| : lws_context_(lws_ctx) { |
| Start(); |
| } |
| |
| WsConnectionContext::~WsConnectionContext() { |
| lws_context_destroy(lws_context_); |
| if (message_loop_.joinable()) { |
| message_loop_.join(); |
| } |
| } |
| |
| void WsConnectionContext::Start() { |
| message_loop_ = std::thread([this]() { |
| for (;;) { |
| if (lws_service(lws_context_, 0) < 0) { |
| break; |
| } |
| } |
| }); |
| } |
| |
| // This wrapper is needed because the ServerConnection objects are meant to be |
| // referenced by std::unique_ptr but WsConnection needs to be referenced by |
| // std::shared_ptr because it's also (weakly) referenced by the websocket |
| // thread. |
| class WsConnectionWrapper : public ServerConnection { |
| public: |
| WsConnectionWrapper(std::shared_ptr<WsConnection> conn) : conn_(conn) {} |
| |
| bool Send(const Json::Value& msg) override { return conn_->Send(msg); } |
| |
| private: |
| void Connect() override { return conn_->Connect(); } |
| std::shared_ptr<WsConnection> conn_; |
| }; |
| |
| std::unique_ptr<ServerConnection> WsConnectionContext::CreateConnection( |
| int port, const std::string& addr, const std::string& path, |
| ServerConfig::Security security, |
| std::weak_ptr<ServerConnectionObserver> observer) { |
| return std::unique_ptr<ServerConnection>( |
| new WsConnectionWrapper(std::make_shared<WsConnection>( |
| port, addr, path, security, observer, shared_from_this()))); |
| } |
| |
| std::shared_ptr<WsConnection> WsConnectionContext::GetConnection(void* raw) { |
| std::shared_ptr<WsConnection> connection; |
| { |
| std::lock_guard<std::mutex> lock(map_mutex_); |
| if (weak_by_ptr_.count(raw) == 0) { |
| return nullptr; |
| } |
| connection = weak_by_ptr_[raw].lock(); |
| if (!connection) { |
| weak_by_ptr_.erase(raw); |
| } |
| } |
| return connection; |
| } |
| |
| void WsConnectionContext::RememberConnection(void* raw, |
| std::weak_ptr<WsConnection> conn) { |
| std::lock_guard<std::mutex> lock(map_mutex_); |
| weak_by_ptr_.emplace( |
| std::pair<void*, std::weak_ptr<WsConnection>>(raw, conn)); |
| } |
| |
| void WsConnectionContext::ForgetConnection(void* raw) { |
| std::lock_guard<std::mutex> lock(map_mutex_); |
| weak_by_ptr_.erase(raw); |
| } |
| |
| WsConnection::WsConnection(int port, const std::string& addr, |
| const std::string& path, |
| ServerConfig::Security security, |
| std::weak_ptr<ServerConnectionObserver> observer, |
| std::shared_ptr<WsConnectionContext> context) |
| : port_(port), |
| addr_(addr), |
| path_(path), |
| security_(security), |
| observer_(observer), |
| context_(context) {} |
| |
| WsConnection::~WsConnection() { |
| context_->ForgetConnection(this); |
| // This will cause the callback to be called which will drop the connection |
| // after seeing the context doesn't remember this object |
| lws_callback_on_writable(wsi_); |
| } |
| |
| void WsConnection::Connect() { |
| memset(&extended_sul_.sul, 0, sizeof(extended_sul_.sul)); |
| extended_sul_.weak_this = weak_from_this(); |
| lws_sul_schedule(context_->lws_context(), 0, &extended_sul_.sul, |
| CreateConnectionCallback, 1); |
| } |
| |
| void WsConnection::OnError(const std::string& error) { |
| auto observer = observer_.lock(); |
| if (observer) { |
| observer->OnError(error); |
| } |
| } |
| void WsConnection::OnReceive(const uint8_t* data, size_t len, bool is_binary) { |
| auto observer = observer_.lock(); |
| if (observer) { |
| observer->OnReceive(data, len, is_binary); |
| } |
| } |
| void WsConnection::OnOpen() { |
| auto observer = observer_.lock(); |
| if (observer) { |
| observer->OnOpen(); |
| } |
| } |
| void WsConnection::OnClose() { |
| auto observer = observer_.lock(); |
| if (observer) { |
| observer->OnClose(); |
| } |
| } |
| |
| void WsConnection::OnWriteable() { |
| WsBuffer buffer; |
| { |
| std::lock_guard<std::mutex> lock(write_queue_mutex_); |
| if (write_queue_.size() == 0) { |
| return; |
| } |
| buffer = std::move(write_queue_.front()); |
| write_queue_.pop_front(); |
| } |
| auto flags = lws_write_ws_flags( |
| buffer.is_binary() ? LWS_WRITE_BINARY : LWS_WRITE_TEXT, true, true); |
| auto res = lws_write(wsi_, buffer.data(), buffer.size(), |
| (enum lws_write_protocol)flags); |
| if (res != buffer.size()) { |
| LOG(WARNING) << "Unable to send the entire message!"; |
| } |
| } |
| |
| bool WsConnection::Send(const Json::Value& msg) { |
| Json::StreamWriterBuilder factory; |
| auto str = Json::writeString(factory, msg); |
| return Send(reinterpret_cast<const uint8_t*>(str.c_str()), str.size()); |
| } |
| |
| bool WsConnection::Send(const uint8_t* data, size_t len, bool binary) { |
| if (!wsi_) { |
| LOG(WARNING) << "Send called on an uninitialized connection!!"; |
| return false; |
| } |
| WsBuffer buffer(data, len, binary); |
| { |
| std::lock_guard<std::mutex> lock(write_queue_mutex_); |
| write_queue_.emplace_back(std::move(buffer)); |
| } |
| |
| lws_callback_on_writable(wsi_); |
| return true; |
| } |
| |
| int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user, |
| void* in, size_t len) { |
| constexpr int DROP = -1; |
| constexpr int OK = 0; |
| |
| // For some values of `reason`, `user` doesn't point to the value provided |
| // when the connection was created. This function object should be used with |
| // care. |
| auto with_connection = |
| [wsi, user](std::function<void(std::shared_ptr<WsConnection>)> cb) { |
| auto context = reinterpret_cast<WsConnectionContext*>(user); |
| auto connection = context->GetConnection(wsi); |
| if (!connection) { |
| return DROP; |
| } |
| cb(connection); |
| return OK; |
| }; |
| |
| switch (reason) { |
| case LWS_CALLBACK_CLIENT_CONNECTION_ERROR: |
| return with_connection([in](std::shared_ptr<WsConnection> connection) { |
| connection->OnError(in ? (char*)in : "(null)"); |
| }); |
| |
| case LWS_CALLBACK_CLIENT_RECEIVE: |
| return with_connection( |
| [in, len, wsi](std::shared_ptr<WsConnection> connection) { |
| connection->OnReceive((const uint8_t*)in, len, |
| lws_frame_is_binary(wsi)); |
| }); |
| |
| case LWS_CALLBACK_CLIENT_ESTABLISHED: |
| return with_connection([](std::shared_ptr<WsConnection> connection) { |
| connection->OnOpen(); |
| }); |
| |
| case LWS_CALLBACK_CLIENT_CLOSED: |
| return with_connection([](std::shared_ptr<WsConnection> connection) { |
| connection->OnClose(); |
| }); |
| |
| case LWS_CALLBACK_CLIENT_WRITEABLE: |
| return with_connection([](std::shared_ptr<WsConnection> connection) { |
| connection->OnWriteable(); |
| }); |
| |
| default: |
| LOG(VERBOSE) << "Unhandled value: " << reason; |
| return lws_callback_http_dummy(wsi, reason, user, in, len); |
| } |
| } |
| |
| void CreateConnectionCallback(lws_sorted_usec_list_t* sul) { |
| std::shared_ptr<WsConnection> connection = |
| reinterpret_cast<WsConnection::CreateConnectionSul*>(sul) |
| ->weak_this.lock(); |
| if (!connection) { |
| LOG(WARNING) << "The object was already destroyed by the time of the first " |
| << "connection attempt. That's unusual."; |
| return; |
| } |
| connection->ConnectInner(); |
| } |
| |
| void WsConnection::ConnectInner() { |
| struct lws_client_connect_info connect_info; |
| |
| memset(&connect_info, 0, sizeof(connect_info)); |
| |
| connect_info.context = context_->lws_context(); |
| connect_info.port = port_; |
| connect_info.address = addr_.c_str(); |
| connect_info.path = path_.c_str(); |
| connect_info.host = connect_info.address; |
| connect_info.origin = connect_info.address; |
| switch (security_) { |
| case ServerConfig::Security::kAllowSelfSigned: |
| connect_info.ssl_connection = LCCSCF_ALLOW_SELFSIGNED | |
| LCCSCF_SKIP_SERVER_CERT_HOSTNAME_CHECK | |
| LCCSCF_USE_SSL; |
| break; |
| case ServerConfig::Security::kStrict: |
| connect_info.ssl_connection = LCCSCF_USE_SSL; |
| break; |
| case ServerConfig::Security::kInsecure: |
| connect_info.ssl_connection = 0; |
| break; |
| } |
| connect_info.protocol = "webrtc-operator"; |
| connect_info.local_protocol_name = kProtocolName; |
| connect_info.pwsi = &wsi_; |
| connect_info.retry_and_idle_policy = &kRetry; |
| // There is no guarantee the connection object still exists when the callback |
| // is called. Put the context instead as the user data which is guaranteed to |
| // still exist and holds a weak ptr to the connection. |
| connect_info.userdata = context_.get(); |
| |
| if (lws_client_connect_via_info(&connect_info)) { |
| // wsi_ is not initialized until after the call to |
| // lws_client_connect_via_info(). Luckily, this is guaranteed to run before |
| // the protocol callback is called because it runs in the same loop. |
| context_->RememberConnection(wsi_, weak_from_this()); |
| } else { |
| LOG(ERROR) << "Connection failed!"; |
| } |
| } |
| |
| } // namespace webrtc_streaming |
| } // namespace cuttlefish |