okhttp: Add missing server support for TLS ClientAuth (#9711)
diff --git a/okhttp/build.gradle b/okhttp/build.gradle
index a044503..439abaa 100644
--- a/okhttp/build.gradle
+++ b/okhttp/build.gradle
@@ -21,6 +21,7 @@
testImplementation project(':grpc-core').sourceSets.test.output,
project(':grpc-api').sourceSets.test.output,
project(':grpc-testing'),
+ project(':grpc-testing-proto'),
libraries.netty.codec.http2,
libraries.okhttp
signature libraries.signature.java
diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java
index f0e8bf4..45d6b9e 100644
--- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java
+++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java
@@ -40,7 +40,10 @@
import io.grpc.internal.SharedResourcePool;
import io.grpc.internal.TransportTracer;
import io.grpc.okhttp.internal.Platform;
+import java.io.IOException;
+import java.net.InetAddress;
import java.net.InetSocketAddress;
+import java.net.Socket;
import java.net.SocketAddress;
import java.security.GeneralSecurityException;
import java.util.EnumSet;
@@ -54,6 +57,8 @@
import javax.net.ServerSocketFactory;
import javax.net.ssl.KeyManager;
import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLSocket;
+import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
/**
@@ -422,9 +427,26 @@
} catch (GeneralSecurityException gse) {
throw new RuntimeException("TLS Provider failure", gse);
}
+ SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory();
+ switch (tlsCreds.getClientAuth()) {
+ case OPTIONAL:
+ sslSocketFactory = new ClientCertRequestingSocketFactory(sslSocketFactory, false);
+ break;
+
+ case REQUIRE:
+ sslSocketFactory = new ClientCertRequestingSocketFactory(sslSocketFactory, true);
+ break;
+
+ case NONE:
+ // NOOP; this is the SSLContext default
+ break;
+
+ default:
+ return HandshakerSocketFactoryResult.error(
+ "Unknown TlsServerCredentials.ClientAuth value: " + tlsCreds.getClientAuth());
+ }
return HandshakerSocketFactoryResult.factory(new TlsServerHandshakerSocketFactory(
- new SslSocketFactoryServerCredentials.ServerCredentials(
- sslContext.getSocketFactory())));
+ new SslSocketFactoryServerCredentials.ServerCredentials(sslSocketFactory)));
} else if (creds instanceof InsecureServerCredentials) {
return HandshakerSocketFactoryResult.factory(new PlaintextHandshakerSocketFactory());
@@ -473,4 +495,59 @@
Preconditions.checkNotNull(factory, "factory"), null);
}
}
+
+ static final class ClientCertRequestingSocketFactory extends SSLSocketFactory {
+ private final SSLSocketFactory socketFactory;
+ private final boolean required;
+
+ public ClientCertRequestingSocketFactory(SSLSocketFactory socketFactory, boolean required) {
+ this.socketFactory = Preconditions.checkNotNull(socketFactory, "socketFactory");
+ this.required = required;
+ }
+
+ private Socket apply(Socket s) throws IOException {
+ if (!(s instanceof SSLSocket)) {
+ throw new IOException(
+ "SocketFactory " + socketFactory + " did not produce an SSLSocket: " + s.getClass());
+ }
+ SSLSocket sslSocket = (SSLSocket) s;
+ if (required) {
+ sslSocket.setNeedClientAuth(true);
+ } else {
+ sslSocket.setWantClientAuth(true);
+ }
+ return sslSocket;
+ }
+
+ @Override public Socket createSocket(Socket s, String host, int port, boolean autoClose)
+ throws IOException {
+ return apply(socketFactory.createSocket(s, host, port, autoClose));
+ }
+
+ @Override public Socket createSocket(String host, int port) throws IOException {
+ return apply(socketFactory.createSocket(host, port));
+ }
+
+ @Override public Socket createSocket(
+ String host, int port, InetAddress localHost, int localPort) throws IOException {
+ return apply(socketFactory.createSocket(host, port, localHost, localPort));
+ }
+
+ @Override public Socket createSocket(InetAddress host, int port) throws IOException {
+ return apply(socketFactory.createSocket(host, port));
+ }
+
+ @Override public Socket createSocket(
+ InetAddress host, int port, InetAddress localAddress, int localPort) throws IOException {
+ return apply(socketFactory.createSocket(host, port, localAddress, localPort));
+ }
+
+ @Override public String[] getDefaultCipherSuites() {
+ return socketFactory.getDefaultCipherSuites();
+ }
+
+ @Override public String[] getSupportedCipherSuites() {
+ return socketFactory.getSupportedCipherSuites();
+ }
+ }
}
diff --git a/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java b/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java
new file mode 100644
index 0000000..cc86c81
--- /dev/null
+++ b/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java
@@ -0,0 +1,271 @@
+/*
+ * Copyright 2015 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.okhttp;
+
+import static com.google.common.truth.Truth.assertThat;
+import static com.google.common.truth.Truth.assertWithMessage;
+
+import com.google.common.base.Throwables;
+import io.grpc.ChannelCredentials;
+import io.grpc.ConnectivityState;
+import io.grpc.ManagedChannel;
+import io.grpc.ManagedChannelBuilder;
+import io.grpc.Server;
+import io.grpc.ServerCredentials;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import io.grpc.TlsChannelCredentials;
+import io.grpc.TlsServerCredentials;
+import io.grpc.internal.testing.TestUtils;
+import io.grpc.stub.StreamObserver;
+import io.grpc.testing.GrpcCleanupRule;
+import io.grpc.testing.TlsTesting;
+import io.grpc.testing.protobuf.SimpleRequest;
+import io.grpc.testing.protobuf.SimpleResponse;
+import io.grpc.testing.protobuf.SimpleServiceGrpc;
+import java.io.IOException;
+import java.io.InputStream;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import org.junit.Assume;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Verify OkHttp's TLS integration. */
+@RunWith(JUnit4.class)
+public class TlsTest {
+ @Rule
+ public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule();
+
+ @Before
+ public void checkForAlpnApi() throws Exception {
+ // This checks for the "Java 9 ALPN API" which was backported to Java 8u252. The Kokoro Windows
+ // CI is on too old of a JDK for us to assume this is available.
+ SSLContext context = SSLContext.getInstance("TLS");
+ context.init(null, null, null);
+ SSLEngine engine = context.createSSLEngine();
+ try {
+ SSLEngine.class.getMethod("getApplicationProtocol").invoke(engine);
+ } catch (NoSuchMethodException | UnsupportedOperationException ex) {
+ Assume.assumeNoException(ex);
+ }
+ }
+
+ @Test
+ public void mtls_succeeds() throws Exception {
+ ServerCredentials serverCreds;
+ try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
+ InputStream serverPrivateKey = TlsTesting.loadCert("server1.key");
+ InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+ serverCreds = TlsServerCredentials.newBuilder()
+ .keyManager(serverCert, serverPrivateKey)
+ .trustManager(caCert)
+ .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE)
+ .build();
+ }
+ ChannelCredentials channelCreds;
+ try (InputStream clientCertChain = TlsTesting.loadCert("client.pem");
+ InputStream clientPrivateKey = TlsTesting.loadCert("client.key");
+ InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+ channelCreds = TlsChannelCredentials.newBuilder()
+ .keyManager(clientCertChain, clientPrivateKey)
+ .trustManager(caCert)
+ .build();
+ }
+ Server server = grpcCleanupRule.register(server(serverCreds));
+ ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds));
+
+ SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance());
+ }
+
+ @Test
+ public void untrustedClient_fails() throws Exception {
+ ServerCredentials serverCreds;
+ try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
+ InputStream serverPrivateKey = TlsTesting.loadCert("server1.key");
+ InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+ serverCreds = TlsServerCredentials.newBuilder()
+ .keyManager(serverCert, serverPrivateKey)
+ .trustManager(caCert)
+ .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE)
+ .build();
+ }
+ ChannelCredentials channelCreds;
+ try (InputStream clientCertChain = TlsTesting.loadCert("badclient.pem");
+ InputStream clientPrivateKey = TlsTesting.loadCert("badclient.key");
+ InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+ channelCreds = TlsChannelCredentials.newBuilder()
+ .keyManager(clientCertChain, clientPrivateKey)
+ .trustManager(caCert)
+ .build();
+ }
+ Server server = grpcCleanupRule.register(server(serverCreds));
+ ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds));
+
+ assertRpcFails(channel);
+ }
+
+ @Test
+ public void missingOptionalClientCert_succeeds() throws Exception {
+ ServerCredentials serverCreds;
+ try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
+ InputStream serverPrivateKey = TlsTesting.loadCert("server1.key");
+ InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+ serverCreds = TlsServerCredentials.newBuilder()
+ .keyManager(serverCert, serverPrivateKey)
+ .trustManager(caCert)
+ .clientAuth(TlsServerCredentials.ClientAuth.OPTIONAL)
+ .build();
+ }
+ ChannelCredentials channelCreds;
+ try (InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+ channelCreds = TlsChannelCredentials.newBuilder()
+ .trustManager(caCert)
+ .build();
+ }
+ Server server = grpcCleanupRule.register(server(serverCreds));
+ ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds));
+
+ SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance());
+ }
+
+ @Test
+ public void missingRequiredClientCert_fails() throws Exception {
+ ServerCredentials serverCreds;
+ try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
+ InputStream serverPrivateKey = TlsTesting.loadCert("server1.key");
+ InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+ serverCreds = TlsServerCredentials.newBuilder()
+ .keyManager(serverCert, serverPrivateKey)
+ .trustManager(caCert)
+ .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE)
+ .build();
+ }
+ ChannelCredentials channelCreds;
+ try (InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+ channelCreds = TlsChannelCredentials.newBuilder()
+ .trustManager(caCert)
+ .build();
+ }
+ Server server = grpcCleanupRule.register(server(serverCreds));
+ ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds));
+
+ assertRpcFails(channel);
+ }
+
+ @Test
+ public void untrustedServer_fails() throws Exception {
+ ServerCredentials serverCreds;
+ try (InputStream serverCert = TlsTesting.loadCert("badserver.pem");
+ InputStream serverPrivateKey = TlsTesting.loadCert("badserver.key");
+ InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+ serverCreds = TlsServerCredentials.newBuilder()
+ .keyManager(serverCert, serverPrivateKey)
+ .trustManager(caCert)
+ .build();
+ }
+ ChannelCredentials channelCreds;
+ try (InputStream clientCertChain = TlsTesting.loadCert("client.pem");
+ InputStream clientPrivateKey = TlsTesting.loadCert("client.key");
+ InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+ channelCreds = TlsChannelCredentials.newBuilder()
+ .keyManager(clientCertChain, clientPrivateKey)
+ .trustManager(caCert)
+ .build();
+ }
+ Server server = grpcCleanupRule.register(server(serverCreds));
+ ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds));
+
+ assertRpcFails(channel);
+ }
+
+ @Test
+ public void unmatchedServerSubjectAlternativeNames_fails() throws Exception {
+ ServerCredentials serverCreds;
+ try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
+ InputStream serverPrivateKey = TlsTesting.loadCert("server1.key");
+ InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+ serverCreds = TlsServerCredentials.newBuilder()
+ .keyManager(serverCert, serverPrivateKey)
+ .trustManager(caCert)
+ .build();
+ }
+ ChannelCredentials channelCreds;
+ try (InputStream clientCertChain = TlsTesting.loadCert("client.pem");
+ InputStream clientPrivateKey = TlsTesting.loadCert("client.key");
+ InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+ channelCreds = TlsChannelCredentials.newBuilder()
+ .keyManager(clientCertChain, clientPrivateKey)
+ .trustManager(caCert)
+ .build();
+ }
+ Server server = grpcCleanupRule.register(server(serverCreds));
+ ManagedChannel channel = grpcCleanupRule.register(clientChannelBuilder(server, channelCreds)
+ .overrideAuthority("notgonnamatch.example.com")
+ .build());
+
+ assertRpcFails(channel);
+ }
+
+ private static Server server(ServerCredentials creds) throws IOException {
+ return OkHttpServerBuilder.forPort(0, creds)
+ .directExecutor()
+ .addService(new SimpleServiceImpl())
+ .build()
+ .start();
+ }
+
+ private static ManagedChannelBuilder<?> clientChannelBuilder(
+ Server server, ChannelCredentials creds) {
+ return OkHttpChannelBuilder.forAddress("localhost", server.getPort(), creds)
+ .directExecutor()
+ .overrideAuthority(TestUtils.TEST_SERVER_HOST);
+ }
+
+ private static ManagedChannel clientChannel(Server server, ChannelCredentials creds) {
+ return clientChannelBuilder(server, creds).build();
+ }
+
+ private static void assertRpcFails(ManagedChannel channel) {
+ SimpleServiceGrpc.SimpleServiceBlockingStub stub = SimpleServiceGrpc.newBlockingStub(channel);
+ try {
+ stub.unaryRpc(SimpleRequest.getDefaultInstance());
+ assertWithMessage("TLS handshake should have failed, but didn't; received RPC response")
+ .fail();
+ } catch (StatusRuntimeException e) {
+ assertWithMessage(Throwables.getStackTraceAsString(e))
+ .that(e.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE);
+ }
+ // We really want to see TRANSIENT_FAILURE here, but if the test runs slowly the 1s backoff
+ // may be exceeded by the time the failure happens (since it counts from the start of the
+ // attempt). Even so, CONNECTING is a strong indicator that the handshake failed; otherwise we'd
+ // expect READY or IDLE.
+ assertThat(channel.getState(false))
+ .isAnyOf(ConnectivityState.TRANSIENT_FAILURE, ConnectivityState.CONNECTING);
+ }
+
+ private static final class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase {
+ @Override
+ public void unaryRpc(SimpleRequest req, StreamObserver<SimpleResponse> respOb) {
+ respOb.onNext(SimpleResponse.getDefaultInstance());
+ respOb.onCompleted();
+ }
+ }
+}