okhttp: Add ChannelCredentials
diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java
index 93ea57e..28e334a 100644
--- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java
+++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java
@@ -22,11 +22,18 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
+import io.grpc.CallCredentials;
+import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger;
+import io.grpc.ChoiceChannelCredentials;
+import io.grpc.CompositeCallCredentials;
+import io.grpc.CompositeChannelCredentials;
import io.grpc.ExperimentalApi;
import io.grpc.ForwardingChannelBuilder;
+import io.grpc.InsecureChannelCredentials;
import io.grpc.Internal;
import io.grpc.ManagedChannelBuilder;
+import io.grpc.TlsChannelCredentials;
import io.grpc.internal.AtomicBackoff;
import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.ConnectionClientTransport;
@@ -45,6 +52,8 @@
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.GeneralSecurityException;
+import java.util.EnumSet;
+import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -113,6 +122,11 @@
return new OkHttpChannelBuilder(host, port);
}
+ /** Creates a new builder with the given host and port. */
+ public static OkHttpChannelBuilder forAddress(String host, int port, ChannelCredentials creds) {
+ return forTarget(GrpcUtil.authorityFromHostAndPort(host, port), creds);
+ }
+
/**
* Creates a new builder for the given target that will be resolved by
* {@link io.grpc.NameResolver}.
@@ -121,11 +135,24 @@
return new OkHttpChannelBuilder(target);
}
+ /**
+ * Creates a new builder for the given target that will be resolved by
+ * {@link io.grpc.NameResolver}.
+ */
+ public static OkHttpChannelBuilder forTarget(String target, ChannelCredentials creds) {
+ SslSocketFactoryResult result = sslSocketFactoryFrom(creds);
+ if (result.error != null) {
+ throw new IllegalArgumentException(result.error);
+ }
+ return new OkHttpChannelBuilder(target, result.factory, result.callCredentials);
+ }
+
private Executor transportExecutor;
private ScheduledExecutorService scheduledExecutorService;
private SocketFactory socketFactory;
private SSLSocketFactory sslSocketFactory;
+ private final boolean freezeSecurityConfiguration;
private HostnameVerifier hostnameVerifier;
private ConnectionSpec connectionSpec = INTERNAL_DEFAULT_CONNECTION_SPEC;
private NegotiationType negotiationType = NegotiationType.TLS;
@@ -147,23 +174,36 @@
}
private OkHttpChannelBuilder(String target) {
- final class OkHttpChannelTransportFactoryBuilder implements ClientTransportFactoryBuilder {
- @Override
- public ClientTransportFactory buildClientTransportFactory() {
- return buildTransportFactory();
- }
- }
-
- final class OkHttpChannelDefaultPortProvider implements ChannelBuilderDefaultPortProvider {
- @Override
- public int getDefaultPort() {
- return OkHttpChannelBuilder.this.getDefaultPort();
- }
- }
-
managedChannelImplBuilder = new ManagedChannelImplBuilder(target,
new OkHttpChannelTransportFactoryBuilder(),
new OkHttpChannelDefaultPortProvider());
+ this.freezeSecurityConfiguration = false;
+ }
+
+ OkHttpChannelBuilder(String target, @Nullable SSLSocketFactory factory,
+ @Nullable CallCredentials callCredentials) {
+ managedChannelImplBuilder = new ManagedChannelImplBuilder(target, callCredentials,
+ new OkHttpChannelTransportFactoryBuilder(),
+ new OkHttpChannelDefaultPortProvider());
+ this.sslSocketFactory = factory;
+ this.negotiationType = factory == null ? NegotiationType.PLAINTEXT : NegotiationType.TLS;
+ this.freezeSecurityConfiguration = true;
+ }
+
+ private final class OkHttpChannelTransportFactoryBuilder
+ implements ClientTransportFactoryBuilder {
+ @Override
+ public ClientTransportFactory buildClientTransportFactory() {
+ return buildTransportFactory();
+ }
+ }
+
+ private final class OkHttpChannelDefaultPortProvider
+ implements ChannelBuilderDefaultPortProvider {
+ @Override
+ public int getDefaultPort() {
+ return OkHttpChannelBuilder.this.getDefaultPort();
+ }
}
@Internal
@@ -214,6 +254,8 @@
*/
@Deprecated
public OkHttpChannelBuilder negotiationType(io.grpc.okhttp.NegotiationType type) {
+ Preconditions.checkState(!freezeSecurityConfiguration,
+ "Cannot change security when using ChannelCredentials");
Preconditions.checkNotNull(type, "type");
switch (type) {
case TLS:
@@ -284,6 +326,8 @@
* Override the default {@link SSLSocketFactory} and enable TLS negotiation.
*/
public OkHttpChannelBuilder sslSocketFactory(SSLSocketFactory factory) {
+ Preconditions.checkState(!freezeSecurityConfiguration,
+ "Cannot change security when using ChannelCredentials");
this.sslSocketFactory = factory;
negotiationType = NegotiationType.TLS;
return this;
@@ -310,6 +354,8 @@
*
*/
public OkHttpChannelBuilder hostnameVerifier(@Nullable HostnameVerifier hostnameVerifier) {
+ Preconditions.checkState(!freezeSecurityConfiguration,
+ "Cannot change security when using ChannelCredentials");
this.hostnameVerifier = hostnameVerifier;
return this;
}
@@ -328,6 +374,8 @@
*/
public OkHttpChannelBuilder connectionSpec(
com.squareup.okhttp.ConnectionSpec connectionSpec) {
+ Preconditions.checkState(!freezeSecurityConfiguration,
+ "Cannot change security when using ChannelCredentials");
Preconditions.checkArgument(connectionSpec.isTls(), "plaintext ConnectionSpec is not accepted");
this.connectionSpec = Utils.convertSpec(connectionSpec);
return this;
@@ -336,6 +384,8 @@
/** Sets the negotiation type for the HTTP/2 connection to plaintext. */
@Override
public OkHttpChannelBuilder usePlaintext() {
+ Preconditions.checkState(!freezeSecurityConfiguration,
+ "Cannot change security when using ChannelCredentials");
negotiationType = NegotiationType.PLAINTEXT;
return this;
}
@@ -345,11 +395,13 @@
*
* <p>With TLS enabled, a default {@link SSLSocketFactory} is created using the best {@link
* java.security.Provider} available and is NOT based on {@link SSLSocketFactory#getDefault}. To
- * more precisely control the TLS configuration call {@link #sslSocketFactory} to override the
- * socket factory used.
+ * more precisely control the TLS configuration call {@link #sslSocketFactory(SSLSocketFactory)}
+ * to override the socket factory used.
*/
@Override
public OkHttpChannelBuilder useTransportSecurity() {
+ Preconditions.checkState(!freezeSecurityConfiguration,
+ "Cannot change security when using ChannelCredentials");
negotiationType = NegotiationType.TLS;
return this;
}
@@ -469,6 +521,99 @@
}
}
+ private static final EnumSet<TlsChannelCredentials.Feature> understoodTlsFeatures =
+ EnumSet.noneOf(TlsChannelCredentials.Feature.class);
+
+ static SslSocketFactoryResult sslSocketFactoryFrom(ChannelCredentials creds) {
+ if (creds instanceof TlsChannelCredentials) {
+ TlsChannelCredentials tlsCreds = (TlsChannelCredentials) creds;
+ Set<TlsChannelCredentials.Feature> incomprehensible =
+ tlsCreds.incomprehensible(understoodTlsFeatures);
+ if (!incomprehensible.isEmpty()) {
+ return SslSocketFactoryResult.error(
+ "TLS features not understood: " + incomprehensible);
+ }
+ SSLSocketFactory sslSocketFactory;
+ try {
+ SSLContext sslContext = SSLContext.getInstance("Default", Platform.get().getProvider());
+ sslSocketFactory = sslContext.getSocketFactory();
+ } catch (GeneralSecurityException gse) {
+ throw new RuntimeException("TLS Provider failure", gse);
+ }
+ return SslSocketFactoryResult.factory(sslSocketFactory);
+
+ } else if (creds instanceof InsecureChannelCredentials) {
+ return SslSocketFactoryResult.plaintext();
+
+ } else if (creds instanceof CompositeChannelCredentials) {
+ CompositeChannelCredentials compCreds = (CompositeChannelCredentials) creds;
+ return sslSocketFactoryFrom(compCreds.getChannelCredentials())
+ .withCallCredentials(compCreds.getCallCredentials());
+
+ } else if (creds instanceof SslSocketFactoryChannelCredentials.ChannelCredentials) {
+ SslSocketFactoryChannelCredentials.ChannelCredentials factoryCreds =
+ (SslSocketFactoryChannelCredentials.ChannelCredentials) creds;
+ return SslSocketFactoryResult.factory(factoryCreds.getFactory());
+
+ } else if (creds instanceof ChoiceChannelCredentials) {
+ ChoiceChannelCredentials choiceCreds = (ChoiceChannelCredentials) creds;
+ StringBuilder error = new StringBuilder();
+ for (ChannelCredentials innerCreds : choiceCreds.getCredentialsList()) {
+ SslSocketFactoryResult result = sslSocketFactoryFrom(innerCreds);
+ if (result.error == null) {
+ return result;
+ }
+ error.append(", ");
+ error.append(result.error);
+ }
+ return SslSocketFactoryResult.error(error.substring(2));
+
+ } else {
+ return SslSocketFactoryResult.error(
+ "Unsupported credential type: " + creds.getClass().getName());
+ }
+ }
+
+ static final class SslSocketFactoryResult {
+ /** {@code null} implies plaintext if {@code error == null}. */
+ public final SSLSocketFactory factory;
+ public final CallCredentials callCredentials;
+ public final String error;
+
+ private SslSocketFactoryResult(SSLSocketFactory factory, CallCredentials creds, String error) {
+ this.factory = factory;
+ this.callCredentials = creds;
+ this.error = error;
+ }
+
+ public static SslSocketFactoryResult error(String error) {
+ return new SslSocketFactoryResult(
+ null, null, Preconditions.checkNotNull(error, "error"));
+ }
+
+ public static SslSocketFactoryResult plaintext() {
+ return new SslSocketFactoryResult(null, null, null);
+ }
+
+ public static SslSocketFactoryResult factory(
+ SSLSocketFactory factory) {
+ return new SslSocketFactoryResult(
+ Preconditions.checkNotNull(factory, "factory"), null, null);
+ }
+
+ public SslSocketFactoryResult withCallCredentials(CallCredentials callCreds) {
+ Preconditions.checkNotNull(callCreds, "callCreds");
+ if (error != null) {
+ return this;
+ }
+ if (this.callCredentials != null) {
+ callCreds = new CompositeCallCredentials(this.callCredentials, callCreds);
+ }
+ return new SslSocketFactoryResult(factory, callCreds, null);
+ }
+ }
+
+
/**
* Creates OkHttp transports. Exposed for internal use, as it should be private.
*/
diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java
index 036cc3a..f8caaea 100644
--- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java
+++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java
@@ -16,6 +16,7 @@
package io.grpc.okhttp;
+import io.grpc.ChannelCredentials;
import io.grpc.Internal;
import io.grpc.InternalServiceProviders;
import io.grpc.ManagedChannelProvider;
@@ -45,4 +46,15 @@
public OkHttpChannelBuilder builderForTarget(String target) {
return OkHttpChannelBuilder.forTarget(target);
}
+
+ @Override
+ public NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentials creds) {
+ OkHttpChannelBuilder.SslSocketFactoryResult result =
+ OkHttpChannelBuilder.sslSocketFactoryFrom(creds);
+ if (result.error != null) {
+ return NewChannelBuilderResult.error(result.error);
+ }
+ return NewChannelBuilderResult.channelBuilder(new OkHttpChannelBuilder(
+ target, result.factory, result.callCredentials));
+ }
}
diff --git a/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryChannelCredentials.java b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryChannelCredentials.java
new file mode 100644
index 0000000..3ee524c
--- /dev/null
+++ b/okhttp/src/main/java/io/grpc/okhttp/SslSocketFactoryChannelCredentials.java
@@ -0,0 +1,44 @@
+/*
+ * Copyright 2020 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.okhttp;
+
+import com.google.common.base.Preconditions;
+import io.grpc.ExperimentalApi;
+import javax.net.ssl.SSLSocketFactory;
+
+/** A credential with full control over the SSLSocketFactory. */
+@ExperimentalApi("There is no plan to make this API stable, given transport API instability")
+public final class SslSocketFactoryChannelCredentials {
+ private SslSocketFactoryChannelCredentials() {}
+
+ public static io.grpc.ChannelCredentials create(SSLSocketFactory factory) {
+ return new ChannelCredentials(factory);
+ }
+
+ // Hide implementation detail of how these credentials operate
+ static final class ChannelCredentials extends io.grpc.ChannelCredentials {
+ private final SSLSocketFactory factory;
+
+ private ChannelCredentials(SSLSocketFactory factory) {
+ this.factory = Preconditions.checkNotNull(factory, "factory");
+ }
+
+ public SSLSocketFactory getFactory() {
+ return factory;
+ }
+ }
+}
diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java
index e15bca3..4e51ce2 100644
--- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java
+++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelBuilderTest.java
@@ -16,15 +16,23 @@
package io.grpc.okhttp;
+import static com.google.common.truth.Truth.assertThat;
import static io.grpc.internal.GrpcUtil.TIMER_SERVICE;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
+import static org.mockito.Mockito.mock;
import com.squareup.okhttp.ConnectionSpec;
+import io.grpc.CallCredentials;
+import io.grpc.ChannelCredentials;
import io.grpc.ChannelLogger;
+import io.grpc.ChoiceChannelCredentials;
+import io.grpc.CompositeChannelCredentials;
+import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel;
+import io.grpc.TlsChannelCredentials;
import io.grpc.internal.ClientTransportFactory;
import io.grpc.internal.FakeClock;
import io.grpc.internal.GrpcUtil;
@@ -35,6 +43,8 @@
import java.net.Socket;
import java.util.concurrent.ScheduledExecutorService;
import javax.net.SocketFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLSocketFactory;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@@ -113,6 +123,101 @@
}
@Test
+ public void sslSocketFactoryFrom_unknown() {
+ OkHttpChannelBuilder.SslSocketFactoryResult result =
+ OkHttpChannelBuilder.sslSocketFactoryFrom(new ChannelCredentials() {});
+ assertThat(result.error).isNotNull();
+ assertThat(result.callCredentials).isNull();
+ assertThat(result.factory).isNull();
+ }
+
+ @Test
+ public void sslSocketFactoryFrom_tls() {
+ OkHttpChannelBuilder.SslSocketFactoryResult result =
+ OkHttpChannelBuilder.sslSocketFactoryFrom(TlsChannelCredentials.create());
+ assertThat(result.error).isNull();
+ assertThat(result.callCredentials).isNull();
+ assertThat(result.factory).isNotNull();
+ }
+
+ @Test
+ public void sslSocketFactoryFrom_unsupportedTls() {
+ OkHttpChannelBuilder.SslSocketFactoryResult result = OkHttpChannelBuilder.sslSocketFactoryFrom(
+ TlsChannelCredentials.newBuilder().requireFakeFeature().build());
+ assertThat(result.error).contains("FAKE");
+ assertThat(result.callCredentials).isNull();
+ assertThat(result.factory).isNull();
+ }
+
+ @Test
+ public void sslSocketFactoryFrom_insecure() {
+ OkHttpChannelBuilder.SslSocketFactoryResult result =
+ OkHttpChannelBuilder.sslSocketFactoryFrom(InsecureChannelCredentials.create());
+ assertThat(result.error).isNull();
+ assertThat(result.callCredentials).isNull();
+ assertThat(result.factory).isNull();
+ }
+
+ @Test
+ public void sslSocketFactoryFrom_composite() {
+ CallCredentials callCredentials = mock(CallCredentials.class);
+ OkHttpChannelBuilder.SslSocketFactoryResult result =
+ OkHttpChannelBuilder.sslSocketFactoryFrom(CompositeChannelCredentials.create(
+ TlsChannelCredentials.create(), callCredentials));
+ assertThat(result.error).isNull();
+ assertThat(result.callCredentials).isSameInstanceAs(callCredentials);
+ assertThat(result.factory).isNotNull();
+
+ result = OkHttpChannelBuilder.sslSocketFactoryFrom(CompositeChannelCredentials.create(
+ InsecureChannelCredentials.create(), callCredentials));
+ assertThat(result.error).isNull();
+ assertThat(result.callCredentials).isSameInstanceAs(callCredentials);
+ assertThat(result.factory).isNull();
+ }
+
+ @Test
+ public void sslSocketFactoryFrom_okHttp() throws Exception {
+ SSLContext sslContext = SSLContext.getInstance("TLS");
+ sslContext.init(null, null, null);
+ SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory();
+ OkHttpChannelBuilder.SslSocketFactoryResult result = OkHttpChannelBuilder.sslSocketFactoryFrom(
+ SslSocketFactoryChannelCredentials.create(sslSocketFactory));
+ assertThat(result.error).isNull();
+ assertThat(result.callCredentials).isNull();
+ assertThat(result.factory).isSameInstanceAs(sslSocketFactory);
+ }
+
+ @Test
+ public void sslSocketFactoryFrom_choice() {
+ OkHttpChannelBuilder.SslSocketFactoryResult result =
+ OkHttpChannelBuilder.sslSocketFactoryFrom(ChoiceChannelCredentials.create(
+ new ChannelCredentials() {},
+ TlsChannelCredentials.create(),
+ InsecureChannelCredentials.create()));
+ assertThat(result.error).isNull();
+ assertThat(result.callCredentials).isNull();
+ assertThat(result.factory).isNotNull();
+
+ result = OkHttpChannelBuilder.sslSocketFactoryFrom(ChoiceChannelCredentials.create(
+ InsecureChannelCredentials.create(),
+ new ChannelCredentials() {},
+ TlsChannelCredentials.create()));
+ assertThat(result.error).isNull();
+ assertThat(result.callCredentials).isNull();
+ assertThat(result.factory).isNull();
+ }
+
+ @Test
+ public void sslSocketFactoryFrom_choice_unknown() {
+ OkHttpChannelBuilder.SslSocketFactoryResult result =
+ OkHttpChannelBuilder.sslSocketFactoryFrom(ChoiceChannelCredentials.create(
+ new ChannelCredentials() {}));
+ assertThat(result.error).isNotNull();
+ assertThat(result.callCredentials).isNull();
+ assertThat(result.factory).isNull();
+ }
+
+ @Test
public void failForUsingClearTextSpecDirectly() {
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("plaintext ConnectionSpec is not accepted");
diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelProviderTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelProviderTest.java
index 3069eb1..363f11e 100644
--- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelProviderTest.java
+++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelProviderTest.java
@@ -16,13 +16,16 @@
package io.grpc.okhttp;
+import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import io.grpc.InternalServiceProviders;
import io.grpc.ManagedChannelProvider;
+import io.grpc.ManagedChannelProvider.NewChannelBuilderResult;
import io.grpc.ManagedChannelRegistryAccessor;
+import io.grpc.TlsChannelCredentials;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -63,4 +66,23 @@
public void builderIsAOkHttpBuilder() {
assertSame(OkHttpChannelBuilder.class, provider.builderForAddress("localhost", 443).getClass());
}
+
+ @Test
+ public void builderForTarget() {
+ assertThat(provider.builderForTarget("localhost:443")).isInstanceOf(OkHttpChannelBuilder.class);
+ }
+
+ @Test
+ public void newChannelBuilder_success() {
+ NewChannelBuilderResult result =
+ provider.newChannelBuilder("localhost:443", TlsChannelCredentials.create());
+ assertThat(result.getChannelBuilder()).isInstanceOf(OkHttpChannelBuilder.class);
+ }
+
+ @Test
+ public void newChannelBuilder_fail() {
+ NewChannelBuilderResult result = provider.newChannelBuilder("localhost:443",
+ TlsChannelCredentials.newBuilder().requireFakeFeature().build());
+ assertThat(result.getError()).contains("FAKE");
+ }
}