// blob: d8cf11f2a7c7ca63f2efba7694526434e98b9037 [file] [log] [blame]
// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#if _WIN32
// For Unix implementation, see async-io-unix.c++.
// Request Vista-level APIs.
#include "win32-api-version.h"
#include "async-io.h"
#include "async-io-internal.h"
#include "async-win32.h"
#include "debug.h"
#include "thread.h"
#include "io.h"
#include "vector.h"
#include <set>
#include <winsock2.h>
#include <ws2ipdef.h>
#include <ws2tcpip.h>
#include <mswsock.h>
#include <stdlib.h>
#ifndef IPV6_V6ONLY
// MinGW's headers are missing this.
#define IPV6_V6ONLY 27
#endif
namespace kj {
namespace _ { // private
struct WinsockInitializer {
WinsockInitializer() {
WSADATA dontcare;
int result = WSAStartup(MAKEWORD(2, 2), &dontcare);
if (result != 0) {
KJ_FAIL_WIN32("WSAStartup()", result);
}
}
};
void initWinsockOnce() {
  // The function-local static is constructed the first time control passes through here, so
  // WSAStartup() runs at most once no matter how often this is called.
  static WinsockInitializer winsockStartup;
}
// Creates a pair of connected TCP sockets over the IPv4 loopback interface.
//
// On success, stores the connecting socket in socks[0] and the accepted socket in socks[1] and
// returns 0. On failure, returns SOCKET_ERROR, sets both entries of `socks` to -1, and leaves the
// original error code available through WSAGetLastError(). Passing a null `socks` fails with
// WSAEINVAL.
int win32Socketpair(SOCKET socks[2]) {
// This function from: https://github.com/ncm/selectable-socketpair/blob/master/socketpair.c
//
// Copyright notice:
//
// Copyright 2007, 2010 by Nathan C. Myers <ncm@cantrip.org>
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// The name of the author must not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Note: This function is called from some Cap'n Proto unit tests, despite not having a public
// header declaration.
// TODO(cleanup): Consider putting this somewhere public? Note that since it depends on Winsock,
// it needs to be in the kj-async library.
initWinsockOnce();
union {
struct sockaddr_in inaddr;
struct sockaddr addr;
} a;
SOCKET listener;
int e;
socklen_t addrlen = sizeof(a.inaddr);
int reuse = 1;
if (socks == 0) {
WSASetLastError(WSAEINVAL);
return SOCKET_ERROR;
}
socks[0] = socks[1] = -1;
listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (listener == -1)
return SOCKET_ERROR;
// Listen on the loopback address; port 0 leaves the port number to be assigned at bind time.
memset(&a, 0, sizeof(a));
a.inaddr.sin_family = AF_INET;
a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
a.inaddr.sin_port = 0;
// Not a real loop: the body ends in `return 0`, and `break` is used to jump to the shared error
// cleanup that follows the loop.
for (;;) {
if (setsockopt(listener, SOL_SOCKET, SO_REUSEADDR,
(char*) &reuse, (socklen_t) sizeof(reuse)) == -1)
break;
if (bind(listener, &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR)
break;
// Read back the address actually bound so that we learn the assigned port.
memset(&a, 0, sizeof(a));
if (getsockname(listener, &a.addr, &addrlen) == SOCKET_ERROR)
break;
// win32 getsockname may only set the port number, p=0.0005.
// ( http://msdn.microsoft.com/library/ms738543.aspx ):
a.inaddr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
a.inaddr.sin_family = AF_INET;
if (listen(listener, 1) == SOCKET_ERROR)
break;
// The connecting end is created with WSA_FLAG_OVERLAPPED.
socks[0] = WSASocket(AF_INET, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
if (socks[0] == -1)
break;
if (connect(socks[0], &a.addr, sizeof(a.inaddr)) == SOCKET_ERROR)
break;
retryAccept:
socks[1] = accept(listener, NULL, NULL);
if (socks[1] == -1)
break;
// Verify that the client is actually us and not someone else who raced to connect first.
// (This check added by Kenton for security.)
// `b` receives the peer address of the accepted socket, `c` the local address of socks[0]; the
// two must match if the accepted connection is our own.
union {
struct sockaddr_in inaddr;
struct sockaddr addr;
} b, c;
socklen_t bAddrlen = sizeof(b.inaddr);
socklen_t cAddrlen = sizeof(b.inaddr);
if (getpeername(socks[1], &b.addr, &bAddrlen) == SOCKET_ERROR)
break;
if (getsockname(socks[0], &c.addr, &cAddrlen) == SOCKET_ERROR)
break;
if (bAddrlen != cAddrlen || memcmp(&b.addr, &c.addr, bAddrlen) != 0) {
// Someone raced to connect first. Ignore.
closesocket(socks[1]);
goto retryAccept;
}
// Success: the listener is no longer needed.
closesocket(listener);
return 0;
}
// Failure path: remember the error that got us here, because the closesocket() calls below may
// overwrite the thread's last-error value; restore it before returning.
e = WSAGetLastError();
closesocket(listener);
closesocket(socks[0]);
closesocket(socks[1]);
WSASetLastError(e);
socks[0] = socks[1] = -1;
return SOCKET_ERROR;
}
} // namespace _
namespace {
// =======================================================================================
// Flags used when wrapping sockets that this file created itself (as opposed to sockets handed in
// by the caller): TAKE_OWNERSHIP makes the wrapper close the socket when it is destroyed.
static constexpr uint NEW_FD_FLAGS = LowLevelAsyncIoProvider::TAKE_OWNERSHIP;
class OwnedFd {
  // Base class that holds a SOCKET and, if constructed with TAKE_OWNERSHIP, closes it on
  // destruction.

public:
  OwnedFd(SOCKET fd, uint flags): fd(fd), flags(flags) {
    // TODO(perf): Maybe use SetFileCompletionNotificationModes() to tell Windows not to bother
    //   delivering an event when the operation completes inline. Not currently implemented on
    //   Wine, though.
  }

  OwnedFd(const OwnedFd&) = delete;
  OwnedFd& operator=(const OwnedFd&) = delete;
  // Not copyable: with TAKE_OWNERSHIP set, a copy would make two objects close the same socket.

  ~OwnedFd() noexcept(false) {
    // Only close sockets we were asked to own. A closesocket() failure is reported through
    // KJ_WINSOCK, with `break` as the recovery path.
    if (flags & LowLevelAsyncIoProvider::TAKE_OWNERSHIP) {
      KJ_WINSOCK(closesocket(fd)) { break; }
    }
  }

protected:
  SOCKET fd;

private:
  uint flags;
};
// =======================================================================================
class AsyncStreamFd: public OwnedFd, public AsyncIoStream {
// AsyncIoStream implementation over a SOCKET. Reads, writes and connects are issued as overlapped
// operations (WSARecv / WSASend / ConnectEx) whose completion is awaited through an IoObserver
// obtained from the Win32EventPort.
public:
AsyncStreamFd(Win32EventPort& eventPort, SOCKET fd, uint flags)
: OwnedFd(fd, flags),
observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))) {}
virtual ~AsyncStreamFd() noexcept(false) {}
// Reads at least `minBytes` and at most `maxBytes` into `buffer`. Hitting EOF before `minBytes`
// have arrived is a "Premature EOF" failure.
Promise<size_t> read(void* buffer, size_t minBytes, size_t maxBytes) override {
return tryRead(buffer, minBytes, maxBytes).then([=](size_t result) {
KJ_REQUIRE(result >= minBytes, "Premature EOF") {
// Pretend we read zeros from the input.
memset(reinterpret_cast<byte*>(buffer) + result, 0, minBytes - result);
return minBytes;
}
return result;
});
}
// Like read(), but a short count (fewer than `minBytes`) is returned to the caller rather than
// treated as an error.
Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
// The WSABUF array is heap-allocated and attached to the promise so that it stays alive while
// the overlapped operation is in flight.
auto bufs = heapArray<WSABUF>(1);
bufs[0].buf = reinterpret_cast<char*>(buffer);
bufs[0].len = maxBytes;
ArrayPtr<WSABUF> ref = bufs;
return tryReadInternal(ref, minBytes, 0).attach(kj::mv(bufs));
}
// Writes the whole buffer; the promise resolves once everything has been sent.
Promise<void> write(const void* buffer, size_t size) override {
auto bufs = heapArray<WSABUF>(1);
bufs[0].buf = const_cast<char*>(reinterpret_cast<const char*>(buffer));
bufs[0].len = size;
ArrayPtr<WSABUF> ref = bufs;
return writeInternal(ref).attach(kj::mv(bufs));
}
// Gather-write: each piece becomes one WSABUF entry, sent in order.
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
auto bufs = heapArray<WSABUF>(pieces.size());
for (auto i: kj::indices(pieces)) {
bufs[i].buf = const_cast<char*>(pieces[i].asChars().begin());
bufs[i].len = pieces[i].size();
}
ArrayPtr<WSABUF> ref = bufs;
return writeInternal(ref).attach(kj::mv(bufs));
}
// Starts an asynchronous connect to `addr` using ConnectEx(). The promise resolves when the
// connection attempt completes.
kj::Promise<void> connect(const struct sockaddr* addr, uint addrlen) {
// In order to connect asynchronously, we need the ConnectEx() function. Apparently, we have
// to query the socket for it dynamically, I guess because of the insanity in which winsock
// can be implemented in userspace and old implementations may not support it.
GUID guid = WSAID_CONNECTEX;
LPFN_CONNECTEX connectEx = nullptr;
DWORD n = 0;
KJ_WINSOCK(WSAIoctl(fd, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
&connectEx, sizeof(connectEx), &n, NULL, NULL)) {
goto fail; // avoid memory leak due to compiler bugs
}
// The `fail` label lives in a never-taken branch so that it is only reachable via the goto
// above.
if (false) {
fail:
return kj::READY_NOW;
}
// OK, phew, we now have our ConnectEx function pointer. Call it.
auto op = observer->newOperation(0);
if (!connectEx(fd, addr, addrlen, NULL, 0, NULL, op->getOverlapped())) {
// ERROR_IO_PENDING just means the operation will complete asynchronously; anything else is a
// real failure.
DWORD error = WSAGetLastError();
if (error != ERROR_IO_PENDING) {
KJ_FAIL_WIN32("ConnectEx()", error) { break; }
return kj::READY_NOW;
}
}
return op->onComplete().then([this](Win32EventPort::IoResult result) {
if (result.errorCode != ERROR_SUCCESS) {
KJ_FAIL_WIN32("ConnectEx()", result.errorCode) { return; }
}
// Enable shutdown() to work.
setsockopt(SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0);
});
}
// Currently unimplemented: returns a promise that never resolves (see the TODO below).
Promise<void> whenWriteDisconnected() override {
// Windows IOCP does not provide a direct, documented way to detect when the socket disconnects
// without actually doing a read or write. However, there is an undocumented-but-stable
// ioctl called IOCTL_AFD_POLL which can be used for this purpose. In fact, select() is
// implemented in terms of this ioctl -- performed synchronously -- but it's entirely possible
// to put only one socket into the list and perform the ioctl asynchronously. Here's the
// source code for select() in Windows 2000 (not sure how this became public...):
//
// https://github.com/pustladi/Windows-2000/blob/661d000d50637ed6fab2329d30e31775046588a9/private/net/sockets/winsock2/wsp/msafd/select.c#L59-L655
//
// And here's an interesting discussion: https://github.com/python-trio/trio/issues/52
//
// TODO(someday): Implement this with IOCTL_AFD_POLL. For now I'm leaving it unimplemented
// because I added this method for a Linux-only use case.
return NEVER_DONE;
}
// Shuts down the sending direction of the socket.
void shutdownWrite() override {
// There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
// Win32AsyncIoProvider interface.
KJ_WINSOCK(shutdown(fd, SD_SEND));
}
// Shuts down the receiving direction of the socket.
void abortRead() override {
// There's no legitimate way to get an AsyncStreamFd that isn't a socket through the
// Win32AsyncIoProvider interface.
KJ_WINSOCK(shutdown(fd, SD_RECEIVE));
}
// The four methods below forward to the Winsock functions of the same name, converting between
// the interface's `uint` lengths and Winsock's `socklen_t`.
void getsockopt(int level, int option, void* value, uint* length) override {
socklen_t socklen = *length;
KJ_WINSOCK(::getsockopt(fd, level, option,
reinterpret_cast<char*>(value), &socklen));
*length = socklen;
}
void setsockopt(int level, int option, const void* value, uint length) override {
KJ_WINSOCK(::setsockopt(fd, level, option,
reinterpret_cast<const char*>(value), length));
}
void getsockname(struct sockaddr* addr, uint* length) override {
socklen_t socklen = *length;
KJ_WINSOCK(::getsockname(fd, addr, &socklen));
*length = socklen;
}
void getpeername(struct sockaddr* addr, uint* length) override {
socklen_t socklen = *length;
KJ_WINSOCK(::getpeername(fd, addr, &socklen));
*length = socklen;
}
private:
Own<Win32EventPort::IoObserver> observer;
// Issues one overlapped WSARecv() and, if it delivered fewer than `minBytes`, advances `bufs`
// past the received data and calls itself again.
Promise<size_t> tryReadInternal(ArrayPtr<WSABUF> bufs, size_t minBytes, size_t alreadyRead) {
// `bufs` will remain valid until the promise completes and may be freely modified.
//
// `alreadyRead` is the number of bytes we have already received via previous reads -- minBytes
// and buffer have already been adjusted to account for them, but this count must be included
// in the final return value.
auto op = observer->newOperation(0);
DWORD flags = 0;
if (WSARecv(fd, bufs.begin(), bufs.size(), NULL, &flags,
op->getOverlapped(), NULL) == SOCKET_ERROR) {
// WSA_IO_PENDING means the receive was queued and will complete later.
DWORD error = WSAGetLastError();
if (error != WSA_IO_PENDING) {
KJ_FAIL_WIN32("WSARecv()", error) { break; }
return alreadyRead;
}
}
return op->onComplete()
.then([this,KJ_CPCAP(bufs),minBytes,alreadyRead](Win32IocpEventPort::IoResult result) mutable
-> Promise<size_t> {
if (result.errorCode != ERROR_SUCCESS) {
if (alreadyRead > 0) {
// Report what we already read.
return alreadyRead;
} else {
KJ_FAIL_WIN32("WSARecv()", result.errorCode) { break; }
return size_t(0);
}
}
// A successful zero-byte completion ends the read early with whatever was received so far.
if (result.bytesTransferred == 0) {
return alreadyRead;
}
alreadyRead += result.bytesTransferred;
if (result.bytesTransferred >= minBytes) {
// We can stop here.
return alreadyRead;
}
minBytes -= result.bytesTransferred;
// Drop the buffers that were completely filled, then trim the front of the partially filled
// one, so that the next WSARecv() continues where this one stopped.
while (result.bytesTransferred >= bufs[0].len) {
result.bytesTransferred -= bufs[0].len;
bufs = bufs.slice(1, bufs.size());
}
if (result.bytesTransferred > 0) {
bufs[0].buf += result.bytesTransferred;
bufs[0].len -= result.bytesTransferred;
}
return tryReadInternal(bufs, minBytes, alreadyRead);
}).attach(kj::mv(bufs));
}
// Issues one overlapped WSASend() and, if it did not consume all of `bufs`, calls itself again
// with the remainder.
Promise<void> writeInternal(ArrayPtr<WSABUF> bufs) {
// `bufs` will remain valid until the promise completes and may be freely modified.
auto op = observer->newOperation(0);
if (WSASend(fd, bufs.begin(), bufs.size(), NULL, 0,
op->getOverlapped(), NULL) == SOCKET_ERROR) {
// WSA_IO_PENDING means the send was queued and will complete later.
DWORD error = WSAGetLastError();
if (error != WSA_IO_PENDING) {
KJ_FAIL_WIN32("WSASend()", error) { break; }
return kj::READY_NOW;
}
}
return op->onComplete()
.then([this,KJ_CPCAP(bufs)](Win32IocpEventPort::IoResult result) mutable -> Promise<void> {
if (result.errorCode != ERROR_SUCCESS) {
KJ_FAIL_WIN32("WSASend()", result.errorCode) { break; }
return kj::READY_NOW;
}
// Skip the buffers that were sent in full, then trim the front of a partially sent one.
while (bufs.size() > 0 && result.bytesTransferred >= bufs[0].len) {
result.bytesTransferred -= bufs[0].len;
bufs = bufs.slice(1, bufs.size());
}
if (result.bytesTransferred > 0) {
bufs[0].buf += result.bytesTransferred;
bufs[0].len -= result.bytesTransferred;
}
if (bufs.size() > 0) {
return writeInternal(bufs);
} else {
return kj::READY_NOW;
}
}).attach(kj::mv(bufs));
}
};
// =======================================================================================
class SocketAddress {
  // Value type holding one socket address (stored in a union large enough for sockaddr_storage)
  // plus a "wildcard" flag. Provides helpers to create and bind sockets for the address, to
  // render it as text, and to parse or resolve address strings.

public:
  SocketAddress(const void* sockaddr, uint len): addrlen(len) {
    // Copies `len` bytes of a raw sockaddr. Fails if it does not fit in our storage.
    KJ_REQUIRE(len <= sizeof(addr), "Sorry, your sockaddr is too big for me.");
    memcpy(&addr.generic, sockaddr, len);
  }

