blob: a35528ef8839eac916bf684e368f34ceb711db93 [file] [log] [blame]
// Copyright (c) 2017 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#include "http.h"
#include "url.h"
#include <kj/debug.h>
#include <kj/parse/char.h>
#include <unordered_map>
#include <stdlib.h>
#include <kj/encoding.h>
#include <deque>
#include <queue>
#include <map>
namespace kj {
// =======================================================================================
// SHA-1 implementation from https://github.com/clibs/sha1
//
// The WebSocket standard depends on SHA-1. ARRRGGGHHHHH.
//
// Any old checksum would have served the purpose, or hell, even just returning the header
// verbatim. But NO, they decided to throw a whole complicated hash algorithm in there, AND
// THEY CHOSE A BROKEN ONE THAT WE OTHERWISE WOULDN'T NEED ANYMORE.
//
// TODO(cleanup): Move this to a shared hashing library. Maybe. Or maybe don't, because no one
// should be using SHA-1 anymore.
//
// THIS USAGE IS NOT SECURITY SENSITIVE. IF YOU REPORT A SECURITY ISSUE BECAUSE YOU SAW SHA1 IN THE
// SOURCE CODE I WILL MAKE FUN OF YOU.
/*
SHA-1 in C
By Steve Reid <steve@edmweb.com>
100% Public Domain
Test Vectors (from FIPS PUB 180-1)
"abc"
A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D
"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"
84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1
A million repetitions of "a"
34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F
*/
// Endianness is selected further down via BYTE_ORDER / LITTLE_ENDIAN / BIG_ENDIAN, which the
// platform headers are expected to provide.
//
// SHA1HANDSOFF makes SHA1Transform() work on a private copy of each input block, so the caller's
// data is never written through.
#define SHA1HANDSOFF

// Running state of an in-progress SHA-1 computation.
struct SHA1_CTX {
  uint32_t state[5];         // intermediate hash words H0..H4
  uint32_t count[2];         // message length in bits: count[0] = low word, count[1] = high word
  unsigned char buffer[64];  // bytes of the current, not-yet-compressed 64-byte block
};
// Rotate the 32-bit unsigned `value` left by `bits` (valid for 0 < bits < 32).
#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits))))
/* blk0() and blk() perform the initial expand. */
/* I got the idea of expanding during the round function from SSLeay */
// blk0(i) yields message word i (0..15) of the current block as a big-endian 32-bit value. On
// little-endian hosts it byte-swaps block->l[i] in place. Both blk0 and blk read and write
// `block->l`, so they are only usable where a local `block` (a CHAR64LONG16 pointer or one-element
// array) is in scope, as in SHA1Transform().
// NOTE(review): if neither BYTE_ORDER nor LITTLE_ENDIAN is defined, both evaluate to 0 in #if and
// the little-endian branch is selected silently.
#if BYTE_ORDER == LITTLE_ENDIAN
#define blk0(i) (block->l[i] = (rol(block->l[i],24)&0xFF00FF00) \
|(rol(block->l[i],8)&0x00FF00FF))
#elif BYTE_ORDER == BIG_ENDIAN
#define blk0(i) block->l[i]
#else
#error "Endianness not defined!"
#endif
// blk(i) computes message-schedule word W[i] for i >= 16 in a 16-word circular buffer:
// W[i] = rol(W[i-3] ^ W[i-8] ^ W[i-14] ^ W[i-16], 1), stored back into block->l[i & 15].
#define blk(i) (block->l[i&15] = rol(block->l[(i+13)&15]^block->l[(i+8)&15] \
^block->l[(i+2)&15]^block->l[i&15],1))
/* (R0+R1), R2, R3, R4 are the different operations used in SHA1 */
// Each Rn macro performs one step: z += f(w,x,y) + W[i] + K + rol(v,5); w = rol(w,30).
// R0/R1 use f = Ch (choose) for steps 0-15 and 16-19, R2 uses parity (20-39), R3 uses Maj
// (majority, 40-59), and R4 uses parity (60-79). Each expands to two statements, so it must not
// be used as the body of an unbraced if/else.
#define R0(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk0(i)+0x5A827999+rol(v,5);w=rol(w,30);
#define R1(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk(i)+0x5A827999+rol(v,5);w=rol(w,30);
#define R2(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0x6ED9EBA1+rol(v,5);w=rol(w,30);
#define R3(v,w,x,y,z,i) z+=(((w|x)&y)|(w&x))+blk(i)+0x8F1BBCDC+rol(v,5);w=rol(w,30);
#define R4(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0xCA62C1D6+rol(v,5);w=rol(w,30);
/* Hash a single 512-bit block. This is the core of the algorithm. */
// Runs the SHA-1 compression function over one 64-byte block and adds the result into
// state[0..4]. SHA1HANDSOFF is defined above, so the block is first copied into a local union.
// The blk0/blk macros rewrite that copy in place as the message schedule, which means the
// caller's `buffer` is never modified. The copy is zeroed before returning.
void SHA1Transform(
uint32_t state[5],
const unsigned char buffer[64]
)
{
uint32_t a, b, c, d, e;
// Lets the 64 input bytes be addressed as 16 32-bit words (block->l) by the blk0/blk macros.
typedef union
{
unsigned char c[64];
uint32_t l[16];
} CHAR64LONG16;
#ifdef SHA1HANDSOFF
CHAR64LONG16 block[1]; /* use array to appear as a pointer */
memcpy(block, buffer, 64);
#else
/* The following had better never be used because it causes the
* pointer-to-const buffer to be cast into a pointer to non-const.
* And the result is written through. I threw a "const" in, hoping
* this will cause a diagnostic.
*/
CHAR64LONG16 *block = (const CHAR64LONG16 *) buffer;
#endif
/* Copy context->state[] to working vars */
a = state[0];
b = state[1];
c = state[2];
d = state[3];
e = state[4];
/* 4 rounds of 20 operations each. Loop unrolled. */
// Rather than shuffling a..e after every step, each call passes the working variables in a
// rotated order. The pattern repeats every five steps.
R0(a, b, c, d, e, 0);
R0(e, a, b, c, d, 1);
R0(d, e, a, b, c, 2);
R0(c, d, e, a, b, 3);
R0(b, c, d, e, a, 4);
R0(a, b, c, d, e, 5);
R0(e, a, b, c, d, 6);
R0(d, e, a, b, c, 7);
R0(c, d, e, a, b, 8);
R0(b, c, d, e, a, 9);
R0(a, b, c, d, e, 10);
R0(e, a, b, c, d, 11);
R0(d, e, a, b, c, 12);
R0(c, d, e, a, b, 13);
R0(b, c, d, e, a, 14);
R0(a, b, c, d, e, 15);
R1(e, a, b, c, d, 16);
R1(d, e, a, b, c, 17);
R1(c, d, e, a, b, 18);
R1(b, c, d, e, a, 19);
R2(a, b, c, d, e, 20);
R2(e, a, b, c, d, 21);
R2(d, e, a, b, c, 22);
R2(c, d, e, a, b, 23);
R2(b, c, d, e, a, 24);
R2(a, b, c, d, e, 25);
R2(e, a, b, c, d, 26);
R2(d, e, a, b, c, 27);
R2(c, d, e, a, b, 28);
R2(b, c, d, e, a, 29);
R2(a, b, c, d, e, 30);
R2(e, a, b, c, d, 31);
R2(d, e, a, b, c, 32);
R2(c, d, e, a, b, 33);
R2(b, c, d, e, a, 34);
R2(a, b, c, d, e, 35);
R2(e, a, b, c, d, 36);
R2(d, e, a, b, c, 37);
R2(c, d, e, a, b, 38);
R2(b, c, d, e, a, 39);
R3(a, b, c, d, e, 40);
R3(e, a, b, c, d, 41);
R3(d, e, a, b, c, 42);
R3(c, d, e, a, b, 43);
R3(b, c, d, e, a, 44);
R3(a, b, c, d, e, 45);
R3(e, a, b, c, d, 46);
R3(d, e, a, b, c, 47);
R3(c, d, e, a, b, 48);
R3(b, c, d, e, a, 49);
R3(a, b, c, d, e, 50);
R3(e, a, b, c, d, 51);
R3(d, e, a, b, c, 52);
R3(c, d, e, a, b, 53);
R3(b, c, d, e, a, 54);
R3(a, b, c, d, e, 55);
R3(e, a, b, c, d, 56);
R3(d, e, a, b, c, 57);
R3(c, d, e, a, b, 58);
R3(b, c, d, e, a, 59);
R4(a, b, c, d, e, 60);
R4(e, a, b, c, d, 61);
R4(d, e, a, b, c, 62);
R4(c, d, e, a, b, 63);
R4(b, c, d, e, a, 64);
R4(a, b, c, d, e, 65);
R4(e, a, b, c, d, 66);
R4(d, e, a, b, c, 67);
R4(c, d, e, a, b, 68);
R4(b, c, d, e, a, 69);
R4(a, b, c, d, e, 70);
R4(e, a, b, c, d, 71);
R4(d, e, a, b, c, 72);
R4(c, d, e, a, b, 73);
R4(b, c, d, e, a, 74);
R4(a, b, c, d, e, 75);
R4(e, a, b, c, d, 76);
R4(d, e, a, b, c, 77);
R4(c, d, e, a, b, 78);
R4(b, c, d, e, a, 79);
/* Add the working vars back into context.state[] */
state[0] += a;
state[1] += b;
state[2] += c;
state[3] += d;
state[4] += e;
/* Wipe variables */
// NOTE(review): the compiler may optimize away these dead stores and the memset below. That is
// harmless here, because this hash is explicitly not used for anything security-sensitive.
a = b = c = d = e = 0;
#ifdef SHA1HANDSOFF
memset(block, '\0', sizeof(block));
#endif
}
/* SHA1Init - Initialize new context */
// Resets `context` to the standard SHA-1 initial hash value (H0..H4) and a zero message length,
// ready for SHA1Update().
void SHA1Init(SHA1_CTX * context)
{
  static const uint32_t INITIAL_STATE[5] = {
    0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0
  };
  for (int word = 0; word < 5; ++word) {
    context->state[word] = INITIAL_STATE[word];
  }
  // No message bytes have been absorbed yet.
  context->count[1] = 0;
  context->count[0] = 0;
}
/* Run your data through this. */
// Appends `len` bytes from `data` to the message being hashed. Complete 64-byte blocks are
// compressed right away; any remainder is kept in context->buffer for the next call or for
// SHA1Final(). Because `len` is a uint32_t, a single call accepts fewer than 2^32 bytes.
void SHA1Update(
SHA1_CTX * context,
const unsigned char *data,
uint32_t len
)
{
uint32_t i;
uint32_t j;
// Save the old low word of the bit count so a carry can be detected, then add len*8 bits.
// `len << 3` drops the top 3 bits of len; those are added to the high word via `len >> 29`.
j = context->count[0];
if ((context->count[0] += len << 3) < j)
context->count[1]++;
context->count[1] += (len >> 29);
// j becomes the number of bytes already sitting in context->buffer (bit count / 8, mod 64).
j = (j >> 3) & 63;
if ((j + len) > 63)
{
// Top up and compress the partial buffer. Then compress each remaining whole 64-byte block
// directly from `data`, without copying it.
memcpy(&context->buffer[j], data, (i = 64 - j));
SHA1Transform(context->state, context->buffer);
for (; i + 63 < len; i += 64)
{
SHA1Transform(context->state, &data[i]);
}
j = 0;
}
else
i = 0;
// Buffer the leftover tail (fewer than 64 bytes) for later.
memcpy(&context->buffer[j], &data[i], len - i);
}
/* Add padding and return the message digest. */
// Completes the hash and writes the 20-byte digest (H0..H4, each big-endian) into `digest`. The
// steps are:
//   1. append the 0x80 pad byte;
//   2. append zero bytes until the length is 56 mod 64;
//   3. append the original message length in bits as a 64-bit big-endian integer.
// `context` is wiped afterwards and must be re-initialized with SHA1Init() before reuse.
void SHA1Final(
    unsigned char digest[20],
    SHA1_CTX * context)
{
  unsigned i;
  unsigned char finalcount[8];
  unsigned char c;

  // Snapshot the bit count as 8 big-endian bytes, high word first. This must happen *before*
  // padding is appended, because SHA1Update() keeps advancing context->count.
  for (i = 0; i < 8; i++) {
    finalcount[i] = (unsigned char)
        ((context->count[(i >= 4 ? 0 : 1)] >> ((3 - (i & 3)) * 8)) & 255);  /* Endian independent */
  }

  c = 0200;
  SHA1Update(context, &c, 1);
  // (count[0] & 504) is the byte offset within the current 64-byte block, times 8. Pad with zero
  // bytes until that offset reaches 56 bytes (448 bits), which leaves exactly 8 bytes for the
  // length.
  while ((context->count[0] & 504) != 448) {
    c = 0000;
    SHA1Update(context, &c, 1);
  }
  SHA1Update(context, finalcount, 8);  /* Should cause a SHA1Transform() */

  for (i = 0; i < 20; i++) {
    digest[i] = (unsigned char)
        ((context->state[i >> 2] >> ((3 - (i & 3)) * 8)) & 255);
  }

  /* Wipe variables */
  memset(context, '\0', sizeof(*context));
  memset(&finalcount, '\0', sizeof(finalcount));
}
// End SHA-1 implementation.
// =======================================================================================
static const char* const METHOD_NAMES[] = {
  // Method names, indexed by the numeric value of HttpMethod (see KJ_STRINGIFY below): one string
  // literal per KJ_HTTP_FOR_EACH_METHOD entry, e.g. "GET". This assumes HttpMethod's enumerators
  // are generated from the same X-macro in the same order; that enum lives in http.h, not here.
  // Both the pointers and the characters are const, so the table cannot be modified at runtime.
#define METHOD_NAME(id) #id,
KJ_HTTP_FOR_EACH_METHOD(METHOD_NAME)
#undef METHOD_NAME
};
kj::StringPtr KJ_STRINGIFY(HttpMethod method) {
  // Returns the method's canonical name as listed in METHOD_NAMES (e.g. "GET").
  //
  // An out-of-range value (e.g. produced by a bad static_cast) is rejected instead of causing a
  // read past the end of METHOD_NAMES. This mirrors the bounds assertion that
  // HttpHeaderId::toString() performs on BUILTIN_HEADER_NAMES.
  auto index = static_cast<uint>(method);
  KJ_REQUIRE(index < kj::size(METHOD_NAMES), "invalid HttpMethod", index);
  return METHOD_NAMES[index];
}
static kj::Maybe<HttpMethod> consumeHttpMethod(char*& ptr) {
// Recognizes a known HTTP method name at the start of the NUL-terminated text at `ptr`, using a
// hand-unrolled prefix trie. Matching is case-sensitive.
//
// On a match, advances `ptr` just past the method name and returns the method. Otherwise it
// returns null and leaves `ptr` unchanged. Only the name itself is checked: whatever follows
// (a space, the NUL terminator, or other characters) is left for the caller to validate.
char* p = ptr;
// EXPECT_REST(prefix, suffix): `prefix` has already been matched and consumed from `p`. Check
// that the remaining text starts with `suffix`; if so, commit by updating `ptr` and return
// HttpMethod::<prefix><suffix>. Every branch returns, so the cases below need no `break`.
#define EXPECT_REST(prefix, suffix) \
if (strncmp(p, #suffix, sizeof(#suffix)-1) == 0) { \
ptr = p + (sizeof(#suffix)-1); \
return HttpMethod::prefix##suffix; \
} else { \
return nullptr; \
}
switch (*p++) {
case 'A': EXPECT_REST(A,CL)
case 'C':
switch (*p++) {
case 'H': EXPECT_REST(CH,ECKOUT)
case 'O': EXPECT_REST(CO,PY)
default: return nullptr;
}
case 'D': EXPECT_REST(D,ELETE)
case 'G': EXPECT_REST(G,ET)
case 'H': EXPECT_REST(H,EAD)
case 'L': EXPECT_REST(L,OCK)
case 'M':
switch (*p++) {
case 'E': EXPECT_REST(ME,RGE)
case 'K':
switch (*p++) {
case 'A': EXPECT_REST(MKA,CTIVITY)
case 'C': EXPECT_REST(MKC,OL)
default: return nullptr;
}
case 'O': EXPECT_REST(MO,VE)
case 'S': EXPECT_REST(MS,EARCH)
default: return nullptr;
}
case 'N': EXPECT_REST(N,OTIFY)
case 'O': EXPECT_REST(O,PTIONS)
case 'P':
switch (*p++) {
case 'A': EXPECT_REST(PA,TCH)
case 'O': EXPECT_REST(PO,ST)
case 'R':
// "PROP" is shared by PROPFIND and PROPPATCH. The || short-circuits, so we never read past a
// mismatching (possibly NUL) character.
if (*p++ != 'O' || *p++ != 'P') return nullptr;
switch (*p++) {
case 'F': EXPECT_REST(PROPF,IND)
case 'P': EXPECT_REST(PROPP,ATCH)
default: return nullptr;
}
case 'U':
switch (*p++) {
case 'R': EXPECT_REST(PUR,GE)
// Empty suffix: "PUT" is already fully matched.
case 'T': EXPECT_REST(PUT,)
default: return nullptr;
}
default: return nullptr;
}
case 'R': EXPECT_REST(R,EPORT)
case 'S':
switch (*p++) {
case 'E': EXPECT_REST(SE,ARCH)
case 'U': EXPECT_REST(SU,BSCRIBE)
default: return nullptr;
}
case 'T': EXPECT_REST(T,RACE)
case 'U':
if (*p++ != 'N') return nullptr;
switch (*p++) {
case 'L': EXPECT_REST(UNL,OCK)
case 'S': EXPECT_REST(UNS,UBSCRIBE)
default: return nullptr;
}
default: return nullptr;
}
#undef EXPECT_REST
}
kj::Maybe<HttpMethod> tryParseHttpMethod(kj::StringPtr name) {
  // Returns the method only if `name` is exactly a known method token, with nothing after it.
  //
  // The const_cast is safe: consumeHttpMethod() only reads through the pointer and advances it.
  // It takes a non-const `char*&` because other code in this file calls it on mutable parse
  // buffers.
  char* cursor = const_cast<char*>(name.begin());
  kj::Maybe<HttpMethod> parsed = consumeHttpMethod(cursor);
  if (*cursor != '\0') {
    // Either no method matched, or a method was followed by extra characters.
    return nullptr;
  }
  return parsed;
}
// =======================================================================================
namespace {

constexpr char WEBSOCKET_GUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
// From RFC6455.

static kj::String generateWebSocketAccept(kj::StringPtr key) {
  // Computes the Sec-WebSocket-Accept value for a client's Sec-WebSocket-Key:
  // base64(SHA-1(key + WEBSOCKET_GUID)).
  //
  // WebSocket demands we do a SHA-1 here. ARRGHH WHY SHA-1 WHYYYYYY?
  SHA1_CTX ctx;
  byte digest[20];
  SHA1Init(&ctx);
  SHA1Update(&ctx, key.asBytes().begin(), key.size());
  SHA1Update(&ctx, reinterpret_cast<const byte*>(WEBSOCKET_GUID), strlen(WEBSOCKET_GUID));
  SHA1Final(digest, &ctx);
  return kj::encodeBase64(digest);
}

constexpr auto HTTP_SEPARATOR_CHARS = kj::parse::anyOfChars("()<>@,;:\\\"/[]?={} \t");
// RFC2616 section 2.2: https://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2

constexpr auto HTTP_TOKEN_CHARS =
    kj::parse::controlChar.orChar('\x7f')
    .orGroup(kj::parse::whitespaceChar)
    .orGroup(HTTP_SEPARATOR_CHARS)
    .invert();
// RFC2616 section 2.2: https://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2

constexpr auto HTTP_HEADER_NAME_CHARS = HTTP_TOKEN_CHARS;
// RFC2616 section 4.2: https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2

static void requireValidHeaderName(kj::StringPtr name) {
  // A header name must be a token, i.e. one or more token characters (RFC 7230 section 3.2:
  // field-name = token = 1*tchar).
  //
  // The empty name is rejected too. Without that check, "" would pass, and serialize() would
  // later emit a malformed header line that starts with ": ".
  KJ_REQUIRE(name.size() > 0, "invalid header name", name);
  for (char c: name) {
    KJ_REQUIRE(HTTP_HEADER_NAME_CHARS.contains(c), "invalid header name", name);
  }
}

static void requireValidHeaderValue(kj::StringPtr value) {
  // Rejects values that HttpHeaders::isValidHeaderValue() refuses. The error shows the value
  // C-escaped so control characters are visible.
  KJ_REQUIRE(HttpHeaders::isValidHeaderValue(value), "invalid header value",
      kj::encodeCEscape(value));
}

static const char* const BUILTIN_HEADER_NAMES[] = {
  // Indexed by header ID, which includes connection headers, so we include those names too.
#define HEADER_NAME(id, name) name,
  KJ_HTTP_FOR_EACH_BUILTIN_HEADER(HEADER_NAME)
#undef HEADER_NAME
};

}  // namespace
// Out-of-line definitions for the HttpHeaders::BuiltinIndices members. These are presumably
// declared `static constexpr` in http.h; definitions like these are needed before C++17 whenever
// such a member is ODR-used.
#define HEADER_ID(id, name) constexpr uint HttpHeaders::BuiltinIndices::id;
KJ_HTTP_FOR_EACH_BUILTIN_HEADER(HEADER_ID)
#undef HEADER_ID
// Define the HttpHeaderId constant for each built-in header. Built-in IDs carry a null table;
// HttpHeaderId::toString() resolves their names through BUILTIN_HEADER_NAMES instead.
#define DEFINE_HEADER(id, name) \
const HttpHeaderId HttpHeaderId::id(nullptr, HttpHeaders::BuiltinIndices::id);
KJ_HTTP_FOR_EACH_BUILTIN_HEADER(DEFINE_HEADER)
#undef DEFINE_HEADER
kj::StringPtr HttpHeaderId::toString() const {
  // IDs created by a table resolve through that table.
  if (table != nullptr) {
    return table->idToString(*this);
  }
  // Built-in IDs have no table; their names live in the static BUILTIN_HEADER_NAMES array.
  KJ_ASSERT(id < kj::size(BUILTIN_HEADER_NAMES));
  return BUILTIN_HEADER_NAMES[id];
}
namespace {

struct HeaderNameHash {
  // Case-insensitive hash and equality for header names. The same type serves as both the Hash
  // and the KeyEqual parameter of the name -> id map.

  size_t operator()(kj::StringPtr s) const {
    // djb2-style hash (h * 33 ^ byte). Masking bit 0x20 makes our hash case-insensitive while
    // conveniently avoiding any collisions that would matter for header names.
    size_t h = 5381;
    for (byte ch: s.asBytes()) {
      h = (h * 33) ^ (ch & ~0x20);
    }
    return h;
  }

  bool operator()(kj::StringPtr a, kj::StringPtr b) const {
    // TODO(perf): I wonder if we can beat strcasecmp() by masking bit 0x20 from each byte. We'd
    // need to prohibit one of the technically-legal characters '^' or '~' from header names
    // since they'd otherwise be ambiguous, but otherwise there is no ambiguity.
#if _MSC_VER
    return _stricmp(a.cStr(), b.cStr()) == 0;
#else
    return strcasecmp(a.cStr(), b.cStr()) == 0;
#endif
  }
};

}  // namespace
struct HttpHeaderTable::IdsByNameMap {
  // Maps a header name to its numeric ID. Names are compared case-insensitively via
  // HeaderNameHash. Keys are StringPtrs, so the map does not own the name text.
  //
  // TODO(perf): If we were cool we could maybe use a perfect hash here, since our hashtable is
  // static once built.
  std::unordered_map<kj::StringPtr,   // header name
                     uint,            // header ID
                     HeaderNameHash,  // case-insensitive hash
                     HeaderNameHash>  // case-insensitive equality
      map;
};
HttpHeaderTable::Builder::Builder()
    : table(kj::heap<HttpHeaderTable>()) {}

HttpHeaderId HttpHeaderTable::Builder::add(kj::StringPtr name) {
  // Registers `name` (matched case-insensitively) and returns its ID. Adding a name that is
  // already present returns the existing ID. The table keeps `name` as a StringPtr rather than
  // copying it, so presumably the caller's string must outlive the table.
  requireValidHeaderName(name);

  auto nextId = table->namesById.size();
  auto inserted = table->idsByName->map.insert(std::make_pair(name, nextId));
  bool isNewName = inserted.second;
  if (isNewName) {
    table->namesById.add(name);
  }
  return HttpHeaderId(table, inserted.first->second);
}
HttpHeaderTable::HttpHeaderTable()
    : idsByName(kj::heap<IdsByNameMap>()) {
  // Pre-register every built-in header. Each name is appended to namesById, and the name is
  // mapped to HttpHeaders::BuiltinIndices::<id>. This presumes those indices follow the list
  // order, so the two stay aligned (confirm in http.h).
#define ADD_HEADER(id, name) \
  namesById.add(name); \
  idsByName->map.insert(std::make_pair(name, HttpHeaders::BuiltinIndices::id));
  KJ_HTTP_FOR_EACH_BUILTIN_HEADER(ADD_HEADER)
#undef ADD_HEADER
}

HttpHeaderTable::~HttpHeaderTable() noexcept(false) {}
kj::Maybe<HttpHeaderId> HttpHeaderTable::stringToId(kj::StringPtr name) const {
  // Case-insensitive lookup of a registered header name. Returns null if the name is unknown.
  auto& byName = idsByName->map;
  auto found = byName.find(name);
  if (found != byName.end()) {
    return HttpHeaderId(this, found->second);
  }
  return nullptr;
}
// =======================================================================================
bool HttpHeaders::isValidHeaderValue(kj::StringPtr value) {
  // A header value is accepted unless it contains NUL, CR, or LF. Every other byte is allowed,
  // including non-ASCII bytes and other control characters.
  //
  // While the HTTP spec suggests that only printable ASCII characters are allowed in header
  // values, reality has a different opinion. See: https://github.com/httpwg/http11bis/issues/19
  // We follow the browsers' lead.
  for (char ch: value) {
    switch (ch) {
      case '\0':
      case '\r':
      case '\n':
        return false;
      default:
        break;
    }
  }
  return true;
}
HttpHeaders::HttpHeaders(const HttpHeaderTable& table)
    : table(&table),
      // One slot per header ID currently known to `table`.
      indexedHeaders(kj::heapArray<kj::StringPtr>(table.idCount())) {}
void HttpHeaders::clear() {
for (auto& header: indexedHeaders) {
header = nullptr;
}
unindexedHeaders.clear();
}
size_t HttpHeaders::size() const {
  // Number of headers present: the non-null indexed slots plus every unindexed header.
  size_t presentIndexed = 0;
  for (auto& slot: indexedHeaders) {
    if (slot != nullptr) ++presentIndexed;
  }
  return presentIndexed + unindexedHeaders.size();
}
HttpHeaders HttpHeaders::clone() const {
  // Deep copy. Every present name and value is copied into storage owned by the new object, so
  // the clone does not depend on the lifetime of this object's strings.
  HttpHeaders copy(*table);

  for (auto i: kj::indices(indexedHeaders)) {
    kj::StringPtr value = indexedHeaders[i];
    if (value == nullptr) continue;
    copy.indexedHeaders[i] = copy.cloneToOwn(value);
  }

  copy.unindexedHeaders.resize(unindexedHeaders.size());
  for (auto i: kj::indices(unindexedHeaders)) {
    auto& source = unindexedHeaders[i];
    auto& target = copy.unindexedHeaders[i];
    target.name = copy.cloneToOwn(source.name);
    target.value = copy.cloneToOwn(source.value);
  }

  return copy;
}
HttpHeaders HttpHeaders::cloneShallow() const {
  // Shallow copy. The new object points at the same string storage as this one and owns none of
  // it, so that storage must outlive the copy.
  HttpHeaders copy(*table);

  for (auto i: kj::indices(indexedHeaders)) {
    kj::StringPtr value = indexedHeaders[i];
    if (value == nullptr) continue;
    copy.indexedHeaders[i] = value;
  }

  copy.unindexedHeaders.resize(unindexedHeaders.size());
  for (auto i: kj::indices(unindexedHeaders)) {
    copy.unindexedHeaders[i] = unindexedHeaders[i];
  }

  return copy;
}
kj::StringPtr HttpHeaders::cloneToOwn(kj::StringPtr str) {
  // Copies `str` into a heap buffer kept in this object's ownedStrings and returns a view of
  // that copy. The view stays valid for as long as ownedStrings holds the buffer.
  kj::String owned = kj::heapString(str);
  kj::StringPtr view = owned;
  takeOwnership(kj::mv(owned));
  return view;
}
namespace {
// Forward declaration of a compile-time string-comparison helper used by isWebSocket() below.
// Its definition is not visible here; presumably it appears later in this file.
template <char... chars>
constexpr bool fastCaseCmp(const char* actual);
} // namespace
bool HttpHeaders::isWebSocket() const {
return fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr());
}
void HttpHeaders::set(HttpHeaderId id, kj::StringPtr value) {
  // Replaces the value of header `id`. The ID must belong to this object's table, and the value
  // must pass header-value validation. `value` is stored by reference, not copied.
  id.requireFrom(*table);
  requireValidHeaderValue(value);
  indexedHeaders[id.id] = value;
}

void HttpHeaders::set(HttpHeaderId id, kj::String&& value) {
  // Same as the StringPtr overload, except that this object also takes ownership of the value's
  // storage.
  kj::StringPtr valueView = value;
  set(id, valueView);
  takeOwnership(kj::mv(value));
}
void HttpHeaders::add(kj::StringPtr name, kj::StringPtr value) {
  // Validates the name and value, then appends the header. The strings are stored by reference,
  // except where addNoCheck() merges the value into an existing one.
  requireValidHeaderName(name);
  requireValidHeaderValue(value);
  addNoCheck(name, value);
}

void HttpHeaders::add(kj::StringPtr name, kj::String&& value) {
  // As above, but this object also takes ownership of the value's storage.
  kj::StringPtr valueView = value;
  add(name, valueView);
  takeOwnership(kj::mv(value));
}

void HttpHeaders::add(kj::String&& name, kj::String&& value) {
  // As above, but this object also takes ownership of both strings' storage.
  kj::StringPtr nameView = name;
  kj::StringPtr valueView = value;
  add(nameView, valueView);
  takeOwnership(kj::mv(name));
  takeOwnership(kj::mv(value));
}
void HttpHeaders::addNoCheck(kj::StringPtr name, kj::StringPtr value) {
  // Adds a header without validating it.
  //  - Names unknown to the table become unindexed headers.
  //  - A repeated indexed header is merged as "old, new", since duplicate HTTP headers are
  //    equivalent to comma-separated values.
  //  - Set-Cookie is the exception: repeats of it go to the unindexed list instead.
  KJ_IF_MAYBE(id, table->stringToId(name)) {
    kj::StringPtr& slot = indexedHeaders[id->id];
    if (slot == nullptr) {
      slot = value;
      return;
    }

#if _MSC_VER
    bool isSetCookie = _stricmp(name.cStr(), "set-cookie") == 0;
#else
    bool isSetCookie = strcasecmp(name.cStr(), "set-cookie") == 0;
#endif
    if (isSetCookie) {
      // Uh-oh, Set-Cookie will be corrupted if we try to concatenate it. We'll make it an
      // unindexed header, which is weird, but the alternative is guaranteed corruption, so...
      // TODO(cleanup): Maybe HttpHeaders should just special-case set-cookie in general?
      unindexedHeaders.add(Header {name, value});
    } else {
      auto merged = kj::str(slot, ", ", value);
      slot = merged;
      ownedStrings.add(merged.releaseArray());
    }
  } else {
    unindexedHeaders.add(Header {name, value});
  }
}
void HttpHeaders::takeOwnership(kj::String&& string) {
  // Keeps the string's heap buffer alive in ownedStrings so StringPtrs into it remain valid.
  takeOwnership(string.releaseArray());
}

void HttpHeaders::takeOwnership(kj::Array<char>&& chars) {
  // Keeps `chars` alive in ownedStrings.
  ownedStrings.add(kj::mv(chars));
}

void HttpHeaders::takeOwnership(HttpHeaders&& otherHeaders) {
  // Moves all string storage owned by `otherHeaders` into this object. Only ownership moves; no
  // header entries are copied.
  auto& donor = otherHeaders.ownedStrings;
  for (kj::Array<char>& buffer: donor) {
    ownedStrings.add(kj::mv(buffer));
  }
  donor.clear();
}
// -----------------------------------------------------------------------------
static inline char* skipSpace(char* p) {
  // Advances past a run of spaces and horizontal tabs. Returns a pointer to the first other
  // character, which may be the NUL terminator or a line break.
  while (*p == ' ' || *p == '\t') {
    ++p;
  }
  return p;
}
static kj::Maybe<kj::StringPtr> consumeWord(char*& ptr) {
  // Skips leading spaces/tabs, then takes characters up to the next space, tab, or NUL.
  //  - Ended by space/tab: that separator is overwritten with NUL and `ptr` moves just past it.
  //  - Ended by NUL: `ptr` is left on the NUL.
  //  - CR or LF before the word ends: returns null and leaves `ptr` unchanged.
  char* const wordBegin = skipSpace(ptr);
  char* cursor = wordBegin;

  while (*cursor != '\0' && *cursor != ' ' && *cursor != '\t') {
    if (*cursor == '\r' || *cursor == '\n') {
      // Not expecting EOL!
      return nullptr;
    }
    ++cursor;
  }

  char* const wordEnd = cursor;
  if (*cursor == '\0') {
    ptr = cursor;
  } else {
    *cursor = '\0';
    ptr = cursor + 1;
  }
  return kj::StringPtr(wordBegin, wordEnd);
}
static kj::Maybe<uint> consumeNumber(char*& ptr) {
  // Parses an unsigned decimal integer after optional leading spaces/tabs. On success, advances
  // `ptr` past the digits and returns the value.
  //
  // Returns null, leaving `ptr` untouched, in two cases:
  //  - no digit follows;
  //  - the value does not fit in a `uint`. Without this check an oversized number (e.g. a bogus
  //    status code) would silently wrap around to an unrelated value.
  char* start = skipSpace(ptr);
  char* p = start;

  constexpr uint MAX_VALUE = ~uint(0);
  uint result = 0;

  while ('0' <= *p && *p <= '9') {
    uint digit = *p - '0';
    // result * 10 + digit <= MAX_VALUE  <=>  result <= (MAX_VALUE - digit) / 10
    if (result > (MAX_VALUE - digit) / 10) {
      return nullptr;
    }
    result = result * 10 + digit;
    ++p;
  }

  if (p == start) return nullptr;
  ptr = p;
  return result;
}
static kj::StringPtr consumeLine(char*& ptr) {
// Returns the rest of the current line, after skipping leading spaces/tabs, and advances `ptr`
// to the start of the next line (or to the terminating NUL).
//  - The line ending (LF, CRLF, or a bare CR) is overwritten with NUL, so the result is
//    NUL-terminated in place.
//  - Obsolete folded continuation lines (a line break followed by a space or tab) are joined
//    into the current line by overwriting the break characters with spaces.
//  - Trailing whitespace before the line ending is not trimmed.
char* start = skipSpace(ptr);
char* p = start;
for (;;) {
switch (*p) {
case '\0':
ptr = p;
return kj::StringPtr(start, p);
case '\r': {
// CR, optionally followed by LF.
char* end = p++;
if (*p == '\n') ++p;
if (*p == ' ' || *p == '\t') {
// Whoa, continuation line. These are deprecated, but historically a line starting with
// a space was treated as a continuation of the previous line. The behavior should be
// the same as if the \r\n were replaced with spaces, so let's do that here to prevent
// confusion later.
// (p[-1] is the LF if there was one, otherwise it is `end` itself.)
*end = ' ';
p[-1] = ' ';
break;
}
ptr = p;
*end = '\0';
return kj::StringPtr(start, end);
}
case '\n': {
char* end = p++;
if (*p == ' ' || *p == '\t') {
// Whoa, continuation line. These are deprecated, but historically a line starting with
// a space was treated as a continuation of the previous line. The behavior should be
// the same as if the \n were replaced with spaces, so let's do that here to prevent
// confusion later.
*end = ' ';
break;
}
ptr = p;
*end = '\0';
return kj::StringPtr(start, end);
}
default:
++p;
break;
}
}
}
static kj::Maybe<kj::StringPtr> consumeHeaderName(char*& ptr) {
  // Parses "<token> [spaces] : [spaces]" at `ptr`. On success, NUL-terminates the name in
  // place, advances `ptr` to the start of the value, and returns the name. On failure (empty
  // name, or no colon after the name), returns null and leaves the buffer and `ptr` untouched.
  //
  // Do NOT skip spaces before the header name. Leading spaces indicate a continuation line; they
  // should have been handled in consumeLine().
  char* const nameBegin = ptr;
  char* nameEnd = nameBegin;
  while (HTTP_HEADER_NAME_CHARS.contains(*nameEnd)) {
    ++nameEnd;
  }
  if (nameEnd == nameBegin) return nullptr;

  char* colon = skipSpace(nameEnd);
  if (*colon != ':') return nullptr;

  ptr = skipSpace(colon + 1);
  *nameEnd = '\0';
  return kj::StringPtr(nameBegin, nameEnd);
}
static char* trimHeaderEnding(kj::ArrayPtr<char> content) {
  // Replaces the single trailing line ending of a header blob (LF or CRLF) with a NUL sentinel
  // and returns a pointer to that sentinel.
  //
  // Returns null if `content` is shorter than two bytes or does not end in LF. Only one line
  // ending is removed.
  if (content.size() < 2) return nullptr;

  char* sentinel = content.end() - 1;
  if (*sentinel != '\n') return nullptr;
  if (sentinel[-1] == '\r') --sentinel;

  *sentinel = '\0';
  return sentinel;
}
HttpHeaders::RequestOrProtocolError HttpHeaders::tryParseRequest(kj::ArrayPtr<char> content) {
  // Parses "<method> <url> <anything>" followed by header lines, modifying `content` in place.
  // On success, the returned url and the headers added to this object point into `content`.
  // Failures come back as ProtocolErrors carrying the status to reply with (400 or 501).
  char* headersEnd = trimHeaderEnding(content);
  if (headersEnd == nullptr) {
    return ProtocolError { 400, "Bad Request",
        "Request headers have no terminal newline.", content };
  }

  char* cursor = content.begin();
  HttpHeaders::Request parsed;

  // The method must be a known token immediately followed by a space or tab.
  bool methodOk = false;
  KJ_IF_MAYBE(method, consumeHttpMethod(cursor)) {
    parsed.method = *method;
    methodOk = *cursor == ' ' || *cursor == '\t';
  }
  if (!methodOk) {
    return ProtocolError { 501, "Not Implemented",
        "Unrecognized request method.", content };
  }
  ++cursor;  // step over the separator after the method

  KJ_IF_MAYBE(url, consumeWord(cursor)) {
    parsed.url = *url;
  } else {
    return ProtocolError { 400, "Bad Request",
        "Invalid request line.", content };
  }

  // Ignore rest of line. Don't care about "HTTP/1.1" or whatever.
  consumeLine(cursor);

  if (!parseHeaders(cursor, headersEnd)) {
    return ProtocolError { 400, "Bad Request",
        "The headers sent by your client are not valid.", content };
  }
  return parsed;
}
HttpHeaders::ResponseOrProtocolError HttpHeaders::tryParseResponse(kj::ArrayPtr<char> content) {
  // Parses "<HTTP-version> <status-code> <reason-phrase>" followed by header lines, modifying
  // `content` in place. Every failure is reported as a 502 ProtocolError, since the bad input
  // came from an upstream server.
  char* headersEnd = trimHeaderEnding(content);
  if (headersEnd == nullptr) {
    return ProtocolError { 502, "Bad Gateway",
        "Response headers have no terminal newline.", content };
  }

  char* cursor = content.begin();
  HttpHeaders::Response parsed;

  bool sawVersion = false;
  bool versionIsHttp = false;
  KJ_IF_MAYBE(version, consumeWord(cursor)) {
    sawVersion = true;
    versionIsHttp = version->startsWith("HTTP/");
  }
  if (!sawVersion) {
    return ProtocolError { 502, "Bad Gateway",
        "Invalid response status line (no spaces).", content };
  }
  if (!versionIsHttp) {
    return ProtocolError { 502, "Bad Gateway",
        "Invalid response status line (invalid protocol).", content };
  }

  KJ_IF_MAYBE(code, consumeNumber(cursor)) {
    parsed.statusCode = *code;
  } else {
    return ProtocolError { 502, "Bad Gateway",
        "Invalid response status line (invalid status code).", content };
  }

  parsed.statusText = consumeLine(cursor);

  if (!parseHeaders(cursor, headersEnd)) {
    return ProtocolError { 502, "Bad Gateway",
        "The headers sent by the server are not valid.", content };
  }
  return parsed;
}
bool HttpHeaders::tryParse(kj::ArrayPtr<char> content) {
  // Parses a bare header block (no start line), modifying `content` in place. Returns false if
  // the block lacks a final line ending or any header line is malformed.
  char* headersEnd = trimHeaderEnding(content);
  return headersEnd != nullptr && parseHeaders(content.begin(), headersEnd);
}
bool HttpHeaders::parseHeaders(char* ptr, char* end) {
  // Consumes "Name: value" lines starting at `ptr` until a NUL is reached, adding each header
  // without validation. Succeeds only if that NUL is the sentinel at `end`. An earlier NUL
  // (e.g. one embedded in the input) therefore causes failure.
  for (;;) {
    if (*ptr == '\0') {
      return ptr == end;
    }
    KJ_IF_MAYBE(name, consumeHeaderName(ptr)) {
      kj::StringPtr value = consumeLine(ptr);
      addNoCheck(*name, value);
    } else {
      return false;
    }
  }
}
// -----------------------------------------------------------------------------
kj::String HttpHeaders::serializeRequest(
    HttpMethod method, kj::StringPtr url,
    kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const {
  // Request line: "<METHOD> <url> HTTP/1.1", followed by the headers.
  auto methodName = kj::toCharSequence(method);
  return serialize(methodName, url, kj::StringPtr("HTTP/1.1"), connectionHeaders);
}

kj::String HttpHeaders::serializeResponse(
    uint statusCode, kj::StringPtr statusText,
    kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const {
  // Status line: "HTTP/1.1 <code> <text>", followed by the headers. The decimal code text must
  // outlive the serialize() call, so it is kept in a local.
  auto codeText = kj::toCharSequence(statusCode);
  return serialize(kj::StringPtr("HTTP/1.1"), codeText, statusText, connectionHeaders);
}
kj::String HttpHeaders::serialize(kj::ArrayPtr<const char> word1,
                                  kj::ArrayPtr<const char> word2,
                                  kj::ArrayPtr<const char> word3,
                                  kj::ArrayPtr<const kj::StringPtr> connectionHeaders) const {
  // Renders, in order:
  //  - a start line "word1 word2 word3\r\n" (only if word1 is non-null);
  //  - one "Name: value\r\n" line per present header;
  //  - a blank line.
  // The first connectionHeaders.size() indexed slots take their values from `connectionHeaders`
  // instead of from this object. The output is sized exactly up front and filled in one pass.
  const kj::StringPtr space = " ";
  const kj::StringPtr newline = "\r\n";
  const kj::StringPtr colon = ": ";

  KJ_ASSERT(connectionHeaders.size() <= indexedHeaders.size());

  // Value to emit for indexed slot `i`; a null value means the header is omitted.
  auto valueAt = [&](size_t i) -> kj::StringPtr {
    return i < connectionHeaders.size() ? connectionHeaders[i] : indexedHeaders[i];
  };

  // Pass 1: compute the exact output length.
  size_t total = newline.size();  // closing blank line
  if (word1 != nullptr) {
    total += word1.size() + space.size() + word2.size() + space.size() + word3.size()
           + newline.size();
  }
  for (auto i: kj::indices(indexedHeaders)) {
    kj::StringPtr value = valueAt(i);
    if (value == nullptr) continue;
    total += table->idToString(HttpHeaderId(table, i)).size() + colon.size() + value.size()
           + newline.size();
  }
  for (auto& header: unindexedHeaders) {
    total += header.name.size() + colon.size() + header.value.size() + newline.size();
  }

  // Pass 2: fill the buffer.
  kj::String out = heapString(total);
  char* cursor = out.begin();
  if (word1 != nullptr) {
    cursor = kj::_::fill(cursor, word1, space, word2, space, word3, newline);
  }
  for (auto i: kj::indices(indexedHeaders)) {
    kj::StringPtr value = valueAt(i);
    if (value == nullptr) continue;
    cursor = kj::_::fill(cursor, table->idToString(HttpHeaderId(table, i)), colon, value, newline);
  }
  for (auto& header: unindexedHeaders) {
    cursor = kj::_::fill(cursor, header.name, colon, header.value, newline);
  }
  cursor = kj::_::fill(cursor, newline);

  KJ_ASSERT(cursor == out.end());
  return out;
}
kj::String HttpHeaders::toString() const {
  // Headers only: no start line (null word1) and no connection-header overrides. The result
  // still ends with the blank "\r\n" line.
  kj::ArrayPtr<const char> noStartLine = nullptr;
  return serialize(noStartLine, nullptr, nullptr, nullptr);
}
// =======================================================================================
namespace {
// Initial size of an HttpInputStreamImpl's header buffer (4 KiB).
constexpr size_t MIN_BUFFER = 4 * 1024;
// Upper buffer limit (128 KiB). Presumably the most the header buffer may grow to; its use is
// outside this view, so confirm.
constexpr size_t MAX_BUFFER = 128 * 1024;
// Size limit for a chunked-encoding size line. Its use is outside this view.
constexpr size_t MAX_CHUNK_HEADER_SIZE = 32;
class HttpInputStreamImpl final: public HttpInputStream {
public:
explicit HttpInputStreamImpl(AsyncInputStream& inner, const HttpHeaderTable& table)
    : inner(inner),
      headerBuffer(kj::heapArray<char>(MIN_BUFFER)),  // start with a MIN_BUFFER-byte buffer
      headers(table) {}
explicit HttpInputStreamImpl(AsyncInputStream& inner,
kj::Array<char> headerBufferParam,
kj::ArrayPtr<char> leftoverParam,
HttpMethod method,
kj::StringPtr url,
HttpHeaders headers)
: inner(inner),
headerBuffer(kj::mv(headerBufferParam)),
// Initialize `messageHeaderEnd` to a safe value, we'll adjust it below.
// NOTE(review): this initializer reads `headerBuffer`, so it is only correct if `headerBuffer`
// is declared before `messageHeaderEnd` in the class (members are initialized in declaration
// order). The member declarations are not visible here; confirm.
messageHeaderEnd(leftoverParam.begin() - headerBuffer.begin()),
leftover(leftoverParam),
// Inside the parentheses `headers` names the constructor parameter (which shadows the member),
// so this moves the parameter into the member.
headers(kj::mv(headers)),
resumingRequest(HttpHeaders::Request { method, url }) {
// Constructor used for resuming a SuspendedRequest.
// We expect headerBuffer to look like this:
// <method> <url> <headers> [CR] LF <leftover>
// We initialized `messageHeaderEnd` to the beginning of `leftover`, but we want to point it at
// the CR (or LF if there's no CR).
KJ_REQUIRE(messageHeaderEnd >= 2 && leftover.end() <= headerBuffer.end(),
"invalid SuspendedRequest - leftover buffer not where it should be");
KJ_REQUIRE(leftover.begin()[-1] == '\n', "invalid SuspendedRequest - missing LF");
// Step back over the LF, and also over the CR if one precedes it.
messageHeaderEnd -= 1 + (leftover.begin()[-2] == '\r');
// We're in the middle of a message, so set up our state as such. Note that the only way to
// resume a SuspendedRequest is via an HttpServer, but HttpServers never call
// `awaitNextMessage()` before fully reading request bodies, meaning we expect that
// `messageReadQueue` will never be used.
++pendingMessageCount;
auto paf = kj::newPromiseAndFulfiller<void>();
onMessageDone = kj::mv(paf.fulfiller);
messageReadQueue = kj::mv(paf.promise);
}
bool canReuse() {
  // The stream can carry another message only if no message is in flight and it was not broken
  // by an abandoned body read.
  return pendingMessageCount == 0 && !broken;
}
bool canSuspend() {
  // True when we are at a suspendable point: the headers have been parsed, but nothing beyond
  // them has been consumed. Concretely, `leftover` must begin right after the header block's
  // terminating LF or CRLF, which starts at headerBuffer[messageHeaderEnd].
  //
  // TODO(cleanup): This is a silly check; we need a more defined way to track the state of the
  // stream.
  auto headerTerminator = headerBuffer.begin() + messageHeaderEnd;
  auto gap = leftover.begin() - headerTerminator;

  bool terminatorLooksRight;
  if (gap == 2) {
    terminatorLooksRight = leftover.begin()[-2] == '\r' && leftover.begin()[-1] == '\n';
  } else if (gap == 1) {
    terminatorLooksRight = leftover.begin()[-1] == '\n';
  } else {
    terminatorLooksRight = false;
  }

  return !broken && headerBuffer.size() > 0 && terminatorLooksRight;
}
// ---------------------------------------------------------------------------
// public interface
kj::Promise<Request> readRequest() override {
  // Reads the next request's headers, then wraps the body that follows in an entity-body
  // stream. If the headers failed to parse, KJ_REQUIRE_NONNULL throws "bad request".
  auto onHeaders = [this](HttpHeaders::RequestOrProtocolError&& parsed)
      -> HttpInputStream::Request {
    auto request = KJ_REQUIRE_NONNULL(
        parsed.tryGet<HttpHeaders::Request>(), "bad request");
    auto body = getEntityBody(HttpInputStreamImpl::REQUEST, request.method, 0, headers);
    return { request.method, request.url, headers, kj::mv(body) };
  };
  return readRequestHeaders().then(kj::mv(onHeaders));
}
kj::Promise<Response> readResponse(HttpMethod requestMethod) override {
  // Reads the next response's headers, then wraps its body. `requestMethod` is forwarded to
  // getEntityBody() together with the status code. If the headers failed to parse,
  // KJ_REQUIRE_NONNULL throws "bad response".
  auto onHeaders = [this, requestMethod](HttpHeaders::ResponseOrProtocolError&& parsed)
      -> HttpInputStream::Response {
    auto response = KJ_REQUIRE_NONNULL(
        parsed.tryGet<HttpHeaders::Response>(), "bad response");
    auto body = getEntityBody(HttpInputStreamImpl::RESPONSE, requestMethod,
                              response.statusCode, headers);
    return { response.statusCode, response.statusText, headers, kj::mv(body) };
  };
  return readResponseHeaders().then(kj::mv(onHeaders));
}
kj::Promise<Message> readMessage() override {
  // Reads a bare header block (no start line) into `headers`, replacing any previous contents.
  // Then wraps the following body, framed as if it were a RESPONSE to a GET with status 0.
  auto onHeaders = [this](kj::ArrayPtr<char> text) -> HttpInputStream::Message {
    headers.clear();
    KJ_REQUIRE(headers.tryParse(text), "bad message");
    auto body = getEntityBody(HttpInputStreamImpl::RESPONSE, HttpMethod::GET, 0, headers);
    return { headers, kj::mv(body) };
  };
  return readMessageHeaders().then(kj::mv(onHeaders));
}
// ---------------------------------------------------------------------------
// Stream locking: While an entity-body is being read, the body stream "locks" the underlying
// HTTP stream. Once the entity-body is complete, we can read the next pipelined message.
void finishRead() {
  // Called when entire request has been read. Releases the stream lock, which lets the next
  // queued read proceed, and drops the in-flight message count. A lock must currently be held.
  KJ_REQUIRE_NONNULL(onMessageDone)->fulfill();
  onMessageDone = nullptr;
  pendingMessageCount -= 1;
}
void abortRead() {
// Called when a body input stream was destroyed without reading to the end.
KJ_REQUIRE_NONNULL(onMessageDone)->reject(KJ_EXCEPTION(FAILED,
"application did not finish reading previous HTTP response body",
"can't read next pipelined request/response"));
onMessageDone = nullptr;
broken = true;
}
// ---------------------------------------------------------------------------
kj::Promise<bool> awaitNextMessage() override {
// Waits until more data is available, but doesn't consume it. Returns false on EOF.
//
// Used on the server after a request is handled, to check for pipelined requests.
//
// Used on the client to detect when idle connections are closed from the server end. (In this
// case, the promise always returns false or is canceled.)
if (resumingRequest != nullptr) {
// We're resuming a request, so report that we have a message.
return true;
}
if (onMessageDone != nullptr) {
// We're still working on reading the previous body.
// Fork the queue so we can wait on it without stealing it: one branch goes back into
// messageReadQueue for the next real reader, and the other re-runs this check once the
// previous body is done.
auto fork = messageReadQueue.fork();
messageReadQueue = fork.addBranch();
return fork.addBranch().then([this]() {
return awaitNextMessage();
});
}
// snarfBufferedLineBreak() is defined outside this view. Presumably it consumes a line break
// left in `leftover` after the previous message and updates lineBreakBeforeNextHeader; confirm.
snarfBufferedLineBreak();
if (!lineBreakBeforeNextHeader && leftover != nullptr) {
return true;
}
// Nothing buffered: read at least one byte into headerBuffer (from its start) and re-check.
// NOTE(review): this overwrites headerBuffer from offset 0. That is presumably safe because no
// message is in progress at this point; confirm.
return inner.tryRead(headerBuffer.begin(), 1, headerBuffer.size())
.then([this](size_t amount) -> kj::Promise<bool> {
if (amount > 0) {
leftover = headerBuffer.slice(0, amount);
return awaitNextMessage();
} else {
return false;
}
});
}
bool isCleanDrain() {
  // True if no entity-body is in the middle of being read and no bytes of a further message are
  // buffered, once any expected trailing line break has been discarded.
  if (onMessageDone != nullptr) return false;

  snarfBufferedLineBreak();
  if (lineBreakBeforeNextHeader) return false;
  return leftover == nullptr;
}
kj::Promise<kj::ArrayPtr<char>> readMessageHeaders() {
  // Queues a read of the next message's header block behind any message whose body is still being
  // read. When our turn comes, the fulfiller that unblocks the *following* read is stored in
  // `onMessageDone`; finishRead()/abortRead() resolve it once this message's body is done.
  ++pendingMessageCount;

  auto nextTurn = kj::newPromiseAndFulfiller<void>();
  auto headerPromise = messageReadQueue.then(kj::mvCapture(nextTurn.fulfiller,
      [this](kj::Own<kj::PromiseFulfiller<void>> doneFulfiller) {
    onMessageDone = kj::mv(doneFulfiller);
    return readHeader(HeaderType::MESSAGE, 0, 0);
  }));

  // Whoever reads after us waits until our body has been consumed.
  messageReadQueue = kj::mv(nextTurn.promise);
  return headerPromise;
}
kj::Promise<uint64_t> readChunkHeader() {
  // Reads the next chunk-size line of a Transfer-Encoding: chunked body and returns the size it
  // declares (0 marks the last chunk). Must only be called while a message body is being read.
  //
  // The line must consist solely of hex digits; anything else (including chunk extensions such
  // as ";name=value") is rejected as an invalid chunk size. A size that does not fit in 64 bits
  // is rejected too, instead of silently wrapping around. Wrapping would make us frame the body
  // differently from other HTTP implementations that see the same bytes.
  KJ_REQUIRE(onMessageDone != nullptr);

  // We use the portion of the header buffer after the end of message headers.
  return readHeader(HeaderType::CHUNK, messageHeaderEnd, messageHeaderEnd)
      .then([](kj::ArrayPtr<char> text) -> uint64_t {
    KJ_REQUIRE(text.size() > 0) { break; }

    uint64_t value = 0;
    for (char c: text) {
      uint64_t digit;
      if ('0' <= c && c <= '9') {
        digit = c - '0';
      } else if ('a' <= c && c <= 'f') {
        digit = c - 'a' + 10;
      } else if ('A' <= c && c <= 'F') {
        digit = c - 'A' + 10;
      } else {
        KJ_FAIL_REQUIRE("invalid HTTP chunk size", text, text.asBytes()) { break; }
        return value;
      }

      // Shifting in another hex digit must not push significant bits off the top.
      if (value > (~uint64_t(0) >> 4)) {
        KJ_FAIL_REQUIRE("invalid HTTP chunk size", text, text.asBytes()) { break; }
        return value;
      }
      value = value * 16 + digit;
    }

    return value;
  });
}
inline kj::Promise<HttpHeaders::RequestOrProtocolError> readRequestHeaders() {
  // Produces the next request's headers. A request being resumed (stored in `resumingRequest`) is
  // handed back exactly once, without touching the wire. Otherwise the next header block is read
  // and parsed into `headers`.
  KJ_IF_MAYBE(resuming, resumingRequest) {
    HttpHeaders::RequestOrProtocolError resumed(*resuming);
    resumingRequest = nullptr;
    return kj::mv(resumed);
  }

  auto parseRequest = [this](kj::ArrayPtr<char> text) {
    headers.clear();
    return headers.tryParseRequest(text);
  };
  return readMessageHeaders().then(kj::mv(parseRequest));
}
inline kj::Promise<HttpHeaders::ResponseOrProtocolError> readResponseHeaders() {
  // Reads and parses the next response header block into `headers`.
  //
  // Pipelined requests may call this several times concurrently. readMessageHeaders() serializes
  // those calls, so any state mutation (such as headers.clear()) must happen inside the
  // continuation, never before that serialization point.
  auto parseResponse = [this](kj::ArrayPtr<char> text) {
    headers.clear();
    return headers.tryParseResponse(text);
  };
  return readMessageHeaders().then(kj::mv(parseResponse));
}
inline const HttpHeaders& getHeaders() const {
  // Read-only view of the HttpHeaders object that the read*() methods parse into.
  return headers;
}
Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) {
  // Reads entity-body bytes for the current message. Bytes already sitting in `leftover` (read
  // past the end of the headers) are served first; the rest comes from the underlying stream.
  KJ_REQUIRE(onMessageDone != nullptr);

  if (leftover == nullptr) {
    // Nothing buffered: pass straight through.
    return inner.tryRead(buffer, minBytes, maxBytes);
  }

  if (maxBytes <= leftover.size()) {
    // The buffered bytes alone satisfy the whole request.
    memcpy(buffer, leftover.begin(), maxBytes);
    leftover = leftover.slice(maxBytes, leftover.size());
    return maxBytes;
  }

  // Drain all buffered bytes, then top up from the underlying stream if still short of minBytes.
  size_t copied = leftover.size();
  memcpy(buffer, leftover.begin(), copied);
  leftover = nullptr;
  if (copied >= minBytes) return copied;

  byte* rest = reinterpret_cast<byte*>(buffer) + copied;
  return inner.tryRead(rest, minBytes - copied, maxBytes - copied)
      .then([copied](size_t n) { return n + copied; });
}
// Which kind of message an entity-body belongs to. getEntityBody() applies different framing
// rules to each (e.g. a request without Content-Length or Transfer-Encoding has no body, while
// such a response is delimited by connection close).
enum RequestOrResponse {
REQUEST,
RESPONSE
};
// Returns a stream over the entity-body of the message whose headers were just parsed, with the
// framing chosen from `type`, `method`, `statusCode` and `headers`.
kj::Own<kj::AsyncInputStream> getEntityBody(
RequestOrResponse type, HttpMethod method, uint statusCode,
const kj::HttpHeaders& headers);
// Result of releaseBuffer(): ownership of the whole header buffer, plus the slice of it holding
// bytes that were read from the stream but not yet consumed.
struct ReleasedBuffer {
kj::Array<byte> buffer;
kj::ArrayPtr<byte> leftover;
};
ReleasedBuffer releaseBuffer() {
  // Gives up ownership of the header buffer. The returned `leftover` still points into that same
  // memory (now owned by `buffer`), so the caller can resume from the unconsumed bytes.
  auto ownedBuffer = headerBuffer.releaseAsBytes();
  auto unread = leftover.asBytes();
  return { kj::mv(ownedBuffer), unread };
}
private:
AsyncInputStream& inner;
// The raw byte stream that messages are parsed from.
kj::Array<char> headerBuffer;
// Holds the current message's header block followed by scratch space for chunk headers.
// readHeader() may replace it with a larger buffer.
size_t messageHeaderEnd = 0;
// Position in headerBuffer where the message headers end -- further buffer space can
// be used for chunk headers.
kj::ArrayPtr<char> leftover;
// Data in headerBuffer that comes immediately after the header content, if any.
HttpHeaders headers;
// Parsed headers, filled in by readRequestHeaders(), readResponseHeaders() or readMessage().
kj::Maybe<HttpHeaders::Request> resumingRequest;
// Non-null if we're resuming a SuspendedRequest.
bool lineBreakBeforeNextHeader = false;
// If true, the next await should expect to start with a spurious '\n' or '\r\n'. This happens
// as a side-effect of HTTP chunked encoding, where such a newline is added to the end of each
// chunk, for no good reason.
bool broken = false;
// Becomes true if the caller failed to read the whole entity-body before closing the stream.
uint pendingMessageCount = 0;
// Number of reads we have queued up. Incremented by readMessageHeaders() and decremented by
// finishRead() (abortRead() does not decrement it).
kj::Promise<void> messageReadQueue = kj::READY_NOW;
// Serializes message reads: each readMessageHeaders() waits on this promise and then replaces it
// with one that resolves only when its own message's body is done (via onMessageDone).
kj::Maybe<kj::Own<kj::PromiseFulfiller<void>>> onMessageDone;
// Fulfill once the current message has been completely read. Unblocks reading of the next
// message headers.
// Kind of newline-delimited header readHeader() is collecting: a full message header block
// (terminated by an empty line) or a single chunk-size line.
enum class HeaderType {
MESSAGE,
CHUNK
};
kj::Promise<kj::ArrayPtr<char>> readHeader(
HeaderType type, size_t bufferStart, size_t bufferEnd) {
// Reads the HTTP message header or a chunk header (as in transfer-encoding chunked) and
// returns the buffer slice containing it.
//
// The main source of complication here is that we want to end up with one continuous buffer
// containing the result, and that the input is delimited by newlines rather than by an upfront
// length.
//
// `bufferStart`..`bufferEnd` is the region of headerBuffer already collected for this header
// (empty on the initial call). Each invocation adds more bytes (from `leftover` or the stream)
// and recurses until the terminating newline is found. The returned slice excludes the final
// line break; any bytes after it are saved in `leftover`. Side effects: for MESSAGE,
// messageHeaderEnd is set (and the buffer may be grown to leave room for chunk headers); for
// CHUNK, lineBreakBeforeNextHeader is set because a line break follows every chunk's data.
kj::Promise<size_t> readPromise = nullptr;
// Figure out where we're reading from.
if (leftover != nullptr) {
// Some data is still left over from the previous message, so start with that.
// This can only happen if this is the initial call to readHeader() (not recursive).
KJ_ASSERT(bufferStart == bufferEnd);
// OK, set bufferStart and bufferEnd to both point to the start of the leftover, and then
// fake a read promise as if we read the bytes from the leftover.
bufferStart = leftover.begin() - headerBuffer.begin();
bufferEnd = bufferStart;
readPromise = leftover.size();
leftover = nullptr;
} else {
// Need to read more data from the underlying stream.
if (bufferEnd == headerBuffer.size()) {
// Out of buffer space.
// Maybe we can move bufferStart backwards to make more space at the end?
// (Chunk headers must stay after messageHeaderEnd so the message headers remain intact.)
size_t minStart = type == HeaderType::MESSAGE ? 0 : messageHeaderEnd;
if (bufferStart > minStart) {
// Move to make space.
memmove(headerBuffer.begin() + minStart, headerBuffer.begin() + bufferStart,
bufferEnd - bufferStart);
bufferEnd = bufferEnd - bufferStart + minStart;
bufferStart = minStart;
} else {
// Really out of buffer space. Grow the buffer.
if (type != HeaderType::MESSAGE) {
// Can't grow because we'd invalidate the HTTP headers.
return KJ_EXCEPTION(FAILED, "invalid HTTP chunk size");
}
KJ_REQUIRE(headerBuffer.size() < MAX_BUFFER, "request headers too large");
// Double the buffer; offsets (bufferStart/bufferEnd) remain valid in the copy.
auto newBuffer = kj::heapArray<char>(headerBuffer.size() * 2);
memcpy(newBuffer.begin(), headerBuffer.begin(), headerBuffer.size());
headerBuffer = kj::mv(newBuffer);
}
}
// How many bytes will we read?
size_t maxBytes = headerBuffer.size() - bufferEnd;
if (type == HeaderType::CHUNK) {
// Roughly limit the amount of data we read to MAX_CHUNK_HEADER_SIZE.
// TODO(perf): This is mainly to avoid copying a lot of body data into our buffer just to
// copy it again when it is read. But maybe the copy would be cheaper than overhead of
// extra event loop turns?
KJ_REQUIRE(bufferEnd - bufferStart <= MAX_CHUNK_HEADER_SIZE, "invalid HTTP chunk size");
maxBytes = kj::min(maxBytes, MAX_CHUNK_HEADER_SIZE);
}
// read() with minBytes = 1 guarantees amount >= 1 in the continuation below.
readPromise = inner.read(headerBuffer.begin() + bufferEnd, 1, maxBytes);
}
return readPromise.then([this,type,bufferStart,bufferEnd](size_t amount) mutable
-> kj::Promise<kj::ArrayPtr<char>> {
if (lineBreakBeforeNextHeader) {
// Hackily deal with expected leading line break.
// Accepts an optional '\r' and then a '\n'; the '\r' may arrive in an earlier read than the '\n'.
if (bufferEnd == bufferStart && headerBuffer[bufferEnd] == '\r') {
++bufferEnd;
--amount;
}
if (amount > 0 && headerBuffer[bufferEnd] == '\n') {
lineBreakBeforeNextHeader = false;
++bufferEnd;
--amount;
// Cut the leading line break out of the buffer entirely.
bufferStart = bufferEnd;
}
if (amount == 0) {
// Everything read so far was line-break bytes; read more.
return readHeader(type, bufferStart, bufferEnd);
}
}
size_t pos = bufferEnd;
size_t newEnd = pos + amount;
for (;;) {
// Search for next newline.
char* nl = reinterpret_cast<char*>(
memchr(headerBuffer.begin() + pos, '\n', newEnd - pos));
if (nl == nullptr) {
// No newline found. Wait for more data.
return readHeader(type, bufferStart, newEnd);
}
// Is this newline which we found the last of the header? For a chunk header, always. For
// a message header, we search for two newlines in a row. We accept either "\r\n" or just
// "\n" as a newline sequence (though the standard requires "\r\n").
// (The ">= 4" guard keeps nl[-1] and nl[-2] inside headerBuffer; note it is measured from the
// start of headerBuffer, not from bufferStart.)
if (type == HeaderType::CHUNK ||
(nl - headerBuffer.begin() >= 4 &&
((nl[-1] == '\r' && nl[-2] == '\n') || (nl[-1] == '\n')))) {
// OK, we've got all the data!
size_t endIndex = nl + 1 - headerBuffer.begin();
size_t leftoverStart = endIndex;
// Strip off the last newline from end.
endIndex -= 1 + (nl[-1] == '\r');
if (type == HeaderType::MESSAGE) {
if (headerBuffer.size() - newEnd < MAX_CHUNK_HEADER_SIZE) {
// Ugh, there's not enough space for the secondary await buffer. Grow once more.
auto newBuffer = kj::heapArray<char>(headerBuffer.size() * 2);
memcpy(newBuffer.begin(), headerBuffer.begin(), headerBuffer.size());
headerBuffer = kj::mv(newBuffer);
}
messageHeaderEnd = endIndex;
} else {
// For some reason, HTTP specifies that there will be a line break after each chunk.
lineBreakBeforeNextHeader = true;
}
// Slices are taken only after any reallocation above, so they point into the current buffer.
auto result = headerBuffer.slice(bufferStart, endIndex);
leftover = headerBuffer.slice(leftoverStart, newEnd);
return result;
} else {
// Not the end of the header; keep scanning after this newline.
pos = nl - headerBuffer.begin() + 1;
}
}
});
}
void snarfBufferedLineBreak() {
  // If a line break is expected before the next header (the one that follows each chunk's data),
  // discard it from the already-buffered `leftover` bytes. This consumes the leading regex
  // /\r*\n?/. If some other byte shows up first, the expectation is simply dropped.
  while (lineBreakBeforeNextHeader && leftover.size() > 0) {
    char first = leftover[0];
    if (first == '\r' || first == '\n') {
      leftover = leftover.slice(1, leftover.size());
    }
    if (first != '\r') {
      // Either we just consumed the '\n', or the line break was missing altogether.
      lineBreakBeforeNextHeader = false;
    }
  }
}
};
// -----------------------------------------------------------------------------
class HttpEntityBodyReader: public kj::AsyncInputStream {
  // Base class for streams reading one message's entity-body from an HttpInputStreamImpl. While
  // the body is unfinished, the input stream cannot move on to the next pipelined message.
  // Subclasses call doneReading() once the body is fully consumed. Destroying the reader earlier
  // reports the abandonment through abortRead().
public:
  HttpEntityBodyReader(HttpInputStreamImpl& inner): inner(inner) {}
  ~HttpEntityBodyReader() noexcept(false) {
    if (finished) return;
    // Body abandoned partway: the next pipelined message can't be located.
    inner.abortRead();
  }

protected:
  HttpInputStreamImpl& inner;

  void doneReading() {
    // Marks the body complete and releases the input stream for the next message.
    KJ_REQUIRE(!finished);
    finished = true;
    inner.finishRead();
  }

  inline bool alreadyDone() {
    return finished;
  }

private:
  bool finished = false;
};
class HttpNullEntityReader final: public HttpEntityBodyReader {
  // Reader for a message that has no entity-body. Every read reports EOF. tryGetLength() still
  // returns whatever length it was constructed with, which is non-zero for a HEAD response that
  // carried a Content-Length.
public:
  HttpNullEntityReader(HttpInputStreamImpl& inner, kj::Maybe<uint64_t> length)
      : HttpEntityBodyReader(inner), length(length) {
    // There is nothing to consume, so the body is complete from the start.
    doneReading();
  }

  Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    // Always EOF.
    return size_t(0);
  }

  Maybe<uint64_t> tryGetLength() override {
    return length;
  }

private:
  kj::Maybe<uint64_t> length;
  // Value reported by tryGetLength().
};
class HttpConnectionCloseEntityReader final: public HttpEntityBodyReader {
  // Reader for a body delimited by connection close: everything up to EOF belongs to the body.
public:
  HttpConnectionCloseEntityReader(HttpInputStreamImpl& inner)
      : HttpEntityBodyReader(inner) {}

  Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    if (alreadyDone()) return size_t(0);

    auto readPromise = inner.tryRead(buffer, minBytes, maxBytes);
    return readPromise.then([this,minBytes](size_t amount) {
      // A short read means the underlying stream hit EOF, which ends the body.
      bool hitEof = amount < minBytes;
      if (hitEof) doneReading();
      return amount;
    });
  }
};
class HttpFixedLengthEntityReader final: public HttpEntityBodyReader {
  // Stream which reads only up to a fixed length from the underlying stream, then emulates EOF.
  //
  // The remaining length is kept as uint64_t rather than size_t. Content-Length is parsed as a
  // 64-bit value, and on 32-bit platforms a size_t would silently truncate bodies of 4 GiB or
  // more, which would frame the message incorrectly.
public:
  HttpFixedLengthEntityReader(HttpInputStreamImpl& inner, uint64_t length)
      : HttpEntityBodyReader(inner), length(length) {
    if (length == 0) doneReading();
  }

  Maybe<uint64_t> tryGetLength() override {
    return length;
  }

  Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    return tryReadInternal(buffer, minBytes, maxBytes, 0);
  }

private:
  uint64_t length;
  // Number of body bytes not yet read.

  Promise<size_t> tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes,
                                  size_t alreadyRead) {
    if (length == 0) return size_t(0);

    // Never request past the end of the body. The comparison is done in 64 bits. The result is
    // at most maxBytes, so narrowing it to size_t is lossless.
    size_t limit = length < maxBytes ? static_cast<size_t>(length) : maxBytes;

    // We have to set minBytes to 1 here so that if we read any data at all, we update our
    // counter immediately, so that we still know where we are in case of cancellation.
    return inner.tryRead(buffer, 1, limit)
        .then([this,buffer,minBytes,maxBytes,alreadyRead](size_t amount)
              -> kj::Promise<size_t> {
      length -= amount;
      if (length > 0) {
        // We haven't reached the end of the entity body yet.
        if (amount == 0) {
          kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED,
              "premature EOF in HTTP entity body; did not reach Content-Length"));
        } else if (amount < minBytes) {
          // We requested a minimum 1 byte above, but our own caller actually set a larger minimum
          // which has not yet been reached. Keep trying until we reach it.
          return tryReadInternal(reinterpret_cast<byte*>(buffer) + amount,
              minBytes - amount, maxBytes - amount, alreadyRead + amount);
        }
      } else if (length == 0) {
        doneReading();
      }
      return amount + alreadyRead;
    });
  }
};
class HttpChunkedEntityReader final: public HttpEntityBodyReader {
  // Stream which reads a Transfer-Encoding: Chunked stream.
  //
  // The remaining size of the current chunk is kept as uint64_t, matching readChunkHeader()'s
  // return type. Storing it in size_t would truncate large chunk sizes on 32-bit platforms and
  // desynchronize the chunk framing.
public:
  HttpChunkedEntityReader(HttpInputStreamImpl& inner)
      : HttpEntityBodyReader(inner) {}

  Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    return tryReadInternal(buffer, minBytes, maxBytes, 0);
  }

