okhttp: Add missing server support for TLS ClientAuth (#9711)

diff --git a/okhttp/build.gradle b/okhttp/build.gradle
index a044503..439abaa 100644
--- a/okhttp/build.gradle
+++ b/okhttp/build.gradle
@@ -21,6 +21,7 @@
     testImplementation project(':grpc-core').sourceSets.test.output,
             project(':grpc-api').sourceSets.test.output,
             project(':grpc-testing'),
+            project(':grpc-testing-proto'),
             libraries.netty.codec.http2,
             libraries.okhttp
     signature libraries.signature.java
diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java
index f0e8bf4..45d6b9e 100644
--- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java
+++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java
@@ -40,7 +40,10 @@
 import io.grpc.internal.SharedResourcePool;
 import io.grpc.internal.TransportTracer;
 import io.grpc.okhttp.internal.Platform;
+import java.io.IOException;
+import java.net.InetAddress;
 import java.net.InetSocketAddress;
+import java.net.Socket;
 import java.net.SocketAddress;
 import java.security.GeneralSecurityException;
 import java.util.EnumSet;
@@ -54,6 +57,8 @@
 import javax.net.ServerSocketFactory;
 import javax.net.ssl.KeyManager;
 import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLSocket;
+import javax.net.ssl.SSLSocketFactory;
 import javax.net.ssl.TrustManager;
 
 /**
@@ -422,9 +427,26 @@
       } catch (GeneralSecurityException gse) {
         throw new RuntimeException("TLS Provider failure", gse);
       }
+      SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory();
+      switch (tlsCreds.getClientAuth()) {
+        case OPTIONAL:
+          sslSocketFactory = new ClientCertRequestingSocketFactory(sslSocketFactory, false);
+          break;
+
+        case REQUIRE:
+          sslSocketFactory = new ClientCertRequestingSocketFactory(sslSocketFactory, true);
+          break;
+
+        case NONE:
+          // NOOP; this is the SSLContext default
+          break;
+
+        default:
+          return HandshakerSocketFactoryResult.error(
+              "Unknown TlsServerCredentials.ClientAuth value: " + tlsCreds.getClientAuth());
+      }
       return HandshakerSocketFactoryResult.factory(new TlsServerHandshakerSocketFactory(
-          new SslSocketFactoryServerCredentials.ServerCredentials(
-              sslContext.getSocketFactory())));
+          new SslSocketFactoryServerCredentials.ServerCredentials(sslSocketFactory)));
 
     } else if (creds instanceof InsecureServerCredentials) {
       return HandshakerSocketFactoryResult.factory(new PlaintextHandshakerSocketFactory());
@@ -473,4 +495,59 @@
           Preconditions.checkNotNull(factory, "factory"), null);
     }
   }
+
+  static final class ClientCertRequestingSocketFactory extends SSLSocketFactory {
+    private final SSLSocketFactory socketFactory;
+    private final boolean required;
+
+    public ClientCertRequestingSocketFactory(SSLSocketFactory socketFactory, boolean required) {
+      this.socketFactory = Preconditions.checkNotNull(socketFactory, "socketFactory");
+      this.required = required;
+    }
+
+    private Socket apply(Socket s) throws IOException {
+      if (!(s instanceof SSLSocket)) {
+        throw new IOException(
+            "SocketFactory " + socketFactory + " did not produce an SSLSocket: " + s.getClass());
+      }
+      SSLSocket sslSocket = (SSLSocket) s;
+      if (required) {
+        sslSocket.setNeedClientAuth(true);
+      } else {
+        sslSocket.setWantClientAuth(true);
+      }
+      return sslSocket;
+    }
+
+    @Override public Socket createSocket(Socket s, String host, int port, boolean autoClose)
+        throws IOException {
+      return apply(socketFactory.createSocket(s, host, port, autoClose));
+    }
+
+    @Override public Socket createSocket(String host, int port) throws IOException {
+      return apply(socketFactory.createSocket(host, port));
+    }
+
+    @Override public Socket createSocket(
+        String host, int port, InetAddress localHost, int localPort) throws IOException {
+      return apply(socketFactory.createSocket(host, port, localHost, localPort));
+    }
+
+    @Override public Socket createSocket(InetAddress host, int port) throws IOException {
+      return apply(socketFactory.createSocket(host, port));
+    }
+
+    @Override public Socket createSocket(
+        InetAddress host, int port, InetAddress localAddress, int localPort) throws IOException {
+      return apply(socketFactory.createSocket(host, port, localAddress, localPort));
+    }
+
+    @Override public String[] getDefaultCipherSuites() {
+      return socketFactory.getDefaultCipherSuites();
+    }
+
+    @Override public String[] getSupportedCipherSuites() {
+      return socketFactory.getSupportedCipherSuites();
+    }
+  }
 }
diff --git a/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java b/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java
new file mode 100644
index 0000000..cc86c81
--- /dev/null
+++ b/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java
@@ -0,0 +1,271 @@
+/*
+ * Copyright 2015 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.okhttp;
+
+import static com.google.common.truth.Truth.assertThat;
+import static com.google.common.truth.Truth.assertWithMessage;
+
+import com.google.common.base.Throwables;
+import io.grpc.ChannelCredentials;
+import io.grpc.ConnectivityState;
+import io.grpc.ManagedChannel;
+import io.grpc.ManagedChannelBuilder;
+import io.grpc.Server;
+import io.grpc.ServerCredentials;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import io.grpc.TlsChannelCredentials;
+import io.grpc.TlsServerCredentials;
+import io.grpc.internal.testing.TestUtils;
+import io.grpc.stub.StreamObserver;
+import io.grpc.testing.GrpcCleanupRule;
+import io.grpc.testing.TlsTesting;
+import io.grpc.testing.protobuf.SimpleRequest;
+import io.grpc.testing.protobuf.SimpleResponse;
+import io.grpc.testing.protobuf.SimpleServiceGrpc;
+import java.io.IOException;
+import java.io.InputStream;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import org.junit.Assume;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Verify OkHttp's TLS integration. */
+@RunWith(JUnit4.class)
+public class TlsTest {
+  @Rule
+  public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule();
+
+  @Before
+  public void checkForAlpnApi() throws Exception {
+    // This checks for the "Java 9 ALPN API" which was backported to Java 8u252. The Kokoro Windows
+    // CI is on too old of a JDK for us to assume this is available.
+    SSLContext context = SSLContext.getInstance("TLS");
+    context.init(null, null, null);
+    SSLEngine engine = context.createSSLEngine();
+    try {
+      SSLEngine.class.getMethod("getApplicationProtocol").invoke(engine);
+    } catch (NoSuchMethodException | UnsupportedOperationException ex) {
+      Assume.assumeNoException(ex);
+    }
+  }
+
+  @Test
+  public void mtls_succeeds() throws Exception {
+    ServerCredentials serverCreds;
+    try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
+         InputStream serverPrivateKey = TlsTesting.loadCert("server1.key");
+         InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+      serverCreds = TlsServerCredentials.newBuilder()
+          .keyManager(serverCert, serverPrivateKey)
+          .trustManager(caCert)
+          .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE)
+          .build();
+    }
+    ChannelCredentials channelCreds;
+    try (InputStream clientCertChain = TlsTesting.loadCert("client.pem");
+         InputStream clientPrivateKey = TlsTesting.loadCert("client.key");
+         InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+      channelCreds = TlsChannelCredentials.newBuilder()
+          .keyManager(clientCertChain, clientPrivateKey)
+          .trustManager(caCert)
+          .build();
+    }
+    Server server = grpcCleanupRule.register(server(serverCreds));
+    ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds));
+
+    SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance());
+  }
+
+  @Test
+  public void untrustedClient_fails() throws Exception {
+    ServerCredentials serverCreds;
+    try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
+         InputStream serverPrivateKey = TlsTesting.loadCert("server1.key");
+         InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+      serverCreds = TlsServerCredentials.newBuilder()
+          .keyManager(serverCert, serverPrivateKey)
+          .trustManager(caCert)
+          .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE)
+          .build();
+    }
+    ChannelCredentials channelCreds;
+    try (InputStream clientCertChain = TlsTesting.loadCert("badclient.pem");
+         InputStream clientPrivateKey = TlsTesting.loadCert("badclient.key");
+         InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+      channelCreds = TlsChannelCredentials.newBuilder()
+          .keyManager(clientCertChain, clientPrivateKey)
+          .trustManager(caCert)
+          .build();
+    }
+    Server server = grpcCleanupRule.register(server(serverCreds));
+    ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds));
+
+    assertRpcFails(channel);
+  }
+
+  @Test
+  public void missingOptionalClientCert_succeeds() throws Exception {
+    ServerCredentials serverCreds;
+    try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
+         InputStream serverPrivateKey = TlsTesting.loadCert("server1.key");
+         InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+      serverCreds = TlsServerCredentials.newBuilder()
+          .keyManager(serverCert, serverPrivateKey)
+          .trustManager(caCert)
+          .clientAuth(TlsServerCredentials.ClientAuth.OPTIONAL)
+          .build();
+    }
+    ChannelCredentials channelCreds;
+    try (InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+      channelCreds = TlsChannelCredentials.newBuilder()
+          .trustManager(caCert)
+          .build();
+    }
+    Server server = grpcCleanupRule.register(server(serverCreds));
+    ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds));
+
+    SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance());
+  }
+
+  @Test
+  public void missingRequiredClientCert_fails() throws Exception {
+    ServerCredentials serverCreds;
+    try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
+         InputStream serverPrivateKey = TlsTesting.loadCert("server1.key");
+         InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+      serverCreds = TlsServerCredentials.newBuilder()
+          .keyManager(serverCert, serverPrivateKey)
+          .trustManager(caCert)
+          .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE)
+          .build();
+    }
+    ChannelCredentials channelCreds;
+    try (InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+      channelCreds = TlsChannelCredentials.newBuilder()
+          .trustManager(caCert)
+          .build();
+    }
+    Server server = grpcCleanupRule.register(server(serverCreds));
+    ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds));
+
+    assertRpcFails(channel);
+  }
+
+  @Test
+  public void untrustedServer_fails() throws Exception {
+    ServerCredentials serverCreds;
+    try (InputStream serverCert = TlsTesting.loadCert("badserver.pem");
+         InputStream serverPrivateKey = TlsTesting.loadCert("badserver.key");
+         InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+      serverCreds = TlsServerCredentials.newBuilder()
+          .keyManager(serverCert, serverPrivateKey)
+          .trustManager(caCert)
+          .build();
+    }
+    ChannelCredentials channelCreds;
+    try (InputStream clientCertChain = TlsTesting.loadCert("client.pem");
+         InputStream clientPrivateKey = TlsTesting.loadCert("client.key");
+         InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+      channelCreds = TlsChannelCredentials.newBuilder()
+          .keyManager(clientCertChain, clientPrivateKey)
+          .trustManager(caCert)
+          .build();
+    }
+    Server server = grpcCleanupRule.register(server(serverCreds));
+    ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds));
+
+    assertRpcFails(channel);
+  }
+
+  @Test
+  public void unmatchedServerSubjectAlternativeNames_fails() throws Exception {
+    ServerCredentials serverCreds;
+    try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
+         InputStream serverPrivateKey = TlsTesting.loadCert("server1.key");
+         InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+      serverCreds = TlsServerCredentials.newBuilder()
+          .keyManager(serverCert, serverPrivateKey)
+          .trustManager(caCert)
+          .build();
+    }
+    ChannelCredentials channelCreds;
+    try (InputStream clientCertChain = TlsTesting.loadCert("client.pem");
+         InputStream clientPrivateKey = TlsTesting.loadCert("client.key");
+         InputStream caCert = TlsTesting.loadCert("ca.pem")) {
+      channelCreds = TlsChannelCredentials.newBuilder()
+          .keyManager(clientCertChain, clientPrivateKey)
+          .trustManager(caCert)
+          .build();
+    }
+    Server server = grpcCleanupRule.register(server(serverCreds));
+    ManagedChannel channel = grpcCleanupRule.register(clientChannelBuilder(server, channelCreds)
+        .overrideAuthority("notgonnamatch.example.com")
+        .build());
+
+    assertRpcFails(channel);
+  }
+
+  private static Server server(ServerCredentials creds) throws IOException {
+    return OkHttpServerBuilder.forPort(0, creds)
+        .directExecutor()
+        .addService(new SimpleServiceImpl())
+        .build()
+        .start();
+  }
+
+  private static ManagedChannelBuilder<?> clientChannelBuilder(
+      Server server, ChannelCredentials creds) {
+    return OkHttpChannelBuilder.forAddress("localhost", server.getPort(), creds)
+        .directExecutor()
+        .overrideAuthority(TestUtils.TEST_SERVER_HOST);
+  }
+
+  private static ManagedChannel clientChannel(Server server, ChannelCredentials creds) {
+    return clientChannelBuilder(server, creds).build();
+  }
+
+  private static void assertRpcFails(ManagedChannel channel) {
+    SimpleServiceGrpc.SimpleServiceBlockingStub stub = SimpleServiceGrpc.newBlockingStub(channel);
+    try {
+      stub.unaryRpc(SimpleRequest.getDefaultInstance());
+      assertWithMessage("TLS handshake should have failed, but didn't; received RPC response")
+          .fail();
+    } catch (StatusRuntimeException e) {
+      assertWithMessage(Throwables.getStackTraceAsString(e))
+          .that(e.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE);
+    }
+    // We really want to see TRANSIENT_FAILURE here, but if the test runs slowly the 1s backoff
+    // may be exceeded by the time the failure happens (since it counts from the start of the
+    // attempt). Even so, CONNECTING is a strong indicator that the handshake failed; otherwise we'd
+    // expect READY or IDLE.
+    assertThat(channel.getState(false))
+        .isAnyOf(ConnectivityState.TRANSIENT_FAILURE, ConnectivityState.CONNECTING);
+  }
+
+  private static final class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase {
+    @Override
+    public void unaryRpc(SimpleRequest req, StreamObserver<SimpleResponse> respOb) {
+      respOb.onNext(SimpleResponse.getDefaultInstance());
+      respOb.onCompleted();
+    }
+  }
+}