  bool operator<(const SocketAddress& other) const {
    // So we can use std::set<SocketAddress>... see DNS lookup code.
    //
    // Orders by wildcard flag, then by address length, then by the raw address bytes.

    if (wildcard < other.wildcard) return true;
    if (wildcard > other.wildcard) return false;

    if (addrlen < other.addrlen) return true;
    if (addrlen > other.addrlen) return false;

    return memcmp(&addr.generic, &other.addr.generic, addrlen) < 0;
  }

  const struct sockaddr* getRaw() const { return &addr.generic; }
  int getRawSize() const { return addrlen; }

  SOCKET socket(int type) const {
    // Creates a new socket of the given type in this address's family. For TCP/IP stream
    // sockets, TCP_NODELAY is enabled. The caller owns the returned socket.

    bool isStream = type == SOCK_STREAM;

    SOCKET result = ::socket(addr.generic.sa_family, type, 0);

    if (result == INVALID_SOCKET) {
      KJ_FAIL_WIN32("socket()", WSAGetLastError()) { return INVALID_SOCKET; }
    }

    // If configuring the socket below throws, close it rather than leaking it: the caller never
    // receives the handle in that case.
    KJ_ON_SCOPE_FAILURE(closesocket(result));

    if (isStream && (addr.generic.sa_family == AF_INET ||
                     addr.generic.sa_family == AF_INET6)) {
      // TODO(perf):  As a hack for the 0.4 release we are always setting
      //   TCP_NODELAY because Nagle's algorithm pretty much kills Cap'n Proto's
      //   RPC protocol.  Later, we should extend the interface to provide more
      //   control over this.  Perhaps write() should have a flag which
      //   specifies whether to pass MSG_MORE.
      BOOL one = TRUE;
      KJ_WINSOCK(setsockopt(result, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(one)));
    }

    return result;
  }