private:
  uint64_t chunkSize = 0;
  // Bytes left in the current chunk. 0 means the next read must start with a chunk header.

  Promise<size_t> tryReadInternal(void* buffer, size_t minBytes, size_t maxBytes,
                                  size_t alreadyRead) {
    if (alreadyDone()) {
      return alreadyRead;
    } else if (chunkSize == 0) {
      // Read next chunk header. A zero-size chunk marks the end of the body.
      return inner.readChunkHeader()
          .then([this,buffer,minBytes,maxBytes,alreadyRead](uint64_t nextChunkSize) {
        if (nextChunkSize == 0) {
          doneReading();
        }
        chunkSize = nextChunkSize;
        return tryReadInternal(buffer, minBytes, maxBytes, alreadyRead);
      });
    } else {
      // Read current chunk, never past its end. The comparison is done in 64 bits. The result is
      // at most maxBytes, so narrowing it to size_t is lossless.
      size_t limit = chunkSize < maxBytes ? static_cast<size_t>(chunkSize) : maxBytes;

      // We have to set minBytes to 1 here so that if we read any data at all, we update our
      // counter immediately, so that we still know where we are in case of cancellation.
      return inner.tryRead(buffer, 1, limit)
          .then([this,buffer,minBytes,maxBytes,alreadyRead](size_t amount)
                -> kj::Promise<size_t> {
        chunkSize -= amount;
        if (amount == 0) {
          kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "premature EOF in HTTP chunk"));
        } else if (amount < minBytes) {
          // We requested a minimum 1 byte above, but our own caller actually set a larger minimum
          // which has not yet been reached. Keep trying until we reach it.
          return tryReadInternal(reinterpret_cast<byte*>(buffer) + amount,
              minBytes - amount, maxBytes - amount, alreadyRead + amount);
        }
        return alreadyRead + amount;
      });
    }
  }
};
template <char...>
struct FastCaseCmp;
template <char first, char... rest>
struct FastCaseCmp<first, rest...> {
static constexpr bool apply(const char* actual) {
return
('a' <= first && first <= 'z') || ('A' <= first && first <= 'Z')
? (*actual | 0x20) == (first | 0x20) && FastCaseCmp<rest...>::apply(actual + 1)
: *actual == first && FastCaseCmp<rest...>::apply(actual + 1);
}
};
template <>
struct FastCaseCmp<> {
static constexpr bool apply(const char* actual) {
return *actual == '\0';
}
};
template <char... chars>
constexpr bool fastCaseCmp(const char* actual) {
return FastCaseCmp<chars...>::apply(actual);
}
// Tests
static_assert(fastCaseCmp<'f','O','o','B','1'>("FooB1"), "");
static_assert(!fastCaseCmp<'f','O','o','B','2'>("FooB1"), "");
static_assert(!fastCaseCmp<'n','O','o','B','1'>("FooB1"), "");
static_assert(!fastCaseCmp<'f','O','o','B'>("FooB1"), "");
static_assert(!fastCaseCmp<'f','O','o','B','1','a'>("FooB1"), "");
kj::Own<kj::AsyncInputStream> HttpInputStreamImpl::getEntityBody(
    RequestOrResponse type, HttpMethod method, uint statusCode,
    const kj::HttpHeaders& headers) {
  // Chooses how the entity-body of the message whose headers were just parsed is delimited, and
  // returns a reader implementing that framing: no body, chunked, fixed length, or "until the
  // connection closes".
  KJ_REQUIRE(headerBuffer.size() > 0, "Cannot get entity body after header buffer release.");

  // Rules to determine how HTTP entity-body is delimited:
  //   https://tools.ietf.org/html/rfc7230#section-3.3.3
  // #1
  if (type == RESPONSE) {
    if (method == HttpMethod::HEAD) {
      // Body elided.
      kj::Maybe<uint64_t> length;
      KJ_IF_MAYBE(cl, headers.get(HttpHeaderId::CONTENT_LENGTH)) {
        length = strtoull(cl->cStr(), nullptr, 10);
      } else if (headers.get(HttpHeaderId::TRANSFER_ENCODING) == nullptr) {
        // HACK: Neither Content-Length nor Transfer-Encoding header in response to HEAD request.
        // Propagate this fact with a 0 expected body length.
        length = uint64_t(0);
      }
      return kj::heap<HttpNullEntityReader>(*this, length);
    } else if (statusCode == 204 || statusCode == 304) {
      // No body.
      return kj::heap<HttpNullEntityReader>(*this, uint64_t(0));
    }
  }

  // #2 deals with the CONNECT method which is handled separately.

  // #3
  KJ_IF_MAYBE(te, headers.get(HttpHeaderId::TRANSFER_ENCODING)) {
    // TODO(someday): Support plugable transfer encodings? Or at least gzip?
    // TODO(someday): Support stacked transfer encodings, e.g. "gzip, chunked".

    // NOTE: #3¶3 is ambiguous about what should happen if Transfer-Encoding and Content-Length are
    //   both present. It says that Transfer-Encoding takes precedence, but also that the request
    //   "ought to be handled as an error", and that proxies "MUST" drop the Content-Length before
    //   forwarding. We ignore the vague "ought to" part and implement the other two. (The
    //   dropping of Content-Length will happen naturally if/when the message is sent back out to
    //   the network.)
    if (fastCaseCmp<'c','h','u','n','k','e','d'>(te->cStr())) {
      // #3¶1
      return kj::heap<HttpChunkedEntityReader>(*this);
    } else if (fastCaseCmp<'i','d','e','n','t','i','t','y'>(te->cStr())) {
      // #3¶2
      KJ_REQUIRE(type != REQUEST, "request body cannot have Transfer-Encoding other than chunked");
      return kj::heap<HttpConnectionCloseEntityReader>(*this);
    } else {
      KJ_FAIL_REQUIRE("unknown transfer encoding", *te) { break; }
    }
  }

  // #4 and #5
  KJ_IF_MAYBE(cl, headers.get(HttpHeaderId::CONTENT_LENGTH)) {
    // NOTE: By spec, multiple Content-Length values are allowed as long as they are the same, e.g.
    //   "Content-Length: 5, 5, 5". Hopefully no one actually does that...
    //
    // RFC 7230 defines Content-Length as 1*DIGIT. We parse it by hand instead of with strtoull().
    // strtoull() would also accept leading whitespace and a sign ("-1" would become 2^64-1), and
    // it silently saturates on overflow. That leniency could make us disagree with another HTTP
    // implementation about where the body ends.
    uint64_t length = 0;
    bool valid = cl->size() > 0;
    for (char c: *cl) {
      if (c < '0' || c > '9') {
        valid = false;
        break;
      }
      uint64_t digit = c - '0';
      if (length > (~uint64_t(0) - digit) / 10) {
        // Value does not fit in 64 bits.
        valid = false;
        break;
      }
      length = length * 10 + digit;
    }

    if (valid) {
      // #5
      return kj::heap<HttpFixedLengthEntityReader>(*this, length);
    } else {
      // #4 (bad content-length)
      KJ_FAIL_REQUIRE("invalid Content-Length header value", *cl);
    }
  }

  // #6
  if (type == REQUEST) {
    // Lack of a Content-Length or Transfer-Encoding means no body for requests.
    return kj::heap<HttpNullEntityReader>(*this, uint64_t(0));
  }

  // RFC 2616 permitted "multipart/byteranges" responses to be self-delimiting, but this was
  // mercifully removed in RFC 7230, and new exceptions of this type are disallowed:
  //   https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.4
  //   https://tools.ietf.org/html/rfc7230#page-81
  // To be extra-safe, we'll reject a multipart/byteranges response that lacks transfer-encoding
  // and content-length.
  KJ_IF_MAYBE(contentType, headers.get(HttpHeaderId::CONTENT_TYPE)) {
    if (contentType->startsWith("multipart/byteranges")) {
      KJ_FAIL_REQUIRE(
          "refusing to handle multipart/byteranges response without transfer-encoding nor "
          "content-length due to ambiguity between RFC 2616 vs RFC 7230.");
    }
  }

  // #7
  return kj::heap<HttpConnectionCloseEntityReader>(*this);
}
} // namespace
kj::Own<HttpInputStream> newHttpInputStream(
    kj::AsyncInputStream& input, const HttpHeaderTable& table) {
  // Public factory: wraps `input` in the HTTP message parser, using `table` for header IDs.
  kj::Own<HttpInputStream> parser = kj::heap<HttpInputStreamImpl>(input, table);
  return parser;
}
// =======================================================================================
namespace {
class HttpOutputStream {
  // Writes HTTP messages (a header block followed by an entity-body) to `inner`. It enforces that
  // a new message is not started until the previous body is finished, and that body writes don't
  // overlap. If a body is abandoned, the stream is marked broken.
public:
  HttpOutputStream(AsyncOutputStream& inner): inner(inner) {}

