netty: fix server and client handlers to check the correct alpn list (#6603)

diff --git a/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java b/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java
index 72701ca..9be8f9e 100644
--- a/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java
+++ b/netty/src/main/java/io/grpc/netty/GrpcSslContexts.java
@@ -56,7 +56,7 @@
   /*
    * List of ALPN/NPN protocols in order of preference.
    */
-  static final List<String> NEXT_PROTOCOL_VERSIONS =
+  private static final List<String> NEXT_PROTOCOL_VERSIONS =
       Collections.unmodifiableList(Arrays.asList(HTTP2_VERSION));
 
   /*
diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java
index b35ab1f..a420d20 100644
--- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java
+++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java
@@ -18,7 +18,6 @@
 
 import static com.google.common.base.Preconditions.checkNotNull;
 import static com.google.common.base.Preconditions.checkState;
-import static io.grpc.netty.GrpcSslContexts.NEXT_PROTOCOL_VERSIONS;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
@@ -190,7 +189,8 @@
           return;
         }
         SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
-        if (!NEXT_PROTOCOL_VERSIONS.contains(sslHandler.applicationProtocol())) {
+        if (!sslContext.applicationProtocolNegotiator().protocols().contains(
+                sslHandler.applicationProtocol())) {
           logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed for new client.", null);
           ctx.fireExceptionCaught(unavailableException(
               "Failed protocol negotiation: Unable to find compatible protocol"));
@@ -359,7 +359,8 @@
         SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
         if (handshakeEvent.isSuccess()) {
           SslHandler handler = ctx.pipeline().get(SslHandler.class);
-          if (NEXT_PROTOCOL_VERSIONS.contains(handler.applicationProtocol())) {
+          if (sslContext.applicationProtocolNegotiator().protocols()
+              .contains(handler.applicationProtocol())) {
             // Successfully negotiated the protocol.
             logSslEngineDetails(Level.FINER, ctx, "TLS negotiation succeeded.", null);
             propagateTlsComplete(ctx, handler.engine().getSession());
diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java
index 5747e74..92b95ee 100644
--- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java
+++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java
@@ -35,6 +35,7 @@
 import io.grpc.SecurityLevel;
 import io.grpc.internal.GrpcAttributes;
 import io.grpc.internal.testing.TestUtils;
+import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler;
 import io.grpc.netty.ProtocolNegotiators.ClientTlsProtocolNegotiator;
 import io.grpc.netty.ProtocolNegotiators.HostPort;
 import io.grpc.netty.ProtocolNegotiators.ServerTlsHandler;
@@ -76,6 +77,7 @@
 import io.netty.handler.codec.http2.Http2ServerUpgradeCodec;
 import io.netty.handler.codec.http2.Http2Settings;
 import io.netty.handler.proxy.ProxyConnectException;
+import io.netty.handler.ssl.ApplicationProtocolConfig;
 import io.netty.handler.ssl.SslContext;
 import io.netty.handler.ssl.SslContextBuilder;
 import io.netty.handler.ssl.SslHandler;
@@ -85,6 +87,8 @@
 import java.io.File;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
+import java.util.Arrays;
+import java.util.List;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
@@ -350,6 +354,186 @@
   }
 
   @Test
+  public void serverTlsHandler_userEventTriggeredSslEvent_supportedProtocolCustom()
+      throws Exception {
+    SslHandler goodSslHandler = new SslHandler(engine, false) {
+      @Override
+      public String applicationProtocol() {
+        return "managed_mtls";
+      }
+    };
+
+    File serverCert = TestUtils.loadCert("server1.pem");
+    File key = TestUtils.loadCert("server1.key");
+    List<String> alpnList = Arrays.asList("managed_mtls", "h2");
+    ApplicationProtocolConfig apn = new ApplicationProtocolConfig(
+        ApplicationProtocolConfig.Protocol.ALPN,
+        ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
+        ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
+        alpnList);
+
+    sslContext = GrpcSslContexts.forServer(serverCert, key)
+        .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE)
+        .applicationProtocolConfig(apn).build();
+
+    ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null);
+    pipeline.addLast(handler);
+
+    pipeline.replace(SslHandler.class, null, goodSslHandler);
+    channelHandlerCtx = pipeline.context(handler);
+    Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
+
+    pipeline.fireUserEventTriggered(sslEvent);
+
+    assertTrue(channel.isOpen());
+    ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
+    assertNotNull(grpcHandlerCtx);
+  }
+
+  @Test
+  public void serverTlsHandler_userEventTriggeredSslEvent_unsupportedProtocolCustom()
+      throws Exception {
+    SslHandler badSslHandler = new SslHandler(engine, false) {
+      @Override
+      public String applicationProtocol() {
+        return "badprotocol";
+      }
+    };
+
+    File serverCert = TestUtils.loadCert("server1.pem");
+    File key = TestUtils.loadCert("server1.key");
+    List<String> alpnList = Arrays.asList("managed_mtls", "h2");
+    ApplicationProtocolConfig apn = new ApplicationProtocolConfig(
+        ApplicationProtocolConfig.Protocol.ALPN,
+        ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
+        ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
+        alpnList);
+
+    sslContext = GrpcSslContexts.forServer(serverCert, key)
+        .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE)
+        .applicationProtocolConfig(apn).build();
+    ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null);
+    pipeline.addLast(handler);
+
+    final AtomicReference<Throwable> error = new AtomicReference<>();
+    ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() {
+      @Override
+      public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+        error.set(cause);
+      }
+    };
+
+    pipeline.addLast(errorCapture);
+
+    pipeline.replace(SslHandler.class, null, badSslHandler);
+    channelHandlerCtx = pipeline.context(handler);
+    Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
+
+    pipeline.fireUserEventTriggered(sslEvent);
+
+    // No h2 protocol was specified, so there should be an error, (normally handled by WBAEH)
+    assertThat(error.get()).hasMessageThat().contains("Unable to find compatible protocol");
+    ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
+    assertNull(grpcHandlerCtx);
+  }
+
+  @Test
+  public void clientTlsHandler_userEventTriggeredSslEvent_supportedProtocolH2() throws Exception {
+    SslHandler goodSslHandler = new SslHandler(engine, false) {
+      @Override
+      public String applicationProtocol() {
+        return "h2";
+      }
+    };
+    DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
+
+    ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", elg);
+    pipeline.addLast(handler);
+    pipeline.replace(SslHandler.class, null, goodSslHandler);
+    pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
+    channelHandlerCtx = pipeline.context(handler);
+    Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
+
+    pipeline.fireUserEventTriggered(sslEvent);
+
+    ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
+    assertNotNull(grpcHandlerCtx);
+  }
+
+  @Test
+  public void clientTlsHandler_userEventTriggeredSslEvent_supportedProtocolCustom()
+      throws Exception {
+    SslHandler goodSslHandler = new SslHandler(engine, false) {
+      @Override
+      public String applicationProtocol() {
+        return "managed_mtls";
+      }
+    };
+    DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
+
+    File clientCert = TestUtils.loadCert("client.pem");
+    File key = TestUtils.loadCert("client.key");
+    List<String> alpnList = Arrays.asList("managed_mtls", "h2");
+    ApplicationProtocolConfig apn = new ApplicationProtocolConfig(
+        ApplicationProtocolConfig.Protocol.ALPN,
+        ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
+        ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
+        alpnList);
+
+    sslContext = GrpcSslContexts.forClient()
+        .keyManager(clientCert, key)
+        .ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE)
+        .applicationProtocolConfig(apn).build();
+
+    ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", elg);
+    pipeline.addLast(handler);
+    pipeline.replace(SslHandler.class, null, goodSslHandler);
+    pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
+    channelHandlerCtx = pipeline.context(handler);
+    Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
+
+    pipeline.fireUserEventTriggered(sslEvent);
+
+    ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
+    assertNotNull(grpcHandlerCtx);
+  }
+
+  @Test
+  public void clientTlsHandler_userEventTriggeredSslEvent_unsupportedProtocol() throws Exception {
+    SslHandler goodSslHandler = new SslHandler(engine, false) {
+      @Override
+      public String applicationProtocol() {
+        return "badproto";
+      }
+    };
+    DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
+
+    ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", elg);
+    pipeline.addLast(handler);
+
+    final AtomicReference<Throwable> error = new AtomicReference<>();
+    ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() {
+      @Override
+      public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+        error.set(cause);
+      }
+    };
+
+    pipeline.addLast(errorCapture);
+    pipeline.replace(SslHandler.class, null, goodSslHandler);
+    pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
+    channelHandlerCtx = pipeline.context(handler);
+    Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
+
+    pipeline.fireUserEventTriggered(sslEvent);
+
+    // Bad protocol was specified, so there should be an error, (normally handled by WBAEH)
+    assertThat(error.get()).hasMessageThat().contains("Unable to find compatible protocol");
+    ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
+    assertNull(grpcHandlerCtx);
+  }
+
+  @Test
   public void engineLog() {
     ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null);
     pipeline.addLast(handler);