  void bind(SOCKET sockfd) const {
    // Binds `sockfd` to this address.

    if (wildcard) {
      // Disable IPV6_V6ONLY because we want to handle both ipv4 and ipv6 on this socket.  (The
      // default value of this option varies across platforms.)
      DWORD value = 0;
      KJ_WINSOCK(setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY,
                            reinterpret_cast<char*>(&value), sizeof(value)));
    }

    KJ_WINSOCK(::bind(sockfd, &addr.generic, addrlen), toString());
  }

  uint getPort() const {
    // Returns the port in host byte order, or 0 for address families other than IPv4/IPv6.

    switch (addr.generic.sa_family) {
      case AF_INET: return ntohs(addr.inet4.sin_port);
      case AF_INET6: return ntohs(addr.inet6.sin6_port);
      default: return 0;
    }
  }

  String toString() const {
    // Renders the address as "a.b.c.d:port", "[v6addr]:port", or "*:port" for wildcards.

    if (wildcard) {
      return str("*:", getPort());
    }

    switch (addr.generic.sa_family) {
      case AF_INET: {
        char buffer[16];
        if (InetNtopA(addr.inet4.sin_family, const_cast<struct in_addr*>(&addr.inet4.sin_addr),
                      buffer, sizeof(buffer)) == nullptr) {
          KJ_FAIL_WIN32("InetNtop", WSAGetLastError()) { break; }
          return heapString("(inet_ntop error)");
        }
        return str(buffer, ':', ntohs(addr.inet4.sin_port));
      }
      case AF_INET6: {
        char buffer[46];
        if (InetNtopA(addr.inet6.sin6_family, const_cast<struct in6_addr*>(&addr.inet6.sin6_addr),
                      buffer, sizeof(buffer)) == nullptr) {
          KJ_FAIL_WIN32("InetNtop", WSAGetLastError()) { break; }
          return heapString("(inet_ntop error)");
        }
        return str('[', buffer, "]:", ntohs(addr.inet6.sin6_port));
      }
      default:
        return str("(unknown address family ", addr.generic.sa_family, ")");
    }
  }

  static Promise<Array<SocketAddress>> lookupHost(
      LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
      _::NetworkFilter& filter);
  // Perform a DNS lookup.

  static Promise<Array<SocketAddress>> parse(
      LowLevelAsyncIoProvider& lowLevel, StringPtr str, uint portHint, _::NetworkFilter& filter) {
    // Parses an address string of the form "host", "host:port", "[ip6]", "[ip6]:port" or "*".
    // Numeric addresses are resolved immediately; anything else (including a non-numeric port,
    // which is treated as a service name) falls back to lookupHost(). `portHint` is used when
    // the string carries no port.

    // TODO(someday):  Allow commas in `str`.

    SocketAddress result;

    // Try to separate the address and port.
    ArrayPtr<const char> addrPart;
    Maybe<StringPtr> portPart;

    int af;

    if (str.startsWith("[")) {
      // Address starts with a bracket, which is a common way to write an ip6 address with a port,
      // since without brackets around the address part, the port looks like another segment of
      // the address.
      af = AF_INET6;
      size_t closeBracket = KJ_ASSERT_NONNULL(str.findLast(']'),
          "Unclosed '[' in address string.", str);

      addrPart = str.slice(1, closeBracket);
      if (str.size() > closeBracket + 1) {
        KJ_REQUIRE(str.slice(closeBracket + 1).startsWith(":"),
                   "Expected port suffix after ']'.", str);
        portPart = str.slice(closeBracket + 2);
      }
    } else {
      KJ_IF_MAYBE(colon, str.findFirst(':')) {
        if (str.slice(*colon + 1).findFirst(':') == nullptr) {
          // There is exactly one colon and no brackets, so it must be an ip4 address with port.
          af = AF_INET;
          addrPart = str.slice(0, *colon);
          portPart = str.slice(*colon + 1);
        } else {
          // There are two or more colons and no brackets, so the whole thing must be an ip6
          // address with no port.
          af = AF_INET6;
          addrPart = str;
        }
      } else {
        // No colons, so it must be an ip4 address without port.
        af = AF_INET;
        addrPart = str;
      }
    }

    // Parse the port.
    unsigned long port;
    KJ_IF_MAYBE(portText, portPart) {
      char* endptr;
      port = strtoul(portText->cStr(), &endptr, 0);
      if (portText->size() == 0 || *endptr != '\0') {
        // Not a number.  Maybe it's a service name.  Fall back to DNS.
        return lookupHost(lowLevel, kj::heapString(addrPart), kj::heapString(*portText), portHint,
                          filter);
      }
      KJ_REQUIRE(port < 65536, "Port number too large.");
    } else {
      port = portHint;
    }

    // Check for wildcard.
    if (addrPart.size() == 1 && addrPart[0] == '*') {
      result.wildcard = true;
      // Create an ip6 socket and set IPV6_V6ONLY to 0 later.
      result.addrlen = sizeof(addr.inet6);
      result.addr.inet6.sin6_family = AF_INET6;
      result.addr.inet6.sin6_port = htons(port);
      auto array = kj::heapArrayBuilder<SocketAddress>(1);
      array.add(result);
      return array.finish();
    }

    void* addrTarget;
    if (af == AF_INET6) {
      result.addrlen = sizeof(addr.inet6);
      result.addr.inet6.sin6_family = AF_INET6;
      result.addr.inet6.sin6_port = htons(port);
      addrTarget = &result.addr.inet6.sin6_addr;
    } else {
      result.addrlen = sizeof(addr.inet4);
      result.addr.inet4.sin_family = AF_INET;
      result.addr.inet4.sin_port = htons(port);
      addrTarget = &result.addr.inet4.sin_addr;
    }

    char buffer[64];
    if (addrPart.size() < sizeof(buffer) - 1) {
      // addrPart is not necessarily NUL-terminated so we have to make a copy.  :(
      memcpy(buffer, addrPart.begin(), addrPart.size());
      buffer[addrPart.size()] = '\0';

      // OK, parse it!
      switch (InetPtonA(af, buffer, addrTarget)) {
        case 1: {
          // success.
          if (!result.parseAllowedBy(filter)) {
            KJ_FAIL_REQUIRE("address family blocked by restrictPeers()");
            return Array<SocketAddress>();
          }

          auto array = kj::heapArrayBuilder<SocketAddress>(1);
          array.add(result);
          return array.finish();
        }
        case 0:
          // It's apparently not a simple address...  fall back to DNS.
          break;
        default:
          KJ_FAIL_WIN32("InetPton", WSAGetLastError(), af, addrPart);
      }
    }

    return lookupHost(lowLevel, kj::heapString(addrPart), nullptr, port, filter);
  }

  static SocketAddress getLocalAddress(SOCKET sockfd) {
    // Returns the local address of `sockfd`, as reported by getsockname().
    SocketAddress result;
    result.addrlen = sizeof(addr);
    KJ_WINSOCK(getsockname(sockfd, &result.addr.generic, &result.addrlen));
    return result;
  }

  static SocketAddress getPeerAddress(SOCKET sockfd) {
    // Returns the remote address of `sockfd`, as reported by getpeername().
    SocketAddress result;
    result.addrlen = sizeof(addr);
    KJ_WINSOCK(getpeername(sockfd, &result.addr.generic, &result.addrlen));
    return result;
  }

  bool allowedBy(LowLevelAsyncIoProvider::NetworkFilter& filter) {
    return filter.shouldAllow(&addr.generic, addrlen);
  }

  bool parseAllowedBy(_::NetworkFilter& filter) {
    return filter.shouldAllowParse(&addr.generic, addrlen);
  }

  static SocketAddress getWildcardForFamily(int family) {
    // Returns an all-zero address (any address, port 0) for the given family. Note that the
    // `wildcard` flag is not set on the result. Fails for families other than IPv4/IPv6.
    SocketAddress result;
    switch (family) {
      case AF_INET:
        result.addrlen = sizeof(addr.inet4);
        result.addr.inet4.sin_family = AF_INET;
        return result;
      case AF_INET6:
        result.addrlen = sizeof(addr.inet6);
        result.addr.inet6.sin6_family = AF_INET6;
        return result;
      default:
        KJ_FAIL_REQUIRE("unknown address family", family);
    }
  }