  bool isInBody() {
    return inBody;
  }

  bool canReuse() {
    // Reusable only between messages, never aborted, and with nothing left half-written.
    return !(inBody || broken || writeInProgress);
  }

  bool canWriteBodyData() {
    return inBody && !writeInProgress;
  }

  bool isBroken() {
    return broken;
  }

  void writeHeaders(String content) {
    // Writes some header content and begins a new entity body.
    KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return; }
    KJ_REQUIRE(!inBody, "previous HTTP message body incomplete; can't write more messages");

    inBody = true;
    queueWrite(kj::mv(content));
  }

  void writeBodyData(kj::String content) {
    // Queues owned text that belongs to the current body (e.g. chunk framing).
    KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return; }
    KJ_REQUIRE(inBody) { return; }

    queueWrite(kj::mv(content));
  }

  kj::Promise<void> writeBodyData(const void* buffer, size_t size) {
    KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return kj::READY_NOW; }
    KJ_REQUIRE(inBody) { return kj::READY_NOW; }

    writeInProgress = true;
    return afterQueuedWrites()
        .then([this,buffer,size]() { return inner.write(buffer, size); })
        .then([this]() { writeInProgress = false; });
  }

  kj::Promise<void> writeBodyData(kj::ArrayPtr<const kj::ArrayPtr<const byte>> pieces) {
    KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return kj::READY_NOW; }
    KJ_REQUIRE(inBody) { return kj::READY_NOW; }

    writeInProgress = true;
    return afterQueuedWrites()
        .then([this,pieces]() { return inner.write(pieces); })
        .then([this]() { writeInProgress = false; });
  }

  Promise<uint64_t> pumpBodyFrom(AsyncInputStream& input, uint64_t amount) {
    KJ_REQUIRE(!writeInProgress, "concurrent write()s not allowed") { return uint64_t(0); }
    KJ_REQUIRE(inBody) { return uint64_t(0); }

    writeInProgress = true;
    return afterQueuedWrites()
        .then([this,&input,amount]() { return input.pumpTo(inner, amount); })
        .then([this](uint64_t actual) {
      writeInProgress = false;
      return actual;
    });
  }

  void finishBody() {
    // Called when entire body was written.
    KJ_REQUIRE(inBody) { return; }
    inBody = false;

    if (!writeInProgress) return;

    // The last write never completed -- possibly because it was canceled or threw an exception.
    // That leaves the stream in an unknown state, so treat this exactly like abortBody().
    broken = true;
    failQueuedWrites();
  }

  void abortBody() {
    // Called if the application failed to write all expected body bytes.
    KJ_REQUIRE(inBody) { return; }
    inBody = false;
    broken = true;
    failQueuedWrites();
  }

  kj::Promise<void> flush() {
    // Resolves once everything queued in `writeQueue` so far has been written.
    return afterQueuedWrites();
  }

  Promise<void> whenWriteDisconnected() {
    return inner.whenWriteDisconnected();
  }

  bool isWriteInProgress() {
    return writeInProgress;
  }

private:
  AsyncOutputStream& inner;
  kj::Promise<void> writeQueue = kj::READY_NOW;
  // Chain of queued header / framing writes; see queueWrite().
  bool inBody = false;
  bool broken = false;
  bool writeInProgress = false;
  // True if a write method has been called and has not completed successfully. In the case that
  // a write throws an exception or is canceled, this remains true forever. In these cases, the
  // underlying stream is in an inconsistent state and cannot be reused.

  kj::Promise<void> afterQueuedWrites() {
    // Returns a promise for the completion of everything currently in `writeQueue`, while leaving
    // `writeQueue` itself usable for later queueing.
    auto queueFork = writeQueue.fork();
    writeQueue = queueFork.addBranch();
    return queueFork.addBranch();
  }

  void failQueuedWrites() {
    // Cancel any writes that are still queued.
    writeQueue = KJ_EXCEPTION(FAILED,
        "previous HTTP message body incomplete; can't write more messages");
  }

  void queueWrite(kj::String content) {
    // We only use queueWrite() in cases where we can take ownership of the write buffer, and where
    // it is convenient if we can return `void` rather than a promise. In particular, this is used
    // to write headers and chunk boundaries. Writes of application data do not go into
    // `writeQueue` because this would prevent cancellation. Instead, they wait until `writeQueue`
    // is empty, then they make the write directly, using `writeInProgress` to detect and block
    // concurrent writes.
    writeQueue = writeQueue.then(kj::mvCapture(content, [this](kj::String&& content) {
      auto promise = inner.write(content.begin(), content.size());
      return promise.attach(kj::mv(content));
    }));
  }
};
class HttpNullEntityWriter final: public kj::AsyncOutputStream {
  // Body writer for a message that has no entity-body. Any attempt to write fails.
public:
  Promise<void> write(const void* buffer, size_t size) override {
    return KJ_EXCEPTION(FAILED, "HTTP message has no entity-body; can't write()");
  }

  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    return KJ_EXCEPTION(FAILED, "HTTP message has no entity-body; can't write()");
  }

  Promise<void> whenWriteDisconnected() override {
    // Never reports a disconnect.
    return kj::NEVER_DONE;
  }
};
class HttpDiscardingEntityWriter final: public kj::AsyncOutputStream {
  // Body writer that accepts any data and silently drops it.
public:
  Promise<void> write(const void* buffer, size_t size) override {
    // Dropped.
    return kj::READY_NOW;
  }

  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    // Dropped.
    return kj::READY_NOW;
  }

  Promise<void> whenWriteDisconnected() override {
    // Never reports a disconnect.
    return kj::NEVER_DONE;
  }
};
class HttpFixedLengthEntityWriter final: public kj::AsyncOutputStream {
// Body writer for a message with a known Content-Length. It refuses to write more than `length`
// bytes and calls inner.finishBody() once exactly that many have been written. If it is destroyed
// early, or while a write is still in progress, it calls inner.abortBody() instead.
public:
HttpFixedLengthEntityWriter(HttpOutputStream& inner, uint64_t length)
: inner(inner), length(length) {
// An empty body is complete immediately.
if (length == 0) inner.finishBody();
}
~HttpFixedLengthEntityWriter() noexcept(false) {
// Abort if bytes are still owed, or if the last write never completed.
if (length > 0 || inner.isWriteInProgress()) {
inner.abortBody();
}
}
Promise<void> write(const void* buffer, size_t size) override {
if (size == 0) return kj::READY_NOW;
KJ_REQUIRE(size <= length, "overwrote Content-Length");
// `length` is decremented before the write starts; maybeFinishAfter() sees the updated value.
length -= size;
return maybeFinishAfter(inner.writeBodyData(buffer, size));
}
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
uint64_t size = 0;
for (auto& piece: pieces) size += piece.size();
if (size == 0) return kj::READY_NOW;
KJ_REQUIRE(size <= length, "overwrote Content-Length");
length -= size;
return maybeFinishAfter(inner.writeBodyData(pieces));
}
Maybe<Promise<uint64_t>> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override {
if (amount == 0) return Promise<uint64_t>(uint64_t(0));
bool overshot = amount > length;
if (overshot) {
// Hmm, the requested amount was too large, but it's common to specify kj::max as the amount
// to pump, in which case we pump to EOF. Let's try to verify whether EOF is where we
// expect it to be.
KJ_IF_MAYBE(available, input.tryGetLength()) {
// Great, the stream knows how large it is. If it's indeed larger than the space available
// then let's abort.
KJ_REQUIRE(*available <= length, "overwrote Content-Length");
} else {
// OK, we have no idea how large the input is, so we'll have to check later.
}
}
amount = kj::min(amount, length);
// Reserve the bytes up front; any shortfall is credited back once the pump completes.
length -= amount;
auto promise = amount == 0
? kj::Promise<uint64_t>(amount)
: inner.pumpBodyFrom(input, amount).then([this,amount](uint64_t actual) {
// Adjust for bytes not written.
length += amount - actual;
if (length == 0) inner.finishBody();
return actual;
});
if (overshot) {
promise = promise.then([amount,&input](uint64_t actual) -> kj::Promise<uint64_t> {
if (actual == amount) {
// We read exactly the amount expected. In order to detect an overshoot, we have to
// try reading one more byte. Ugh.
// The byte is only read to probe for EOF; its value is discarded, so one shared static
// buffer is used for it.
static byte junk;
return input.tryRead(&junk, 1, 1).then([actual](size_t extra) {
KJ_REQUIRE(extra == 0, "overwrote Content-Length");
return actual;
});
} else {
// We actually read less data than requested so we couldn't have overshot. In fact, we
// undershot.
return actual;
}
});
}
return kj::mv(promise);
}
Promise<void> whenWriteDisconnected() override {
return inner.whenWriteDisconnected();
}
private:
HttpOutputStream& inner;
uint64_t length;
// Body bytes still permitted (and required) before the body is complete.
kj::Promise<void> maybeFinishAfter(kj::Promise<void> promise) {
// If the write just issued used up the last of `length`, finish the body once it lands.
if (length == 0) {
return promise.then([this]() { inner.finishBody(); });
} else {
return kj::mv(promise);
}
}
};
class HttpChunkedEntityWriter final: public kj::AsyncOutputStream {
  // Body writer for Transfer-Encoding: chunked. Each non-empty write() becomes one chunk
  // ("<hex size>\r\n<data>\r\n"). On destruction, the terminating zero-size chunk is written if the
  // body can still accept data; otherwise the body is aborted.
public:
  HttpChunkedEntityWriter(HttpOutputStream& inner)
      : inner(inner) {}
  ~HttpChunkedEntityWriter() noexcept(false) {
    if (inner.canWriteBodyData()) {
      inner.writeBodyData(kj::str("0\r\n\r\n"));
      inner.finishBody();
    } else {
      inner.abortBody();
    }
  }

