okhttp: Add ChannelCredentials
diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java
index 93ea57e..28e334a 100644
--- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java
+++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java
@@ -22,11 +22,18 @@
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
+import io.grpc.CallCredentials;
+import io.grpc.ChannelCredentials;
 import io.grpc.ChannelLogger;
+import io.grpc.ChoiceChannelCredentials;
+import io.grpc.CompositeCallCredentials;
+import io.grpc.CompositeChannelCredentials;
 import io.grpc.ExperimentalApi;
 import io.grpc.ForwardingChannelBuilder;
+import io.grpc.InsecureChannelCredentials;
 import io.grpc.Internal;
 import io.grpc.ManagedChannelBuilder;
+import io.grpc.TlsChannelCredentials;
 import io.grpc.internal.AtomicBackoff;
 import io.grpc.internal.ClientTransportFactory;
 import io.grpc.internal.ConnectionClientTransport;
@@ -45,6 +52,8 @@
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
 import java.security.GeneralSecurityException;
+import java.util.EnumSet;
+import java.util.Set;
 import java.util.concurrent.Executor;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
@@ -113,6 +122,11 @@
     return new OkHttpChannelBuilder(host, port);
   }
 
+  /** Creates a new builder with the given host and port. */
+  public static OkHttpChannelBuilder forAddress(String host, int port, ChannelCredentials creds) {
+    return forTarget(GrpcUtil.authorityFromHostAndPort(host, port), creds);
+  }
+
   /**
    * Creates a new builder for the given target that will be resolved by
    * {@link io.grpc.NameResolver}.
@@ -121,11 +135,24 @@
     return new OkHttpChannelBuilder(target);
   }
 
+  /**
+   * Creates a new builder for the given target that will be resolved by
+   * {@link io.grpc.NameResolver}.
+   */
+  public static OkHttpChannelBuilder forTarget(String target, ChannelCredentials creds) {
+    SslSocketFactoryResult result = sslSocketFactoryFrom(creds);
+    if (result.error != null) {
+      throw new IllegalArgumentException(result.error);
+    }
+    return new OkHttpChannelBuilder(target, result.factory, result.callCredentials);
+  }
+
   private Executor transportExecutor;
   private ScheduledExecutorService scheduledExecutorService;
 
   private SocketFactory socketFactory;
   private SSLSocketFactory sslSocketFactory;
+  private final boolean freezeSecurityConfiguration;
   private HostnameVerifier hostnameVerifier;
   private ConnectionSpec connectionSpec = INTERNAL_DEFAULT_CONNECTION_SPEC;
   private NegotiationType negotiationType = NegotiationType.TLS;
@@ -147,23 +174,36 @@
   }
 
   private OkHttpChannelBuilder(String target) {
-    final class OkHttpChannelTransportFactoryBuilder implements ClientTransportFactoryBuilder {
-      @Override
-      public ClientTransportFactory buildClientTransportFactory() {
-        return buildTransportFactory();
-      }
-    }
-
-    final class OkHttpChannelDefaultPortProvider implements ChannelBuilderDefaultPortProvider {
-      @Override
-      public int getDefaultPort() {
-        return OkHttpChannelBuilder.this.getDefaultPort();
-      }
-    }
-
     managedChannelImplBuilder = new ManagedChannelImplBuilder(target,
         new OkHttpChannelTransportFactoryBuilder(),
         new OkHttpChannelDefaultPortProvider());
+    this.freezeSecurityConfiguration = false;
+  }
+
+  OkHttpChannelBuilder(String target, @Nullable SSLSocketFactory factory,
+      @Nullable CallCredentials callCredentials) {
+    managedChannelImplBuilder = new ManagedChannelImplBuilder(target, callCredentials,
+        new OkHttpChannelTransportFactoryBuilder(),
+        new OkHttpChannelDefaultPortProvider());
+    this.sslSocketFactory = factory;
+    this.negotiationType = factory == null ? NegotiationType.PLAINTEXT : NegotiationType.TLS;
+    this.freezeSecurityConfiguration = true;
+  }
+
+  private final class OkHttpChannelTransportFactoryBuilder
+      implements ClientTransportFactoryBuilder {
+    @Override
+    public ClientTransportFactory buildClientTransportFactory() {
+      return buildTransportFactory();
+    }
+  }
+
+  private final class OkHttpChannelDefaultPortProvider
+      implements ChannelBuilderDefaultPortProvider {
+    @Override
+    public int getDefaultPort() {
+      return OkHttpChannelBuilder.this.getDefaultPort();
+    }
   }
 
   @Internal