private:
  SocketAddress(): addrlen(0) {
    // Starts out as an all-zero address of length zero; the static factories fill it in.
    memset(&addr, 0, sizeof(addr));
  }

  socklen_t addrlen;
  bool wildcard = false;
  union {
    struct sockaddr generic;
    struct sockaddr_in inet4;
    struct sockaddr_in6 inet6;
    struct sockaddr_storage storage;
  } addr;

  struct LookupParams;
  class LookupReader;
};
class SocketAddress::LookupReader {
  // Consumes SocketAddress records arriving over a stream fed by a separate thread that runs
  // getaddrinfo(), and collects them into the final result array.

public:
  LookupReader(kj::Own<Thread>&& thread, kj::Own<AsyncInputStream>&& input,
               _::NetworkFilter& filter)
      : thread(kj::mv(thread)), input(kj::mv(input)), filter(filter) {}

  ~LookupReader() {
    // If the thread is still attached when we go away, let it finish on its own.
    if (thread) thread->detach();
  }

  Promise<Array<SocketAddress>> read() {
    // Reads one fixed-size record at a time until a short read marks the end of the stream.
    auto nextRecord = input->tryRead(&current, sizeof(current), sizeof(current));

    return nextRecord.then([this](size_t n) -> Promise<Array<SocketAddress>> {
      if (n >= sizeof(current)) {
        // A full record arrived.
        //
        // getaddrinfo() can hand back the same address more than once. A big reason is that we
        // pass no socket type (SOCK_STREAM vs. SOCK_DGRAM), so it may produce one copy per type
        // unless it knows the service name belongs to a single type. We can't pass a type
        // because we don't know which one the user wants, and asking for SOCK_STREAM when the
        // user named a UDP service would yield a resolution error. (At least, I think that's how
        // it works.)
        //
        // Hence: de-dupe the results, keeping only first sightings that the filter permits.
        bool firstSighting = alreadySeen.insert(current).second;
        if (firstSighting && current.parseAllowedBy(filter)) {
          addresses.add(current);
        }
        return read();
      }

      // Short read: the stream is finished. Dropping the thread handle here means the
      // destructor has nothing left to detach.
      thread = nullptr;

      // getaddrinfo()'s docs seem to say it never returns an empty list, but check anyway.
      KJ_REQUIRE(addresses.size() > 0, "DNS lookup returned no permitted addresses.") { break; }
      return addresses.releaseAsArray();
    });
  }

private:
  kj::Own<Thread> thread;
  kj::Own<AsyncInputStream> input;
  _::NetworkFilter& filter;
  SocketAddress current;
  kj::Vector<SocketAddress> addresses;
  std::set<SocketAddress> alreadySeen;
};
// Bundles the host and service strings so that both can be moved into the lookup thread's
// closure as a single value (see lookupHost()).
struct SocketAddress::LookupParams {
kj::String host;
kj::String service;
};
Promise<Array<SocketAddress>> SocketAddress::lookupHost(
    LowLevelAsyncIoProvider& lowLevel, kj::String host, kj::String service, uint portHint,
    _::NetworkFilter& filter) {
  // Resolves `host` (and optionally `service`) to a list of addresses.
  //
  // This spawns a thread to run getaddrinfo(), because getaddrinfo() is the only cross-platform
  // DNS API and it is blocking. The thread writes each result as a raw SocketAddress over a
  // socket pair; a LookupReader on the event loop side collects, de-dupes and filters them.
  //
  // If `service` is null, `portHint` (host byte order) is used as the port of every result. A
  // host of "*" produces wildcard addresses.
  //
  // TODO(perf): Use GetAddrInfoEx(). But there are problems:
  // - Not implemented in Wine.
  // - Doesn't seem compatible with I/O completion ports, in particular because it's not associated
  //   with a handle. Could signal completion as an APC instead, but that requires the IOCP code
  //   to use GetQueuedCompletionStatusEx() which it doesn't right now because it's not available
  //   in Wine.
  // - Requires Unicode, for some reason. Only GetAddrInfoExW() supports async, according to the
  //   docs. Never mind that DNS itself is ASCII...

  SOCKET fds[2];
  KJ_WINSOCK(_::win32Socketpair(fds));

  auto input = lowLevel.wrapInputFd(fds[0], NEW_FD_FLAGS);

  // Keep the full-width SOCKET type: narrowing the handle to `int` could truncate it.
  SOCKET outFd = fds[1];

  LookupParams params = { kj::mv(host), kj::mv(service) };

  auto thread = heap<Thread>(kj::mvCapture(params, [outFd,portHint](LookupParams&& params) {
    // Closing the write end is what signals end-of-results to the reader.
    KJ_DEFER(closesocket(outFd));

    addrinfo* list;
    int status = getaddrinfo(
        params.host == "*" ? nullptr : params.host.cStr(),
        params.service == nullptr ? nullptr : params.service.cStr(),
        nullptr, &list);
    if (status == 0) {
      KJ_DEFER(freeaddrinfo(list));

      addrinfo* cur = list;
      while (cur != nullptr) {
        if (params.service == nullptr) {
          // No service name was given, so apply the caller's port hint.
          switch (cur->ai_addr->sa_family) {
            case AF_INET:
              ((struct sockaddr_in*)cur->ai_addr)->sin_port = htons(portHint);
              break;
            case AF_INET6:
              ((struct sockaddr_in6*)cur->ai_addr)->sin6_port = htons(portHint);
              break;
            default:
              break;
          }
        }

        SocketAddress addr;
        memset(&addr, 0, sizeof(addr));  // mollify valgrind
        if (params.host == "*") {
          // Set up a wildcard SocketAddress.  Only use the port number returned by getaddrinfo().
          addr.wildcard = true;
          addr.addrlen = sizeof(addr.addr.inet6);
          addr.addr.inet6.sin6_family = AF_INET6;
          switch (cur->ai_addr->sa_family) {
            case AF_INET:
              addr.addr.inet6.sin6_port = ((struct sockaddr_in*)cur->ai_addr)->sin_port;
              break;
            case AF_INET6:
              addr.addr.inet6.sin6_port = ((struct sockaddr_in6*)cur->ai_addr)->sin6_port;
              break;
            default:
              // sin6_port is stored in network byte order, while portHint is in host byte
              // order, so it must be converted like everywhere else.
              addr.addr.inet6.sin6_port = htons(portHint);
              break;
          }
        } else {
          addr.addrlen = cur->ai_addrlen;
          memcpy(&addr.addr.generic, cur->ai_addr, cur->ai_addrlen);
        }
        KJ_ASSERT_CAN_MEMCPY(SocketAddress);

        // Ship the raw bytes of the SocketAddress to the reader, looping because send() may
        // accept fewer bytes than requested.
        const char* data = reinterpret_cast<const char*>(&addr);
        size_t size = sizeof(addr);
        while (size > 0) {
          int n;
          KJ_WINSOCK(n = send(outFd, data, size, 0));
          data += n;
          size -= n;
        }

        cur = cur->ai_next;
      }
    } else {
      KJ_FAIL_WIN32("getaddrinfo()", status, params.host, params.service) {
        return;
      }
    }
  }));

  // The reader owns the thread and the input stream; attach it to the promise so that it lives
  // until the lookup completes.
  auto reader = heap<LookupReader>(kj::mv(thread), kj::mv(input), filter);
  return reader->read().attach(kj::mv(reader));
}
// =======================================================================================
class FdConnectionReceiver final: public ConnectionReceiver, public OwnedFd {
// ConnectionReceiver implementation over a listening SOCKET. Connections are accepted
// asynchronously with AcceptEx() and checked against a NetworkFilter before being returned.
public:
FdConnectionReceiver(Win32EventPort& eventPort, SOCKET fd,
LowLevelAsyncIoProvider::NetworkFilter& filter, uint flags)
: OwnedFd(fd, flags), eventPort(eventPort), filter(filter),
observer(eventPort.observeIo(reinterpret_cast<HANDLE>(fd))),
address(SocketAddress::getLocalAddress(fd)) {
// In order to accept asynchronously, we need the AcceptEx() function. Apparently, we have
// to query the socket for it dynamically, I guess because of the insanity in which winsock
// can be implemented in userspace and old implementations may not support it.
GUID guid = WSAID_ACCEPTEX;
DWORD n = 0;
KJ_WINSOCK(WSAIoctl(fd, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
&acceptEx, sizeof(acceptEx), &n, NULL, NULL)) {
acceptEx = nullptr;
return;
}
}
// Waits for the next incoming connection that the filter allows and returns it as a stream.
Promise<Own<AsyncIoStream>> accept() override {
// AcceptEx() needs the socket for the new connection to be created up front; it is created in
// the same address family as the listening address. Wrapping it right away with NEW_FD_FLAGS
// ties its lifetime to `result`.
SOCKET newFd = address.socket(SOCK_STREAM);
KJ_ASSERT(newFd != INVALID_SOCKET);
auto result = heap<AsyncStreamFd>(eventPort, newFd, NEW_FD_FLAGS);
// Scratch space for the addresses AcceptEx() writes: 128 bytes each for the local and remote
// address, and 0 bytes of initial receive data.
auto scratch = heapArray<byte>(256);
DWORD dummy;
auto op = observer->newOperation(0);
if (!acceptEx(fd, newFd, scratch.begin(), 0, 128, 128, &dummy, op->getOverlapped())) {
// ERROR_IO_PENDING means the accept was queued and will complete later.
DWORD error = WSAGetLastError();
if (error != ERROR_IO_PENDING) {
KJ_FAIL_WIN32("AcceptEx()", error) { break; }
return Own<AsyncIoStream>(kj::mv(result)); // dummy, won't be used
}
}
// `result` and `scratch` are moved into the continuation so both stay alive until the
// operation completes.
return op->onComplete().then(mvCapture(result, mvCapture(scratch,
[this,newFd]
(Array<byte> scratch, Own<AsyncIoStream> stream, Win32EventPort::IoResult ioResult)
-> Promise<Own<AsyncIoStream>> {
if (ioResult.errorCode != ERROR_SUCCESS) {
KJ_FAIL_WIN32("AcceptEx()", ioResult.errorCode) { break; }
} else {
// Associate the accepted socket with the listening socket.
SOCKET me = fd;
stream->setsockopt(SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT,
reinterpret_cast<char*>(&me), sizeof(me));
}
// Supposedly, AcceptEx() places the local and peer addresses into the buffer (which we've
// named `scratch`). However, the format in which it writes these is undocumented, and
// doesn't even match between native Windows and WINE. Apparently it is useless. I don't know
// why they require the buffer to have space for it in the first place. We'll need to call
// getpeername() to get the address.
auto addr = SocketAddress::getPeerAddress(newFd);
if (addr.allowedBy(filter)) {
return kj::mv(stream);
} else {
// Peer rejected by the filter: drop this stream and wait for the next connection.
return accept();
}
})));
}
// Returns the port of the local address captured at construction time.
uint getPort() override {
return address.getPort();
}
// The three methods below forward to the Winsock functions of the same name, converting between
// the interface's `uint` lengths and Winsock's `socklen_t`.
void getsockopt(int level, int option, void* value, uint* length) override {
socklen_t socklen = *length;
KJ_WINSOCK(::getsockopt(fd, level, option,
reinterpret_cast<char*>(value), &socklen));
*length = socklen;
}
void setsockopt(int level, int option, const void* value, uint length) override {
KJ_WINSOCK(::setsockopt(fd, level, option,
reinterpret_cast<const char*>(value), length));
}
void getsockname(struct sockaddr* addr, uint* length) override {
socklen_t socklen = *length;
KJ_WINSOCK(::getsockname(fd, addr, &socklen));
*length = socklen;
}
public:
Win32EventPort& eventPort;
LowLevelAsyncIoProvider::NetworkFilter& filter;
Own<Win32EventPort::IoObserver> observer;
// AcceptEx() entry point looked up in the constructor; null if that lookup failed.
LPFN_ACCEPTEX acceptEx = nullptr;
// Local address of the listening socket, captured at construction time.
SocketAddress address;
};
// TODO(someday): DatagramPortImpl
class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider {
  // LowLevelAsyncIoProvider backed by a Win32IocpEventPort. Owns the event port together with
  // the EventLoop and WaitScope built on top of it, and wraps raw SOCKETs into KJ stream and
  // listener objects.

public:
  LowLevelAsyncIoProviderImpl()
      : eventLoop(eventPort), waitScope(eventLoop) {}