  Promise<void> write(const void* buffer, size_t size) override {
    if (size == 0) return kj::READY_NOW;  // can't encode zero-size chunk since it indicates EOF.

    auto header = kj::str(kj::hex(size), "\r\n");
    auto parts = kj::heapArray<ArrayPtr<const byte>>(3);
    parts[0] = header.asBytes();
    parts[1] = kj::arrayPtr(reinterpret_cast<const byte*>(buffer), size);
    parts[2] = kj::StringPtr("\r\n").asBytes();

    auto promise = inner.writeBodyData(parts.asPtr());
    return promise.attach(kj::mv(header), kj::mv(parts));
  }

  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
    uint64_t size = 0;
    for (auto& piece: pieces) size += piece.size();

    if (size == 0) return kj::READY_NOW;  // can't encode zero-size chunk since it indicates EOF.

    auto header = kj::str(kj::hex(size), "\r\n");
    auto partsBuilder = kj::heapArrayBuilder<ArrayPtr<const byte>>(pieces.size() + 2);
    partsBuilder.add(header.asBytes());
    for (auto& piece: pieces) {
      partsBuilder.add(piece);
    }
    partsBuilder.add(kj::StringPtr("\r\n").asBytes());

    auto parts = partsBuilder.finish();
    auto promise = inner.writeBodyData(parts.asPtr());
    return promise.attach(kj::mv(header), kj::mv(parts));
  }

  Maybe<Promise<uint64_t>> tryPumpFrom(AsyncInputStream& input, uint64_t amount) override {
    KJ_IF_MAYBE(l, input.tryGetLength()) {
      // Hey, we know exactly how large the input is, so we can write just one chunk.
      uint64_t length = kj::min(amount, *l);

      if (length == 0) {
        // Nothing to pump. Emitting a chunk here would be wrong: a zero-size chunk header
        // ("0\r\n") is the chunked-encoding terminator and would end the body prematurely.
        return Promise<uint64_t>(uint64_t(0));
      }

      inner.writeBodyData(kj::str(kj::hex(length), "\r\n"));
      return inner.pumpBodyFrom(input, length)
          .then([this,length](uint64_t actual) {
        if (actual < length) {
          inner.abortBody();
          KJ_FAIL_REQUIRE(
              "value returned by input.tryGetLength() was greater than actual bytes transferred") {
            break;
          }
          // Only reached when exceptions are disabled. The body has been aborted, so don't try to
          // write the chunk trailer.
          return actual;
        }

        inner.writeBodyData(kj::str("\r\n"));
        return actual;
      });
    } else {
      // Need to use naive read/write loop.
      return nullptr;
    }
  }

  Promise<void> whenWriteDisconnected() override {
    return inner.whenWriteDisconnected();
  }

private:
  HttpOutputStream& inner;
};
// =======================================================================================
class WebSocketImpl final: public WebSocket {
  // A WebSocket that reads and writes WebSocket frames directly on top of a raw byte stream.
  //
  // * Every outgoing message is written as a single frame with FIN set. Outgoing frames are
  //   masked iff `maskKeyGenerator` is non-null (clients must mask; servers must not).
  // * Incoming frames are parsed out of `recvBuffer`. Fragmented data messages are reassembled,
  //   pings are answered with pongs, and unsolicited pongs are ignored.
  // * Only one send may be in flight at a time (`currentlySending`).
public:
  WebSocketImpl(kj::Own<kj::AsyncIoStream> stream,
                kj::Maybe<EntropySource&> maskKeyGenerator,
                kj::Array<byte> buffer = kj::heapArray<byte>(4096),
                kj::ArrayPtr<byte> leftover = nullptr,
                kj::Maybe<kj::Promise<void>> waitBeforeSend = nullptr)
      : stream(kj::mv(stream)), maskKeyGenerator(maskKeyGenerator),
        sendingPong(kj::mv(waitBeforeSend)),
        recvBuffer(kj::mv(buffer)), recvData(leftover) {}
  // `buffer` becomes the receive buffer.
  //
  // `leftover` holds bytes that were already read off the stream. It must point into `buffer`,
  // because receive() computes free space relative to `recvBuffer`.
  //
  // If `waitBeforeSend` is provided, the first send (or disconnect) waits for it to resolve.

  kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
    return sendImpl(OPCODE_BINARY, message);
  }
  kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
    return sendImpl(OPCODE_TEXT, message.asBytes());
  }

  kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
    // The Close payload is the big-endian status code followed by the reason text. Code 1005 is
    // represented by an empty payload, matching receive(), which reports an empty Close payload
    // as 1005.
    kj::Array<byte> payload;
    if (code == 1005) {
      KJ_REQUIRE(reason.size() == 0, "WebSocket close code 1005 cannot have a reason");

      // code 1005 -- leave payload empty
    } else {
      payload = heapArray<byte>(reason.size() + 2);
      payload[0] = code >> 8;
      payload[1] = code;
      memcpy(payload.begin() + 2, reason.begin(), reason.size());
    }

    auto promise = sendImpl(OPCODE_CLOSE, payload);
    return promise.attach(kj::mv(payload));
  }

  kj::Promise<void> disconnect() override {
    // Shuts down the write side of the stream. Any pending pong or startup write is allowed to
    // finish first.
    KJ_REQUIRE(!currentlySending, "another message send is already in progress");

    KJ_IF_MAYBE(p, sendingPong) {
      // We recently sent a pong, make sure it's finished before proceeding.
      currentlySending = true;
      auto promise = p->then([this]() {
        currentlySending = false;
        return disconnect();
      });
      sendingPong = nullptr;
      return promise;
    }

    disconnected = true;

    stream->shutdownWrite();
    return kj::READY_NOW;
  }

  void abort() override {
    // Drops any pending pongs and immediately tears down both directions of the stream.
    queuedPong = nullptr;
    sendingPong = nullptr;
    disconnected = true;
    stream->abortRead();
    stream->shutdownWrite();
  }

  kj::Promise<void> whenAborted() override {
    return stream->whenWriteDisconnected();
  }

  kj::Promise<Message> receive(size_t maxSize) override {
    // Reads one complete message.
    //
    // Fragments of a data message are accumulated in `fragments` until the final frame arrives.
    // Pings trigger a pong and pongs are skipped; in both cases this recurses to read the next
    // frame.
    size_t headerSize = Header::headerSize(recvData.begin(), recvData.size());

    if (headerSize > recvData.size()) {
      // The frame header is not fully buffered yet: compact the buffer, read more, and retry.
      if (recvData.begin() != recvBuffer.begin()) {
        // Move existing data to front of buffer.
        if (recvData.size() > 0) {
          memmove(recvBuffer.begin(), recvData.begin(), recvData.size());
        }
        recvData = recvBuffer.slice(0, recvData.size());
      }

      return stream->tryRead(recvData.end(), 1, recvBuffer.end() - recvData.end())
          .then([this,maxSize](size_t actual) -> kj::Promise<Message> {
        receivedBytes += actual;
        if (actual == 0) {
          if (recvData.size() > 0) {
            return KJ_EXCEPTION(DISCONNECTED, "WebSocket EOF in frame header");
          } else {
            // It's incorrect for the WebSocket to disconnect without sending `Close`.
            return KJ_EXCEPTION(DISCONNECTED,
                "WebSocket disconnected between frames without sending `Close`.");
          }
        }

        recvData = recvBuffer.slice(0, recvData.size() + actual);
        return receive(maxSize);
      });
    }

    // The header is parsed in place; it stays valid in `recvBuffer` while we use it below.
    auto& recvHeader = *reinterpret_cast<Header*>(recvData.begin());

    recvData = recvData.slice(headerSize, recvData.size());

    // Keep the declared length as 64 bits until it has been checked against `maxSize`. On
    // platforms with a 32-bit size_t, narrowing it first could wrap a huge declared length into
    // a small one and bypass the limit. After the check the value is known to fit in size_t.
    uint64_t payloadLen = recvHeader.getPayloadLen();
    KJ_REQUIRE(payloadLen < maxSize, "WebSocket message is too large");

    auto opcode = recvHeader.getOpcode();
    bool isData = opcode < OPCODE_FIRST_CONTROL;
    if (opcode == OPCODE_CONTINUATION) {
      KJ_REQUIRE(!fragments.empty(), "unexpected continuation frame in WebSocket");

      opcode = fragmentOpcode;
    } else if (isData) {
      KJ_REQUIRE(fragments.empty(), "expected continuation frame in WebSocket");
    }

    bool isFin = recvHeader.isFin();

    kj::Array<byte> message;  // space to allocate
    byte* payloadTarget;      // location into which to read payload (size is payloadLen)
    if (isFin) {
      // Add space for NUL terminator when allocating text message.
      size_t amountToAllocate = payloadLen + (opcode == OPCODE_TEXT && isFin);

      if (isData && !fragments.empty()) {
        // Final frame of a fragmented message. Gather the fragments.
        size_t offset = 0;
        for (auto& fragment: fragments) offset += fragment.size();
        message = kj::heapArray<byte>(offset + amountToAllocate);

        offset = 0;
        for (auto& fragment: fragments) {
          memcpy(message.begin() + offset, fragment.begin(), fragment.size());
          offset += fragment.size();
        }
        payloadTarget = message.begin() + offset;

        fragments.clear();
        fragmentOpcode = 0;
      } else {
        // Single-frame message.
        message = kj::heapArray<byte>(amountToAllocate);
        payloadTarget = message.begin();
      }
    } else {
      // Fragmented message, and this isn't the final fragment.
      KJ_REQUIRE(isData, "WebSocket control frame cannot be fragmented");

      message = kj::heapArray<byte>(payloadLen);
      payloadTarget = message.begin();
      if (fragments.empty()) {
        // This is the first fragment, so set the opcode.
        fragmentOpcode = opcode;
      }
    }

    Mask mask = recvHeader.getMask();

    // Runs once the payload has been fully copied into `payloadTarget`: unmasks it and turns the
    // frame into a Message (or keeps looping for fragments and control frames).
    auto handleMessage = kj::mvCapture(message,
        [this,opcode,payloadTarget,payloadLen,mask,isFin,maxSize]
        (kj::Array<byte>&& message) -> kj::Promise<Message> {
      if (!mask.isZero()) {
        mask.apply(kj::arrayPtr(payloadTarget, payloadLen));
      }

      if (!isFin) {
        // Add fragment to the list and loop.
        auto newMax = maxSize - message.size();
        fragments.add(kj::mv(message));
        return receive(newMax);
      }

      switch (opcode) {
        case OPCODE_CONTINUATION:
          // Shouldn't get here; handled above.
          KJ_UNREACHABLE;
        case OPCODE_TEXT:
          message.back() = '\0';
          return Message(kj::String(message.releaseAsChars()));
        case OPCODE_BINARY:
          return Message(message.releaseAsBytes());
        case OPCODE_CLOSE:
          if (message.size() < 2) {
            return Message(Close { 1005, nullptr });
          } else {
            uint16_t status = (static_cast<uint16_t>(message[0]) << 8)
                            | (static_cast<uint16_t>(message[1])     );
            return Message(Close {
              status, kj::heapString(message.slice(2, message.size()).asChars())
            });
          }
        case OPCODE_PING:
          // Send back a pong.
          queuePong(kj::mv(message));
          return receive(maxSize);
        case OPCODE_PONG:
          // Unsolicited pong. Ignore.
          return receive(maxSize);
        default:
          KJ_FAIL_REQUIRE("unknown WebSocket opcode", opcode);
      }
    });

    if (payloadLen <= recvData.size()) {
      // All data already received.
      memcpy(payloadTarget, recvData.begin(), payloadLen);
      recvData = recvData.slice(payloadLen, recvData.size());
      return handleMessage();
    } else {
      // Need to read more data.
      memcpy(payloadTarget, recvData.begin(), recvData.size());
      size_t remaining = payloadLen - recvData.size();
      auto promise = stream->tryRead(payloadTarget + recvData.size(), remaining, remaining)
          .then([this, remaining](size_t amount) {
        receivedBytes += amount;
        if (amount < remaining) {
          kj::throwRecoverableException(KJ_EXCEPTION(DISCONNECTED, "WebSocket EOF in message"));
        }
      });
      recvData = nullptr;
      return promise.then(kj::mv(handleMessage));
    }
  }

  kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
    KJ_IF_MAYBE(optOther, kj::dynamicDowncastIfAvailable<WebSocketImpl>(other)) {
      // Both WebSockets are raw WebSockets, so we can pump the streams directly rather than read
      // whole messages.

      if ((maskKeyGenerator == nullptr) == (optOther->maskKeyGenerator == nullptr)) {
        // Oops, it appears that we either believe we are the client side of both sockets, or we
        // are the server side of both sockets. Since clients must "mask" their outgoing frames but
        // servers must *not* do so, we can't direct-pump. Sad.
        return nullptr;
      }

      // Check same error conditions as with sendImpl().
      KJ_REQUIRE(!disconnected, "WebSocket can't send after disconnect()");
      KJ_REQUIRE(!currentlySending, "another message send is already in progress");
      currentlySending = true;

      // If the application chooses to pump messages out, but receives incoming messages normally
      // with `receive()`, then we will receive pings and attempt to send pongs. But we can't
      // safely insert a pong in the middle of a pumped stream. We kind of don't have a choice
      // except to drop them on the floor, which is what will happen if we set `hasSentClose` true.
      // Hopefully most apps that set up a pump do so in both directions at once, and so pings will
      // flow through and pongs will flow back.
      hasSentClose = true;

      return optOther->optimizedPumpTo(*this);
    }

    return nullptr;
  }

  uint64_t sentByteCount() override { return sentBytes; }
  uint64_t receivedByteCount() override { return receivedBytes; }

