blob: 6d658a8f5ebf10d3f2a3465f01fd29d441cba140 [file] [log] [blame]
// Copyright (c) 2016 Sandstorm Development Group, Inc. and contributors
// Licensed under the MIT License:
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#if KJ_HAS_OPENSSL
#include "tls.h"
#include "readiness-io.h"
#include <openssl/bio.h>
#include <openssl/conf.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
#include <openssl/tls1.h>
#include <openssl/x509.h>
#include <openssl/x509v3.h>
#include <kj/async-queue.h>
#include <kj/debug.h>
#include <kj/vector.h>
#if OPENSSL_VERSION_NUMBER < 0x10100000L
#define BIO_set_init(x,v) (x->init=v)
#define BIO_get_data(x) (x->ptr)
#define BIO_set_data(x,v) (x->ptr=v)
#endif
namespace kj {
// =======================================================================================
// misc helpers
namespace {
KJ_NORETURN(void throwOpensslError());
void throwOpensslError() {
// Call when an OpenSSL function returns an error code to convert that into an exception and
// throw it.
kj::Vector<kj::String> lines;
while (unsigned long long error = ERR_get_error()) {
char message[1024];
ERR_error_string_n(error, message, sizeof(message));
lines.add(kj::heapString(message));
}
kj::String message = kj::strArray(lines, "\n");
KJ_FAIL_ASSERT("OpenSSL error", message);
}
#if OPENSSL_VERSION_NUMBER < 0x10100000L && !defined(OPENSSL_IS_BORINGSSL)
// Older versions of OpenSSL don't define _up_ref() functions.
void EVP_PKEY_up_ref(EVP_PKEY* pkey) {
CRYPTO_add(&pkey->references, 1, CRYPTO_LOCK_EVP_PKEY);
}
void X509_up_ref(X509* x509) {
CRYPTO_add(&x509->references, 1, CRYPTO_LOCK_X509);
}
#endif
#if OPENSSL_VERSION_NUMBER < 0x10100000L
class OpenSslInit {
// Initializes the OpenSSL library.
public:
OpenSslInit() {
SSL_library_init();
SSL_load_error_strings();
OPENSSL_config(nullptr);
}
};
void ensureOpenSslInitialized() {
// Initializes the OpenSSL library the first time it is called.
static OpenSslInit init;
}
#else
inline void ensureOpenSslInitialized() {
// As of 1.1.0, no initialization is needed.
}
#endif
} // namespace
// =======================================================================================
// Implementation of kj::AsyncIoStream that applies TLS on top of some other AsyncIoStream.
//
// TODO(perf): OpenSSL's I/O abstraction layer, "BIO", is readiness-based, but AsyncIoStream is
// completion-based. This forces us to use an intermediate buffer which wastes memory and incurs
// redundant copies. We could improve the situation by creating a way to detect if the underlying
// AsyncIoStream is simply wrapping a file descriptor (or other readiness-based stream?) and use
// that directly if so.
class TlsConnection final: public kj::AsyncIoStream {
public:
TlsConnection(kj::Own<kj::AsyncIoStream> stream, SSL_CTX* ctx)
: TlsConnection(*stream, ctx) {
ownInner = kj::mv(stream);
}
TlsConnection(kj::AsyncIoStream& stream, SSL_CTX* ctx)
: inner(stream), readBuffer(stream), writeBuffer(stream) {
ssl = SSL_new(ctx);
if (ssl == nullptr) {
throwOpensslError();
}
BIO* bio = BIO_new(const_cast<BIO_METHOD*>(getBioVtable()));
if (bio == nullptr) {
SSL_free(ssl);
throwOpensslError();
}
BIO_set_data(bio, this);
BIO_set_init(bio, 1);
SSL_set_bio(ssl, bio, bio);
}
kj::Promise<void> connect(kj::StringPtr expectedServerHostname) {
if (!SSL_set_tlsext_host_name(ssl, expectedServerHostname.cStr())) {
throwOpensslError();
}
X509_VERIFY_PARAM* verify = SSL_get0_param(ssl);
if (verify == nullptr) {
throwOpensslError();
}
if (X509_VERIFY_PARAM_set1_host(
verify, expectedServerHostname.cStr(), expectedServerHostname.size()) <= 0) {
throwOpensslError();
}
// As of OpenSSL 1.1.0, X509_V_FLAG_TRUSTED_FIRST is on by default. Turning it on for older
// versions -- as well as certain OpenSSL-compatible libraries -- fixes the problem described
// here: https://community.letsencrypt.org/t/openssl-client-compatibility-changes-for-let-s-encrypt-certificates/143816
//
// Otherwise, certificates issued by Let's Encrypt won't work as of September 30, 2021:
// https://letsencrypt.org/docs/dst-root-ca-x3-expiration-september-2021/
X509_VERIFY_PARAM_set_flags(verify, X509_V_FLAG_TRUSTED_FIRST);
return sslCall([this]() { return SSL_connect(ssl); }).then([this](size_t) {
X509* cert = SSL_get_peer_certificate(ssl);
KJ_REQUIRE(cert != nullptr, "TLS peer provided no certificate");
X509_free(cert);
auto result = SSL_get_verify_result(ssl);
if (result != X509_V_OK) {
const char* reason = X509_verify_cert_error_string(result);
KJ_FAIL_REQUIRE("TLS peer's certificate is not trusted", reason);
}
});
}
kj::Promise<void> accept() {
// We are the server. Set SSL options to prefer server's cipher choice.
SSL_set_options(ssl, SSL_OP_CIPHER_SERVER_PREFERENCE);
auto acceptPromise = sslCall([this]() {
return SSL_accept(ssl);
});
return acceptPromise.then([](size_t ret) {
if (ret == 0) {
kj::throwRecoverableException(
KJ_EXCEPTION(DISCONNECTED, "Client disconnected during SSL_accept()"));
}
});
}
kj::Own<TlsPeerIdentity> getIdentity(kj::Own<kj::PeerIdentity> inner) {
return kj::heap<TlsPeerIdentity>(SSL_get_peer_certificate(ssl), kj::mv(inner),
kj::Badge<TlsConnection>());
}
~TlsConnection() noexcept(false) {
SSL_free(ssl);
}
kj::Promise<size_t> tryRead(void* buffer, size_t minBytes, size_t maxBytes) override {
return tryReadInternal(buffer, minBytes, maxBytes, 0);
}
Promise<void> write(const void* buffer, size_t size) override {
return writeInternal(kj::arrayPtr(reinterpret_cast<const byte*>(buffer), size), nullptr);
}
Promise<void> write(ArrayPtr<const ArrayPtr<const byte>> pieces) override {
auto cork = writeBuffer.cork();
return writeInternal(pieces[0], pieces.slice(1, pieces.size())).attach(kj::mv(cork));
}
Promise<void> whenWriteDisconnected() override {
return inner.whenWriteDisconnected();
}
void shutdownWrite() override {
KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()");
// TODO(0.10): shutdownWrite() is problematic because it doesn't return a promise. It was
// designed to assume that it would only be called after all writes are finished and that
// there was no reason to block at that point, but SSL sessions don't fit this since they
// actually have to send a shutdown message.
shutdownTask = sslCall([this]() {
// The first SSL_shutdown() call is expected to return 0 and may flag a misleading error.
int result = SSL_shutdown(ssl);
return result == 0 ? 1 : result;
}).ignoreResult().eagerlyEvaluate([](kj::Exception&& e) {
KJ_LOG(ERROR, e);
});
}
void abortRead() override {
inner.abortRead();
}
void getsockopt(int level, int option, void* value, uint* length) override {
inner.getsockopt(level, option, value, length);
}
void setsockopt(int level, int option, const void* value, uint length) override {
inner.setsockopt(level, option, value, length);
}
void getsockname(struct sockaddr* addr, uint* length) override {
inner.getsockname(addr, length);
}
void getpeername(struct sockaddr* addr, uint* length) override {
inner.getpeername(addr, length);
}
kj::Maybe<int> getFd() const override {
return inner.getFd();
}
private:
SSL* ssl;
kj::AsyncIoStream& inner;
kj::Own<kj::AsyncIoStream> ownInner;
bool disconnected = false;
kj::Maybe<kj::Promise<void>> shutdownTask;
ReadyInputStreamWrapper readBuffer;
ReadyOutputStreamWrapper writeBuffer;
kj::Promise<size_t> tryReadInternal(
void* buffer, size_t minBytes, size_t maxBytes, size_t alreadyDone) {
if (disconnected) return alreadyDone;
return sslCall([this,buffer,maxBytes]() { return SSL_read(ssl, buffer, maxBytes); })
.then([this,buffer,minBytes,maxBytes,alreadyDone](size_t n) -> kj::Promise<size_t> {
if (n >= minBytes || n == 0) {
return alreadyDone + n;
} else {
return tryReadInternal(reinterpret_cast<byte*>(buffer) + n,
minBytes - n, maxBytes - n, alreadyDone + n);
}
});
}
Promise<void> writeInternal(kj::ArrayPtr<const byte> first,
kj::ArrayPtr<const kj::ArrayPtr<const byte>> rest) {
KJ_REQUIRE(shutdownTask == nullptr, "already called shutdownWrite()");
// SSL_write() with a zero-sized input returns 0, but a 0 return is documented as indicating
// an error. So, we need to avoid zero-sized writes entirely.
while (first.size() == 0) {
if (rest.size() == 0) {
return kj::READY_NOW;
}
first = rest.front();
rest = rest.slice(1, rest.size());
}
return sslCall([this,first]() { return SSL_write(ssl, first.begin(), first.size()); })
.then([this,first,rest](size_t n) -> kj::Promise<void> {
if (n == 0) {
return KJ_EXCEPTION(DISCONNECTED, "ssl connection ended during write");
} else if (n < first.size()) {
return writeInternal(first.slice(n, first.size()), rest);
} else if (rest.size() > 0) {
return writeInternal(rest[0], rest.slice(1, rest.size()));
} else {
return kj::READY_NOW;
}
});
}
template <typename Func>
kj::Promise<size_t> sslCall(Func&& func) {
if (disconnected) return size_t(0);
auto result = func();
if (result > 0) {
return result;
} else {
int error = SSL_get_error(ssl, result);
switch (error) {
case SSL_ERROR_ZERO_RETURN:
disconnected = true;
return size_t(0);
case SSL_ERROR_WANT_READ:
return readBuffer.whenReady().then(kj::mvCapture(func,
[this](Func&& func) mutable { return sslCall(kj::fwd<Func>(func)); }));
case SSL_ERROR_WANT_WRITE:
return writeBuffer.whenReady().then(kj::mvCapture(func,
[this](Func&& func) mutable { return sslCall(kj::fwd<Func>(func)); }));
case SSL_ERROR_SSL:
throwOpensslError();
case SSL_ERROR_SYSCALL:
if (result == 0) {
disconnected = true;
return size_t(0);
} else {
// According to documentation we shouldn't get here, because our BIO never returns an
// "error". But in practice we do get here sometimes when the peer disconnects
// prematurely.
return KJ_EXCEPTION(DISCONNECTED, "SSL unable to continue I/O");
}
default:
KJ_FAIL_ASSERT("unexpected SSL error code", error);
}
}
}
static int bioRead(BIO* b, char* out, int outl) {
BIO_clear_retry_flags(b);
KJ_IF_MAYBE(n, reinterpret_cast<TlsConnection*>(BIO_get_data(b))->readBuffer
.read(kj::arrayPtr(out, outl).asBytes())) {
return *n;
} else {
BIO_set_retry_read(b);
return -1;
}
}
static int bioWrite(BIO* b, const char* in, int inl) {
BIO_clear_retry_flags(b);
KJ_IF_MAYBE(n, reinterpret_cast<TlsConnection*>(BIO_get_data(b))->writeBuffer
.write(kj::arrayPtr(in, inl).asBytes())) {
return *n;
} else {
BIO_set_retry_write(b);
return -1;
}
}
static long bioCtrl(BIO* b, int cmd, long num, void* ptr) {
switch (cmd) {
case BIO_CTRL_FLUSH:
return 1;
case BIO_CTRL_PUSH:
case BIO_CTRL_POP:
// Informational?
return 0;
default:
KJ_LOG(WARNING, "unimplemented bio_ctrl", cmd);
return 0;
}
}
static int bioCreate(BIO* b) {
BIO_set_data(b, nullptr);
return 1;
}
static int bioDestroy(BIO* b) {
// The BIO does NOT own the TlsConnection.
return 1;
}
#if OPENSSL_VERSION_NUMBER < 0x10100000L
static const BIO_METHOD* getBioVtable() {
static const BIO_METHOD VTABLE {
BIO_TYPE_SOURCE_SINK,
"KJ stream",
TlsConnection::bioWrite,
TlsConnection::bioRead,
nullptr, // puts
nullptr, // gets
TlsConnection::bioCtrl,
TlsConnection::bioCreate,
TlsConnection::bioDestroy,
nullptr
};
return &VTABLE;
}
#else
static const BIO_METHOD* getBioVtable() {
static const BIO_METHOD* const vtable = makeBioVtable();
return vtable;
}
static const BIO_METHOD* makeBioVtable() {
BIO_METHOD* vtable = BIO_meth_new(BIO_TYPE_SOURCE_SINK, "KJ stream");
BIO_meth_set_write(vtable, TlsConnection::bioWrite);
BIO_meth_set_read(vtable, TlsConnection::bioRead);
BIO_meth_set_ctrl(vtable, TlsConnection::bioCtrl);
BIO_meth_set_create(vtable, TlsConnection::bioCreate);
BIO_meth_set_destroy(vtable, TlsConnection::bioDestroy);
return vtable;
}
#endif
};
// =======================================================================================
// Implementations of ConnectionReceiver, NetworkAddress, and Network as wrappers adding TLS.
class TlsConnectionReceiver final: public ConnectionReceiver, public TaskSet::ErrorHandler {
public:
TlsConnectionReceiver(
TlsContext &tls, Own<ConnectionReceiver> inner,
kj::Maybe<TlsErrorHandler> acceptErrorHandler)
: tls(tls), inner(kj::mv(inner)),
acceptLoopTask(acceptLoop().eagerlyEvaluate([this](Exception &&e) {
onAcceptFailure(kj::mv(e));
})),
acceptErrorHandler(kj::mv(acceptErrorHandler)),
tasks(*this) {}
void taskFailed(Exception&& e) override {
KJ_IF_MAYBE(handler, acceptErrorHandler){
handler->operator()(kj::mv(e));
} else if (e.getType() != Exception::Type::DISCONNECTED) {
KJ_LOG(ERROR, "error accepting tls connection", kj::mv(e));
}
};
Promise<Own<AsyncIoStream>> accept() override {
return acceptAuthenticated().then([](AuthenticatedStream&& stream) {
return kj::mv(stream.stream);
});
}
Promise<AuthenticatedStream> acceptAuthenticated() override {
KJ_IF_MAYBE(e, maybeInnerException) {
// We've experienced an exception from the inner receiver, we consider this unrecoverable.
return Exception(*e);
}
return queue.pop();
}
uint getPort() override {
return inner->getPort();
}
void getsockopt(int level, int option, void* value, uint* length) override {
return inner->getsockopt(level, option, value, length);
}
void setsockopt(int level, int option, const void* value, uint length) override {
return inner->setsockopt(level, option, value, length);
}
private:
void onAcceptSuccess(AuthenticatedStream&& stream) {
// Queue this stream to go through SSL_accept.
auto acceptPromise = kj::evalNow([&] {
// Do the SSL acceptance procedure.
return tls.wrapServer(kj::mv(stream));
});
auto sslPromise = acceptPromise.then([this](auto&& stream) -> Promise<void> {
// This is only attached to the success path, thus the error handler will catch if our
// promise fails.
queue.push(kj::mv(stream));
return kj::READY_NOW;
});
tasks.add(kj::mv(sslPromise));
}
void onAcceptFailure(Exception&& e) {
// Store this exception to reject all future calls to accept() and reject any unfulfilled
// promises from the queue.
maybeInnerException = kj::mv(e);
queue.rejectAll(Exception(KJ_REQUIRE_NONNULL(maybeInnerException)));
}
Promise<void> acceptLoop() {
// Accept one connection and queue up the next accept on our TaskSet.
return inner->acceptAuthenticated().then(
[this](AuthenticatedStream&& stream) {
onAcceptSuccess(kj::mv(stream));
// Queue up the next accept loop immediately without waiting for SSL_accept()/wrapServer().
return acceptLoop();
});
}
TlsContext& tls;
Own<ConnectionReceiver> inner;
Promise<void> acceptLoopTask;
ProducerConsumerQueue<AuthenticatedStream> queue;
kj::Maybe<TlsErrorHandler> acceptErrorHandler;
TaskSet tasks;
Maybe<Exception> maybeInnerException;
};
class TlsNetworkAddress final: public kj::NetworkAddress {
public:
TlsNetworkAddress(TlsContext& tls, kj::String hostname, kj::Own<kj::NetworkAddress>&& inner)
: tls(tls), hostname(kj::mv(hostname)), inner(kj::mv(inner)) {}
Promise<Own<AsyncIoStream>> connect() override {
// Note: It's unfortunately pretty common for people to assume they can drop the NetworkAddress
// as soon as connect() returns, and this works with the native network implementation.
// So, we make some copies here.
auto& tlsRef = tls;
auto hostnameCopy = kj::str(hostname);
return inner->connect().then(kj::mvCapture(hostnameCopy,
[&tlsRef](kj::String&& hostname, Own<AsyncIoStream>&& stream) {
return tlsRef.wrapClient(kj::mv(stream), hostname);
}));
}
Promise<kj::AuthenticatedStream> connectAuthenticated() override {
// Note: It's unfortunately pretty common for people to assume they can drop the NetworkAddress
// as soon as connect() returns, and this works with the native network implementation.
// So, we make some copies here.
auto& tlsRef = tls;
auto hostnameCopy = kj::str(hostname);
return inner->connectAuthenticated().then(
[&tlsRef, hostname = kj::mv(hostnameCopy)](kj::AuthenticatedStream stream) {
return tlsRef.wrapClient(kj::mv(stream), hostname);
});
}
Own<ConnectionReceiver> listen() override {
return tls.wrapPort(inner->listen());
}
Own<NetworkAddress> clone() override {
return kj::heap<TlsNetworkAddress>(tls, kj::str(hostname), inner->clone());
}
String toString() override {
return kj::str("tls:", inner->toString());
}
private:
TlsContext& tls;
kj::String hostname;
kj::Own<kj::NetworkAddress> inner;
};
class TlsNetwork final: public kj::Network {
public:
TlsNetwork(TlsContext& tls, kj::Network& inner): tls(tls), inner(inner) {}
TlsNetwork(TlsContext& tls, kj::Own<kj::Network> inner)
: tls(tls), inner(*inner), ownInner(kj::mv(inner)) {}
Promise<Own<NetworkAddress>> parseAddress(StringPtr addr, uint portHint) override {
kj::String hostname;
KJ_IF_MAYBE(pos, addr.findFirst(':')) {
hostname = kj::heapString(addr.slice(0, *pos));
} else {
hostname = kj::heapString(addr);
}
return inner.parseAddress(addr, portHint)
.then(kj::mvCapture(hostname, [this](kj::String&& hostname, kj::Own<NetworkAddress>&& addr)
-> kj::Own<kj::NetworkAddress> {
return kj::heap<TlsNetworkAddress>(tls, kj::mv(hostname), kj::mv(addr));
}));
}
Own<NetworkAddress> getSockaddr(const void* sockaddr, uint len) override {
KJ_UNIMPLEMENTED("TLS does not implement getSockaddr() because it needs to know hostnames");
}
Own<Network> restrictPeers(
kj::ArrayPtr<const kj::StringPtr> allow,
kj::ArrayPtr<const kj::StringPtr> deny = nullptr) override {
// TODO(someday): Maybe we could implement the ability to specify CA or hostname restrictions?
// Or is it better to let people do that via the TlsContext? A neat thing about
// restrictPeers() is that it's easy to make user-configurable.
return kj::heap<TlsNetwork>(tls, inner.restrictPeers(allow, deny));
}
private:
TlsContext& tls;
kj::Network& inner;
kj::Own<kj::Network> ownInner;
};
// =======================================================================================
// class TlsContext
TlsContext::Options::Options()
: useSystemTrustStore(true),
verifyClients(false),
minVersion(TlsVersion::TLS_1_2),
cipherList("ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305") {}
// Cipher list is Mozilla's "intermediate" list, except with classic DH removed since we don't
// currently support setting dhparams. See:
// https://mozilla.github.io/server-side-tls/ssl-config-generator/
//
// Classic DH is arguably obsolete and will only become more so as time passes, so perhaps we'll
// never bother.
struct TlsContext::SniCallback {
// struct SniCallback exists only so that callback() can be declared in the .c++ file, since it
// references OpenSSL types.
static int callback(SSL* ssl, int* ad, void* arg);
};
TlsContext::TlsContext(Options options) {
ensureOpenSslInitialized();
#if OPENSSL_VERSION_NUMBER >= 0x10100000L || defined(OPENSSL_IS_BORINGSSL)
SSL_CTX* ctx = SSL_CTX_new(TLS_method());
#else
SSL_CTX* ctx = SSL_CTX_new(SSLv23_method());
#endif
if (ctx == nullptr) {
throwOpensslError();
}
KJ_ON_SCOPE_FAILURE(SSL_CTX_free(ctx));
// honor options.useSystemTrustStore
if (options.useSystemTrustStore) {
if (!SSL_CTX_set_default_verify_paths(ctx)) {
throwOpensslError();
}
}
// honor options.trustedCertificates
if (options.trustedCertificates.size() > 0) {
X509_STORE* store = SSL_CTX_get_cert_store(ctx);
if (store == nullptr) {
throwOpensslError();
}
for (auto& cert: options.trustedCertificates) {
if (!X509_STORE_add_cert(store, reinterpret_cast<X509*>(cert.chain[0]))) {
throwOpensslError();
}
}
}
if (options.verifyClients) {
SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL);
}
// honor options.minVersion
long optionFlags = 0;
if (options.minVersion > TlsVersion::SSL_3) {
optionFlags |= SSL_OP_NO_SSLv3;
}
if (options.minVersion > TlsVersion::TLS_1_0) {
optionFlags |= SSL_OP_NO_TLSv1;
}
if (options.minVersion > TlsVersion::TLS_1_1) {
optionFlags |= SSL_OP_NO_TLSv1_1;
}
if (options.minVersion > TlsVersion::TLS_1_2) {
optionFlags |= SSL_OP_NO_TLSv1_2;
}
SSL_CTX_set_options(ctx, optionFlags); // note: never fails; returns new options bitmask
// honor options.cipherList
if (!SSL_CTX_set_cipher_list(ctx, options.cipherList.cStr())) {
throwOpensslError();
}
// honor options.defaultKeypair
KJ_IF_MAYBE(kp, options.defaultKeypair) {
if (!SSL_CTX_use_PrivateKey(ctx, reinterpret_cast<EVP_PKEY*>(kp->privateKey.pkey))) {
throwOpensslError();
}
if (!SSL_CTX_use_certificate(ctx, reinterpret_cast<X509*>(kp->certificate.chain[0]))) {
throwOpensslError();
}
for (size_t i = 1; i < kj::size(kp->certificate.chain); i++) {
X509* x509 = reinterpret_cast<X509*>(kp->certificate.chain[i]);
if (x509 == nullptr) break; // end of chain
if (!SSL_CTX_add_extra_chain_cert(ctx, x509)) {
throwOpensslError();
}
// SSL_CTX_add_extra_chain_cert() does NOT up the refcount itself.
X509_up_ref(x509);
}
}
// honor options.sniCallback
KJ_IF_MAYBE(sni, options.sniCallback) {
SSL_CTX_set_tlsext_servername_callback(ctx, &SniCallback::callback);
SSL_CTX_set_tlsext_servername_arg(ctx, sni);
}
KJ_IF_MAYBE(timeout, options.acceptTimeout) {
this->timer = KJ_REQUIRE_NONNULL(options.timer,
"acceptTimeout option requires that a timer is also provided");
this->acceptTimeout = *timeout;
}
this->acceptErrorHandler = kj::mv(options.acceptErrorHandler);
this->ctx = ctx;
}
int TlsContext::SniCallback::callback(SSL* ssl, int* ad, void* arg) {
// The third parameter is actually type TlsSniCallback*.
KJ_IF_MAYBE(exception, kj::runCatchingExceptions([&]() {
TlsSniCallback& sni = *reinterpret_cast<TlsSniCallback*>(arg);
const char* name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
if (name != nullptr) {
KJ_IF_MAYBE(kp, sni.getKey(name)) {
if (!SSL_use_PrivateKey(ssl, reinterpret_cast<EVP_PKEY*>(kp->privateKey.pkey))) {
throwOpensslError();
}
if (!SSL_use_certificate(ssl, reinterpret_cast<X509*>(kp->certificate.chain[0]))) {
throwOpensslError();
}
if (!SSL_clear_chain_certs(ssl)) {
throwOpensslError();
}
for (size_t i = 1; i < kj::size(kp->certificate.chain); i++) {
X509* x509 = reinterpret_cast<X509*>(kp->certificate.chain[i]);
if (x509 == nullptr) break; // end of chain
if (!SSL_add0_chain_cert(ssl, x509)) {
throwOpensslError();
}
// SSL_add0_chain_cert() does NOT up the refcount itself.
X509_up_ref(x509);
}
}
}
})) {
KJ_LOG(ERROR, "exception when invoking SNI callback", *exception);
*ad = SSL_AD_INTERNAL_ERROR;
return SSL_TLSEXT_ERR_ALERT_FATAL;
}
return SSL_TLSEXT_ERR_OK;
}
TlsContext::~TlsContext() noexcept(false) {
SSL_CTX_free(reinterpret_cast<SSL_CTX*>(ctx));
}
kj::Promise<kj::Own<kj::AsyncIoStream>> TlsContext::wrapClient(
kj::Own<kj::AsyncIoStream> stream, kj::StringPtr expectedServerHostname) {
auto conn = kj::heap<TlsConnection>(kj::mv(stream), reinterpret_cast<SSL_CTX*>(ctx));
auto promise = conn->connect(expectedServerHostname);
return promise.then(kj::mvCapture(conn, [](kj::Own<TlsConnection> conn)
-> kj::Own<kj::AsyncIoStream> {
return kj::mv(conn);
}));
}
kj::Promise<kj::Own<kj::AsyncIoStream>> TlsContext::wrapServer(kj::Own<kj::AsyncIoStream> stream) {
auto conn = kj::heap<TlsConnection>(kj::mv(stream), reinterpret_cast<SSL_CTX*>(ctx));
auto promise = conn->accept();
KJ_IF_MAYBE(timeout, acceptTimeout) {
promise = KJ_REQUIRE_NONNULL(timer).afterDelay(*timeout).then([]() -> kj::Promise<void> {
return KJ_EXCEPTION(DISCONNECTED, "timed out waiting for client during TLS handshake");
}).exclusiveJoin(kj::mv(promise));
}
return promise.then(kj::mvCapture(conn, [](kj::Own<TlsConnection> conn)
-> kj::Own<kj::AsyncIoStream> {
return kj::mv(conn);
}));
}
kj::Promise<kj::AuthenticatedStream> TlsContext::wrapClient(
kj::AuthenticatedStream stream, kj::StringPtr expectedServerHostname) {
auto conn = kj::heap<TlsConnection>(kj::mv(stream.stream), reinterpret_cast<SSL_CTX*>(ctx));
auto promise = conn->connect(expectedServerHostname);
return promise.then([conn=kj::mv(conn),innerId=kj::mv(stream.peerIdentity)]() mutable {
auto id = conn->getIdentity(kj::mv(innerId));
return kj::AuthenticatedStream { kj::mv(conn), kj::mv(id) };
});
}
kj::Promise<kj::AuthenticatedStream> TlsContext::wrapServer(kj::AuthenticatedStream stream) {
auto conn = kj::heap<TlsConnection>(kj::mv(stream.stream), reinterpret_cast<SSL_CTX*>(ctx));
auto promise = conn->accept();
KJ_IF_MAYBE(timeout, acceptTimeout) {
promise = KJ_REQUIRE_NONNULL(timer).afterDelay(*timeout).then([]() -> kj::Promise<void> {
return KJ_EXCEPTION(DISCONNECTED, "timed out waiting for client during TLS handshake");
}).exclusiveJoin(kj::mv(promise));
}
return promise.then([conn=kj::mv(conn),innerId=kj::mv(stream.peerIdentity)]() mutable {
auto id = conn->getIdentity(kj::mv(innerId));
return kj::AuthenticatedStream { kj::mv(conn), kj::mv(id) };
});
}
kj::Own<kj::ConnectionReceiver> TlsContext::wrapPort(kj::Own<kj::ConnectionReceiver> port) {
auto handler = acceptErrorHandler.map([](TlsErrorHandler& handler) {
return handler.reference();
});
return kj::heap<TlsConnectionReceiver>(*this, kj::mv(port), kj::mv(handler));
}
kj::Own<kj::Network> TlsContext::wrapNetwork(kj::Network& network) {
return kj::heap<TlsNetwork>(*this, network);
}
// =======================================================================================
// class TlsPrivateKey
TlsPrivateKey::TlsPrivateKey(kj::ArrayPtr<const byte> asn1) {
ensureOpenSslInitialized();
const byte* ptr = asn1.begin();
pkey = d2i_AutoPrivateKey(nullptr, &ptr, asn1.size());
if (pkey == nullptr) {
throwOpensslError();
}
}
TlsPrivateKey::TlsPrivateKey(kj::StringPtr pem, kj::Maybe<kj::StringPtr> password) {
ensureOpenSslInitialized();
// const_cast apparently needed for older versions of OpenSSL.
BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size());
KJ_DEFER(BIO_free(bio));
pkey = PEM_read_bio_PrivateKey(bio, nullptr, &passwordCallback, &password);
if (pkey == nullptr) {
throwOpensslError();
}
}
TlsPrivateKey::TlsPrivateKey(const TlsPrivateKey& other)
: pkey(other.pkey) {
if (pkey != nullptr) EVP_PKEY_up_ref(reinterpret_cast<EVP_PKEY*>(pkey));
}
TlsPrivateKey& TlsPrivateKey::operator=(const TlsPrivateKey& other) {
if (pkey != other.pkey) {
EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey));
pkey = other.pkey;
if (pkey != nullptr) EVP_PKEY_up_ref(reinterpret_cast<EVP_PKEY*>(pkey));
}
return *this;
}
TlsPrivateKey::~TlsPrivateKey() noexcept(false) {
EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(pkey));
}
int TlsPrivateKey::passwordCallback(char* buf, int size, int rwflag, void* u) {
auto& password = *reinterpret_cast<kj::Maybe<kj::StringPtr>*>(u);
KJ_IF_MAYBE(p, password) {
int result = kj::min(p->size(), size);
memcpy(buf, p->begin(), result);
return result;
} else {
return 0;
}
}
// =======================================================================================
// class TlsCertificate
TlsCertificate::TlsCertificate(kj::ArrayPtr<const kj::ArrayPtr<const byte>> asn1) {
ensureOpenSslInitialized();
KJ_REQUIRE(asn1.size() > 0, "must provide at least one certificate in chain");
KJ_REQUIRE(asn1.size() <= kj::size(chain),
"exceeded maximum certificate chain length of 10");
memset(chain, 0, sizeof(chain));
for (auto i: kj::indices(asn1)) {
auto p = asn1[i].begin();
// "_AUX" apparently refers to some auxiliary information that can be appended to the
// certificate, but should only be trusted for your own certificate, not the whole chain??
// I don't really know, I'm just cargo-culting.
chain[i] = i == 0 ? d2i_X509_AUX(nullptr, &p, asn1[i].size())
: d2i_X509(nullptr, &p, asn1[i].size());
if (chain[i] == nullptr) {
for (size_t j = 0; j < i; j++) {
X509_free(reinterpret_cast<X509*>(chain[j]));
}
throwOpensslError();
}
}
}
TlsCertificate::TlsCertificate(kj::ArrayPtr<const byte> asn1)
: TlsCertificate(kj::arrayPtr(&asn1, 1)) {}
TlsCertificate::TlsCertificate(kj::StringPtr pem) {
ensureOpenSslInitialized();
memset(chain, 0, sizeof(chain));
// const_cast apparently needed for older versions of OpenSSL.
BIO* bio = BIO_new_mem_buf(const_cast<char*>(pem.begin()), pem.size());
KJ_DEFER(BIO_free(bio));
for (auto i: kj::indices(chain)) {
// "_AUX" apparently refers to some auxiliary information that can be appended to the
// certificate, but should only be trusted for your own certificate, not the whole chain??
// I don't really know, I'm just cargo-culting.
chain[i] = i == 0 ? PEM_read_bio_X509_AUX(bio, nullptr, nullptr, nullptr)
: PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
if (chain[i] == nullptr) {
auto error = ERR_peek_last_error();
if (i > 0 && ERR_GET_LIB(error) == ERR_LIB_PEM &&
ERR_GET_REASON(error) == PEM_R_NO_START_LINE) {
// EOF; we're done.
ERR_clear_error();
return;
} else {
for (size_t j = 0; j < i; j++) {
X509_free(reinterpret_cast<X509*>(chain[j]));
}
throwOpensslError();
}
}
}
// We reached the chain length limit. Try to read one more to verify that the chain ends here.
X509* dummy = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
if (dummy != nullptr) {
X509_free(dummy);
for (auto i: kj::indices(chain)) {
X509_free(reinterpret_cast<X509*>(chain[i]));
}
KJ_FAIL_REQUIRE("exceeded maximum certificate chain length of 10");
}
}
TlsCertificate::TlsCertificate(const TlsCertificate& other) {
memcpy(chain, other.chain, sizeof(chain));
for (void* p: chain) {
if (p == nullptr) break; // end of chain; quit early
X509_up_ref(reinterpret_cast<X509*>(p));
}
}
TlsCertificate& TlsCertificate::operator=(const TlsCertificate& other) {
for (auto i: kj::indices(chain)) {
if (chain[i] != other.chain[i]) {
EVP_PKEY_free(reinterpret_cast<EVP_PKEY*>(chain[i]));
chain[i] = other.chain[i];
if (chain[i] != nullptr) X509_up_ref(reinterpret_cast<X509*>(chain[i]));
} else if (chain[i] == nullptr) {
// end of both chains; quit early
break;
}
}
return *this;
}
TlsCertificate::~TlsCertificate() noexcept(false) {
for (void* p: chain) {
if (p == nullptr) break; // end of chain; quit early
X509_free(reinterpret_cast<X509*>(p));
}
}
// =======================================================================================
// class TlsPeerIdentity
TlsPeerIdentity::~TlsPeerIdentity() noexcept(false) {
if (cert != nullptr) {
X509_free(reinterpret_cast<X509*>(cert));
}
}
kj::String TlsPeerIdentity::toString() {
if (hasCertificate()) {
return getCommonName();
} else {
return kj::str("(anonymous client)");
}
}
kj::String TlsPeerIdentity::getCommonName() {
if (cert == nullptr) {
KJ_FAIL_REQUIRE("client did not provide a certificate") { return nullptr; }
}
X509_NAME* subj = X509_get_subject_name(reinterpret_cast<X509*>(cert));
int index = X509_NAME_get_index_by_NID(subj, NID_commonName, -1);
KJ_ASSERT(index != -1, "certificate has no common name?");
X509_NAME_ENTRY* entry = X509_NAME_get_entry(subj, index);
KJ_ASSERT(entry != nullptr);
ASN1_STRING* data = X509_NAME_ENTRY_get_data(entry);
KJ_ASSERT(data != nullptr);
unsigned char* out = nullptr;
int len = ASN1_STRING_to_UTF8(&out, data);
KJ_ASSERT(len >= 0);
KJ_DEFER(OPENSSL_free(out));
return kj::heapString(reinterpret_cast<char*>(out), len);
}
} // namespace kj
#endif // KJ_HAS_OPENSSL