  inline WaitScope& getWaitScope() { return waitScope; }

  // The three stream wrappers all produce the same concrete type; only the interface through
  // which the caller sees it differs.

  Own<AsyncInputStream> wrapInputFd(SOCKET fd, uint flags = 0) override {
    return heap<AsyncStreamFd>(eventPort, fd, flags);
  }

  Own<AsyncOutputStream> wrapOutputFd(SOCKET fd, uint flags = 0) override {
    return heap<AsyncStreamFd>(eventPort, fd, flags);
  }

  Own<AsyncIoStream> wrapSocketFd(SOCKET fd, uint flags = 0) override {
    return heap<AsyncStreamFd>(eventPort, fd, flags);
  }

  Promise<Own<AsyncIoStream>> wrapConnectingSocketFd(
      SOCKET fd, const struct sockaddr* addr, uint addrlen, uint flags = 0) override {
    auto stream = heap<AsyncStreamFd>(eventPort, fd, flags);

    // ConnectEx requires that the socket be bound, for some reason. Bind to an arbitrary port.
    auto anyAddress = SocketAddress::getWildcardForFamily(addr->sa_family);
    anyAddress.bind(fd);

    // Start the connect as its own statement, before `stream` is moved into the continuation.
    auto connecting = stream->connect(addr, addrlen);
    return connecting.then(kj::mvCapture(stream, [](Own<AsyncIoStream>&& connectedStream) {
      return kj::mv(connectedStream);
    }));
  }