private:
  class Mask {
    // A 4-byte masking key. The all-zero key means "no masking".
  public:
    Mask(): maskBytes { 0, 0, 0, 0 } {}
    Mask(const byte* ptr) { memcpy(maskBytes, ptr, 4); }

    Mask(kj::Maybe<EntropySource&> generator) {
      // Random key from `generator` if present, otherwise the zero (no-op) key.
      KJ_IF_MAYBE(g, generator) {
        g->generate(maskBytes);
      } else {
        memset(maskBytes, 0, 4);
      }
    }

    void apply(kj::ArrayPtr<byte> bytes) const {
      // XORs `bytes` in place with the key repeated; applying it twice restores the input.
      apply(bytes.begin(), bytes.size());
    }

    void copyTo(byte* output) const {
      memcpy(output, maskBytes, 4);
    }

    bool isZero() const {
      return (maskBytes[0] | maskBytes[1] | maskBytes[2] | maskBytes[3]) == 0;
    }

  private:
    byte maskBytes[4];

    void apply(byte* __restrict__ bytes, size_t size) const {
      for (size_t i = 0; i < size; i++) {
        bytes[i] ^= maskBytes[i % 4];
      }
    }
  };

  class Header {
    // Raw bytes of a frame header: at most 14 bytes (2 fixed, up to 8 extended length, 4 mask
    // key). Used both to compose outgoing headers and, via reinterpret_cast, to parse incoming
    // ones in place.
  public:
    kj::ArrayPtr<const byte> compose(bool fin, byte opcode, uint64_t payloadLen, Mask mask) {
      // Writes a header for a frame of `payloadLen` bytes and returns the filled prefix.
      //
      // The length uses the shortest form:
      // * stored inline when < 126;
      // * marker 126 plus a 16-bit big-endian length when < 65536;
      // * otherwise marker 127 plus a 64-bit big-endian length.
      bytes[0] = (fin ? FIN_MASK : 0) | opcode;
      bool hasMask = !mask.isZero();

      size_t fill;

      if (payloadLen < 126) {
        bytes[1] = (hasMask ? USE_MASK_MASK : 0) | payloadLen;
        if (hasMask) {
          mask.copyTo(bytes + 2);
          fill = 6;
        } else {
          fill = 2;
        }
      } else if (payloadLen < 65536) {
        bytes[1] = (hasMask ? USE_MASK_MASK : 0) | 126;
        bytes[2] = static_cast<byte>(payloadLen >> 8);
        bytes[3] = static_cast<byte>(payloadLen     );
        if (hasMask) {
          mask.copyTo(bytes + 4);
          fill = 8;
        } else {
          fill = 4;
        }
      } else {
        // Big-endian: bytes[i] carries payloadLen >> (8 * (9 - i)), mirroring getPayloadLen().
        bytes[1] = (hasMask ? USE_MASK_MASK : 0) | 127;
        bytes[2] = static_cast<byte>(payloadLen >> 56);
        bytes[3] = static_cast<byte>(payloadLen >> 48);
        bytes[4] = static_cast<byte>(payloadLen >> 40);
        bytes[5] = static_cast<byte>(payloadLen >> 32);
        bytes[6] = static_cast<byte>(payloadLen >> 24);
        bytes[7] = static_cast<byte>(payloadLen >> 16);
        bytes[8] = static_cast<byte>(payloadLen >>  8);
        bytes[9] = static_cast<byte>(payloadLen      );
        if (hasMask) {
          mask.copyTo(bytes + 10);
          fill = 14;
        } else {
          fill = 10;
        }
      }

      return arrayPtr(bytes, fill);
    }

    bool isFin() const {
      return bytes[0] & FIN_MASK;
    }

    bool hasRsv() const {
      // NOTE(review): not consulted by receive(), so frames with RSV bits set are accepted.
      return bytes[0] & RSV_MASK;
    }

    byte getOpcode() const {
      return bytes[0] & OPCODE_MASK;
    }

    uint64_t getPayloadLen() const {
      // Decodes the inline, 16-bit or 64-bit big-endian length written by compose().
      byte payloadLen = bytes[1] & PAYLOAD_LEN_MASK;
      if (payloadLen == 127) {
        return (static_cast<uint64_t>(bytes[2]) << 56)
             | (static_cast<uint64_t>(bytes[3]) << 48)
             | (static_cast<uint64_t>(bytes[4]) << 40)
             | (static_cast<uint64_t>(bytes[5]) << 32)
             | (static_cast<uint64_t>(bytes[6]) << 24)
             | (static_cast<uint64_t>(bytes[7]) << 16)
             | (static_cast<uint64_t>(bytes[8]) <<  8)
             | (static_cast<uint64_t>(bytes[9])      );
      } else if (payloadLen == 126) {
        return (static_cast<uint64_t>(bytes[2]) <<  8)
             | (static_cast<uint64_t>(bytes[3])      );
      } else {
        return payloadLen;
      }
    }

    Mask getMask() const {
      // The mask key, if present, immediately follows the (possibly extended) length.
      if (bytes[1] & USE_MASK_MASK) {
        byte payloadLen = bytes[1] & PAYLOAD_LEN_MASK;
        if (payloadLen == 127) {
          return Mask(bytes + 10);
        } else if (payloadLen == 126) {
          return Mask(bytes + 4);
        } else {
          return Mask(bytes + 2);
        }
      } else {
        return Mask();
      }
    }

    static size_t headerSize(byte const* bytes, size_t sizeSoFar) {
      // Total header size implied by the bytes buffered so far. When fewer than two bytes are
      // available, returns 2 (the minimum needed to decide).
      if (sizeSoFar < 2) return 2;

      size_t required = 2;

      if (bytes[1] & USE_MASK_MASK) {
        required += 4;
      }

      byte payloadLen = bytes[1] & PAYLOAD_LEN_MASK;
      if (payloadLen == 127) {
        required += 8;
      } else if (payloadLen == 126) {
        required += 2;
      }

      return required;
    }

  private:
    byte bytes[14];

    static constexpr byte FIN_MASK = 0x80;
    static constexpr byte RSV_MASK = 0x70;
    static constexpr byte OPCODE_MASK = 0x0f;

    static constexpr byte USE_MASK_MASK = 0x80;
    static constexpr byte PAYLOAD_LEN_MASK = 0x7f;
  };

  static constexpr byte OPCODE_CONTINUATION = 0;
  static constexpr byte OPCODE_TEXT         = 1;
  static constexpr byte OPCODE_BINARY       = 2;
  static constexpr byte OPCODE_CLOSE        = 8;
  static constexpr byte OPCODE_PING         = 9;
  static constexpr byte OPCODE_PONG         = 10;

  static constexpr byte OPCODE_FIRST_CONTROL = 8;

  // ---------------------------------------------------------------------------

  kj::Own<kj::AsyncIoStream> stream;
  kj::Maybe<EntropySource&> maskKeyGenerator;
  // Non-null iff outgoing frames must be masked (i.e. we are the client side).

  bool hasSentClose = false;
  // Set once a Close has been sent (or a direct pump has started); suppresses further pongs.
  bool disconnected = false;
  // Set by disconnect()/abort() or by a peer's optimized pump; further sends are rejected.
  bool currentlySending = false;
  // True while a send is in flight; only one send may be outstanding at a time.

  Header sendHeader;
  kj::ArrayPtr<const byte> sendParts[2];
  // Scratch space for the header and the two gather-write pieces of the frame being sent.

  kj::Maybe<kj::Array<byte>> queuedPong;
  // If a Ping is received while currentlySending is true, then queuedPong is set to the body of
  // a pong message that should be sent once the current send is complete.

  kj::Maybe<kj::Promise<void>> sendingPong;
  // If a Pong is being sent asynchronously in response to a Ping, this is a promise for the
  // completion of that send.
  //
  // Additionally, this member is used if we need to block our first send on WebSocket startup,
  // e.g. because we need to wait for HTTP handshake writes to flush before we can start sending
  // WebSocket data. `sendingPong` was overloaded for this use case because the logic is the same.
  // Perhaps it should be renamed to `blockSend` or `writeQueue`.

  uint fragmentOpcode = 0;
  kj::Vector<kj::Array<byte>> fragments;
  // If `fragments` is non-empty, we've already received some fragments of a message.
  // `fragmentOpcode` is the original opcode.

  kj::Array<byte> recvBuffer;
  kj::ArrayPtr<byte> recvData;
  // `recvData` is the not-yet-consumed slice of `recvBuffer`.

  uint64_t sentBytes = 0;
  uint64_t receivedBytes = 0;

  kj::Promise<void> sendImpl(byte opcode, kj::ArrayPtr<const byte> message) {
    // Writes `message` as one FIN frame with `opcode`. Waits for any pending pong or startup
    // write first. Sends a queued pong once the write completes.
    KJ_REQUIRE(!disconnected, "WebSocket can't send after disconnect()");
    KJ_REQUIRE(!currentlySending, "another message send is already in progress");

    currentlySending = true;

    KJ_IF_MAYBE(p, sendingPong) {
      // We recently sent a pong, make sure it's finished before proceeding.
      auto promise = p->then([this, opcode, message]() {
        currentlySending = false;
        return sendImpl(opcode, message);
      });
      sendingPong = nullptr;
      return promise;
    }

    // We don't stop the application from sending further messages after close() -- this is the
    // application's error to make. But, we do want to make sure we don't send any PONGs after a
    // close, since that would be our error. So we stack whether we closed for that reason.
    hasSentClose = hasSentClose || opcode == OPCODE_CLOSE;

    Mask mask(maskKeyGenerator);

    kj::Array<byte> ownMessage;
    if (!mask.isZero()) {
      // Sadness, we have to make a copy to apply the mask.
      ownMessage = kj::heapArray(message);
      mask.apply(ownMessage);
      message = ownMessage;
    }

    sendParts[0] = sendHeader.compose(true, opcode, message.size(), mask);
    sendParts[1] = message;

    auto promise = stream->write(sendParts);
    if (!mask.isZero()) {
      promise = promise.attach(kj::mv(ownMessage));
    }
    return promise.then([this, size = sendParts[0].size() + sendParts[1].size()]() {
      currentlySending = false;

      // Send queued pong if needed.
      KJ_IF_MAYBE(q, queuedPong) {
        kj::Array<byte> payload = kj::mv(*q);
        queuedPong = nullptr;
        queuePong(kj::mv(payload));
      }
      sentBytes += size;
    });
  }

  void queuePong(kj::Array<byte> payload) {
    // Arranges for a pong carrying `payload` to be sent without interleaving with other writes.
    if (currentlySending) {
      // There is a message-send in progress, so we cannot write to the stream now.
      //
      // Note: According to spec, if the server receives a second ping before responding to the
      //   previous one, it can opt to respond only to the last ping. So we don't have to check if
      //   queuedPong is already non-null.
      queuedPong = kj::mv(payload);
    } else KJ_IF_MAYBE(promise, sendingPong) {
      // We're still sending a previous pong. Wait for it to finish before sending ours.
      sendingPong = promise->then(kj::mvCapture(payload, [this](kj::Array<byte> payload) mutable {
        return sendPong(kj::mv(payload));
      }));
    } else {
      // We're not sending any pong currently.
      sendingPong = sendPong(kj::mv(payload));
    }
  }

  kj::Promise<void> sendPong(kj::Array<byte> payload) {
    // Writes a pong frame, unless we've already sent Close or disconnected.
    if (hasSentClose || disconnected) {
      return kj::READY_NOW;
    }

    sendParts[0] = sendHeader.compose(true, OPCODE_PONG, payload.size(), Mask(maskKeyGenerator));
    sendParts[1] = payload;
    return stream->write(sendParts).attach(kj::mv(payload));
  }

  kj::Promise<void> optimizedPumpTo(WebSocketImpl& other) {
    // Copies raw frame bytes from our stream to `other`'s stream.
    //
    // Order of operations:
    // 1. Wait for `other`'s pending pong.
    // 2. Flush anything we already buffered.
    // 3. Pump the rest of the stream.
    //
    // The pump is aborted if `other`'s write side disconnects.
    KJ_IF_MAYBE(p, other.sendingPong) {
      // We recently sent a pong, make sure it's finished before proceeding.
      auto promise = p->then([this, &other]() {
        return optimizedPumpTo(other);
      });
      other.sendingPong = nullptr;
      return promise;
    }

    if (recvData.size() > 0) {
      // We have some data buffered. Write it first.
      return other.stream->write(recvData.begin(), recvData.size())
          .then([this, &other, size = recvData.size()]() {
        recvData = nullptr;
        other.sentBytes += size;
        return optimizedPumpTo(other);
      });
    }

    auto cancelPromise = other.stream->whenWriteDisconnected()
        .then([this]() -> kj::Promise<void> {
      this->abort();
      return KJ_EXCEPTION(DISCONNECTED,
          "destination of WebSocket pump disconnected prematurely");
    });

    // There's no buffered incoming data, so start pumping stream now.
    return stream->pumpTo(*other.stream).then([this, &other](size_t s) -> kj::Promise<void> {
      // WebSocket pumps are expected to include end-of-stream.
      other.disconnected = true;
      other.stream->shutdownWrite();
      receivedBytes += s;
      other.sentBytes += s;
      return kj::READY_NOW;
    }, [&other](kj::Exception&& e) -> kj::Promise<void> {
      // We don't know if it was a read or a write that threw. If it was a read that threw, we need
      // to send a disconnect on the destination. If it was the destination that threw, it
      // shouldn't hurt to disconnect() it again, but we'll catch and squelch any exceptions.
      other.disconnected = true;
      kj::runCatchingExceptions([&other]() { other.stream->shutdownWrite(); });
      return kj::mv(e);
    }).exclusiveJoin(kj::mv(cancelPromise));
  }
};
kj::Own<WebSocket> upgradeToWebSocket(
    kj::Own<kj::AsyncIoStream> stream, HttpInputStreamImpl& httpInput, HttpOutputStream& httpOutput,
    kj::Maybe<EntropySource&> maskKeyGenerator) {
  // Turns a connection that just completed an HTTP upgrade into a WebSocket.
  //
  // Bytes the HTTP parser already buffered past the handshake are handed over as pre-received
  // data. The WebSocket is told to hold its first send until the HTTP output has been flushed.
  auto released = httpInput.releaseBuffer();
  auto handshakeFlushed = httpOutput.flush();
  return kj::heap<WebSocketImpl>(kj::mv(stream), maskKeyGenerator,
      kj::mv(released.buffer), released.leftover, kj::mv(handshakeFlushed));
}
} // namespace
kj::Own<WebSocket> newWebSocket(kj::Own<kj::AsyncIoStream> stream,
                                kj::Maybe<EntropySource&> maskKeyGenerator) {
  // Wraps a raw stream as a WebSocket with a default receive buffer and no pre-received data.
  // Supplying `maskKeyGenerator` makes this the masking (client) side.
  kj::Own<WebSocket> socket = kj::heap<WebSocketImpl>(kj::mv(stream), maskKeyGenerator);
  return socket;
}
static kj::Promise<void> pumpWebSocketLoop(WebSocket& from, WebSocket& to) {
  // Generic message-by-message pump: relays each message received from `from` to `to`, and
  // recurses until a Close has been forwarded.
  //
  // If receiving fails with DISCONNECTED, `to` is disconnected. Any other failure is reported to
  // `to` as a Close with status 1002 and the exception description as the reason.
  return from.receive().then([&from,&to](WebSocket::Message&& message) {
    KJ_SWITCH_ONEOF(message) {
      KJ_CASE_ONEOF(text, kj::String) {
        return to.send(text)
            .attach(kj::mv(text))
            .then([&from,&to]() { return pumpWebSocketLoop(from, to); });
      }
      KJ_CASE_ONEOF(data, kj::Array<byte>) {
        return to.send(data)
            .attach(kj::mv(data))
            .then([&from,&to]() { return pumpWebSocketLoop(from, to); });
      }
      KJ_CASE_ONEOF(close, WebSocket::Close) {
        // Once a close has passed through, the pump is complete.
        return to.close(close.code, close.reason)
            .attach(kj::mv(close));
      }
    }
    KJ_UNREACHABLE;
  }, [&to](kj::Exception&& e) {
    if (e.getType() == kj::Exception::Type::DISCONNECTED) {
      return to.disconnect();
    } else {
      // `close()` may keep referencing `reason` until its promise resolves (a WebSocketPipe
      // stores it in a BlockedSend as a StringPtr), but `e` does not outlive this handler.
      // Copy the description and keep the copy alive for the duration of the close.
      auto reason = kj::str(e.getDescription());
      auto promise = to.close(1002, reason);
      return promise.attach(kj::mv(reason));
    }
  });
}
kj::Promise<void> WebSocket::pumpTo(WebSocket& other) {
  // Prefer a pump provided by the destination (e.g. a direct stream-to-stream copy). Otherwise
  // relay message by message, aborting this side if the destination goes away first.
  KJ_IF_MAYBE(optimized, other.tryPumpFrom(*this)) {
    return kj::mv(*optimized);
  }

  return kj::evalNow([&]() {
    auto destinationGone = other.whenAborted().then([this]() -> kj::Promise<void> {
      this->abort();
      return KJ_EXCEPTION(DISCONNECTED,
          "destination of WebSocket pump disconnected prematurely");
    });
    auto relay = pumpWebSocketLoop(*this, other);
    return relay.exclusiveJoin(kj::mv(destinationGone));
  });
}
kj::Maybe<kj::Promise<void>> WebSocket::tryPumpFrom(WebSocket& other) {
  // Default: no optimized pump is available, so WebSocket::pumpTo() uses its generic loop.
  kj::Maybe<kj::Promise<void>> noOptimizedPump = nullptr;
  return noOptimizedPump;
}
namespace {
class WebSocketPipeImpl final: public WebSocket, public kj::Refcounted {
// Represents one direction of a WebSocket pipe.
//
// This class behaves as a "loopback" WebSocket: a message sent using send() is received using
// receive(), on the same object. This is *not* how WebSocket implementations usually behave.
// But, this object is actually used to implement only one direction of a bidirectional pipe. At
// another layer above this, the pipe is actually composed of two WebSocketPipeEnd instances,
// which layer on top of two WebSocketPipeImpl instances representing the two directions. So,
// send() calls on a WebSocketPipeImpl instance always come from one of the two WebSocketPipeEnds
// while receive() calls come from the other end.
public:
~WebSocketPipeImpl() noexcept(false) {
// Destroying the pipe while a Blocked* state object still refers to it leaves that object with a
// dangling `pipe` reference. Terminal states installed via `ownState` are tolerated.
KJ_REQUIRE(state == nullptr || ownState.get() != nullptr,
"destroying WebSocketPipe with operation still in-progress; probably going to segfault") {
// Don't std::terminate().
break;
}
}
void abort() override {
KJ_IF_MAYBE(s, state) {
s->abort();
} else {
ownState = heap<Aborted>();
state = *ownState;
aborted = true;
KJ_IF_MAYBE(f, abortedFulfiller) {
f->get()->fulfill();
abortedFulfiller = nullptr;
}
}
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
  // The payload is counted toward transferredBytes only once the send has completed.
  auto countSent = [this, size = message.size()]() { transferredBytes += size; };
  KJ_IF_MAYBE(current, state) {
    return current->send(message).then(kj::mv(countSent));
  }
  return newAdaptedPromise<void, BlockedSend>(*this, MessagePtr(message))
      .then(kj::mv(countSent));
}
kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
  // Text variant of send(); counted toward transferredBytes only once it has completed.
  auto countSent = [this, size = message.size()]() { transferredBytes += size; };
  KJ_IF_MAYBE(current, state) {
    return current->send(message).then(kj::mv(countSent));
  }
  return newAdaptedPromise<void, BlockedSend>(*this, MessagePtr(message))
      .then(kj::mv(countSent));
}
kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
  // A Close is counted as its 2-byte status code plus the reason text, once it has completed.
  auto countSent = [this, size = reason.size()]() { transferredBytes += 2 + size; };
  KJ_IF_MAYBE(current, state) {
    return current->close(code, reason).then(kj::mv(countSent));
  }
  return newAdaptedPromise<void, BlockedSend>(*this, MessagePtr(ClosePtr { code, reason }))
      .then(kj::mv(countSent));
}
kj::Promise<void> disconnect() override {
  KJ_IF_MAYBE(current, state) {
    return current->disconnect();
  }

  // Nothing is pending: switch permanently to the Disconnected state.
  ownState = heap<Disconnected>();
  state = *ownState;
  return kj::READY_NOW;
}
kj::Promise<void> whenAborted() override {
  // Resolves immediately if already aborted. Otherwise all callers share one forked promise,
  // created on first use and fulfilled by abort().
  if (aborted) {
    return kj::READY_NOW;
  }

  KJ_IF_MAYBE(existing, abortedPromise) {
    return existing->addBranch();
  }

  auto paf = newPromiseAndFulfiller<void>();
  abortedFulfiller = kj::mv(paf.fulfiller);
  auto forked = paf.promise.fork();
  auto branch = forked.addBranch();
  abortedPromise = kj::mv(forked);
  return branch;
}
kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
  KJ_IF_MAYBE(current, state) {
    return current->tryPumpFrom(other);
  }

  // No reader yet: park a BlockedPumpFrom so future reads are served straight from `other`.
  kj::Promise<void> blocked = newAdaptedPromise<void, BlockedPumpFrom>(*this, other);
  return kj::mv(blocked);
}
kj::Promise<Message> receive(size_t maxSize) override {
  KJ_IF_MAYBE(current, state) {
    return current->receive(maxSize);
  }

  // No writer yet: park a BlockedReceive for the next send/close/pump to satisfy.
  return newAdaptedPromise<Message, BlockedReceive>(*this, maxSize);
}
kj::Promise<void> pumpTo(WebSocket& other) override {
  KJ_IF_MAYBE(current, state) {
    // The state objects do not count bytes, so measure the delegated pump through `other`'s
    // receive counter and add the difference when the pump is done.
    auto receivedBefore = other.receivedByteCount();
    return current->pumpTo(other).attach(kj::defer([this, &other, receivedBefore]() {
      transferredBytes += other.receivedByteCount() - receivedBefore;
    }));
  }

  return newAdaptedPromise<void, BlockedPumpTo>(*this, other);
}
uint64_t sentByteCount() override {
return transferredBytes;
}
uint64_t receivedByteCount() override {
return transferredBytes;
}
private:
kj::Maybe<WebSocket&> state;
// Object-oriented state! If any method call is blocked waiting on activity from the other end,
// then `state` is non-null and method calls should be forwarded to it. If no calls are
// outstanding, `state` is null. After abort()/disconnect() with nothing pending, it points at
// the terminal state held in `ownState`.
kj::Own<WebSocket> ownState;
// Owns `state` when it is a terminal state (Aborted or Disconnected) created by this object.
// Null while `state` points at a Blocked* object, whose lifetime is tied to the promise created
// by newAdaptedPromise().
uint64_t transferredBytes = 0;
// Bytes moved through this direction of the pipe; reported as both sent and received count.
bool aborted = false;
// Set once abort() has installed the Aborted state; makes whenAborted() resolve immediately.
Maybe<Own<PromiseFulfiller<void>>> abortedFulfiller = nullptr;
Maybe<ForkedPromise<void>> abortedPromise = nullptr;
// Created lazily by whenAborted(). abort() fulfills `abortedFulfiller`, which resolves every
// branch handed out from `abortedPromise`.
void endState(WebSocket& obj) {
  // Clears `state`, but only if it still refers to `obj`, so a finished state object cannot
  // clobber a different one that has since taken its place.
  KJ_IF_MAYBE(current, state) {
    if (current != &obj) return;
    state = nullptr;
  }
}
struct ClosePtr {
  // Non-owning view of a Close message; `reason` borrows the caller's string.
  uint16_t code;
  kj::StringPtr reason;
};
using MessagePtr = kj::OneOf<kj::ArrayPtr<const char>, kj::ArrayPtr<const byte>, ClosePtr>;
// A message a blocked sender is waiting to hand off, referenced without copying its payload.
class BlockedSend final: public WebSocket {
// Pipe state while a send() or close() is waiting for the reading side. `message` only borrows
// the sender's data (MessagePtr is non-owning). `fulfiller` completes the sender's promise once
// the message has been handed off via receive() or pumpTo().
public:
BlockedSend(kj::PromiseFulfiller<void>& fulfiller, WebSocketPipeImpl& pipe, MessagePtr message)
: fulfiller(fulfiller), pipe(pipe), message(kj::mv(message)) {
// Register as the pipe's current state; no other operation may already be blocked.
KJ_REQUIRE(pipe.state == nullptr);
pipe.state = *this;
}
~BlockedSend() noexcept(false) {
// Unregister if still current (e.g. the send promise was dropped before hand-off).
pipe.endState(*this);
}
void abort() override {
// Fail the blocked sender, then detach *before* calling pipe.abort(). With `state` cleared, the
// pipe installs its terminal Aborted state instead of forwarding back to this object.
canceler.cancel("other end of WebSocketPipe was destroyed");
fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed"));
pipe.endState(*this);
pipe.abort();
}
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
// A second send/close/disconnect/pump while one send is already blocked is a usage error.
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
KJ_FAIL_ASSERT("another message send is already in progress");
}
kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
KJ_FAIL_ASSERT("another message send is already in progress");
}
kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
KJ_FAIL_ASSERT("another message send is already in progress");
}
kj::Promise<void> disconnect() override {
KJ_FAIL_ASSERT("another message send is already in progress");
}
kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
KJ_FAIL_ASSERT("another message send is already in progress");
}
kj::Promise<Message> receive(size_t maxSize) override {
// The reader has arrived: release the sender and return an owned copy of the borrowed payload.
// Refused while pumpTo() is still forwarding this message.
KJ_REQUIRE(canceler.isEmpty(), "already pumping");
fulfiller.fulfill();
pipe.endState(*this);
KJ_SWITCH_ONEOF(message) {
KJ_CASE_ONEOF(arr, kj::ArrayPtr<const char>) {
return Message(kj::str(arr));
}
KJ_CASE_ONEOF(arr, kj::ArrayPtr<const byte>) {
auto copy = kj::heapArray<byte>(arr.size());
memcpy(copy.begin(), arr.begin(), arr.size());
return Message(kj::mv(copy));
}
KJ_CASE_ONEOF(close, ClosePtr) {
return Message(Close { close.code, kj::str(close.reason) });
}
}
KJ_UNREACHABLE;
}
kj::Promise<void> pumpTo(WebSocket& other) override {
// Forward the borrowed message to `other`. On success, release the sender and keep pumping the
// rest of the pipe into `other`. On failure, propagate the error to both the sender and the
// pump. Wrapping in `canceler` lets abort() cancel the in-flight forward.
KJ_REQUIRE(canceler.isEmpty(), "already pumping");
kj::Promise<void> promise = nullptr;
KJ_SWITCH_ONEOF(message) {
KJ_CASE_ONEOF(arr, kj::ArrayPtr<const char>) {
promise = other.send(arr);
}
KJ_CASE_ONEOF(arr, kj::ArrayPtr<const byte>) {
promise = other.send(arr);
}
KJ_CASE_ONEOF(close, ClosePtr) {
promise = other.close(close.code, close.reason);
}
}
return canceler.wrap(promise.then([this,&other]() {
canceler.release();
fulfiller.fulfill();
pipe.endState(*this);
return pipe.pumpTo(other);
}, [this](kj::Exception&& e) -> kj::Promise<void> {
canceler.release();
fulfiller.reject(kj::cp(e));
pipe.endState(*this);
return kj::mv(e);
}));
}
uint64_t sentByteCount() override {
KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
}
uint64_t receivedByteCount() override {
KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
}
private:
kj::PromiseFulfiller<void>& fulfiller;
WebSocketPipeImpl& pipe;
MessagePtr message;
Canceler canceler;
// Non-empty while pumpTo() is forwarding `message`.
};
class BlockedPumpFrom final: public WebSocket {
// Pipe state while tryPumpFrom(input) is waiting for the reading side. Reads are served directly
// from `input`. `fulfiller` completes the tryPumpFrom() promise when `input` delivers a Close,
// finishes pumping, or fails.
public:
BlockedPumpFrom(kj::PromiseFulfiller<void>& fulfiller, WebSocketPipeImpl& pipe,
WebSocket& input)
: fulfiller(fulfiller), pipe(pipe), input(input) {
// Register as the pipe's current state; no other operation may already be blocked.
KJ_REQUIRE(pipe.state == nullptr);
pipe.state = *this;
}
~BlockedPumpFrom() noexcept(false) {
pipe.endState(*this);
}
void abort() override {
// Fail the pump, detach, then let the pipe enter its terminal Aborted state.
canceler.cancel("other end of WebSocketPipe was destroyed");
fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed"));
pipe.endState(*this);
pipe.abort();
}
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
// The write side is occupied by the pump, so any further writes are usage errors.
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
KJ_FAIL_ASSERT("another message send is already in progress");
}
kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
KJ_FAIL_ASSERT("another message send is already in progress");
}
kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
KJ_FAIL_ASSERT("another message send is already in progress");
}
kj::Promise<void> disconnect() override {
KJ_FAIL_ASSERT("another message send is already in progress");
}
kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
KJ_FAIL_ASSERT("another message send is already in progress");
}
kj::Promise<Message> receive(size_t maxSize) override {
// Serve one read from `input`. The pump ends after a Close message or an error. Any other
// message leaves this state in place for the next read.
KJ_REQUIRE(canceler.isEmpty(), "another message receive is already in progress");
return canceler.wrap(input.receive(maxSize)
.then([this](Message message) {
if (message.is<Close>()) {
canceler.release();
fulfiller.fulfill();
pipe.endState(*this);
}
return kj::mv(message);
}, [this](kj::Exception&& e) -> Message {
canceler.release();
fulfiller.reject(kj::cp(e));
pipe.endState(*this);
kj::throwRecoverableException(kj::mv(e));
return Message(kj::String());
}));
}
kj::Promise<void> pumpTo(WebSocket& other) override {
// Hand the whole remaining stream from `input` to `other`; the pump ends when that does.
KJ_REQUIRE(canceler.isEmpty(), "another message receive is already in progress");
return canceler.wrap(input.pumpTo(other)
.then([this]() {
canceler.release();
fulfiller.fulfill();
pipe.endState(*this);
}, [this](kj::Exception&& e) {
canceler.release();
fulfiller.reject(kj::cp(e));
pipe.endState(*this);
kj::throwRecoverableException(kj::mv(e));
}));
}
uint64_t sentByteCount() override {
KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
}
uint64_t receivedByteCount() override {
KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
}
private:
kj::PromiseFulfiller<void>& fulfiller;
WebSocketPipeImpl& pipe;
WebSocket& input;
Canceler canceler;
// Non-empty while a read or pump from `input` is in progress.
};
class BlockedReceive final: public WebSocket {
// Pipe state while a receive() is waiting for a writer. Writes complete immediately: their
// content is copied into the Message that resolves the receiver's promise.
public:
BlockedReceive(kj::PromiseFulfiller<Message>& fulfiller, WebSocketPipeImpl& pipe,
size_t maxSize)
: fulfiller(fulfiller), pipe(pipe), maxSize(maxSize) {
// Register as the pipe's current state; no other operation may already be blocked.
KJ_REQUIRE(pipe.state == nullptr);
pipe.state = *this;
}
~BlockedReceive() noexcept(false) {
pipe.endState(*this);
}
void abort() override {
// Fail the receiver, detach, then let the pipe enter its terminal Aborted state.
canceler.cancel("other end of WebSocketPipe was destroyed");
fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed"));
pipe.endState(*this);
pipe.abort();
}
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
// Deliver an owned copy to the receiver; the sender does not need to wait.
KJ_REQUIRE(canceler.isEmpty(), "already pumping");
auto copy = kj::heapArray<byte>(message.size());
memcpy(copy.begin(), message.begin(), message.size());
fulfiller.fulfill(Message(kj::mv(copy)));
pipe.endState(*this);
return kj::READY_NOW;
}
kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
KJ_REQUIRE(canceler.isEmpty(), "already pumping");
fulfiller.fulfill(Message(kj::str(message)));
pipe.endState(*this);
return kj::READY_NOW;
}
kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
KJ_REQUIRE(canceler.isEmpty(), "already pumping");
fulfiller.fulfill(Message(Close { code, kj::str(reason) }));
pipe.endState(*this);
return kj::READY_NOW;
}
kj::Promise<void> disconnect() override {
// Fail the receiver with DISCONNECTED. Then, with `state` cleared, let the pipe enter its
// terminal Disconnected state.
KJ_REQUIRE(canceler.isEmpty(), "already pumping");
fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, "WebSocket disconnected"));
pipe.endState(*this);
return pipe.disconnect();
}
kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
// Satisfy the pending receive with the first message from `other`, then pump the remainder of
// `other` into the pipe.
KJ_REQUIRE(canceler.isEmpty(), "already pumping");
return canceler.wrap(other.receive(maxSize).then([this,&other](Message message) {
canceler.release();
fulfiller.fulfill(kj::mv(message));
pipe.endState(*this);
return other.pumpTo(pipe);
}, [this](kj::Exception&& e) -> kj::Promise<void> {
canceler.release();
fulfiller.reject(kj::cp(e));
pipe.endState(*this);
return kj::mv(e);
}));
}
// A second read while one receive is already blocked is a usage error.
kj::Promise<Message> receive(size_t maxSize) override {
KJ_FAIL_ASSERT("another message receive is already in progress");
}
kj::Promise<void> pumpTo(WebSocket& other) override {
KJ_FAIL_ASSERT("another message receive is already in progress");
}
uint64_t sentByteCount() override {
KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
}
uint64_t receivedByteCount() override {
KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
}
private:
kj::PromiseFulfiller<Message>& fulfiller;
WebSocketPipeImpl& pipe;
size_t maxSize;
Canceler canceler;
// Non-empty while tryPumpFrom() is reading the first message from its source.
};
class BlockedPumpTo final: public WebSocket {
// Pipe state while pumpTo(output) is waiting for writers. Writes are forwarded straight to
// `output`. `fulfiller` completes the pumpTo() promise in any of these cases:
// * a Close or disconnect passes through;
// * a tryPumpFrom() feeding the pipe finishes;
// * the pipe is aborted.
public:
BlockedPumpTo(kj::PromiseFulfiller<void>& fulfiller, WebSocketPipeImpl& pipe, WebSocket& output)
: fulfiller(fulfiller), pipe(pipe), output(output) {
// Register as the pipe's current state; no other operation may already be blocked.
KJ_REQUIRE(pipe.state == nullptr);
pipe.state = *this;
}
~BlockedPumpTo() noexcept(false) {
pipe.endState(*this);
}
void abort() override {
canceler.cancel("other end of WebSocketPipe was destroyed");
// abort() is called when the pipe end is dropped. This should be treated as disconnecting,
// so pumpTo() should complete normally.
fulfiller.fulfill();
pipe.endState(*this);
pipe.abort();
}
kj::Promise<void> whenAborted() override {
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
// Forward directly; the pump stays active for further messages.
KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress");
return canceler.wrap(output.send(message));
}
kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress");
return canceler.wrap(output.send(message));
}
kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress");
return canceler.wrap(output.close(code, reason).then([this]() {
// A pump is expected to end upon seeing a Close message.
canceler.release();
pipe.endState(*this);
fulfiller.fulfill();
}));
}
kj::Promise<void> disconnect() override {
// Forward the disconnect, finish the pump, then move the pipe into its Disconnected state.
KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress");
return canceler.wrap(output.disconnect().then([this]() {
canceler.release();
pipe.endState(*this);
fulfiller.fulfill();
return pipe.disconnect();
}));
}
kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
// Connect `other` directly to `output`; this pump finishes when that one does.
KJ_REQUIRE(canceler.isEmpty(), "another message send is already in progress");
return canceler.wrap(other.pumpTo(output).then([this]() {
canceler.release();
pipe.endState(*this);
fulfiller.fulfill();
}));
}
// A second reader while a pump is already blocked is a usage error.
kj::Promise<Message> receive(size_t maxSize) override {
KJ_FAIL_ASSERT("another message receive is already in progress");
}
kj::Promise<void> pumpTo(WebSocket& other) override {
KJ_FAIL_ASSERT("another message receive is already in progress");
}
uint64_t sentByteCount() override {
KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
}
uint64_t receivedByteCount() override {
KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
}
private:
kj::PromiseFulfiller<void>& fulfiller;
WebSocketPipeImpl& pipe;
WebSocket& output;
Canceler canceler;
// Non-empty while a forwarded write to `output` is in progress.
};
class Disconnected final: public WebSocket {
// WebSocketPipeImpl state entered after disconnect(). In this state:
// - send(), close() and tryPumpFrom() are usage errors (KJ_FAIL_REQUIRE);
// - repeated disconnect() and pumpTo() complete immediately;
// - receive() fails with a DISCONNECTED exception;
// - byte counting is not supported at the per-state level.
public:
void abort() override {
// can ignore
}
kj::Promise<void> whenAborted() override {
// Handled by WebSocketPipeImpl itself, per the assertion message.
KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
}
kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
KJ_FAIL_REQUIRE("can't send() after disconnect()");
}
kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
KJ_FAIL_REQUIRE("can't send() after disconnect()");
}
kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
KJ_FAIL_REQUIRE("can't close() after disconnect()");
}
kj::Promise<void> disconnect() override {
// Disconnecting twice is harmless.
return kj::READY_NOW;
}
kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
KJ_FAIL_REQUIRE("can't tryPumpFrom() after disconnect()");
}
kj::Promise<Message> receive(size_t maxSize) override {
return KJ_EXCEPTION(DISCONNECTED, "WebSocket disconnected");
}
kj::Promise<void> pumpTo(WebSocket& other) override {
// Nothing more will arrive, so a pump out of this state ends immediately.
return kj::READY_NOW;
}
uint64_t sentByteCount() override {
KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
}
uint64_t receivedByteCount() override {
KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
}
};
class Aborted final: public WebSocket {
  // WebSocketPipeImpl state entered once the other end of the pipe has been destroyed. Every I/O
  // operation fails with the same DISCONNECTED exception. abort() is a no-op, and byte counting
  // is not supported at the per-state level.
public:
  void abort() override {
    // can ignore
  }
  kj::Promise<void> whenAborted() override {
    KJ_FAIL_ASSERT("can't get here -- implemented by WebSocketPipeImpl");
  }

  kj::Promise<void> send(kj::ArrayPtr<const byte> message) override { return destroyedError(); }
  kj::Promise<void> send(kj::ArrayPtr<const char> message) override { return destroyedError(); }
  kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
    return destroyedError();
  }
  kj::Promise<void> disconnect() override { return destroyedError(); }
  kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
    return kj::Promise<void>(destroyedError());
  }
  kj::Promise<Message> receive(size_t maxSize) override { return destroyedError(); }
  kj::Promise<void> pumpTo(WebSocket& other) override { return destroyedError(); }

  uint64_t sentByteCount() override {
    KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
  }
  uint64_t receivedByteCount() override {
    KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl.");
  }

private:
  static kj::Exception destroyedError() {
    // The single error every operation on an aborted pipe reports.
    return KJ_EXCEPTION(DISCONNECTED, "other end of WebSocketPipe was destroyed");
  }
};
};
class WebSocketPipeEnd final: public WebSocket {
  // One endpoint of a WebSocketPipe. Everything this end emits (send, close, disconnect, pumps
  // into it) goes to `out`; everything it consumes (receive, pumpTo) comes from `in`.
public:
  WebSocketPipeEnd(kj::Own<WebSocketPipeImpl> in, kj::Own<WebSocketPipeImpl> out)
      : in(kj::mv(in)), out(kj::mv(out)) {}
  ~WebSocketPipeEnd() noexcept(false) {
    // Dropping an endpoint aborts both directions (the class is final, so this is our abort()).
    abort();
  }

  // ---- outbound direction -------------------------------------------------------------------
  kj::Promise<void> send(kj::ArrayPtr<const byte> message) override { return out->send(message); }
  kj::Promise<void> send(kj::ArrayPtr<const char> message) override { return out->send(message); }
  kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
    return out->close(code, reason);
  }
  kj::Promise<void> disconnect() override { return out->disconnect(); }
  kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
    return out->tryPumpFrom(other);
  }
  uint64_t sentByteCount() override { return out->sentByteCount(); }

  // ---- inbound direction --------------------------------------------------------------------
  kj::Promise<Message> receive(size_t maxSize) override { return in->receive(maxSize); }
  kj::Promise<void> pumpTo(WebSocket& other) override { return in->pumpTo(other); }
  uint64_t receivedByteCount() override {
    // What we received is, presumably, exactly what the peer sent into our inbound pipe.
    return in->sentByteCount();
  }

  // ---- lifecycle ----------------------------------------------------------------------------
  void abort() override {
    in->abort();
    out->abort();
  }
  kj::Promise<void> whenAborted() override { return out->whenAborted(); }

