netty: fix server and client handlers to check the correct alpn list (#6603)
diff --git a/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java b/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java
index 72701ca..9be8f9e 100644
--- a/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java
+++ b/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java
@@ -56,7 +56,7 @@
/*
* List of ALPN/NPN protocols in order of preference.
*/
- static final List<String> NEXT_PROTOCOL_VERSIONS =
+ private static final List<String> NEXT_PROTOCOL_VERSIONS =
Collections.unmodifiableList(Arrays.asList(HTTP2_VERSION));
/*
diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java
index b35ab1f..a420d20 100644
--- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java
+++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java
@@ -18,7 +18,6 @@
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
-import static io.grpc.netty.GrpcSslContexts.NEXT_PROTOCOL_VERSIONS;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
@@ -190,7 +189,8 @@
return;
}
SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
- if (!NEXT_PROTOCOL_VERSIONS.contains(sslHandler.applicationProtocol())) {
+ if (!sslContext.applicationProtocolNegotiator().protocols().contains(
+ sslHandler.applicationProtocol())) {
logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed for new client.", null);
ctx.fireExceptionCaught(unavailableException(
"Failed protocol negotiation: Unable to find compatible protocol"));
@@ -359,7 +359,8 @@
SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
if (handshakeEvent.isSuccess()) {
SslHandler handler = ctx.pipeline().get(SslHandler.class);
- if (NEXT_PROTOCOL_VERSIONS.contains(handler.applicationProtocol())) {
+ if (sslContext.applicationProtocolNegotiator().protocols()
+ .contains(handler.applicationProtocol())) {
// Successfully negotiated the protocol.
logSslEngineDetails(Level.FINER, ctx, "TLS negotiation succeeded.", null);
propagateTlsComplete(ctx, handler.engine().getSession());
diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java
index 5747e74..92b95ee 100644
--- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java
+++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java
@@ -35,6 +35,7 @@
import io.grpc.SecurityLevel;
import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.testing.TestUtils;
+import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler;
import io.grpc.netty.ProtocolNegotiators.ClientTlsProtocolNegotiator;
import io.grpc.netty.ProtocolNegotiators.HostPort;
import io.grpc.netty.ProtocolNegotiators.ServerTlsHandler;
@@ -76,6 +77,7 @@
import io.netty.handler.codec.http2.Http2ServerUpgradeCodec;
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.proxy.ProxyConnectException;
+import io.netty.handler.ssl.ApplicationProtocolConfig;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
@@ -85,6 +87,8 @@
import java.io.File;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
+import java.util.Arrays;
+import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
@@ -350,6 +354,186 @@
}
@Test
+ public void serverTlsHandler_userEventTriggeredSslEvent_supportedProtocolCustom()
+ throws Exception {
+ SslHandler goodSslHandler = new SslHandler(engine, false) {
+ @Override
+ public String applicationProtocol() {
+ return "managed_mtls";
+ }
+ };
+
+ File serverCert = TestUtils.loadCert("server1.pem");
+ File key = TestUtils.loadCert("server1.key");
+ List<String> alpnList = Arrays.asList("managed_mtls", "h2");
+ ApplicationProtocolConfig apn = new ApplicationProtocolConfig(
+ ApplicationProtocolConfig.Protocol.ALPN,
+ ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
+ ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
+ alpnList);
+
+ sslContext = GrpcSslContexts.forServer(serverCert, key)
+ .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE)
+ .applicationProtocolConfig(apn).build();
+
+ ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null);
+ pipeline.addLast(handler);
+
+ pipeline.replace(SslHandler.class, null, goodSslHandler);
+ channelHandlerCtx = pipeline.context(handler);
+ Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
+
+ pipeline.fireUserEventTriggered(sslEvent);
+
+ assertTrue(channel.isOpen());
+ ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
+ assertNotNull(grpcHandlerCtx);
+ }
+
+ @Test
+ public void serverTlsHandler_userEventTriggeredSslEvent_unsupportedProtocolCustom()
+ throws Exception {
+ SslHandler badSslHandler = new SslHandler(engine, false) {
+ @Override
+ public String applicationProtocol() {
+ return "badprotocol";
+ }
+ };
+
+ File serverCert = TestUtils.loadCert("server1.pem");
+ File key = TestUtils.loadCert("server1.key");
+ List<String> alpnList = Arrays.asList("managed_mtls", "h2");
+ ApplicationProtocolConfig apn = new ApplicationProtocolConfig(
+ ApplicationProtocolConfig.Protocol.ALPN,
+ ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
+ ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
+ alpnList);
+
+ sslContext = GrpcSslContexts.forServer(serverCert, key)
+ .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE)
+ .applicationProtocolConfig(apn).build();
+ ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null);
+ pipeline.addLast(handler);
+
+ final AtomicReference<Throwable> error = new AtomicReference<>();
+ ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() {
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+ error.set(cause);
+ }
+ };
+
+ pipeline.addLast(errorCapture);
+
+ pipeline.replace(SslHandler.class, null, badSslHandler);
+ channelHandlerCtx = pipeline.context(handler);
+ Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
+
+ pipeline.fireUserEventTriggered(sslEvent);
+
+ // No h2 protocol was specified, so there should be an error, (normally handled by WBAEH)
+ assertThat(error.get()).hasMessageThat().contains("Unable to find compatible protocol");
+ ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
+ assertNull(grpcHandlerCtx);
+ }
+
+ @Test
+ public void clientTlsHandler_userEventTriggeredSslEvent_supportedProtocolH2() throws Exception {
+ SslHandler goodSslHandler = new SslHandler(engine, false) {
+ @Override
+ public String applicationProtocol() {
+ return "h2";
+ }
+ };
+ DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
+
+ ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", elg);
+ pipeline.addLast(handler);
+ pipeline.replace(SslHandler.class, null, goodSslHandler);
+ pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
+ channelHandlerCtx = pipeline.context(handler);
+ Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
+
+ pipeline.fireUserEventTriggered(sslEvent);
+
+ ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
+ assertNotNull(grpcHandlerCtx);
+ }
+
+ @Test
+ public void clientTlsHandler_userEventTriggeredSslEvent_supportedProtocolCustom()
+ throws Exception {
+ SslHandler goodSslHandler = new SslHandler(engine, false) {
+ @Override
+ public String applicationProtocol() {
+ return "managed_mtls";
+ }
+ };
+ DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
+
+ File clientCert = TestUtils.loadCert("client.pem");
+ File key = TestUtils.loadCert("client.key");
+ List<String> alpnList = Arrays.asList("managed_mtls", "h2");
+ ApplicationProtocolConfig apn = new ApplicationProtocolConfig(
+ ApplicationProtocolConfig.Protocol.ALPN,
+ ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
+ ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
+ alpnList);
+
+ sslContext = GrpcSslContexts.forClient()
+ .keyManager(clientCert, key)
+ .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE)
+ .applicationProtocolConfig(apn).build();
+
+ ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", elg);
+ pipeline.addLast(handler);
+ pipeline.replace(SslHandler.class, null, goodSslHandler);
+ pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
+ channelHandlerCtx = pipeline.context(handler);
+ Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
+
+ pipeline.fireUserEventTriggered(sslEvent);
+
+ ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
+ assertNotNull(grpcHandlerCtx);
+ }
+
+ @Test
+ public void clientTlsHandler_userEventTriggeredSslEvent_unsupportedProtocol() throws Exception {
+ SslHandler goodSslHandler = new SslHandler(engine, false) {
+ @Override
+ public String applicationProtocol() {
+ return "badproto";
+ }
+ };
+ DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
+
+ ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", elg);
+ pipeline.addLast(handler);
+
+ final AtomicReference<Throwable> error = new AtomicReference<>();
+ ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() {
+ @Override
+ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+ error.set(cause);
+ }
+ };
+
+ pipeline.addLast(errorCapture);
+ pipeline.replace(SslHandler.class, null, goodSslHandler);
+ pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
+ channelHandlerCtx = pipeline.context(handler);
+ Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
+
+ pipeline.fireUserEventTriggered(sslEvent);
+
+ // Bad protocol was specified, so there should be an error, (normally handled by WBAEH)
+ assertThat(error.get()).hasMessageThat().contains("Unable to find compatible protocol");
+ ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
+ assertNull(grpcHandlerCtx);
+ }
+
+ @Test
public void engineLog() {
ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null);
pipeline.addLast(handler);