  Own<ConnectionReceiver> wrapListenSocketFd(
      SOCKET fd, NetworkFilter& filter, uint flags = 0) override {
    return heap<FdConnectionReceiver>(eventPort, fd, filter, flags);
  }

  Timer& getTimer() override { return eventPort.getTimer(); }

  Win32EventPort& getEventPort() { return eventPort; }

private:
  // Declaration order matters: the loop is constructed from the port, and the wait scope from
  // the loop.
  Win32IocpEventPort eventPort;
  EventLoop eventLoop;
  WaitScope waitScope;
};
// =======================================================================================
class NetworkAddressImpl final: public NetworkAddress {
public:
// Stores references to the low-level provider and the network filter, and takes ownership of the
// array of candidate addresses.
NetworkAddressImpl(LowLevelAsyncIoProvider& lowLevel,
LowLevelAsyncIoProvider::NetworkFilter& filter,
Array<SocketAddress> addrs)
: lowLevel(lowLevel), filter(filter), addrs(kj::mv(addrs)) {}
Promise<Own<AsyncIoStream>> connect() override {
  // Work from a private copy of the address list, attached to the returned promise so that it
  // stays alive for as long as the connection attempt is in flight.
  auto snapshot = heapArray(addrs.asPtr());
  return connectImpl(lowLevel, filter, snapshot).attach(kj::mv(snapshot));
}
Own<ConnectionReceiver> listen() override {
  // Creates a stream socket bound to the first address, starts listening on it, and wraps it in
  // a ConnectionReceiver that owns the socket.

  if (addrs.size() > 1) {
    KJ_LOG(WARNING, "Bind address resolved to multiple addresses.  Only the first address will "
        "be used.  If this is incorrect, specify the address numerically.  This may be fixed "
        "in the future.", addrs[0].toString());
  }

  // Keep the handle as a SOCKET: narrowing it to `int` could truncate it.
  SOCKET fd = addrs[0].socket(SOCK_STREAM);

  {
    // Until the receiver takes ownership, close the socket ourselves if setup throws.
    KJ_ON_SCOPE_FAILURE(closesocket(fd));

    // We always enable SO_REUSEADDR because having to take your server down for five minutes
    // before it can restart really sucks.
    int optval = 1;
    KJ_WINSOCK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
                          reinterpret_cast<char*>(&optval), sizeof(optval)));