private:
  kj::Own<WebSocketPipeImpl> in;
  kj::Own<WebSocketPipeImpl> out;
};
} // namespace
WebSocketPipe newWebSocketPipe() {
  // Build two one-way pipes and cross-wire them: each end reads from the pipe the other end
  // writes into.
  auto intoFirst = kj::refcounted<WebSocketPipeImpl>();
  auto intoSecond = kj::refcounted<WebSocketPipeImpl>();
  auto first = kj::heap<WebSocketPipeEnd>(kj::addRef(*intoFirst), kj::addRef(*intoSecond));
  auto second = kj::heap<WebSocketPipeEnd>(kj::mv(intoSecond), kj::mv(intoFirst));
  return { { kj::mv(first), kj::mv(second) } };
}
// =======================================================================================
namespace {
class HttpClientImpl final: public HttpClient,
                            private HttpClientErrorHandler {
  // HttpClient speaking HTTP/1.x over one connection (`rawStream`). Requests are written via
  // `httpOutput` and responses parsed via `httpInput`. Once the connection has been upgraded
  // (or an upgrade is pending) or marked closed, further requests are rejected. Protocol errors
  // are routed to `settings.errorHandler`, falling back to this object's inherited default
  // HttpClientErrorHandler behavior.
public:
  HttpClientImpl(const HttpHeaderTable& responseHeaderTable, kj::Own<kj::AsyncIoStream> rawStream,
                 HttpClientSettings settings)
      : httpInput(*rawStream, responseHeaderTable),
        httpOutput(*rawStream),
        ownStream(kj::mv(rawStream)),
        settings(kj::mv(settings)) {}

  bool canReuse() {
    // Returns true if we can immediately reuse this HttpClient for another message (so all
    // previous messages have been fully read).
    return !upgraded && !closed && httpInput.canReuse() && httpOutput.canReuse();
  }

  Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
                  kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
    // Writes the request head immediately and returns a body stream plus a promise for the
    // response. GET/HEAD with no known body size and no Transfer-Encoding header get no body.
    KJ_REQUIRE(!upgraded,
        "can't make further requests on this HttpClient because it has been or is in the process "
        "of being upgraded");
    KJ_REQUIRE(!closed,
        "this HttpClient's connection has been closed by the server or due to an error");
    KJ_REQUIRE(httpOutput.canReuse(),
        "can't start new request until previous request body has been fully written");
    closeWatcherTask = nullptr;

    kj::StringPtr connectionHeaders[HttpHeaders::CONNECTION_HEADERS_COUNT];
    kj::String lengthStr;

    bool isGet = method == HttpMethod::GET || method == HttpMethod::HEAD;
    bool hasBody;

    KJ_IF_MAYBE(s, expectedBodySize) {
      if (isGet && *s == 0) {
        // GET with empty body; don't send any Content-Length.
        hasBody = false;
      } else {
        lengthStr = kj::str(*s);
        connectionHeaders[HttpHeaders::BuiltinIndices::CONTENT_LENGTH] = lengthStr;
        hasBody = true;
      }
    } else {
      if (isGet && headers.get(HttpHeaderId::TRANSFER_ENCODING) == nullptr) {
        // GET with empty body; don't send any Transfer-Encoding.
        hasBody = false;
      } else {
        // HACK: Normally GET requests shouldn't have bodies. But, if the caller set a
        // Transfer-Encoding header on a GET, we use this as a special signal that it might
        // actually want to send a body. This allows pass-through of a GET request with a chunked
        // body to "just work". We strongly discourage writing any new code that sends
        // full-bodied GETs.
        connectionHeaders[HttpHeaders::BuiltinIndices::TRANSFER_ENCODING] = "chunked";
        hasBody = true;
      }
    }

    httpOutput.writeHeaders(headers.serializeRequest(method, url, connectionHeaders));

    kj::Own<kj::AsyncOutputStream> bodyStream;
    if (!hasBody) {
      // No entity-body.
      httpOutput.finishBody();
      bodyStream = heap<HttpNullEntityWriter>();
    } else KJ_IF_MAYBE(s, expectedBodySize) {
      bodyStream = heap<HttpFixedLengthEntityWriter>(httpOutput, *s);
    } else {
      bodyStream = heap<HttpChunkedEntityWriter>(httpOutput);
    }

    auto id = ++counter;

    auto responsePromise = httpInput.readResponseHeaders().then(
        [this,method,id](HttpHeaders::ResponseOrProtocolError&& responseOrProtocolError)
            -> HttpClient::Response {
      KJ_SWITCH_ONEOF(responseOrProtocolError) {
        KJ_CASE_ONEOF(response, HttpHeaders::Response) {
          auto& responseHeaders = httpInput.getHeaders();
          HttpClient::Response result {
            response.statusCode,
            response.statusText,
            &responseHeaders,
            httpInput.getEntityBody(
                HttpInputStreamImpl::RESPONSE, method, response.statusCode, responseHeaders)
          };

          if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>(
              responseHeaders.get(HttpHeaderId::CONNECTION).orDefault(nullptr).cStr())) {
            closed = true;
          } else if (counter == id) {
            watchForClose();
          } else {
            // Another request was already queued after this one, so we don't want to watch for
            // stream closure because we're fully expecting another response.
          }

          return result;
        }
        KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) {
          closed = true;
          return settings.errorHandler.orDefault(*this).handleProtocolError(
              kj::mv(protocolError));
        }
      }

      KJ_UNREACHABLE;
    });

    return { kj::mv(bodyStream), kj::mv(responsePromise) };
  }

  kj::Promise<WebSocketResponse> openWebSocket(
      kj::StringPtr url, const HttpHeaders& headers) override {
    // Sends a GET with the WebSocket upgrade headers and a random Sec-WebSocket-Key, then
    // validates the server's 101 response (Upgrade and Sec-WebSocket-Accept headers). On a
    // non-101 response the body is returned and the connection may be reused.
    KJ_REQUIRE(!upgraded,
        "can't make further requests on this HttpClient because it has been or is in the process "
        "of being upgraded");
    KJ_REQUIRE(!closed,
        "this HttpClient's connection has been closed by the server or due to an error");
    closeWatcherTask = nullptr;

    // Mark upgraded for now, even though the upgrade could fail, because we can't allow pipelined
    // requests in the meantime.
    upgraded = true;

    byte keyBytes[16];
    KJ_ASSERT_NONNULL(settings.entropySource,
        "can't use openWebSocket() because no EntropySource was provided when creating the "
        "HttpClient").generate(keyBytes);
    auto keyBase64 = kj::encodeBase64(keyBytes);

    kj::StringPtr connectionHeaders[HttpHeaders::WEBSOCKET_CONNECTION_HEADERS_COUNT];
    connectionHeaders[HttpHeaders::BuiltinIndices::CONNECTION] = "Upgrade";
    connectionHeaders[HttpHeaders::BuiltinIndices::UPGRADE] = "websocket";
    connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_VERSION] = "13";
    connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_KEY] = keyBase64;

    httpOutput.writeHeaders(headers.serializeRequest(HttpMethod::GET, url, connectionHeaders));

    // No entity-body.
    httpOutput.finishBody();

    auto id = ++counter;

    return httpInput.readResponseHeaders()
        .then([this,id,keyBase64 = kj::mv(keyBase64)](
            HttpHeaders::ResponseOrProtocolError&& responseOrProtocolError)
            -> HttpClient::WebSocketResponse {
      KJ_SWITCH_ONEOF(responseOrProtocolError) {
        KJ_CASE_ONEOF(response, HttpHeaders::Response) {
          auto& responseHeaders = httpInput.getHeaders();
          if (response.statusCode == 101) {
            if (!fastCaseCmp<'w', 'e', 'b', 's', 'o', 'c', 'k', 'e', 't'>(
                    responseHeaders.get(HttpHeaderId::UPGRADE).orDefault(nullptr).cStr())) {
              kj::String ownMessage;
              kj::StringPtr message;
              KJ_IF_MAYBE(actual, responseHeaders.get(HttpHeaderId::UPGRADE)) {
                ownMessage = kj::str(
                    "Server failed WebSocket handshake: incorrect Upgrade header: "
                    "expected 'websocket', got '", *actual, "'.");
                message = ownMessage;
              } else {
                message = "Server failed WebSocket handshake: missing Upgrade header.";
              }
              return settings.errorHandler.orDefault(*this).handleWebSocketProtocolError({
                502, "Bad Gateway", message, nullptr
              });
            }

            auto expectedAccept = generateWebSocketAccept(keyBase64);
            if (responseHeaders.get(HttpHeaderId::SEC_WEBSOCKET_ACCEPT).orDefault(nullptr)
                  != expectedAccept) {
              kj::String ownMessage;
              kj::StringPtr message;
              KJ_IF_MAYBE(actual, responseHeaders.get(HttpHeaderId::SEC_WEBSOCKET_ACCEPT)) {
                ownMessage = kj::str(
                    "Server failed WebSocket handshake: incorrect Sec-WebSocket-Accept header: "
                    "expected '", expectedAccept, "', got '", *actual, "'.");
                message = ownMessage;
              } else {
                // The header that is missing here is Sec-WebSocket-Accept, not Upgrade.
                message = "Server failed WebSocket handshake: missing Sec-WebSocket-Accept header.";
              }
              return settings.errorHandler.orDefault(*this).handleWebSocketProtocolError({
                502, "Bad Gateway", message, nullptr
              });
            }

            return {
              response.statusCode,
              response.statusText,
              &httpInput.getHeaders(),
              upgradeToWebSocket(kj::mv(ownStream), httpInput, httpOutput, settings.entropySource),
            };
          } else {
            upgraded = false;
            HttpClient::WebSocketResponse result {
              response.statusCode,
              response.statusText,
              &responseHeaders,
              httpInput.getEntityBody(HttpInputStreamImpl::RESPONSE, HttpMethod::GET,
                                      response.statusCode, responseHeaders)
            };
            if (fastCaseCmp<'c', 'l', 'o', 's', 'e'>(
                responseHeaders.get(HttpHeaderId::CONNECTION).orDefault(nullptr).cStr())) {
              closed = true;
            } else if (counter == id) {
              watchForClose();
            } else {
              // Another request was already queued after this one, so we don't want to watch for
              // stream closure because we're fully expecting another response.
            }
            return result;
          }
        }
        KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) {
          return settings.errorHandler.orDefault(*this).handleWebSocketProtocolError(
              kj::mv(protocolError));
        }
      }

      KJ_UNREACHABLE;
    });
  }

private:
  HttpInputStreamImpl httpInput;
  HttpOutputStream httpOutput;
  kj::Own<AsyncIoStream> ownStream;
  HttpClientSettings settings;
  kj::Maybe<kj::Promise<void>> closeWatcherTask;
  bool upgraded = false;
  bool closed = false;

  uint counter = 0;
  // Counts requests for the sole purpose of detecting if more requests have been made after some
  // point in history.

  void watchForClose() {
    // While the connection is idle, waits for the next incoming message. If the server hangs up
    // (EOF), marks the client closed and, unless a request body is still being written, flushes
    // output and then releases the socket.
    closeWatcherTask = httpInput.awaitNextMessage()
        .then([this](bool hasData) -> kj::Promise<void> {
      if (hasData) {
        // Uhh... The server sent some data before we asked for anything. Perhaps due to properties
        // of this application, the server somehow already knows what the next request will be, and
        // it is trying to optimize. Or maybe this is some sort of test and the server is just
        // replaying a script. In any case, we will humor it -- leave the data in the buffer and
        // let it become the response to the next request.
        return kj::READY_NOW;
      } else {
        // EOF -- server disconnected.
        closed = true;
        if (httpOutput.isInBody()) {
          // Huh, the application is still sending a request. We should let it finish. We do not
          // need to proactively free the socket in this case because we know that we're not
          // sitting in a reusable connection pool, because we know the application is still
          // actively using the connection.
          return kj::READY_NOW;
        } else {
          return httpOutput.flush().then([this]() {
            // We might be sitting in NetworkAddressHttpClient's `availableClients` pool. We don't
            // have a way to notify it to remove this client from the pool; instead, when it tries
            // to pull this client from the pool later, it will notice the client is dead and will
            // discard it then. But, we would like to avoid holding on to a socket forever. So,
            // destroy the socket now.
            // TODO(cleanup): Maybe we should arrange to proactively remove ourselves? Seems
            //   like the code will be awkward.
            ownStream = nullptr;
          });
        }
      }
    }).eagerlyEvaluate(nullptr);
  }
};
} // namespace
kj::Promise<HttpClient::WebSocketResponse> HttpClient::openWebSocket(
    kj::StringPtr url, const HttpHeaders& headers) {
  // Default implementation for clients without WebSocket support: issue an ordinary GET with no
  // body and report its response body as the non-WebSocket alternative of the result.
  auto toWebSocketResponse = [](HttpClient::Response&& response) -> WebSocketResponse {
    return {
      response.statusCode,
      response.statusText,
      response.headers,
      kj::mv(response.body)
    };
  };
  return request(HttpMethod::GET, url, headers, nullptr)
      .response.then(kj::mv(toWebSocketResponse));
}
kj::Promise<kj::Own<kj::AsyncIoStream>> HttpClient::connect(kj::StringPtr host) {
// Default implementation: HTTP CONNECT is unsupported unless a subclass overrides this.
KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpClient");
}
kj::Own<HttpClient> newHttpClient(
    const HttpHeaderTable& responseHeaderTable, kj::AsyncIoStream& stream,
    HttpClientSettings settings) {
  // The caller keeps ownership of `stream`. It is wrapped in an Own with NullDisposer so that
  // HttpClientImpl can hold it like an owned stream without ever deleting it.
  kj::Own<kj::AsyncIoStream> borrowedStream(&stream, kj::NullDisposer::instance);
  return kj::heap<HttpClientImpl>(responseHeaderTable, kj::mv(borrowedStream), kj::mv(settings));
}
HttpClient::Response HttpClientErrorHandler::handleProtocolError(
HttpHeaders::ProtocolError protocolError) {
// Default handling for a malformed HTTP response: raise a requirement failure carrying the
// error's description. The `{ break; }` recovery block presumably runs only when KJ's
// exception callback declines to throw; in that case a default-constructed Response is returned.
KJ_FAIL_REQUIRE(protocolError.description) { break; }
return HttpClient::Response();
}
HttpClient::WebSocketResponse HttpClientErrorHandler::handleWebSocketProtocolError(
    HttpHeaders::ProtocolError protocolError) {
  // Default behavior: handle the error exactly like a plain HTTP protocol error (via the
  // overridable handleProtocolError()), then repackage that Response, including its body, as
  // the non-WebSocket alternative of a WebSocketResponse.
  HttpClient::Response plain = handleProtocolError(protocolError);
  HttpClient::WebSocketResponse converted {
    plain.statusCode,
    plain.statusText,
    plain.headers,
    kj::mv(plain.body)
  };
  return converted;
}
// =======================================================================================
namespace {
class NetworkAddressHttpClient final: public HttpClient {
// HttpClient bound to a single NetworkAddress. Each request borrows an HttpClientImpl connection:
// an idle one from `availableClients` if one is still reusable, otherwise a new connection to
// `address`. When every reference to the borrowed connection is gone, the connection is handed to
// returnClientToAvailable(). That function pools it for `settings.idleTimeout` if it is reusable.
public:
NetworkAddressHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable,
kj::Own<kj::NetworkAddress> address, HttpClientSettings settings)
: timer(timer),
responseHeaderTable(responseHeaderTable),
address(kj::mv(address)),
settings(kj::mv(settings)) {}
bool isDrained() {
// Returns true if there are no open connections.
return activeConnectionCount == 0 && availableClients.empty();
}
kj::Promise<void> onDrained() {
// Returns a promise which resolves the next time isDrained() transitions from false to true.
// NOTE(review): only one waiter is supported -- calling this again replaces the stored
// fulfiller, dropping the previous one.
auto paf = kj::newPromiseAndFulfiller<void>();
drainedFulfiller = kj::mv(paf.fulfiller);
return kj::mv(paf.promise);
}
Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
// The request body stream, and later the response body, each hold a reference to the borrowed
// connection. It is released only after both are gone.
auto refcounted = getClient();
auto result = refcounted->client->request(method, url, headers, expectedBodySize);
result.body = result.body.attach(kj::addRef(*refcounted));
result.response = result.response.then(kj::mvCapture(refcounted,
[](kj::Own<RefcountedClient>&& refcounted, Response&& response) {
response.body = response.body.attach(kj::mv(refcounted));
return kj::mv(response);
}));
return result;
}
kj::Promise<WebSocketResponse> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers) override {
// Like request(), but the connection reference ends up attached to whichever of body or
// WebSocket the response carries.
auto refcounted = getClient();
auto result = refcounted->client->openWebSocket(url, headers);
return result.then(kj::mvCapture(refcounted,
[](kj::Own<RefcountedClient>&& refcounted, WebSocketResponse&& response) {
KJ_SWITCH_ONEOF(response.webSocketOrBody) {
KJ_CASE_ONEOF(body, kj::Own<kj::AsyncInputStream>) {
response.webSocketOrBody = body.attach(kj::mv(refcounted));
}
KJ_CASE_ONEOF(ws, kj::Own<WebSocket>) {
// The only reason we need to attach the client to the WebSocket is because otherwise
// the response headers will be deleted prematurely. Otherwise, the WebSocket has taken
// ownership of the connection.
//
// TODO(perf): Maybe we could transfer ownership of the response headers specifically?
response.webSocketOrBody = ws.attach(kj::mv(refcounted));
}
}
return kj::mv(response);
}));
}
private:
kj::Timer& timer;
const HttpHeaderTable& responseHeaderTable;
kj::Own<kj::NetworkAddress> address;
HttpClientSettings settings;
kj::Maybe<kj::Own<kj::PromiseFulfiller<void>>> drainedFulfiller;
// Set by onDrained(); fulfilled by applyTimeouts() once the pool is empty and nothing is active.
uint activeConnectionCount = 0;
// Number of live RefcountedClient objects, i.e. connections currently lent out.
bool timeoutsScheduled = false;
// True while an applyTimeouts() loop is pending in `timeoutTask`.
kj::Promise<void> timeoutTask = nullptr;
struct AvailableClient {
// An idle pooled connection and the time after which it should be discarded.
kj::Own<HttpClientImpl> client;
kj::TimePoint expires;
};
std::deque<AvailableClient> availableClients;
struct RefcountedClient final: public kj::Refcounted {
// A connection currently lent out to callers. It counts toward `activeConnectionCount` while it
// lives. When the last reference is dropped, the connection is offered back to the parent's
// pool, and any exception thrown while doing so is logged instead of escaping the destructor.
RefcountedClient(NetworkAddressHttpClient& parent, kj::Own<HttpClientImpl> client)
: parent(parent), client(kj::mv(client)) {
++parent.activeConnectionCount;
}
~RefcountedClient() noexcept(false) {
--parent.activeConnectionCount;
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
parent.returnClientToAvailable(kj::mv(client));
})) {
KJ_LOG(ERROR, *exception);
}
}
NetworkAddressHttpClient& parent;
kj::Own<HttpClientImpl> client;
};
kj::Own<RefcountedClient> getClient() {
// Picks the most recently pooled connection (from the back of the deque) that is still
// reusable, discarding dead ones. If the pool runs out, it opens a new connection. The pending
// connect() is wrapped with newPromisedStream(), presumably so the client can be used before
// the connection is established.
for (;;) {
if (availableClients.empty()) {
auto stream = newPromisedStream(address->connect());
return kj::refcounted<RefcountedClient>(*this,
kj::heap<HttpClientImpl>(responseHeaderTable, kj::mv(stream), settings));
} else {
auto client = kj::mv(availableClients.back().client);
availableClients.pop_back();
if (client->canReuse()) {
return kj::refcounted<RefcountedClient>(*this, kj::mv(client));
}
// Whoops, this client's connection was closed by the server at some point. Discard.
}
}
}
void returnClientToAvailable(kj::Own<HttpClientImpl> client) {
// Only return the connection to the pool if it is reusable and if our settings indicate we
// should reuse connections.
if (client->canReuse() && settings.idleTimeout > 0 * kj::SECONDS) {
availableClients.push_back(AvailableClient {
kj::mv(client), timer.now() + settings.idleTimeout
});
}
// Call this either way because it also signals onDrained().
if (!timeoutsScheduled) {
timeoutsScheduled = true;
timeoutTask = applyTimeouts();
}
}
kj::Promise<void> applyTimeouts() {
// Evicts idle connections over time. It waits until the front entry's deadline, drops every
// front entry whose deadline has passed, and repeats. When the pool is empty, the loop stops
// (clearing `timeoutsScheduled`). If no connections are active either, it then fulfills a
// pending onDrained() promise.
// NOTE(review): assumes `availableClients` is ordered by `expires`, which holds while entries
// are only push_back()ed with timer.now() + the same idleTimeout.
if (availableClients.empty()) {
timeoutsScheduled = false;
if (activeConnectionCount == 0) {
KJ_IF_MAYBE(f, drainedFulfiller) {
f->get()->fulfill();
drainedFulfiller = nullptr;
}
}
return kj::READY_NOW;
} else {
auto time = availableClients.front().expires;
return timer.atTime(time).then([this,time]() {
while (!availableClients.empty() && availableClients.front().expires <= time) {
availableClients.pop_front();
}
return applyTimeouts();
});
}
}
};
class PromiseNetworkAddressHttpClient final: public HttpClient {
  // An HttpClient which waits for a promise to resolve then forwards all calls to the promised
  // client. Calls made before resolution are deferred until it happens. If resolution fails,
  // the client reports itself as drained.
public:
  PromiseNetworkAddressHttpClient(kj::Promise<kj::Own<NetworkAddressHttpClient>> promise)
      : promise(promise.then([this](kj::Own<NetworkAddressHttpClient>&& client) {
          this->client = kj::mv(client);
        }).fork()) {}

  bool isDrained() {
    KJ_IF_MAYBE(c, client) {
      return c->get()->isDrained();
    }
    // No client yet: only "drained" if the connection attempt is already known to have failed.
    return failed;
  }

  kj::Promise<void> onDrained() {
    KJ_IF_MAYBE(c, client) {
      return c->get()->onDrained();
    }
    auto whenConnected = [this]() {
      return KJ_ASSERT_NONNULL(client)->onDrained();
    };
    auto whenConnectFailed = [this](kj::Exception&& e) {
      // Connecting failed. Treat as immediately drained.
      failed = true;
      return kj::READY_NOW;
    };
    return promise.addBranch().then(kj::mv(whenConnected), kj::mv(whenConnectFailed));
  }

  Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
                  kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
    KJ_IF_MAYBE(c, client) {
      return c->get()->request(method, url, headers, expectedBodySize);
    }

    // request() must hand back a body stream and a response promise right away. So both are
    // produced by one deferred promise, which is then split. The URL and headers are copied
    // because the caller may destroy them as soon as we return.
    auto deferred = promise.addBranch().then(
        [this, method, expectedBodySize, urlCopy = kj::str(url), headersCopy = headers.clone()]()
        mutable -> kj::Tuple<kj::Own<kj::AsyncOutputStream>, kj::Promise<Response>> {
      auto req = KJ_ASSERT_NONNULL(client)->request(
          method, urlCopy, headersCopy, expectedBodySize);
      return kj::tuple(kj::mv(req.body), kj::mv(req.response));
    });
    auto parts = deferred.split();
    return {
      newPromisedStream(kj::mv(kj::get<0>(parts))),
      kj::mv(kj::get<1>(parts))
    };
  }

  kj::Promise<WebSocketResponse> openWebSocket(
      kj::StringPtr url, const HttpHeaders& headers) override {
    KJ_IF_MAYBE(c, client) {
      return c->get()->openWebSocket(url, headers);
    }

    // Defer until the client exists, keeping our own copies of the URL and headers.
    return promise.addBranch().then(
        [this, urlCopy = kj::str(url), headersCopy = headers.clone()]() mutable {
      return KJ_ASSERT_NONNULL(client)->openWebSocket(urlCopy, headersCopy);
    });
  }

private:
  kj::ForkedPromise<void> promise;
  kj::Maybe<kj::Own<NetworkAddressHttpClient>> client;
  bool failed = false;
};
class NetworkHttpClient final: public HttpClient, private kj::TaskSet::ErrorHandler {
public:
NetworkHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable,
kj::Network& network, kj::Maybe<kj::Network&> tlsNetwork,
HttpClientSettings settings)
: timer(timer),
responseHeaderTable(responseHeaderTable),
network(network),
tlsNetwork(tlsNetwork),
settings(kj::mv(settings)),
tasks(*this) {}
Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
// We need to parse the proxy-style URL to convert it to host-style.
// Use URL parsing options that avoid unnecessary rewrites.
Url::Options urlOptions;
urlOptions.allowEmpty = true;
urlOptions.percentDecode = false;
auto parsed = Url::parse(url, Url::HTTP_PROXY_REQUEST, urlOptions);
auto path = parsed.toString(Url::HTTP_REQUEST);
auto headersCopy = headers.clone();
headersCopy.set(HttpHeaderId::HOST, parsed.host);
return getClient(parsed).request(method, path, headersCopy, expectedBodySize);
}
kj::Promise<WebSocketResponse> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers) override {
// We need to parse the proxy-style URL to convert it to host-style.
// Use URL parsing options that avoid unnecessary rewrites.
Url::Options urlOptions;
urlOptions.allowEmpty = true;
urlOptions.percentDecode = false;
auto parsed = Url::parse(url, Url::HTTP_PROXY_REQUEST, urlOptions);
auto path = parsed.toString(Url::HTTP_REQUEST);
auto headersCopy = headers.clone();
headersCopy.set(HttpHeaderId::HOST, parsed.host);
return getClient(parsed).openWebSocket(path, headersCopy);
}
private:
kj::Timer& timer;
const HttpHeaderTable& responseHeaderTable;
kj::Network& network;
kj::Maybe<kj::Network&> tlsNetwork;
HttpClientSettings settings;
struct Host {
kj::String name; // including port, if non-default
kj::Own<PromiseNetworkAddressHttpClient> client;
};
std::map<kj::StringPtr, Host> httpHosts;
std::map<kj::StringPtr, Host> httpsHosts;
struct RequestInfo {
HttpMethod method;
kj::String hostname;
kj::String path;
HttpHeaders headers;
kj::Maybe<uint64_t> expectedBodySize;
};
kj::TaskSet tasks;
HttpClient& getClient(kj::Url& parsed) {
bool isHttps = parsed.scheme == "https";
bool isHttp = parsed.scheme == "http";
KJ_REQUIRE(isHttp || isHttps);
auto& hosts = isHttps ? httpsHosts : httpHosts;
// Look for a cached client for this host.
// TODO(perf): It would be nice to recognize when different hosts have the same address and
// reuse the same connection pool, but:
// - We'd need a reliable way to compare NetworkAddresses, e.g. .equals() and .hashCode().
// It's very Java... ick.
// - Correctly handling TLS would be tricky: we'd need to verify that the new hostname is
// on the certificate. When SNI is in use we might have to request an additional
// certificate (is that possible?).
auto iter = hosts.find(parsed.host);
if (iter == hosts.end()) {
// Need to open a new connection.
kj::Network* networkToUse = &network;
if (isHttps) {
networkToUse = &KJ_REQUIRE_NONNULL(tlsNetwork, "this HttpClient doesn't support HTTPS");
}
auto promise = networkToUse->parseAddress(parsed.host, isHttps ? 443 : 80)
.then([this](kj::Own<kj::NetworkAddress> addr) {
return kj::heap<NetworkAddressHttpClient>(
timer, responseHeaderTable, kj::mv(addr), settings);
});
Host host {
kj::mv(parsed.host),
kj::heap<PromiseNetworkAddressHttpClient>(kj::mv(promise))
};
kj::StringPtr nameRef = host.name;
auto insertResult = hosts.insert(std::make_pair(nameRef, kj::mv(host)));
KJ_ASSERT(insertResult.second);
iter = insertResult.first;
tasks.add(handleCleanup(hosts, iter));
}
return *iter->second.client;
}
kj::Promise<void> handleCleanup(std::map<kj::StringPtr, Host>& hosts,
std::map<kj::StringPtr, Host>::iterator iter) {
return iter->second.client->onDrained()
.then([this,&hosts,iter]() -> kj::Promise<void> {
// Double-check that it's really drained to avoid race conditions.
if (iter->second.client->isDrained()) {
hosts.erase(iter);
return kj::READY_NOW;
} else {
return handleCleanup(hosts, iter);
}
});
}
void taskFailed(kj::Exception&& exception) override {
KJ_LOG(ERROR, exception);
}
};
} // namespace
kj::Own<HttpClient> newHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable,
                                  kj::NetworkAddress& addr, HttpClientSettings settings) {
  // `addr` is borrowed from the caller. NullDisposer ensures the client never deletes it.
  kj::Own<kj::NetworkAddress> borrowedAddr(&addr, kj::NullDisposer::instance);
  return kj::heap<NetworkAddressHttpClient>(
      timer, responseHeaderTable, kj::mv(borrowedAddr), kj::mv(settings));
}
kj::Own<HttpClient> newHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable,
kj::Network& network, kj::Maybe<kj::Network&> tlsNetwork,
HttpClientSettings settings) {
// Creates a client that takes proxy-style absolute URLs and connects through `network`.
// HTTPS URLs are rejected unless `tlsNetwork` is provided.
return kj::heap<NetworkHttpClient>(
timer, responseHeaderTable, network, tlsNetwork, kj::mv(settings));
}
// =======================================================================================
namespace {
class ConcurrencyLimitingHttpClient final: public HttpClient {
  // Wraps `inner` so that at most `maxConcurrentRequests` requests / WebSocket opens are in
  // flight at once. Additional calls are queued FIFO and started as slots free up. A slot is
  // held by a ConnectionCounter, which is attached to the response body (or WebSocket) and
  // released when that object is destroyed. `countChangedCallback(running, pending)` is invoked
  // after calls start or are queued, and whenever a slot is released.
public:
  ConcurrencyLimitingHttpClient(
      kj::HttpClient& inner, uint maxConcurrentRequests,
      kj::Function<void(uint runningCount, uint pendingCount)> countChangedCallback)
      : inner(inner),
        maxConcurrentRequests(maxConcurrentRequests),
        countChangedCallback(kj::mv(countChangedCallback)) {}

  Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
                  kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
    if (concurrentRequests < maxConcurrentRequests) {
      auto counter = ConnectionCounter(*this);
      auto request = inner.request(method, url, headers, expectedBodySize);
      fireCountChanged();
      auto promise = attachCounter(kj::mv(request.response), kj::mv(counter));
      return { kj::mv(request.body), kj::mv(promise) };
    }

    // Over the limit: queue a fulfiller. serviceQueue() hands it a ConnectionCounter once a slot
    // frees up. The URL and headers are copied because the caller may destroy them after return.
    auto paf = kj::newPromiseAndFulfiller<ConnectionCounter>();
    auto urlCopy = kj::str(url);
    auto headersCopy = headers.clone();

    auto combined = paf.promise
        .then([this,
               method,
               urlCopy = kj::mv(urlCopy),
               headersCopy = kj::mv(headersCopy),
               expectedBodySize](ConnectionCounter&& counter) mutable {
      auto req = inner.request(method, urlCopy, headersCopy, expectedBodySize);
      return kj::tuple(kj::mv(req.body), attachCounter(kj::mv(req.response), kj::mv(counter)));
    });
    auto split = combined.split();
    pendingRequests.push(kj::mv(paf.fulfiller));
    fireCountChanged();
    return { newPromisedStream(kj::mv(kj::get<0>(split))), kj::mv(kj::get<1>(split)) };
  }

  kj::Promise<WebSocketResponse> openWebSocket(
      kj::StringPtr url, const kj::HttpHeaders& headers) override {
    if (concurrentRequests < maxConcurrentRequests) {
      auto counter = ConnectionCounter(*this);
      auto response = inner.openWebSocket(url, headers);
      fireCountChanged();
      return attachCounter(kj::mv(response), kj::mv(counter));
    }

    // Over the limit: same queuing scheme as request().
    auto paf = kj::newPromiseAndFulfiller<ConnectionCounter>();
    auto urlCopy = kj::str(url);
    auto headersCopy = headers.clone();

    auto promise = paf.promise
        .then([this,
               urlCopy = kj::mv(urlCopy),
               headersCopy = kj::mv(headersCopy)](ConnectionCounter&& counter) mutable {
      return attachCounter(inner.openWebSocket(urlCopy, headersCopy), kj::mv(counter));
    });

    pendingRequests.push(kj::mv(paf.fulfiller));
    fireCountChanged();
    return kj::mv(promise);
  }

private:
  struct ConnectionCounter;

  kj::HttpClient& inner;
  uint maxConcurrentRequests;
  uint concurrentRequests = 0;
  kj::Function<void(uint runningCount, uint pendingCount)> countChangedCallback;

  std::queue<kj::Own<kj::PromiseFulfiller<ConnectionCounter>>> pendingRequests;
  // TODO(someday): want maximum cap on queue size?

  struct ConnectionCounter final {
    // Holds one of the parent's concurrency slots for as long as it lives. Moving transfers the
    // slot. A moved-from counter (parent == nullptr) holds nothing.
    ConnectionCounter(ConcurrencyLimitingHttpClient& client) : parent(&client) {
      ++parent->concurrentRequests;
    }
    KJ_DISALLOW_COPY(ConnectionCounter);
    ~ConnectionCounter() noexcept(false) {
      if (parent != nullptr) {
        --parent->concurrentRequests;
        parent->serviceQueue();
        parent->fireCountChanged();
      }
    }
    ConnectionCounter(ConnectionCounter&& other) noexcept : parent(other.parent) {
      other.parent = nullptr;
    }
    ConnectionCounter& operator=(ConnectionCounter&& other) {
      if (this != &other) {
        // Release the slot we currently hold (if any) before adopting `other`'s. Simply
        // overwriting `parent` would leak a slot and could stall the queue forever. `released`
        // takes over our slot and gives it back when it goes out of scope.
        ConnectionCounter released(kj::mv(*this));
        parent = other.parent;
        other.parent = nullptr;
      }
      return *this;
    }

    ConcurrencyLimitingHttpClient* parent;
  };

  void serviceQueue() {
    // Starts queued calls while slots are available, skipping fulfillers whose promises were
    // already dropped.
    while (concurrentRequests < maxConcurrentRequests && !pendingRequests.empty()) {
      auto fulfiller = kj::mv(pendingRequests.front());
      pendingRequests.pop();
      // ConnectionCounter's destructor calls this function, so we can avoid unnecessary recursion
      // if we only create a ConnectionCounter when we find a waiting fulfiller.
      if (fulfiller->isWaiting()) {
        fulfiller->fulfill(ConnectionCounter(*this));
      }
    }
  }