@@ -214,6 +254,8 @@
    */
   @Deprecated
   public OkHttpChannelBuilder negotiationType(io.grpc.okhttp.NegotiationType type) {
+    Preconditions.checkState(!freezeSecurityConfiguration,
+        "Cannot change security when using ChannelCredentials");
     Preconditions.checkNotNull(type, "type");
     switch (type) {
       case TLS:
@@ -284,6 +326,8 @@
    * Override the default {@link SSLSocketFactory} and enable TLS negotiation.
    */
   public OkHttpChannelBuilder sslSocketFactory(SSLSocketFactory factory) {
+    Preconditions.checkState(!freezeSecurityConfiguration,
+        "Cannot change security when using ChannelCredentials");
     this.sslSocketFactory = factory;
     negotiationType = NegotiationType.TLS;
     return this;
@@ -310,6 +354,8 @@
    *
    */
   public OkHttpChannelBuilder hostnameVerifier(@Nullable HostnameVerifier hostnameVerifier) {
+    Preconditions.checkState(!freezeSecurityConfiguration,
+        "Cannot change security when using ChannelCredentials");
     this.hostnameVerifier = hostnameVerifier;
     return this;
   }
@@ -328,6 +374,8 @@
    */
   public OkHttpChannelBuilder connectionSpec(
       com.squareup.okhttp.ConnectionSpec connectionSpec) {
+    Preconditions.checkState(!freezeSecurityConfiguration,
+        "Cannot change security when using ChannelCredentials");
     Preconditions.checkArgument(connectionSpec.isTls(), "plaintext ConnectionSpec is not accepted");
     this.connectionSpec = Utils.convertSpec(connectionSpec);
     return this;
@@ -336,6 +384,8 @@
   /** Sets the negotiation type for the HTTP/2 connection to plaintext. */
   @Override
   public OkHttpChannelBuilder usePlaintext() {
+    Preconditions.checkState(!freezeSecurityConfiguration,
+        "Cannot change security when using ChannelCredentials");
     negotiationType = NegotiationType.PLAINTEXT;
     return this;
   }
@@ -345,11 +395,13 @@
    *
    * <p>With TLS enabled, a default {@link SSLSocketFactory} is created using the best {@link
    * java.security.Provider} available and is NOT based on {@link SSLSocketFactory#getDefault}. To
-   * more precisely control the TLS configuration call {@link #sslSocketFactory} to override the
-   * socket factory used.
+   * more precisely control the TLS configuration call {@link #sslSocketFactory(SSLSocketFactory)}
+   * to override the socket factory used.
    */
   @Override
   public OkHttpChannelBuilder useTransportSecurity() {
+    Preconditions.checkState(!freezeSecurityConfiguration,
+        "Cannot change security when using ChannelCredentials");
     negotiationType = NegotiationType.TLS;
     return this;
   }
@@ -469,6 +521,99 @@
     }
   }
 
+  private static final EnumSet<TlsChannelCredentials.Feature> understoodTlsFeatures =
+      EnumSet.noneOf(TlsChannelCredentials.Feature.class);
+
+  static SslSocketFactoryResult sslSocketFactoryFrom(ChannelCredentials creds) {
+    if (creds instanceof TlsChannelCredentials) {
+      TlsChannelCredentials tlsCreds = (TlsChannelCredentials) creds;
+      Set<TlsChannelCredentials.Feature> incomprehensible =
+          tlsCreds.incomprehensible(understoodTlsFeatures);
+      if (!incomprehensible.isEmpty()) {
+        return SslSocketFactoryResult.error(
+            "TLS features not understood: " + incomprehensible);
+      }
+      SSLSocketFactory sslSocketFactory;
+      try {
+        SSLContext sslContext = SSLContext.getInstance("Default", Platform.get().getProvider());
+        sslSocketFactory = sslContext.getSocketFactory();
+      } catch (GeneralSecurityException gse) {
+        throw new RuntimeException("TLS Provider failure", gse);
+      }
+      return SslSocketFactoryResult.factory(sslSocketFactory);
+
+    } else if (creds instanceof InsecureChannelCredentials) {
+      return SslSocketFactoryResult.plaintext();
+
+    } else if (creds instanceof CompositeChannelCredentials) {
+      CompositeChannelCredentials compCreds = (CompositeChannelCredentials) creds;
+      return sslSocketFactoryFrom(compCreds.getChannelCredentials())
+          .withCallCredentials(compCreds.getCallCredentials());
+
+    } else if (creds instanceof SslSocketFactoryChannelCredentials.ChannelCredentials) {
+      SslSocketFactoryChannelCredentials.ChannelCredentials factoryCreds =
+          (SslSocketFactoryChannelCredentials.ChannelCredentials) creds;
+      return SslSocketFactoryResult.factory(factoryCreds.getFactory());
+
+    } else if (creds instanceof ChoiceChannelCredentials) {
+      ChoiceChannelCredentials choiceCreds = (ChoiceChannelCredentials) creds;
+      StringBuilder error = new StringBuilder();
+      for (ChannelCredentials innerCreds : choiceCreds.getCredentialsList()) {
+        SslSocketFactoryResult result = sslSocketFactoryFrom(innerCreds);
+        if (result.error == null) {
+          return result;
+        }
+        error.append(", ");
+        error.append(result.error);
+      }
+      return SslSocketFactoryResult.error(error.substring(2));
+
+    } else {
+      return SslSocketFactoryResult.error(
+          "Unsupported credential type: " + creds.getClass().getName());
+    }
+  }
+
+  static final class SslSocketFactoryResult {
+    /** {@code null} implies plaintext if {@code error == null}. */
+    public final SSLSocketFactory factory;
+    public final CallCredentials callCredentials;
+    public final String error;
+
+    private SslSocketFactoryResult(SSLSocketFactory factory, CallCredentials creds, String error) {
+      this.factory = factory;
+      this.callCredentials = creds;
+      this.error = error;
+    }
+
+    public static SslSocketFactoryResult error(String error) {
+      return new SslSocketFactoryResult(
+          null, null, Preconditions.checkNotNull(error, "error"));
+    }
+
+    public static SslSocketFactoryResult plaintext() {
+      return new SslSocketFactoryResult(null, null, null);
+    }
+
+    public static SslSocketFactoryResult factory(
+        SSLSocketFactory factory) {
+      return new SslSocketFactoryResult(
+          Preconditions.checkNotNull(factory, "factory"), null, null);
+    }
+
+    public SslSocketFactoryResult withCallCredentials(CallCredentials callCreds) {
+      Preconditions.checkNotNull(callCreds, "callCreds");
+      if (error != null) {
+        return this;
+      }
+      if (this.callCredentials != null) {
+        callCreds = new CompositeCallCredentials(this.callCredentials, callCreds);
+      }
+      return new SslSocketFactoryResult(factory, callCreds, null);
+    }
+  }
+
+
   /**
    * Creates OkHttp transports. Exposed for internal use, as it should be private.
    */
diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java
index 036cc3a..f8caaea 100644
--- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java
+++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java
@@ -16,6 +16,7 @@
 
 package io.grpc.okhttp;
 
+import io.grpc.ChannelCredentials;
 import io.grpc.Internal;
 import io.grpc.InternalServiceProviders;
 import io.grpc.ManagedChannelProvider;
@@ -45,4 +46,15 @@
   public OkHttpChannelBuilder builderForTarget(String target) {
     return OkHttpChannelBuilder.forTarget(target);
   }
+
+  @Override
+  public NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentials creds) {
+    OkHttpChannelBuilder.SslSocketFactoryResult result =
+        OkHttpChannelBuilder.sslSocketFactoryFrom(creds);
+    if (result.error != null) {
+      return NewChannelBuilderResult.error(result.error);
+    }
+    return NewChannelBuilderResult.channelBuilder(new OkHttpChannelBuilder(
+        target, result.factory, result.callCredentials));
+  }
 }
diff --git a/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryChannelCredentials.java b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryChannelCredentials.java
new file mode 100644
index 0000000..3ee524c
--- /dev/null
+++ b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryChannelCredentials.java
@@ -0,0 +1,44 @@
+/*
+ * Copyright 2020 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.okhttp;
+
+import com.google.common.base.Preconditions;
+import io.grpc.ExperimentalApi;
+import javax.net.ssl.SSLSocketFactory;
+
+/** A credential with full control over the SSLSocketFactory. */
+@ExperimentalApi("There is no plan to make this API stable, given transport API instability")
+public final class SslSocketFactoryChannelCredentials {
+  private SslSocketFactoryChannelCredentials() {}
+
+  public static io.grpc.ChannelCredentials create(SSLSocketFactory factory) {
+    return new ChannelCredentials(factory);
+  }
+
+  // Hide implementation detail of how these credentials operate
+  static final class ChannelCredentials extends io.grpc.ChannelCredentials {
+    private final SSLSocketFactory factory;
+
+    private ChannelCredentials(SSLSocketFactory factory) {
+      this.factory = Preconditions.checkNotNull(factory, "factory");
+    }
+
+    public SSLSocketFactory getFactory() {
+      return factory;
+    }
+  }
+}
diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java
index e15bca3..4e51ce2 100644
--- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java
+++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java
@@ -16,15 +16,23 @@
 
 package io.grpc.okhttp;
 
+import static com.google.common.truth.Truth.assertThat;
 import static io.grpc.internal.GrpcUtil.TIMER_SERVICE;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertSame;
+import static org.mockito.Mockito.mock;
 
 import com.squareup.okhttp.ConnectionSpec;
+import io.grpc.CallCredentials;
+import io.grpc.ChannelCredentials;
 import io.grpc.ChannelLogger;
+import io.grpc.ChoiceChannelCredentials;
+import io.grpc.CompositeChannelCredentials;
+import io.grpc.InsecureChannelCredentials;
 import io.grpc.ManagedChannel;
+import io.grpc.TlsChannelCredentials;
 import io.grpc.internal.ClientTransportFactory;
 import io.grpc.internal.FakeClock;
 import io.grpc.internal.GrpcUtil;
@@ -35,6 +43,8 @@
 import java.net.Socket;
 import java.util.concurrent.ScheduledExecutorService;
 import javax.net.SocketFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLSocketFactory;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -113,6 +123,101 @@
   }
 
   @Test
+  public void sslSocketFactoryFrom_unknown() {
+    OkHttpChannelBuilder.SslSocketFactoryResult result =
+        OkHttpChannelBuilder.sslSocketFactoryFrom(new ChannelCredentials() {});
+    assertThat(result.error).isNotNull();
+    assertThat(result.callCredentials).isNull();
+    assertThat(result.factory).isNull();
+  }
+
+  @Test
+  public void sslSocketFactoryFrom_tls() {
+    OkHttpChannelBuilder.SslSocketFactoryResult result =
+        OkHttpChannelBuilder.sslSocketFactoryFrom(TlsChannelCredentials.create());
+    assertThat(result.error).isNull();
+    assertThat(result.callCredentials).isNull();
+    assertThat(result.factory).isNotNull();
+  }
+
+  @Test
+  public void sslSocketFactoryFrom_unsupportedTls() {
+    OkHttpChannelBuilder.SslSocketFactoryResult result = OkHttpChannelBuilder.sslSocketFactoryFrom(
+        TlsChannelCredentials.newBuilder().requireFakeFeature().build());
+    assertThat(result.error).contains("FAKE");
+    assertThat(result.callCredentials).isNull();
+    assertThat(result.factory).isNull();
+  }
+
+  @Test
+  public void sslSocketFactoryFrom_insecure() {
+    OkHttpChannelBuilder.SslSocketFactoryResult result =
+        OkHttpChannelBuilder.sslSocketFactoryFrom(InsecureChannelCredentials.create());
+    assertThat(result.error).isNull();
+    assertThat(result.callCredentials).isNull();
+    assertThat(result.factory).isNull();
+  }
+
+  @Test
+  public void sslSocketFactoryFrom_composite() {
+    CallCredentials callCredentials = mock(CallCredentials.class);
+    OkHttpChannelBuilder.SslSocketFactoryResult result =
+        OkHttpChannelBuilder.sslSocketFactoryFrom(CompositeChannelCredentials.create(
+          TlsChannelCredentials.create(), callCredentials));
+    assertThat(result.error).isNull();
+    assertThat(result.callCredentials).isSameInstanceAs(callCredentials);
+    assertThat(result.factory).isNotNull();
+
+    result = OkHttpChannelBuilder.sslSocketFactoryFrom(CompositeChannelCredentials.create(
+          InsecureChannelCredentials.create(), callCredentials));
+    assertThat(result.error).isNull();
+    assertThat(result.callCredentials).isSameInstanceAs(callCredentials);
+    assertThat(result.factory).isNull();
+  }
+
+  @Test
+  public void sslSocketFactoryFrom_okHttp() throws Exception {
+    SSLContext sslContext = SSLContext.getInstance("TLS");
+    sslContext.init(null, null, null);
+    SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory();
+    OkHttpChannelBuilder.SslSocketFactoryResult result = OkHttpChannelBuilder.sslSocketFactoryFrom(
+        SslSocketFactoryChannelCredentials.create(sslSocketFactory));
+    assertThat(result.error).isNull();
+    assertThat(result.callCredentials).isNull();
+    assertThat(result.factory).isSameInstanceAs(sslSocketFactory);
+  }
+
+  @Test
+  public void sslSocketFactoryFrom_choice() {
+    OkHttpChannelBuilder.SslSocketFactoryResult result =
+        OkHttpChannelBuilder.sslSocketFactoryFrom(ChoiceChannelCredentials.create(
+          new ChannelCredentials() {},
+          TlsChannelCredentials.create(),
+          InsecureChannelCredentials.create()));
+    assertThat(result.error).isNull();
+    assertThat(result.callCredentials).isNull();
+    assertThat(result.factory).isNotNull();
+
+    result = OkHttpChannelBuilder.sslSocketFactoryFrom(ChoiceChannelCredentials.create(
+          InsecureChannelCredentials.create(),
+          new ChannelCredentials() {},
+          TlsChannelCredentials.create()));
+    assertThat(result.error).isNull();
+    assertThat(result.callCredentials).isNull();
+    assertThat(result.factory).isNull();
+  }
+
+  @Test
+  public void sslSocketFactoryFrom_choice_unknown() {
+    OkHttpChannelBuilder.SslSocketFactoryResult result =
+        OkHttpChannelBuilder.sslSocketFactoryFrom(ChoiceChannelCredentials.create(
+          new ChannelCredentials() {}));
+    assertThat(result.error).isNotNull();
+    assertThat(result.callCredentials).isNull();
+    assertThat(result.factory).isNull();
+  }
+
+  @Test
   public void failForUsingClearTextSpecDirectly() {
     thrown.expect(IllegalArgumentException.class);
     thrown.expectMessage("plaintext ConnectionSpec is not accepted");
diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelProviderTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelProviderTest.java
index 3069eb1..363f11e 100644
--- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelProviderTest.java
+++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelProviderTest.java
@@ -16,13 +16,16 @@
 
 package io.grpc.okhttp;
 
+import static com.google.common.truth.Truth.assertThat;
 import static org.junit.Assert.assertSame;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 import io.grpc.InternalServiceProviders;
 import io.grpc.ManagedChannelProvider;
+import io.grpc.ManagedChannelProvider.NewChannelBuilderResult;
 import io.grpc.ManagedChannelRegistryAccessor;
+import io.grpc.TlsChannelCredentials;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -63,4 +66,23 @@
   public void builderIsAOkHttpBuilder() {
     assertSame(OkHttpChannelBuilder.class, provider.builderForAddress("localhost", 443).getClass());
   }
+
+  @Test
+  public void builderForTarget() {
+    assertThat(provider.builderForTarget("localhost:443")).isInstanceOf(OkHttpChannelBuilder.class);
+  }
+
+  @Test
+  public void newChannelBuilder_success() {
+    NewChannelBuilderResult result =
+        provider.newChannelBuilder("localhost:443", TlsChannelCredentials.create());
+    assertThat(result.getChannelBuilder()).isInstanceOf(OkHttpChannelBuilder.class);
+  }
+
+  @Test
+  public void newChannelBuilder_fail() {
+    NewChannelBuilderResult result = provider.newChannelBuilder("localhost:443",
+        TlsChannelCredentials.newBuilder().requireFakeFeature().build());
+    assertThat(result.getError()).contains("FAKE");
+  }
 }