    addrs[0].bind(fd);

    // TODO(someday):  Let queue size be specified explicitly in string addresses.
    KJ_WINSOCK(::listen(fd, SOMAXCONN));
  }

  return lowLevel.wrapListenSocketFd(fd, filter, NEW_FD_FLAGS);
}
Own<DatagramPort> bindDatagramPort() override {
if (addrs.size() > 1) {
KJ_LOG(WARNING, "Bind address resolved to multiple addresses. Only the first address will "
"be used. If this is incorrect, specify the address numerically. This may be fixed "
"in the future.", addrs[0].toString());
}
int fd = addrs[0].socket(SOCK_DGRAM);
{
KJ_ON_SCOPE_FAILURE(closesocket(fd));
// We always enable SO_REUSEADDR because having to take your server down for five minutes
// before it can restart really sucks.
int optval = 1;
KJ_WINSOCK(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
reinterpret_cast<char*>(&optval), sizeof(optval)));
addrs[0].bind(fd);
}
return lowLevel.wrapDatagramSocketFd(fd, filter, NEW_FD_FLAGS);
}
Own<NetworkAddress> clone() override {
return kj::heap<NetworkAddressImpl>(lowLevel, filter, kj::heapArray(addrs.asPtr()));
}
String toString() override {
return strArray(KJ_MAP(addr, addrs) { return addr.toString(); }, ",");
}
const SocketAddress& chooseOneAddress() {
KJ_REQUIRE(addrs.size() > 0, "No addresses available.");
return addrs[counter++ % addrs.size()];
}
private:
LowLevelAsyncIoProvider& lowLevel;
LowLevelAsyncIoProvider::NetworkFilter& filter;
Array<SocketAddress> addrs;
uint counter = 0;
static Promise<Own<AsyncIoStream>> connectImpl(
LowLevelAsyncIoProvider& lowLevel,
LowLevelAsyncIoProvider::NetworkFilter& filter,
ArrayPtr<SocketAddress> addrs) {
KJ_ASSERT(addrs.size() > 0);
int fd = addrs[0].socket(SOCK_STREAM);
return kj::evalNow([&]() -> Promise<Own<AsyncIoStream>> {
if (!addrs[0].allowedBy(filter)) {
return KJ_EXCEPTION(FAILED, "connect() blocked by restrictPeers()");
} else {
return lowLevel.wrapConnectingSocketFd(
fd, addrs[0].getRaw(), addrs[0].getRawSize(), NEW_FD_FLAGS);
}
}).then([](Own<AsyncIoStream>&& stream) -> Promise<Own<AsyncIoStream>> {
// Success, pass along.
return kj::mv(stream);
}, [&lowLevel,&filter,KJ_CPCAP(addrs)](Exception&& exception) mutable
-> Promise<Own<AsyncIoStream>> {
// Connect failed.
if (addrs.size() > 1) {
// Try the next address instead.
return connectImpl(lowLevel, filter, addrs.slice(1, addrs.size()));
} else {
// No more addresses to try, so propagate the exception.
return kj::mv(exception);
}
});
}
};
class SocketNetwork final: public Network {
  // Network implementation that produces NetworkAddressImpl objects, each of which shares this
  // network's low-level provider and peer filter by reference.

public:
  explicit SocketNetwork(LowLevelAsyncIoProvider& lowLevel): lowLevel(lowLevel) {}
  // Creates a network whose filter is default-constructed.