  void fireCountChanged() {
    countChangedCallback(concurrentRequests, pendingRequests.size());
  }

  using WebSocketOrBody = kj::OneOf<kj::Own<kj::AsyncInputStream>, kj::Own<WebSocket>>;
  static WebSocketOrBody attachCounter(WebSocketOrBody&& webSocketOrBody,
                                       ConnectionCounter&& counter) {
    KJ_SWITCH_ONEOF(webSocketOrBody) {
      KJ_CASE_ONEOF(ws, kj::Own<WebSocket>) {
        return ws.attach(kj::mv(counter));
      }
      KJ_CASE_ONEOF(body, kj::Own<kj::AsyncInputStream>) {
        return body.attach(kj::mv(counter));
      }
    }
    KJ_UNREACHABLE;
  }

  static kj::Promise<WebSocketResponse> attachCounter(kj::Promise<WebSocketResponse>&& promise,
                                                      ConnectionCounter&& counter) {
    return promise.then([counter = kj::mv(counter)](WebSocketResponse&& response) mutable {
      return WebSocketResponse {
        response.statusCode,
        response.statusText,
        response.headers,
        attachCounter(kj::mv(response.webSocketOrBody), kj::mv(counter))
      };
    });
  }

  static kj::Promise<Response> attachCounter(kj::Promise<Response>&& promise,
                                             ConnectionCounter&& counter) {
    return promise.then([counter = kj::mv(counter)](Response&& response) mutable {
      return Response {
        response.statusCode,
        response.statusText,
        response.headers,
        response.body.attach(kj::mv(counter))
      };
    });
  }
};
}
kj::Own<HttpClient> newConcurrencyLimitingHttpClient(
HttpClient& inner, uint maxConcurrentRequests,
kj::Function<void(uint runningCount, uint pendingCount)> countChangedCallback) {
// Wraps `inner` (borrowed, not owned) so that at most `maxConcurrentRequests` calls run at once.
// `countChangedCallback` is told the running and pending counts as they change.
return kj::heap<ConcurrencyLimitingHttpClient>(inner, maxConcurrentRequests,
kj::mv(countChangedCallback));
}
// =======================================================================================
namespace {
class NullInputStream final: public kj::AsyncInputStream {
  // An AsyncInputStream that is permanently at EOF: every read and pump completes immediately
  // having transferred zero bytes. tryGetLength() reports whatever length the creator supplied.
  // This lets a response with no body (e.g. one to a HEAD request) still report the length the
  // service declared.
public:
  NullInputStream(kj::Maybe<uint64_t> expectedLength = uint64_t(0))
      : expectedLength(expectedLength) {}

  kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
    // Always EOF.
    return size_t(0);
  }

  kj::Maybe<uint64_t> tryGetLength() override {
    return expectedLength;
  }

  kj::Promise<uint64_t> pumpTo(AsyncOutputStream& output, uint64_t amount) override {
    // Nothing to pump.
    return uint64_t(0);
  }

private:
  kj::Maybe<uint64_t> expectedLength;
  // Stored as uint64_t to match tryGetLength() and the `expectedBodySize` values callers pass
  // in. A size_t would truncate lengths above 4 GiB on platforms where size_t is 32 bits.
};
class NullOutputStream final: public kj::AsyncOutputStream {
  // An AsyncOutputStream that discards everything written to it. Every write completes
  // immediately, and the stream never reports a disconnect.
public:
  Promise<void> write(ArrayPtr<const ArrayPtr<const byte>>) override {
    return kj::READY_NOW;
  }
  Promise<void> write(const void*, size_t) override {
    return kj::READY_NOW;
  }

  Promise<void> whenWriteDisconnected() override {
    // There is no peer, so this never resolves.
    return kj::NEVER_DONE;
  }

  // tryPumpFrom() is intentionally not overridden. It can't really be optimized unless
  // AsyncInputStream grows a skip() method.
};
class HttpClientAdapter final: public HttpClient {
public:
// Stores `service` by reference only, so the service must outlive this adapter.
HttpClientAdapter(HttpService& service): service(service) {}
Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
// Calls service.request() directly. The caller writes the request body into the returned
// `body` stream (one end of a one-way pipe), and the service reads it from the other end.
// The returned response promise is fulfilled through `responder` (see ResponseImpl::send()),
// or rejected if the service fails before responding.
//
// We have to clone the URL and headers because HttpService implementations are allowed to
// assume that they remain valid until the service handler completes whereas HttpClient callers
// are allowed to destroy them immediately after the call.
auto urlCopy = kj::str(url);
auto headersCopy = kj::heap(headers.clone());
auto pipe = newOneWayPipe(expectedBodySize);
// TODO(cleanup): The ownership relationships here are a mess. Can we do something better
// involving a PromiseAdapter, maybe?
auto paf = kj::newPromiseAndFulfiller<Response>();
auto responder = kj::refcounted<ResponseImpl>(method, kj::mv(paf.fulfiller));
// The responder gets a placeholder promise *before* service.request() runs. That way a send()
// made synchronously from inside the service already has a task to chain onto. The
// placeholder is then fulfilled with the real request promise below.
auto requestPaf = kj::newPromiseAndFulfiller<kj::Promise<void>>();
responder->setPromise(kj::mv(requestPaf.promise));
auto promise = service.request(method, urlCopy, *headersCopy, *pipe.in, *responder)
.attach(kj::mv(pipe.in), kj::mv(urlCopy), kj::mv(headersCopy));
requestPaf.fulfiller->fulfill(kj::mv(promise));
return {
kj::mv(pipe.out),
paf.promise.attach(kj::mv(responder))
};
}
kj::Promise<WebSocketResponse> openWebSocket(
kj::StringPtr url, const HttpHeaders& headers) override {
// Sends a GET with an empty body (NullInputStream) to the service. The result is fulfilled
// through `responder` (see WebSocketResponseImpl): either a WebSocket if the service calls
// acceptWebSocket(), or a plain body if it calls send().
//
// We have to clone the URL and headers because HttpService implementations are allowed to
// assume that they remain valid until the service handler completes whereas HttpClient callers
// are allowed to destroy them immediately after the call. Also we need to add
// `Upgrade: websocket` so that headers.isWebSocket() returns true on the service side.
auto urlCopy = kj::str(url);
auto headersCopy = kj::heap(headers.clone());
headersCopy->set(HttpHeaderId::UPGRADE, "websocket");
KJ_DASSERT(headersCopy->isWebSocket());
auto paf = kj::newPromiseAndFulfiller<WebSocketResponse>();
auto responder = kj::refcounted<WebSocketResponseImpl>(kj::mv(paf.fulfiller));
// As in request(): give the responder its task before the service runs, then fulfill it.
auto requestPaf = kj::newPromiseAndFulfiller<kj::Promise<void>>();
responder->setPromise(kj::mv(requestPaf.promise));
auto in = kj::heap<NullInputStream>();
auto promise = service.request(HttpMethod::GET, urlCopy, *headersCopy, *in, *responder)
.attach(kj::mv(in), kj::mv(urlCopy), kj::mv(headersCopy));
requestPaf.fulfiller->fulfill(kj::mv(promise));
return paf.promise.attach(kj::mv(responder));
}
kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host) override {
  // CONNECT needs no adaptation; delegate straight to the wrapped service.
  return service.connect(host);
}
private:
HttpService& service;
class DelayedEofInputStream final: public kj::AsyncInputStream {
// An AsyncInputStream wrapper that, when it reaches EOF, delays the final read until some
// promise completes.
//
// ResponseImpl / WebSocketResponseImpl use this for response bodies the service writes into a
// pipe. `completionTask` is the service's request() promise. Holding back EOF means a client
// that drops the body as soon as it sees EOF does not cancel a request() that is still running.
public:
DelayedEofInputStream(kj::Own<kj::AsyncInputStream> inner, kj::Promise<void> completionTask)
: inner(kj::mv(inner)), completionTask(kj::mv(completionTask)) {}
kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
return wrap(minBytes, inner->tryRead(buffer, minBytes, maxBytes));
}
kj::Maybe<uint64_t> tryGetLength() override {
return inner->tryGetLength();
}
kj::Promise<uint64_t> pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override {
return wrap(amount, inner->pumpTo(output, amount));
}
private:
kj::Own<kj::AsyncInputStream> inner;
kj::Maybe<kj::Promise<void>> completionTask;
// Non-null until the first EOF or error is observed. It is then consumed (set to nullptr), so
// any later reads pass straight through.
template <typename T>
kj::Promise<T> wrap(T requested, kj::Promise<T> innerPromise) {
// Applies the delay to one read or pump. A result smaller than `requested` means EOF; an
// exception means the underlying stream failed. Either way, the first such outcome is held
// until `completionTask` finishes.
return innerPromise.then([this,requested](T actual) -> kj::Promise<T> {
if (actual < requested) {
// Must have reached EOF.
KJ_IF_MAYBE(t, completionTask) {
// Delay until completion.
auto result = t->then([actual]() { return actual; });
completionTask = nullptr;
return result;
} else {
// Must have called tryRead() again after we already signaled EOF. Fine.
return actual;
}
} else {
return actual;
}
}, [this](kj::Exception&& e) -> kj::Promise<T> {
// The stream threw an exception, but this exception is almost certainly just complaining
// that the other end of the stream was dropped. In all likelihood, the HttpService
// request() call itself will throw a much more interesting error -- we'd rather propagate
// that one, if so.
KJ_IF_MAYBE(t, completionTask) {
auto result = t->then([e = kj::mv(e)]() mutable -> kj::Promise<T> {
// Looks like the service didn't throw. I guess we should propagate the stream error
// after all.
return kj::mv(e);
});
completionTask = nullptr;
return result;
} else {
// Must have called tryRead() again after we already signaled EOF or threw. Fine.
return kj::mv(e);
}
});
}
};
class ResponseImpl final: public HttpService::Response, public kj::Refcounted {
// The HttpService::Response handed to the service by HttpClientAdapter::request(). It converts
// the service's send() call into fulfillment of the client's HttpClient::Response promise.
public:
ResponseImpl(kj::HttpMethod method,
kj::Own<kj::PromiseFulfiller<HttpClient::Response>> fulfiller)
: method(method), fulfiller(kj::mv(fulfiller)) {}
void setPromise(kj::Promise<void> promise) {
// Takes over the service's request() promise. If it fails before a response was produced,
// the client's response promise is rejected. If it fails afterwards, the error is rethrown so
// that `task` itself rejects.
task = promise.eagerlyEvaluate([this](kj::Exception&& exception) {
if (fulfiller->isWaiting()) {
fulfiller->reject(kj::mv(exception));
} else {
// We need to cause the response stream's read() to throw this, so we should propagate it.
kj::throwRecoverableException(kj::mv(exception));
}
});
}
kj::Own<kj::AsyncOutputStream> send(
uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers,
kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
// The caller of HttpClient is allowed to assume that the statusText and headers remain
// valid until the body stream is dropped, but the HttpService implementation is allowed to
// send values that are only valid until send() returns, so we have to copy.
auto statusTextCopy = kj::str(statusText);
auto headersCopy = kj::heap(headers.clone());
if (method == kj::HttpMethod::HEAD || expectedBodySize.orDefault(1) == 0) {
// We're not expecting any body. We need to delay reporting completion to the client until
// the server side has actually returned from the service method, otherwise we may
// prematurely cancel it.
task = task.then([this,statusCode,statusTextCopy=kj::mv(statusTextCopy),
headersCopy=kj::mv(headersCopy),expectedBodySize]() mutable {
fulfiller->fulfill({
statusCode, statusTextCopy, headersCopy.get(),
kj::heap<NullInputStream>(expectedBodySize)
.attach(kj::mv(statusTextCopy), kj::mv(headersCopy))
});
}).eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); });
// Anything the service writes is discarded.
return kj::heap<NullOutputStream>();
} else {
auto pipe = newOneWayPipe(expectedBodySize);
// Wrap the stream in a wrapper that delays the last read (the one that signals EOF) until
// the service's request promise has finished.
// The addRef keeps this ResponseImpl alive for as long as the body stream holds `task`.
auto wrapper = kj::heap<DelayedEofInputStream>(
kj::mv(pipe.in), task.attach(kj::addRef(*this)));
fulfiller->fulfill({
statusCode, statusTextCopy, headersCopy.get(),
wrapper.attach(kj::mv(statusTextCopy), kj::mv(headersCopy))
});
return kj::mv(pipe.out);
}
}
kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override {
// HttpClientAdapter::request() never asks for an upgrade, so this is a usage error.
KJ_FAIL_REQUIRE("a WebSocket was not requested");
}
private:
kj::HttpMethod method;
kj::Own<kj::PromiseFulfiller<HttpClient::Response>> fulfiller;
kj::Promise<void> task = nullptr;
// The service's request() promise (set via setPromise()). send() later extends it or hands it
// to the body stream.
};
class DelayedCloseWebSocket final: public WebSocket {
  // Wraps a WebSocket so that, once a Close has been both sent and received through this
  // wrapper, whichever of those operations finishes second also waits for `completionTask`
  // before resolving. Everything else is forwarded to `inner` untouched.
public:
  DelayedCloseWebSocket(kj::Own<kj::WebSocket> inner, kj::Promise<void> completionTask)
      : inner(kj::mv(inner)), completionTask(kj::mv(completionTask)) {}

  // ---- outbound direction ----

  kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
    return inner->send(message);
  }
  kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
    return inner->send(message);
  }

  kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
    auto closing = inner->close(code, reason);
    return closing.then([this]() { return markSentClose(); });
  }

  kj::Maybe<kj::Promise<void>> tryPumpFrom(WebSocket& other) override {
    // A finished pump into `inner` counts as having sent Close.
    kj::Promise<void> pumping = other.pumpTo(*inner);
    return pumping.then([this]() { return markSentClose(); });
  }

  // ---- inbound direction ----

  kj::Promise<Message> receive(size_t maxSize) override {
    return inner->receive(maxSize).then([this](Message&& message) -> kj::Promise<Message> {
      if (!message.is<WebSocket::Close>()) {
        return kj::mv(message);
      }
      return markReceivedClose()
          .then([message = kj::mv(message)]() mutable { return kj::mv(message); });
    });
  }

  kj::Promise<void> pumpTo(WebSocket& other) override {
    // A finished pump out of `inner` counts as having received Close.
    auto pumping = inner->pumpTo(other);
    return pumping.then([this]() { return markReceivedClose(); });
  }

  // ---- lifecycle and statistics ----

  kj::Promise<void> disconnect() override {
    return inner->disconnect();
  }
  void abort() override {
    // Don't need to worry about completion task in this case -- cancelling it is reasonable.
    inner->abort();
  }
  kj::Promise<void> whenAborted() override {
    return inner->whenAborted();
  }

  uint64_t sentByteCount() override { return inner->sentByteCount(); }
  uint64_t receivedByteCount() override { return inner->receivedByteCount(); }

