| // Copyright 2016 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "mojo/core/broker_host.h" |
| |
| #include <utility> |
| |
| #include "base/logging.h" |
| #include "base/memory/platform_shared_memory_region.h" |
| #include "base/memory/ref_counted.h" |
| #include "base/threading/thread_task_runner_handle.h" |
| #include "build/build_config.h" |
| #include "mojo/core/broker_messages.h" |
| #include "mojo/core/platform_handle_utils.h" |
| |
| #if defined(OS_WIN) |
| #include <windows.h> |
| #endif |
| |
| namespace mojo { |
| namespace core { |
| |
| BrokerHost::BrokerHost(base::ProcessHandle client_process, |
| ConnectionParams connection_params, |
| const ProcessErrorCallback& process_error_callback) |
| : process_error_callback_(process_error_callback) |
| #if defined(OS_WIN) |
| , |
| client_process_(ScopedProcessHandle::CloneFrom(client_process)) |
| #endif |
| { |
| CHECK(connection_params.endpoint().is_valid() || |
| connection_params.server_endpoint().is_valid()); |
| |
| base::MessageLoopCurrent::Get()->AddDestructionObserver(this); |
| |
| channel_ = Channel::Create(this, std::move(connection_params), |
| base::ThreadTaskRunnerHandle::Get()); |
| channel_->Start(); |
| } |
| |
| BrokerHost::~BrokerHost() { |
| // We're always destroyed on the creation thread, which is the IO thread. |
| base::MessageLoopCurrent::Get()->RemoveDestructionObserver(this); |
| |
| if (channel_) |
| channel_->ShutDown(); |
| } |
| |
| bool BrokerHost::PrepareHandlesForClient( |
| std::vector<PlatformHandleInTransit>* handles) { |
| #if defined(OS_WIN) |
| bool handles_ok = true; |
| for (auto& handle : *handles) { |
| if (!handle.TransferToProcess(client_process_.Clone())) |
| handles_ok = false; |
| } |
| return handles_ok; |
| #else |
| return true; |
| #endif |
| } |
| |
| bool BrokerHost::SendChannel(PlatformHandle handle) { |
| CHECK(handle.is_valid()); |
| CHECK(channel_); |
| |
| #if defined(OS_WIN) |
| InitData* data; |
| Channel::MessagePtr message = |
| CreateBrokerMessage(BrokerMessageType::INIT, 1, 0, &data); |
| data->pipe_name_length = 0; |
| #else |
| Channel::MessagePtr message = |
| CreateBrokerMessage(BrokerMessageType::INIT, 1, nullptr); |
| #endif |
| std::vector<PlatformHandleInTransit> handles(1); |
| handles[0] = PlatformHandleInTransit(std::move(handle)); |
| |
| // This may legitimately fail on Windows if the client process is in another |
| // session, e.g., is an elevated process. |
| if (!PrepareHandlesForClient(&handles)) |
| return false; |
| |
| message->SetHandles(std::move(handles)); |
| channel_->Write(std::move(message)); |
| return true; |
| } |
| |
| #if defined(OS_WIN) |
| |
| void BrokerHost::SendNamedChannel(const base::StringPiece16& pipe_name) { |
| InitData* data; |
| base::char16* name_data; |
| Channel::MessagePtr message = CreateBrokerMessage( |
| BrokerMessageType::INIT, 0, sizeof(*name_data) * pipe_name.length(), |
| &data, reinterpret_cast<void**>(&name_data)); |
| data->pipe_name_length = static_cast<uint32_t>(pipe_name.length()); |
| std::copy(pipe_name.begin(), pipe_name.end(), name_data); |
| channel_->Write(std::move(message)); |
| } |
| |
| #endif // defined(OS_WIN) |
| |
| void BrokerHost::OnBufferRequest(uint32_t num_bytes) { |
| base::subtle::PlatformSharedMemoryRegion region = |
| base::subtle::PlatformSharedMemoryRegion::CreateWritable(num_bytes); |
| |
| std::vector<PlatformHandleInTransit> handles(2); |
| if (region.IsValid()) { |
| PlatformHandle h[2]; |
| ExtractPlatformHandlesFromSharedMemoryRegionHandle( |
| region.PassPlatformHandle(), &h[0], &h[1]); |
| handles[0] = PlatformHandleInTransit(std::move(h[0])); |
| handles[1] = PlatformHandleInTransit(std::move(h[1])); |
| #if !defined(OS_POSIX) || defined(OS_ANDROID) || defined(OS_FUCHSIA) || \ |
| (defined(OS_MACOSX) && !defined(OS_IOS)) |
| // Non-POSIX systems, as well as Android, Fuchsia, and non-iOS Mac, only use |
| // a single handle to represent a writable region. |
| DCHECK(!handles[1].handle().is_valid()); |
| handles.resize(1); |
| #else |
| DCHECK(handles[1].handle().is_valid()); |
| #endif |
| } |
| |
| BufferResponseData* response; |
| Channel::MessagePtr message = CreateBrokerMessage( |
| BrokerMessageType::BUFFER_RESPONSE, handles.size(), 0, &response); |
| if (!handles.empty()) { |
| base::UnguessableToken guid = region.GetGUID(); |
| response->guid_high = guid.GetHighForSerialization(); |
| response->guid_low = guid.GetLowForSerialization(); |
| PrepareHandlesForClient(&handles); |
| message->SetHandles(std::move(handles)); |
| } |
| |
| channel_->Write(std::move(message)); |
| } |
| |
| void BrokerHost::OnChannelMessage(const void* payload, |
| size_t payload_size, |
| std::vector<PlatformHandle> handles) { |
| if (payload_size < sizeof(BrokerMessageHeader)) |
| return; |
| |
| const BrokerMessageHeader* header = |
| static_cast<const BrokerMessageHeader*>(payload); |
| switch (header->type) { |
| case BrokerMessageType::BUFFER_REQUEST: |
| if (payload_size == |
| sizeof(BrokerMessageHeader) + sizeof(BufferRequestData)) { |
| const BufferRequestData* request = |
| reinterpret_cast<const BufferRequestData*>(header + 1); |
| OnBufferRequest(request->size); |
| } |
| break; |
| |
| default: |
| DLOG(ERROR) << "Unexpected broker message type: " << header->type; |
| break; |
| } |
| } |
| |
| void BrokerHost::OnChannelError(Channel::Error error) { |
| if (process_error_callback_ && |
| error == Channel::Error::kReceivedMalformedData) { |
| process_error_callback_.Run("Broker host received malformed message"); |
| } |
| |
| delete this; |
| } |
| |
| void BrokerHost::WillDestroyCurrentMessageLoop() { |
| delete this; |
| } |
| |
| } // namespace core |
| } // namespace mojo |