  explicit SocketNetwork(SocketNetwork& parent,
                         kj::ArrayPtr<const kj::StringPtr> allow,
                         kj::ArrayPtr<const kj::StringPtr> deny)
      : lowLevel(parent.lowLevel), filter(allow, deny, parent.filter) {}
  // Creates a network sharing `parent`'s low-level provider, with a filter built from
  // `allow`, `deny` and the parent's filter.

  Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint = 0) override {
    // The address text is copied to the heap because parsing is deferred via evalLater() and
    // the caller's `addr` may no longer be valid by the time it runs.
    auto ownedText = heapString(addr);

    auto parsed = evalLater(mvCapture(ownedText, [this,portHint](String&& text) {
      return SocketAddress::parse(lowLevel, text, portHint, filter);
    }));

    return parsed.then([this](Array<SocketAddress> resolved) -> Own<NetworkAddress> {
      return heap<NetworkAddressImpl>(lowLevel, filter, kj::mv(resolved));
    });
  }

  Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
    // Builds a single-element address list from the raw sockaddr, checking it against the
    // filter first.
    auto builder = kj::heapArrayBuilder<SocketAddress>(1);
    builder.add(SocketAddress(sockaddr, len));

    KJ_REQUIRE(builder[0].allowedBy(filter), "address blocked by restrictPeers()") { break; }

    return heap<NetworkAddressImpl>(lowLevel, filter, builder.finish());
  }

  Own<Network> restrictPeers(
      kj::ArrayPtr<const kj::StringPtr> allow,
      kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
    // The new network refers back to this one as its parent.
    return kj::heap<SocketNetwork>(*this, allow, deny);
  }

private:
  LowLevelAsyncIoProvider& lowLevel;
  _::NetworkFilter filter;
};
// =======================================================================================
class AsyncIoProviderImpl final: public AsyncIoProvider {
  // AsyncIoProvider built on top of a LowLevelAsyncIoProvider. Pipes are implemented as socket
  // pairs obtained from _::win32Socketpair(). `lowLevel` is held by reference and must outlive
  // this object.

public:
  AsyncIoProviderImpl(LowLevelAsyncIoProvider& lowLevel)
      : lowLevel(lowLevel), network(lowLevel) {}

  OneWayPipe newOneWayPipe() override {
    // Both ends come from one socket pair. The input end is wrapped as a full socket so that
    // its write direction can be shut down before it is returned.
    SOCKET fds[2];
    KJ_WINSOCK(_::win32Socketpair(fds));
    auto in = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);
    auto out = lowLevel.wrapOutputFd(fds[1], NEW_FD_FLAGS);
    in->shutdownWrite();
    return { kj::mv(in), kj::mv(out) };
  }

  TwoWayPipe newTwoWayPipe() override {
    // Each end of the socket pair is wrapped as a bidirectional stream.
    SOCKET fds[2];
    KJ_WINSOCK(_::win32Socketpair(fds));
    return TwoWayPipe { {
      lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS),
      lowLevel.wrapSocketFd(fds[1], NEW_FD_FLAGS)
    } };
  }

  Network& getNetwork() override {
    return network;
  }

  PipeThread newPipeThread(
      Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)> startFunc) override {
    // Starts a thread that runs `startFunc` with its own I/O provider, and connects it to the
    // calling thread through a socket pair: fds[0] stays here, fds[1] goes to the new thread.
    SOCKET fds[2];
    KJ_WINSOCK(_::win32Socketpair(fds));

    // Keep the thread's end as a SOCKET. Storing it in an `int` would narrow the handle on
    // 64-bit Windows before it is passed to closesocket() / wrapSocketFd().
    SOCKET threadFd = fds[1];
    KJ_ON_SCOPE_FAILURE(closesocket(threadFd));

    auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS);

    auto thread = heap<Thread>(kj::mvCapture(startFunc,
        [threadFd](Function<void(AsyncIoProvider&, AsyncIoStream&, WaitScope&)>&& startFunc) {
      // The new thread constructs its own low-level provider and wraps its end of the pipe
      // with it, rather than using the parent's.
      LowLevelAsyncIoProviderImpl lowLevel;
      auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS);
      AsyncIoProviderImpl ioProvider(lowLevel);
      startFunc(ioProvider, *stream, lowLevel.getWaitScope());
    }));

    return { kj::mv(thread), kj::mv(pipe) };
  }

  Timer& getTimer() override { return lowLevel.getTimer(); }

private:
  LowLevelAsyncIoProvider& lowLevel;
  SocketNetwork network;
};
} // namespace
// Creates an AsyncIoProvider layered on `lowLevel`. The provider stores only a reference, so
// `lowLevel` must outlive the returned object.
Own<AsyncIoProvider> newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) {
  return heap<AsyncIoProviderImpl>(lowLevel);
}
// Builds a complete async I/O context: a low-level provider, a high-level provider layered on
// it, and references to the low-level provider's wait scope and event port.
AsyncIoContext setupAsyncIo() {
  // Make sure WSAStartup() has been called (at most once per process).
  _::initWinsockOnce();

  auto lowLevelProvider = heap<LowLevelAsyncIoProviderImpl>();
  auto highLevelProvider = kj::heap<AsyncIoProviderImpl>(*lowLevelProvider);

  // Grab these references before the owning pointer is moved into the result below.
  auto& scope = lowLevelProvider->getWaitScope();
  auto& port = lowLevelProvider->getEventPort();

  return { kj::mv(lowLevelProvider), kj::mv(highLevelProvider), scope, port };
}
} // namespace kj
#endif // _WIN32