private:
  kj::Own<kj::WebSocket> inner;
  kj::Maybe<kj::Promise<void>> completionTask;
  bool sentClose = false;
  bool receivedClose = false;

  kj::Promise<void> markSentClose() {
    sentClose = true;
    return takeCompletionIfFullyClosed();
  }

  kj::Promise<void> markReceivedClose() {
    receivedClose = true;
    return takeCompletionIfFullyClosed();
  }

  kj::Promise<void> takeCompletionIfFullyClosed() {
    // Hands out `completionTask` once: the first time both directions are closed. Otherwise
    // (or on any later call) there is nothing to wait for.
    if (sentClose && receivedClose) {
      KJ_IF_MAYBE(pending, completionTask) {
        kj::Promise<void> result = kj::mv(*pending);
        completionTask = nullptr;
        return result;
      }
    }
    return kj::READY_NOW;
  }
};
class WebSocketResponseImpl final: public HttpService::Response, public kj::Refcounted {
// The HttpService::Response handed to the service by HttpClientAdapter::openWebSocket(). The
// client's WebSocketResponse promise is fulfilled by acceptWebSocket() (a WebSocket) or by
// send() (a plain body, i.e. the service declined the upgrade).
public:
WebSocketResponseImpl(kj::Own<kj::PromiseFulfiller<HttpClient::WebSocketResponse>> fulfiller)
: fulfiller(kj::mv(fulfiller)) {}
void setPromise(kj::Promise<void> promise) {
// Takes over the service's request() promise. A failure before any response rejects the
// client's promise; a failure afterwards is rethrown so that `task` rejects.
task = promise.eagerlyEvaluate([this](kj::Exception&& exception) {
if (fulfiller->isWaiting()) {
fulfiller->reject(kj::mv(exception));
} else {
// We need to cause the client-side WebSocket to throw on close, so propagate the
// exception.
kj::throwRecoverableException(kj::mv(exception));
}
});
}
kj::Own<kj::AsyncOutputStream> send(
uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers,
kj::Maybe<uint64_t> expectedBodySize = nullptr) override {
// The caller of HttpClient is allowed to assume that the statusText and headers remain
// valid until the body stream is dropped, but the HttpService implementation is allowed to
// send values that are only valid until send() returns, so we have to copy.
auto statusTextCopy = kj::str(statusText);
auto headersCopy = kj::heap(headers.clone());
if (expectedBodySize.orDefault(1) == 0) {
// We're not expecting any body. We need to delay reporting completion to the client until
// the server side has actually returned from the service method, otherwise we may
// prematurely cancel it.
task = task.then([this,statusCode,statusTextCopy=kj::mv(statusTextCopy),
headersCopy=kj::mv(headersCopy),expectedBodySize]() mutable {
fulfiller->fulfill({
statusCode, statusTextCopy, headersCopy.get(),
kj::Own<AsyncInputStream>(kj::heap<NullInputStream>(expectedBodySize)
.attach(kj::mv(statusTextCopy), kj::mv(headersCopy)))
});
}).eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); });
return kj::heap<NullOutputStream>();
} else {
auto pipe = newOneWayPipe(expectedBodySize);
// Wrap the stream in a wrapper that delays the last read (the one that signals EOF) until
// the service's request promise has finished.
// The explicit kj::Own<AsyncInputStream> type lets the stream convert into the
// webSocketOrBody OneOf. The addRef keeps this object alive while the body holds `task`.
kj::Own<AsyncInputStream> wrapper =
kj::heap<DelayedEofInputStream>(kj::mv(pipe.in), task.attach(kj::addRef(*this)));
fulfiller->fulfill({
statusCode, statusTextCopy, headersCopy.get(),
wrapper.attach(kj::mv(statusTextCopy), kj::mv(headersCopy))
});
return kj::mv(pipe.out);
}
}
kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override {
// The caller of HttpClient is allowed to assume that the headers remain valid until the body
// stream is dropped, but the HttpService implementation is allowed to send headers that are
// only valid until acceptWebSocket() returns, so we have to copy.
auto headersCopy = kj::heap(headers.clone());
auto pipe = newWebSocketPipe();
// Wrap the client-side WebSocket in a wrapper that delays clean close of the WebSocket until
// the service's request promise has finished.
kj::Own<WebSocket> wrapper =
kj::heap<DelayedCloseWebSocket>(kj::mv(pipe.ends[0]), task.attach(kj::addRef(*this)));
fulfiller->fulfill({
101, "Switching Protocols", headersCopy.get(),
wrapper.attach(kj::mv(headersCopy))
});
// The service gets the other end of the in-memory WebSocket pipe.
return kj::mv(pipe.ends[1]);
}
private:
kj::Own<kj::PromiseFulfiller<HttpClient::WebSocketResponse>> fulfiller;
kj::Promise<void> task = nullptr;
// The service's request() promise (set via setPromise()).
};
};
} // namespace
kj::Own<HttpClient> newHttpClient(HttpService& service) {
// Wraps an in-process HttpService as an HttpClient. `service` is held by reference and must
// outlive the returned client.
return kj::heap<HttpClientAdapter>(service);
}
// =======================================================================================
namespace {
class HttpServiceAdapter final: public HttpService {
// Presents an HttpClient as an HttpService: each incoming request is forwarded to `client`,
// and the client's response (or WebSocket) is relayed back through the service Response.
public:
HttpServiceAdapter(HttpClient& client): client(client) {}
kj::Promise<void> request(
HttpMethod method, kj::StringPtr url, const HttpHeaders& headers,
kj::AsyncInputStream& requestBody, Response& response) override {
if (!headers.isWebSocket()) {
// Plain HTTP: upload the request body and download the response concurrently. The request
// completes when both directions are done.
auto innerReq = client.request(method, url, headers, requestBody.tryGetLength());
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(requestBody.pumpTo(*innerReq.body).ignoreResult()
.attach(kj::mv(innerReq.body)).eagerlyEvaluate(nullptr));
promises.add(innerReq.response
.then([&response](HttpClient::Response&& innerResponse) {
auto out = response.send(
innerResponse.statusCode, innerResponse.statusText, *innerResponse.headers,
innerResponse.body->tryGetLength());
auto promise = innerResponse.body->pumpTo(*out);
return promise.ignoreResult().attach(kj::mv(out), kj::mv(innerResponse.body));
}));
return kj::joinPromises(promises.finish());
} else {
// WebSocket upgrade request: ask the client to open a WebSocket too.
return client.openWebSocket(url, headers)
.then([&response](HttpClient::WebSocketResponse&& innerResponse) -> kj::Promise<void> {
KJ_SWITCH_ONEOF(innerResponse.webSocketOrBody) {
KJ_CASE_ONEOF(ws, kj::Own<WebSocket>) {
// Upgrade accepted downstream: accept upstream as well and pump both directions.
auto ws2 = response.acceptWebSocket(*innerResponse.headers);
auto promises = kj::heapArrayBuilder<kj::Promise<void>>(2);
promises.add(ws->pumpTo(*ws2));
promises.add(ws2->pumpTo(*ws));
return kj::joinPromises(promises.finish()).attach(kj::mv(ws), kj::mv(ws2));
}
KJ_CASE_ONEOF(body, kj::Own<kj::AsyncInputStream>) {
// Upgrade refused downstream: relay the ordinary HTTP response instead.
auto out = response.send(
innerResponse.statusCode, innerResponse.statusText, *innerResponse.headers,
body->tryGetLength());
auto promise = body->pumpTo(*out);
return promise.ignoreResult().attach(kj::mv(out), kj::mv(body));
}
}
KJ_UNREACHABLE;
});
}
}
kj::Promise<kj::Own<kj::AsyncIoStream>> connect(kj::StringPtr host) override {
return client.connect(kj::mv(host));
}
private:
HttpClient& client;
};
} // namespace
kj::Own<HttpService> newHttpService(HttpClient& client) {
// Wraps an HttpClient as an HttpService. `client` is held by reference and must outlive the
// returned service.
return kj::heap<HttpServiceAdapter>(client);
}
// =======================================================================================
kj::Promise<void> HttpService::Response::sendError(
    uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers) {
  // Sends a complete error response whose body is the status text itself, with a
  // Content-Length equal to that text's size. The returned promise owns the body stream until
  // the write finishes.
  auto bodyLength = statusText.size();
  auto out = send(statusCode, statusText, headers, bodyLength);
  auto writing = out->write(statusText.begin(), bodyLength);
  return writing.attach(kj::mv(out));
}
kj::Promise<void> HttpService::Response::sendError(
    uint statusCode, kj::StringPtr statusText, const HttpHeaderTable& headerTable) {
  // Convenience overload: sends the error with a freshly constructed HttpHeaders built from
  // `headerTable`.
  HttpHeaders freshHeaders(headerTable);
  return sendError(statusCode, statusText, freshHeaders);
}
kj::Promise<kj::Own<kj::AsyncIoStream>> HttpService::connect(kj::StringPtr host) {
// Default implementation: services that don't override connect() reject CONNECT requests by
// failing with KJ_UNIMPLEMENTED. `host` is ignored.
KJ_UNIMPLEMENTED("CONNECT is not implemented by this HttpService");
}
class HttpServer::Connection final: private HttpService::Response,
private HttpServerErrorHandler {
public:
Connection(HttpServer& server, kj::AsyncIoStream& stream,
SuspendableHttpServiceFactory factory, kj::Maybe<SuspendedRequest> suspendedRequest)
// `server` and `stream` are held by reference. When `suspendedRequest` is non-null, the input
// side is seeded from it (see makeHttpInput()) instead of starting from an empty buffer.
: server(server),
stream(stream),
factory(kj::mv(factory)),
httpInput(makeHttpInput(stream, server.requestHeaderTable, kj::mv(suspendedRequest))),
httpOutput(stream) {
// Counted so that ~Connection() can tell when the last connection goes away.
++server.connectionCount;
}
~Connection() noexcept(false) {
  // Undoes the constructor's increment. When this was the last live connection, fulfill the
  // server's zeroConnectionsFulfiller, if one is registered.
  auto remaining = --server.connectionCount;
  if (remaining != 0) {
    return;
  }
  KJ_IF_MAYBE(waiter, server.zeroConnectionsFulfiller) {
    waiter->get()->fulfill();
  }
}
public:
kj::Promise<bool> startLoop(bool firstRequest) {
  // Runs the request loop and converts any exception it throws into an error response. If a
  // WebSocket error response was already queued, that response is completed instead.
  auto onException = [this](kj::Exception&& e) -> kj::Promise<bool> {
    KJ_IF_MAYBE(pending, webSocketError) {
      // sendWebSocketError() was called. Finish sending and close the connection. Don't log
      // the exception because it's probably a side-effect of this.
      kj::Promise<bool> finishing = kj::mv(*pending);
      webSocketError = nullptr;
      return kj::mv(finishing);
    }
    // Exception; report 5xx.
    return sendError(kj::mv(e));
  };
  return loop(firstRequest).catch_(kj::mv(onException));
}
SuspendedRequest suspend(SuspendableRequest& suspendable) {
// Detaches the current request from this connection. It releases the input buffer (plus any
// leftover bytes) together with the request's method, URL and a shallow clone of its headers,
// so they can be handed to a later Connection (see makeHttpInput()). Only allowed before the
// request body has been consumed.
KJ_REQUIRE(httpInput.canSuspend(),
"suspend() may only be called before the request body is consumed");
// `suspended` is set only when this function exits. loop() checks it right after calling the
// factory and stops handling the request.
KJ_DEFER(suspended = true);
auto released = httpInput.releaseBuffer();
return {
kj::mv(released.buffer),
released.leftover,
suspendable.method,
suspendable.url,
suspendable.headers.cloneShallow(),
};
}
private:
HttpServer& server;
kj::AsyncIoStream& stream;
SuspendableHttpServiceFactory factory;
// Creates a new kj::Own<HttpService> for each request we handle on this connection.
HttpInputStreamImpl httpInput;
HttpOutputStream httpOutput;
kj::Maybe<HttpMethod> currentMethod;
// Method of the request being handled. loop() sets it when a request is dispatched (GET for
// protocol errors), and send() clears it; non-null therefore means no response has been sent.
bool timedOut = false;
// Set when a pipeline or header timeout fires, or when the server starts draining before the
// first request's headers arrive.
bool closed = false;
// Set when waiting for the next request yields no data (client closed, or pipeline timeout).
bool upgraded = false;
// Set by acceptWebSocket() once the 101 response headers have been written.
bool webSocketClosed = false;
// Set when the stream handed to the accepted WebSocket is released.
bool closeAfterSend = false; // True if send() should set Connection: close.
bool suspended = false;
// Set by suspend(); tells loop() to stop handling the current request.
kj::Maybe<kj::Promise<bool>> webSocketError;
// Pending error response from sendWebSocketError() (defined outside this view). startLoop() and
// loop() finish it instead of producing their own response.
static HttpInputStreamImpl makeHttpInput(
    kj::AsyncIoStream& stream,
    const kj::HttpHeaderTable& table,
    kj::Maybe<SuspendedRequest> suspendedRequest) {
  // Constructor helper that builds our HttpInputStreamImpl. A brand-new connection gets a
  // fresh input stream. A resumed one is seeded from the suspended request's buffer, leftover
  // bytes, method, URL and headers.
  if (suspendedRequest == nullptr) {
    return HttpInputStreamImpl(stream, table);
  }
  auto& resumed = KJ_ASSERT_NONNULL(suspendedRequest);
  return HttpInputStreamImpl(stream,
      resumed.buffer.releaseAsChars(),
      resumed.leftover.asChars(),
      resumed.method,
      resumed.url,
      kj::mv(resumed.headers));
}
kj::Promise<bool> loop(bool firstRequest) {
// Reads one request from the connection, dispatches it to a service created by `factory`, and
// on a clean finish recurses with loop(false) for the next request.
//
// The promise resolves to true only on a clean drain: the server is draining and nothing of a
// further request has been buffered (the early return below, or the timeout branch). Every
// other way the loop ends resolves to false. That covers client close, errors, WebSocket
// upgrade, suspension, and a request body the service did not finish reading.
if (!firstRequest && server.draining && httpInput.isCleanDrain()) {
// Don't call awaitNextMessage() in this case because that will initiate a read() which will
// immediately be canceled, losing data.
return true;
}
// Phase 1: wait for the first byte of the next request. For later requests this wait is
// bounded by the pipeline timeout and, when nothing is buffered, by server drain.
auto firstByte = httpInput.awaitNextMessage();
if (!firstRequest) {
// For requests after the first, require that the first byte arrive before the pipeline
// timeout, otherwise treat it like the connection was simply closed.
auto timeoutPromise = server.timer.afterDelay(server.settings.pipelineTimeout);
if (httpInput.isCleanDrain()) {
// If we haven't buffered any data, then we can safely drain here, so allow the wait to
// be canceled by the onDrain promise.
auto cleanDrainPromise = server.onDrain.addBranch()
.then([this]() -> kj::Promise<void> {
// This is a little tricky... drain() has been called, BUT we could have read some data
// into the buffer in the meantime, and we don't want to lose that. If any data has
// arrived, then we have no choice but to read the rest of the request and respond to
// it.
if (!httpInput.isCleanDrain()) {
return kj::NEVER_DONE;
}
// OK... As far as we know, no data has arrived in the buffer. However, unfortunately,
// we don't *really* know that, because read() is asynchronous. It may have already
// delivered some bytes, but we just haven't received the notification yet, because it's
// still queued on the event loop. As a horrible hack, we use evalLast(), so that any
// such pending notifications get a chance to be delivered.
// TODO(someday): Does this actually work on Windows, where the notification could also
// be queued on the IOCP?
return kj::evalLast([this]() -> kj::Promise<void> {
if (httpInput.isCleanDrain()) {
return kj::READY_NOW;
} else {
return kj::NEVER_DONE;
}
});
});
timeoutPromise = timeoutPromise.exclusiveJoin(kj::mv(cleanDrainPromise));
}
firstByte = firstByte.exclusiveJoin(timeoutPromise.then([this]() -> bool {
timedOut = true;
return false;
}));
}
// Phase 2: read the request headers once data is available. `hasData == false` means the
// client closed the connection or the wait above was cut short.
auto receivedHeaders = firstByte
.then([this,firstRequest](bool hasData)
-> kj::Promise<HttpHeaders::RequestOrProtocolError> {
if (hasData) {
auto readHeaders = httpInput.readRequestHeaders();
if (!firstRequest) {
// On requests other than the first, the header timeout starts ticking when we receive
// the first byte of a pipelined request.
readHeaders = readHeaders.exclusiveJoin(
server.timer.afterDelay(server.settings.headerTimeout)
.then([this]() -> HttpHeaders::RequestOrProtocolError {
timedOut = true;
return HttpHeaders::ProtocolError {
408, "Request Timeout",
"Timed out waiting for next request headers.", nullptr
};
}));
}
return kj::mv(readHeaders);
} else {
// Client closed connection or pipeline timed out with no bytes received. This is not an
// error, so don't report one.
this->closed = true;
return HttpHeaders::RequestOrProtocolError(HttpHeaders::ProtocolError {
408, "Request Timeout",
"Client closed connection or connection timeout "
"while waiting for request headers.", nullptr
});
}
});
if (firstRequest) {
// On the first request, the header timeout starts ticking immediately upon request opening.
auto timeoutPromise = server.timer.afterDelay(server.settings.headerTimeout)
.exclusiveJoin(server.onDrain.addBranch())
.then([this]() -> HttpHeaders::RequestOrProtocolError {
timedOut = true;
return HttpHeaders::ProtocolError {
408, "Request Timeout",
"Timed out waiting for initial request headers.", nullptr
};
});
receivedHeaders = receivedHeaders.exclusiveJoin(kj::mv(timeoutPromise));
}
// Phase 3: act on the outcome -- timeout, client close, a parsed request, or a protocol error.
return receivedHeaders
.then([this](HttpHeaders::RequestOrProtocolError&& requestOrProtocolError)
-> kj::Promise<bool> {
if (timedOut) {
// Client took too long to send anything, so we're going to close the connection. In
// theory, we should send back an HTTP 408 error -- it is designed exactly for this
// purpose. Alas, in practice, Google Chrome does not have any special handling for 408
// errors -- it will assume the error is a response to the next request it tries to send,
// and will happily serve the error to the user. OTOH, if we simply close the connection,
// Chrome does the "right thing", apparently. (Though I'm not sure what happens if a
// request is in-flight when we close... if it's a GET, the browser should retry. But if
// it's a POST, retrying may be dangerous. This is why 408 exists -- it unambiguously
// tells the client that it should retry.)
//
// Also note that if we ever decide to send 408 again, we might want to send some other
// error in the case that the server is draining, which also sets timedOut = true; see
// above.
return httpOutput.flush().then([this]() {
return server.draining && httpInput.isCleanDrain();
});
}
if (closed) {
// Client closed connection. Close our end too.
return httpOutput.flush().then([]() { return false; });
}
KJ_SWITCH_ONEOF(requestOrProtocolError) {
KJ_CASE_ONEOF(request, HttpHeaders::Request) {
auto& headers = httpInput.getHeaders();
currentMethod = request.method;
// The factory may call suspend() (which sets `suspended`) instead of returning a service;
// in that case stop here without touching the request body.
SuspendableRequest suspendable(*this, request.method, request.url, headers);
auto maybeService = factory(suspendable);
if (suspended) {
return false;
}
auto service = KJ_ASSERT_NONNULL(kj::mv(maybeService),
"SuspendableHttpServiceFactory did not suspend, but returned nullptr.");
auto body = httpInput.getEntityBody(
HttpInputStreamImpl::REQUEST, request.method, 0, headers);
// TODO(perf): If the client disconnects, should we cancel the response? Probably, to
// prevent permanent deadlock. It's slightly weird in that arguably the client should
// be able to shutdown the upstream but still wait on the downstream, but I believe many
// other HTTP servers do similar things.
auto promise = service->request(
request.method, request.url, headers, *body, *this).attach(kj::mv(service));
return promise.then([this, body = kj::mv(body)]() mutable -> kj::Promise<bool> {
// Response done. Await next request.
KJ_IF_MAYBE(p, webSocketError) {
// sendWebSocketError() was called. Finish sending and close the connection.
auto promise = kj::mv(*p);
webSocketError = nullptr;
return kj::mv(promise);
}
if (upgraded) {
// We've upgraded to WebSocket, and by now we should have closed the WebSocket.
if (!webSocketClosed) {
// This is gonna segfault later so abort now instead.
KJ_LOG(FATAL, "Accepted WebSocket object must be destroyed before HttpService "
"request handler completes.");
abort();
}
// Once we start a WebSocket there's no going back to HTTP.
return false;
}
// currentMethod still set means the service returned without calling send(). The
// no-argument sendError() overload (defined outside this view) handles that case.
if (currentMethod != nullptr) {
return sendError();
}
if (httpOutput.isBroken()) {
// We started a response but didn't finish it. But HttpService returns success?
// Perhaps it decided that it doesn't want to finish this response. We'll have to
// disconnect here. If the response body is not complete (e.g. Content-Length not
// reached), the client should notice. We don't want to log an error because this
// condition might be intentional on the service's part.
return false;
}
return httpOutput.flush().then(
[this, body = kj::mv(body)]() mutable -> kj::Promise<bool> {
if (httpInput.canReuse()) {
// Things look clean. Go ahead and accept the next request.
// Note that we don't have to handle server.draining here because we'll take care of
// it the next time around the loop.
return loop(false);
} else {
// Apparently, the application did not read the request body. Maybe this is a bug,
// or maybe not: maybe the client tried to upload too much data and the application
// legitimately wants to cancel the upload without reading all of it.
//
// We have a problem, though: We did send a response, and we didn't send
// `Connection: close`, so the client may expect that it can send another request.
// Perhaps the client has even finished sending the previous request's body, in
// which case the moment it finishes receiving the response, it could be completely
// within its rights to start a new request. If we close the socket now, we might
// interrupt that new request.
//
// There's no way we can get out of this perfectly cleanly. HTTP just isn't good
// enough at connection management. The best we can do is give the client some grace
// period and then abort the connection.
// Grace is bounded both by bytes (canceledUploadGraceBytes) and by time
// (canceledUploadGracePeriod); whichever finishes first decides.
auto dummy = kj::heap<HttpDiscardingEntityWriter>();
auto lengthGrace = body->pumpTo(*dummy, server.settings.canceledUploadGraceBytes)
.then([this](size_t amount) {
if (httpInput.canReuse()) {
// Success, we can continue.
return true;
} else {
// Still more data. Give up.
return false;
}
});
lengthGrace = lengthGrace.attach(kj::mv(dummy), kj::mv(body));
auto timeGrace = server.timer.afterDelay(server.settings.canceledUploadGracePeriod)
.then([]() { return false; });
return lengthGrace.exclusiveJoin(kj::mv(timeGrace))
.then([this](bool clean) -> kj::Promise<bool> {
if (clean) {
// We recovered. Continue loop.
return loop(false);
} else {
// Client still not done. Return broken.
return false;
}
});
}
});
});
}
KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) {
// Bad request.
// sendError() uses Response::send(), which requires that we have a currentMethod, but we
// never read one. GET seems like the correct choice here.
currentMethod = HttpMethod::GET;
return sendError(kj::mv(protocolError));
}
}
KJ_UNREACHABLE;
});
}
kj::Own<kj::AsyncOutputStream> send(
    uint statusCode, kj::StringPtr statusText, const HttpHeaders& headers,
    kj::Maybe<uint64_t> expectedBodySize) override {
  // Writes the status line and headers of the current request's response, then returns the
  // stream the service must write the body to.
  //
  // May be called once per request: it consumes `currentMethod`, so a second call fails the
  // KJ_REQUIRE_NONNULL below.
  //
  // Body framing:
  //   - 204 and 304 get no body headers.
  //   - 205 gets an explicit `Content-Length: 0`.
  //   - Otherwise a known `expectedBodySize` becomes Content-Length, and an unknown one selects
  //     chunked Transfer-Encoding.
  //   - Responses to HEAD never carry a body; whatever the service writes is discarded.
  auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()");
  currentMethod = nullptr;

  kj::StringPtr connectionHeaders[HttpHeaders::CONNECTION_HEADERS_COUNT];
  kj::String lengthStr;  // Owns the Content-Length text that connectionHeaders may point into.

  if (!closeAfterSend) {
    // Check if application wants us to close connections.
    KJ_IF_MAYBE(c, server.settings.callbacks) {
      if (c->shouldClose()) {
        closeAfterSend = true;
      }
    }
  }

  // TODO(0.10): If `server.draining`, we should probably set `closeAfterSend` -- UNLESS the
  // connection was created using listenHttpCleanDrain(), in which case the application may
  // intend to continue using the connection.

  if (closeAfterSend) {
    connectionHeaders[HttpHeaders::BuiltinIndices::CONNECTION] = "close";
  }

  if (statusCode == 204 || statusCode == 304) {
    // No entity-body.
  } else if (statusCode == 205) {
    // Status code 205 also has no body, but unlike 204 and 304, it must explicitly encode an
    // empty body, e.g. using content-length: 0. I'm guessing this is one of those things, where
    // some early clients expected an explicit body while others assumed an empty body, and so
    // the standard had to choose the common denominator.
    //
    // Spec: https://tools.ietf.org/html/rfc7231#section-6.3.6
    connectionHeaders[HttpHeaders::BuiltinIndices::CONTENT_LENGTH] = "0";
  } else KJ_IF_MAYBE(s, expectedBodySize) {
    // HACK: We interpret a zero-length expected body length on responses to HEAD requests to mean
    // "don't set a Content-Length header at all." This provides a way to omit a body header on
    // HEAD responses with non-null-body status codes. This is a hack that *only* makes sense
    // for HEAD responses.
    if (method != HttpMethod::HEAD || *s > 0) {
      lengthStr = kj::str(*s);
      connectionHeaders[HttpHeaders::BuiltinIndices::CONTENT_LENGTH] = lengthStr;
    }
  } else {
    connectionHeaders[HttpHeaders::BuiltinIndices::TRANSFER_ENCODING] = "chunked";
  }

  // For HEAD requests, if the application specified a Content-Length or Transfer-Encoding
  // header, use that instead of whatever we decided above.
  kj::ArrayPtr<kj::StringPtr> connectionHeadersArray = connectionHeaders;
  if (method == HttpMethod::HEAD) {
    if (headers.get(HttpHeaderId::CONTENT_LENGTH) != nullptr ||
        headers.get(HttpHeaderId::TRANSFER_ENCODING) != nullptr) {
      connectionHeadersArray = connectionHeadersArray
          .slice(0, HttpHeaders::HEAD_RESPONSE_CONNECTION_HEADERS_COUNT);
    }
  }

  httpOutput.writeHeaders(headers.serializeResponse(
      statusCode, statusText, connectionHeadersArray));

  // Choose the body writer; each branch returns it directly.
  if (method == HttpMethod::HEAD) {
    // Ignore entity-body.
    httpOutput.finishBody();
    return heap<HttpDiscardingEntityWriter>();
  } else if (statusCode == 204 || statusCode == 205 || statusCode == 304) {
    // No entity-body.
    httpOutput.finishBody();
    return heap<HttpNullEntityWriter>();
  } else KJ_IF_MAYBE(s, expectedBodySize) {
    return heap<HttpFixedLengthEntityWriter>(httpOutput, *s);
  } else {
    return heap<HttpChunkedEntityWriter>(httpOutput);
  }
}
kj::Own<WebSocket> acceptWebSocket(const HttpHeaders& headers) override {
// Validates the client's upgrade request, writes a `101 Switching Protocols` response, and
// returns a WebSocket layered directly over this connection's stream. Requests that fail
// validation go through sendWebSocketError() (defined outside this view) instead.
auto& requestHeaders = httpInput.getHeaders();
KJ_REQUIRE(requestHeaders.isWebSocket(),
"can't call acceptWebSocket() if the request headers didn't have Upgrade: WebSocket");
auto method = KJ_REQUIRE_NONNULL(currentMethod, "already called send()");
// Unlike send(), we neither need nor want to null out currentMethod. The error cases below
// depend on it being non-null to allow error responses to be sent, and the happy path expects
// it to be GET.
if (method != HttpMethod::GET) {
return sendWebSocketError("WebSocket must be initiated with a GET request.");
}
if (requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_VERSION).orDefault(nullptr) != "13") {
return sendWebSocketError("The requested WebSocket version is not supported.");
}
kj::String key;
KJ_IF_MAYBE(k, requestHeaders.get(HttpHeaderId::SEC_WEBSOCKET_KEY)) {
key = kj::str(*k);
} else {
return sendWebSocketError("Missing Sec-WebSocket-Key");
}
// Presumably the RFC 6455 Sec-WebSocket-Accept digest of the client's key (helper not shown).
auto websocketAccept = generateWebSocketAccept(key);
kj::StringPtr connectionHeaders[HttpHeaders::WEBSOCKET_CONNECTION_HEADERS_COUNT];
connectionHeaders[HttpHeaders::BuiltinIndices::SEC_WEBSOCKET_ACCEPT] = websocketAccept;
connectionHeaders[HttpHeaders::BuiltinIndices::UPGRADE] = "websocket";
connectionHeaders[HttpHeaders::BuiltinIndices::CONNECTION] = "Upgrade";
httpOutput.writeHeaders(headers.serializeResponse(
101, "Switching Protocols", connectionHeaders));
// loop() checks this flag after the handler returns and never goes back to HTTP.
upgraded = true;
// We need to give the WebSocket an Own<AsyncIoStream>, but we only have a reference. This is
// safe because the application is expected to drop the WebSocket object before returning
// from the request handler. For some extra safety, we check that webSocketClosed has been
// set true when the handler returns.
auto deferNoteClosed = kj::defer([this]() { webSocketClosed = true; });
kj::Own<kj::AsyncIoStream> ownStream(&stream, kj::NullDisposer::instance);
return upgradeToWebSocket(ownStream.attach(kj::mv(deferNoteClosed)),
httpInput, httpOutput, nullptr);
}
kj::Promise<bool> sendError(HttpHeaders::ProtocolError protocolError) {
  // Reports a request-parsing failure to the client and then ends the request loop.
  //
  // Protocol errors are detected while parsing request headers, before the HttpService is called.
  // No response can have been sent yet, so `*this` is always offered as the Response. The result
  // resolves to `false` once the error response has been flushed, which ends the loop.
  closeAfterSend = true;

  return server.settings.errorHandler.orDefault(*this)
      .handleClientProtocolError(kj::mv(protocolError), *this)
      .then([this]() { return httpOutput.flush(); })
      .then([]() { return false; });
}
kj::Promise<bool> sendError(kj::Exception&& exception) {
  // Reports an application exception through the configured error handler and then ends the
  // request loop.
  //
  // `currentMethod` is still set only if no response has been sent yet. Only in that case is
  // `*this` offered as the Response; otherwise the handler receives an empty Maybe. The result
  // resolves to `false` after the output has been flushed, which ends the loop.
  closeAfterSend = true;

  auto responseIfUnsent = currentMethod.map([this](auto&&) -> Response& { return *this; });

  return server.settings.errorHandler.orDefault(*this)
      .handleApplicationError(kj::mv(exception), responseIfUnsent)
      .then([this]() { return httpOutput.flush(); })
      .then([]() { return false; });
}
kj::Promise<bool> sendError() {
  // Sends the error handler's "no response" reply and then ends the request loop. This is used
  // when the HttpService finished without producing a response, so `*this` is still available
  // as the Response. The result resolves to `false` once flushed.
  closeAfterSend = true;

  return server.settings.errorHandler.orDefault(*this)
      .handleNoResponse(*this)
      .then([this]() { return httpOutput.flush(); })
      .then([]() { return false; });
}
kj::Own<WebSocket> sendWebSocketError(StringPtr errorMessage) {
  // Rejects a bad WebSocket handshake.
  //
  // Starts a `400 Bad Request` error response via sendError() and stores its promise in the
  // `webSocketError` member (presumably awaited by the connection loop -- confirm). It then
  // reports the failure to the application by throwing a recoverable exception.
  //
  // When exceptions are disabled, throwRecoverableException() can return. In that case, a
  // WebSocket is returned whose every operation fails with the same exception.
  //
  // The exception is constructed exactly once, so the thrown exception and the fallback's
  // exception cannot drift apart.
  kj::Exception exception = KJ_EXCEPTION(FAILED,
      "received bad WebSocket handshake", errorMessage);
  webSocketError = sendError(
      HttpHeaders::ProtocolError { 400, "Bad Request", errorMessage, nullptr });
  kj::throwRecoverableException(kj::cp(exception));

  // Fallback path when exceptions are disabled.
  class BrokenWebSocket final: public WebSocket {
  public:
    // `explicit`: this is not meant as an implicit Exception -> WebSocket conversion.
    explicit BrokenWebSocket(kj::Exception exception): exception(kj::mv(exception)) {}

    kj::Promise<void> send(kj::ArrayPtr<const byte> message) override {
      return kj::cp(exception);
    }
    kj::Promise<void> send(kj::ArrayPtr<const char> message) override {
      return kj::cp(exception);
    }
    kj::Promise<void> close(uint16_t code, kj::StringPtr reason) override {
      return kj::cp(exception);
    }
    kj::Promise<void> disconnect() override {
      return kj::cp(exception);
    }
    void abort() override {
      kj::throwRecoverableException(kj::cp(exception));
    }
    kj::Promise<void> whenAborted() override {
      return kj::cp(exception);
    }
    kj::Promise<Message> receive(size_t maxSize) override {
      return kj::cp(exception);
    }

    uint64_t sentByteCount() override { KJ_FAIL_ASSERT("received bad WebSocket handshake"); }
    uint64_t receivedByteCount() override { KJ_FAIL_ASSERT("received bad WebSocket handshake"); }

  private:
    kj::Exception exception;
  };

  return kj::heap<BrokenWebSocket>(kj::mv(exception));
}
};
HttpServer::HttpServer(kj::Timer& timerRef, const HttpHeaderTable& headerTable,
                       HttpService& serviceRef, Settings serverSettings)
    : HttpServer(timerRef, headerTable, &serviceRef, serverSettings,
                 kj::newPromiseAndFulfiller<void>()) {
  // Delegates to the OneOf-based constructor. The caller-owned service is stored as a raw
  // pointer, and a fresh promise/fulfiller pair is supplied to back drain().
}
HttpServer::HttpServer(kj::Timer& timerRef, const HttpHeaderTable& headerTable,
                       HttpServiceFactory factory, Settings serverSettings)
    : HttpServer(timerRef, headerTable, kj::mv(factory), serverSettings,
                 kj::newPromiseAndFulfiller<void>()) {
  // Delegates to the OneOf-based constructor. The factory is invoked once per connection to
  // produce that connection's HttpService, and a fresh promise/fulfiller pair backs drain().
}
HttpServer::HttpServer(kj::Timer& timerRef, const HttpHeaderTable& headerTable,
                       kj::OneOf<HttpService*, HttpServiceFactory> serviceOrFactory,
                       Settings serverSettings, kj::PromiseFulfillerPair<void> drainPaf)
    : timer(timerRef),
      requestHeaderTable(headerTable),
      service(kj::mv(serviceOrFactory)),
      settings(serverSettings),
      // onDrain is forked so that every listenHttp(port) call can race against its own branch;
      // drain() fires the matching fulfiller.
      onDrain(drainPaf.promise.fork()),
      drainFulfiller(kj::mv(drainPaf.fulfiller)),
      tasks(*this) {}
kj::Promise<void> HttpServer::drain() {
  // Stops accepting new connections by fulfilling the drain signal that listenHttp(port) races
  // against. If no connections are open, the returned promise is already resolved. Otherwise it
  // is tied to zeroConnectionsFulfiller, which is presumably fulfilled when the last connection
  // closes. May only be called once.
  KJ_REQUIRE(!draining, "you can only call drain() once");

  draining = true;
  drainFulfiller->fulfill();

  if (connectionCount != 0) {
    auto waiter = kj::newPromiseAndFulfiller<void>();
    zeroConnectionsFulfiller = kj::mv(waiter.fulfiller);
    return kj::mv(waiter.promise);
  }

  return kj::READY_NOW;
}
kj::Promise<void> HttpServer::listenHttp(kj::ConnectionReceiver& port) {
  // Runs the accept loop until it finishes on its own or until drain() fulfills `onDrain`,
  // whichever happens first. exclusiveJoin() cancels the other promise.
  auto acceptLoop = listenLoop(port);
  return acceptLoop.exclusiveJoin(onDrain.addBranch());
}
kj::Promise<void> HttpServer::listenLoop(kj::ConnectionReceiver& port) {
  // Accepts one connection, hands it to `tasks` to be served in the background, and then calls
  // itself again to accept the next one. If drain() has begun by the time a connection arrives,
  // the loop ends without serving that connection.
  return port.accept().then(
      [this, &port](kj::Own<kj::AsyncIoStream>&& connection) -> kj::Promise<void> {
    if (!draining) {
      tasks.add(listenHttp(kj::mv(connection)));
      return listenLoop(port);
    }

    // Draining started while accept() was still pending.
    return kj::READY_NOW;
  });
}
kj::Promise<void> HttpServer::listenHttp(kj::Own<kj::AsyncIoStream> connection) {
  // Serves HTTP on a single connection, discarding the bool result of listenHttpCleanDrain().
  // The stream is borrowed before ownership moves into the promise chain, so the call order is
  // irrelevant.
  kj::AsyncIoStream& stream = *connection;

  return listenHttpCleanDrain(stream)
      .ignoreResult()
      .attach(kj::mv(connection))
      // eagerlyEvaluate() keeps the historical guarantee that this method closes the connection
      // eagerly once serving is done.
      .eagerlyEvaluate(nullptr);
}
kj::Promise<bool> HttpServer::listenHttpCleanDrain(kj::AsyncIoStream& connection) {
  // Resolves the HttpService for this connection and serves it. The service is either the single
  // service supplied at construction or a new one built by the configured factory.
  kj::Own<HttpService> srv;

  KJ_SWITCH_ONEOF(service) {
    KJ_CASE_ONEOF(factory, HttpServiceFactory) {
      srv = factory(connection);
    }
    KJ_CASE_ONEOF(sharedService, HttpService*) {
      // A non-owning Own is acceptable here. The HttpService is assumed to outlive this
      // HttpServer, and this HttpServer is assumed to outlive the promise returned below, which
      // ends up holding this Own.
      srv = kj::Own<HttpService>(sharedService, kj::NullDisposer::instance);
    }
  }

  KJ_ASSERT_NONNULL(srv.get());

  // The factory lambda below owns the service and is itself owned by the Connection object. That
  // Connection outlives every service.request() promise, and the Owns handed out here are attached
  // to those promises. The lambda therefore outlives all of them, so handing out non-owning Owns
  // is safe.
  return listenHttpCleanDrain(connection,
      [owned = kj::mv(srv)](SuspendableRequest&) mutable {
    return kj::Own<HttpService>(owned.get(), kj::NullDisposer::instance);
  });
}
kj::Promise<bool> HttpServer::listenHttpCleanDrain(kj::AsyncIoStream& connection,
                                                  SuspendableHttpServiceFactory factory,
                                                  kj::Maybe<SuspendedRequest> suspendedRequest) {
  // Creates a Connection object for `connection` (optionally resuming `suspendedRequest`) and
  // runs its request loop.
  auto conn = heap<Connection>(*this, connection, kj::mv(factory), kj::mv(suspendedRequest));

  // The loop is raced against client write-disconnection. If the client goes away first, the
  // loop is cancelled and the result is `false`.
  auto requestLoop = conn->startLoop(true);
  auto clientGone = connection.whenWriteDisconnected().then([]() { return false; });

  // Eagerly evaluated so the Connection is dropped as soon as the promise resolves, even if the
  // caller never evaluates it eagerly.
  return requestLoop.exclusiveJoin(kj::mv(clientGone))
      .attach(kj::mv(conn))
      .eagerlyEvaluate(nullptr);
}
namespace {
// Fallback used for exceptions from background listen tasks when no custom
// HttpServerErrorHandler is configured. The exception is only logged at ERROR severity; nothing
// else is done with it.
void defaultHandleListenLoopException(kj::Exception&& exception) {
KJ_LOG(ERROR, "unhandled exception in HTTP server", exception);
}
} // namespace
void HttpServer::taskFailed(kj::Exception&& exception) {
  // TaskSet callback for failures of background connection tasks (see `tasks`). A user-supplied
  // error handler takes priority; otherwise the default logger is used.
  KJ_IF_MAYBE(handler, settings.errorHandler) {
    handler->handleListenLoopException(kj::mv(exception));
    return;
  }

  defaultHandleListenLoopException(kj::mv(exception));
}
HttpServer::SuspendedRequest::SuspendedRequest(
    kj::Array<byte> bufferParam, kj::ArrayPtr<byte> leftoverParam, HttpMethod methodParam,
    kj::StringPtr urlParam, HttpHeaders headersParam)
    : buffer(kj::mv(bufferParam)),
      leftover(leftoverParam),
      method(methodParam),
      url(urlParam),
      headers(kj::mv(headersParam)) {
  // `leftover` must always start inside `buffer`, even when it is empty. HttpInputStreamImpl's
  // constructor relies on its position to initialize `messageHeaderEnd` correctly.
  KJ_ASSERT(leftover.begin() >= buffer.begin() && leftover.begin() <= buffer.end());

  if (leftover.size() > 0) {
    // A non-empty `leftover` must also end inside `buffer`, i.e. be a genuine slice of it.
    KJ_ASSERT(leftover.end() >= buffer.begin() && leftover.end() <= buffer.end());
  }
}
HttpServer::SuspendedRequest HttpServer::SuspendableRequest::suspend() {
  // Suspension is implemented by the Connection this request is associated with; forward to it.
  auto& owner = connection;
  return owner.suspend(*this);
}
namespace {

// Sends `errorMessage` as a complete `text/plain` response with the given status line, declaring
// `errorMessage.size()` as the expected body size. The returned promise resolves once the body
// has been written, and it keeps both the message text and the body stream alive until then.
//
// This helper is shared by the default HttpServerErrorHandler implementations below so that they
// all format error responses the same way.
kj::Promise<void> respondWithPlainTextError(
    kj::HttpService::Response& response, uint statusCode, kj::StringPtr statusText,
    kj::String errorMessage) {
  HttpHeaderTable headerTable {};
  HttpHeaders headers(headerTable);
  headers.set(HttpHeaderId::CONTENT_TYPE, "text/plain");

  auto body = response.send(statusCode, statusText, headers, errorMessage.size());
  return body->write(errorMessage.begin(), errorMessage.size())
      .attach(kj::mv(errorMessage), kj::mv(body));
}

}  // namespace

kj::Promise<void> HttpServerErrorHandler::handleClientProtocolError(
    HttpHeaders::ProtocolError protocolError, kj::HttpService::Response& response) {
  // Default error handler implementation: report the request-parsing failure to the client using
  // the status code, status message and description carried by `protocolError`.
  return respondWithPlainTextError(response, protocolError.statusCode,
      protocolError.statusMessage, kj::str("ERROR: ", protocolError.description));
}

kj::Promise<void> HttpServerErrorHandler::handleApplicationError(
    kj::Exception exception, kj::Maybe<kj::HttpService::Response&> response) {
  // Default error handler implementation.
  //
  // DISCONNECTED exceptions produce no response at all. Otherwise, if `response` is still
  // available, a plain-text error is sent:
  //   - 503 for OVERLOADED,
  //   - 501 for UNIMPLEMENTED,
  //   - 500 for anything else.
  // If a partial response was already sent, the error can only be logged.
  if (exception.getType() == kj::Exception::Type::DISCONNECTED) {
    // How do we tell an HTTP client that there was a transient network error, and it should
    // try again immediately? There's no HTTP status code for this (503 is meant for "try
    // again later, not now"). Here's an idea: Don't send any response; just close the
    // connection, so that it looks like the connection between the HTTP client and server
    // was dropped. A good client should treat this exactly the way we want.
    //
    // We also bail here to avoid logging the disconnection, which isn't very interesting.
    return kj::READY_NOW;
  }

  KJ_IF_MAYBE(r, response) {
    KJ_LOG(INFO, "threw exception while serving HTTP response", exception);

    if (exception.getType() == kj::Exception::Type::OVERLOADED) {
      return respondWithPlainTextError(*r, 503, "Service Unavailable", kj::str(
          "ERROR: The server is temporarily unable to handle your request. Details:\n\n",
          exception));
    } else if (exception.getType() == kj::Exception::Type::UNIMPLEMENTED) {
      return respondWithPlainTextError(*r, 501, "Not Implemented", kj::str(
          "ERROR: The server does not implement this operation. Details:\n\n", exception));
    } else {
      return respondWithPlainTextError(*r, 500, "Internal Server Error", kj::str(
          "ERROR: The server threw an exception. Details:\n\n", exception));
    }
  }

  KJ_LOG(ERROR, "HttpService threw exception after generating a partial response",
                "too late to report error to client", exception);
  return kj::READY_NOW;
}

void HttpServerErrorHandler::handleListenLoopException(kj::Exception&& exception) {
  defaultHandleListenLoopException(kj::mv(exception));
}

kj::Promise<void> HttpServerErrorHandler::handleNoResponse(kj::HttpService::Response& response) {
  // Default handling when the HttpService returned without sending any response: reply with a
  // generic 500 so the client is not left waiting.
  return respondWithPlainTextError(response, 500, "Internal Server Error",
      kj::str("ERROR: The HttpService did not generate a response."));
}
} // namespace kj