xds: suppress hostname check but keep trust check in the delegated X509TrustManagerImpl (#6589)

diff --git a/xds/src/main/java/io/grpc/xds/sds/trust/SdsTrustManagerFactory.java b/xds/src/main/java/io/grpc/xds/sds/trust/SdsTrustManagerFactory.java
index 0ae906a..e3a2274 100644
--- a/xds/src/main/java/io/grpc/xds/sds/trust/SdsTrustManagerFactory.java
+++ b/xds/src/main/java/io/grpc/xds/sds/trust/SdsTrustManagerFactory.java
@@ -19,6 +19,7 @@
 import static com.google.common.base.Preconditions.checkNotNull;
 import static com.google.common.base.Preconditions.checkState;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Strings;
 import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
 import io.envoyproxy.envoy.api.v2.core.DataSource.SpecifierCase;
@@ -53,7 +54,7 @@
   public SdsTrustManagerFactory(CertificateValidationContext certificateValidationContext)
       throws CertificateException, IOException, CertStoreException {
     checkNotNull(certificateValidationContext, "certificateValidationContext");
-    createSdsX509TrustManager(
+    sdsX509TrustManager = createSdsX509TrustManager(
         getTrustedCaFromCertContext(certificateValidationContext), certificateValidationContext);
   }
 
@@ -76,7 +77,8 @@
     }
   }
 
-  private void createSdsX509TrustManager(
+  @VisibleForTesting
+  static SdsX509TrustManager createSdsX509TrustManager(
       X509Certificate[] certs, CertificateValidationContext certContext) throws CertStoreException {
     TrustManagerFactory tmf = null;
     try {
@@ -109,7 +111,7 @@
     if (myDelegate == null) {
       throw new CertStoreException("Native X509 TrustManager not found.");
     }
-    sdsX509TrustManager = new SdsX509TrustManager(certContext, myDelegate);
+    return new SdsX509TrustManager(certContext, myDelegate);
   }
 
   @Override
diff --git a/xds/src/main/java/io/grpc/xds/sds/trust/SdsX509TrustManager.java b/xds/src/main/java/io/grpc/xds/sds/trust/SdsX509TrustManager.java
index c2d4b28..ce843d9 100644
--- a/xds/src/main/java/io/grpc/xds/sds/trust/SdsX509TrustManager.java
+++ b/xds/src/main/java/io/grpc/xds/sds/trust/SdsX509TrustManager.java
@@ -30,6 +30,8 @@
 import java.util.Locale;
 import javax.annotation.Nullable;
 import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLParameters;
+import javax.net.ssl.SSLSocket;
 import javax.net.ssl.X509ExtendedTrustManager;
 import javax.net.ssl.X509TrustManager;
 
@@ -256,6 +258,14 @@
   @Override
   public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket)
       throws CertificateException {
+    if (socket instanceof SSLSocket) {
+      SSLSocket sslSocket = (SSLSocket) socket;
+      SSLParameters sslParams = sslSocket.getSSLParameters();
+      if (sslParams != null) {
+        sslParams.setEndpointIdentificationAlgorithm(null);
+        sslSocket.setSSLParameters(sslParams);
+      }
+    }
     delegate.checkServerTrusted(chain, authType, socket);
     verifySubjectAltNameInChain(chain);
   }
@@ -263,6 +273,11 @@
   @Override
   public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine sslEngine)
       throws CertificateException {
+    SSLParameters sslParams = sslEngine.getSSLParameters();
+    if (sslParams != null) {
+      sslParams.setEndpointIdentificationAlgorithm(null);
+      sslEngine.setSSLParameters(sslParams);
+    }
     delegate.checkServerTrusted(chain, authType, sslEngine);
     verifySubjectAltNameInChain(chain);
   }
diff --git a/xds/src/test/java/io/grpc/xds/sds/trust/SdsX509TrustManagerTest.java b/xds/src/test/java/io/grpc/xds/sds/trust/SdsX509TrustManagerTest.java
index e627612..64b2fd1 100644
--- a/xds/src/test/java/io/grpc/xds/sds/trust/SdsX509TrustManagerTest.java
+++ b/xds/src/test/java/io/grpc/xds/sds/trust/SdsX509TrustManagerTest.java
@@ -17,15 +17,26 @@
 package io.grpc.xds.sds.trust;
 
 import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.CALLS_REAL_METHODS;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
 import io.grpc.internal.testing.TestUtils;
 import java.io.FileNotFoundException;
 import java.io.IOException;
+import java.security.cert.CertStoreException;
 import java.security.cert.CertificateException;
 import java.security.cert.X509Certificate;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLParameters;
+import javax.net.ssl.SSLSession;
+import javax.net.ssl.SSLSocket;
 import javax.net.ssl.X509ExtendedTrustManager;
-import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -33,6 +44,7 @@
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnit;
 import org.mockito.junit.MockitoRule;
+import sun.security.validator.ValidatorException;
 
 /**
  * Unit tests for {@link SdsX509TrustManager}.
@@ -40,21 +52,32 @@
 @RunWith(JUnit4.class)
 public class SdsX509TrustManagerTest {
 
+  /** Trust store cert. */
+  private static final String CA_PEM_FILE = "ca.pem";
+
   /** server1 has 4 SANs. */
   private static final String SERVER_1_PEM_FILE = "server1.pem";
 
   /** client has no SANs. */
   private static final String CLIENT_PEM_FILE = "client.pem";
 
+  /** Untrusted server. */
+  private static final String BAD_SERVER_PEM_FILE = "badserver.pem";
+
   @Rule
   public final MockitoRule mockitoRule = MockitoJUnit.rule();
 
   @Mock
   private X509ExtendedTrustManager mockDelegate;
 
+  @Mock
+  private SSLSession mockSession;
+
+  private SdsX509TrustManager trustManager;
+
   @Test
   public void nullCertContextTest() throws CertificateException, IOException {
-    SdsX509TrustManager trustManager = new SdsX509TrustManager(null, mockDelegate);
+    trustManager = new SdsX509TrustManager(null, mockDelegate);
     X509Certificate[] certs =
         CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE));
     trustManager.verifySubjectAltNameInChain(certs);
@@ -63,7 +86,7 @@
   @Test
   public void emptySanListContextTest() throws CertificateException, IOException {
     CertificateValidationContext certContext = CertificateValidationContext.getDefaultInstance();
-    SdsX509TrustManager trustManager = new SdsX509TrustManager(certContext, mockDelegate);
+    trustManager = new SdsX509TrustManager(certContext, mockDelegate);
     X509Certificate[] certs =
         CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE));
     trustManager.verifySubjectAltNameInChain(certs);
@@ -73,10 +96,10 @@
   public void missingPeerCerts() throws CertificateException, FileNotFoundException {
     CertificateValidationContext certContext =
         CertificateValidationContext.newBuilder().addVerifySubjectAltName("foo.com").build();
-    SdsX509TrustManager trustManager = new SdsX509TrustManager(certContext, mockDelegate);
+    trustManager = new SdsX509TrustManager(certContext, mockDelegate);
     try {
       trustManager.verifySubjectAltNameInChain(null);
-      Assert.fail("no exception thrown");
+      fail("no exception thrown");
     } catch (CertificateException expected) {
       assertThat(expected).hasMessageThat().isEqualTo("Peer certificate(s) missing");
     }
@@ -86,10 +109,10 @@
   public void emptyArrayPeerCerts() throws CertificateException, FileNotFoundException {
     CertificateValidationContext certContext =
         CertificateValidationContext.newBuilder().addVerifySubjectAltName("foo.com").build();
-    SdsX509TrustManager trustManager = new SdsX509TrustManager(certContext, mockDelegate);
+    trustManager = new SdsX509TrustManager(certContext, mockDelegate);
     try {
       trustManager.verifySubjectAltNameInChain(new X509Certificate[0]);
-      Assert.fail("no exception thrown");
+      fail("no exception thrown");
     } catch (CertificateException expected) {
       assertThat(expected).hasMessageThat().isEqualTo("Peer certificate(s) missing");
     }
@@ -99,12 +122,12 @@
   public void noSansInPeerCerts() throws CertificateException, IOException {
     CertificateValidationContext certContext =
         CertificateValidationContext.newBuilder().addVerifySubjectAltName("foo.com").build();
-    SdsX509TrustManager trustManager = new SdsX509TrustManager(certContext, mockDelegate);
+    trustManager = new SdsX509TrustManager(certContext, mockDelegate);
     X509Certificate[] certs =
         CertificateUtils.toX509Certificates(TestUtils.loadCert(CLIENT_PEM_FILE));
     try {
       trustManager.verifySubjectAltNameInChain(certs);
-      Assert.fail("no exception thrown");
+      fail("no exception thrown");
     } catch (CertificateException expected) {
       assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed");
     }
@@ -116,7 +139,7 @@
         CertificateValidationContext.newBuilder()
             .addVerifySubjectAltName("waterzooi.test.google.be")
             .build();
-    SdsX509TrustManager trustManager = new SdsX509TrustManager(certContext, mockDelegate);
+    trustManager = new SdsX509TrustManager(certContext, mockDelegate);
     X509Certificate[] certs =
         CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE));
     trustManager.verifySubjectAltNameInChain(certs);
@@ -130,7 +153,7 @@
             .addVerifySubjectAltName("x.foo.com")
             .addVerifySubjectAltName("waterzooi.test.google.be")
             .build();
-    SdsX509TrustManager trustManager = new SdsX509TrustManager(certContext, mockDelegate);
+    trustManager = new SdsX509TrustManager(certContext, mockDelegate);
     X509Certificate[] certs =
         CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE));
     trustManager.verifySubjectAltNameInChain(certs);
@@ -141,12 +164,12 @@
       throws CertificateException, IOException {
     CertificateValidationContext certContext =
         CertificateValidationContext.newBuilder().addVerifySubjectAltName("x.foo.com").build();
-    SdsX509TrustManager trustManager = new SdsX509TrustManager(certContext, mockDelegate);
+    trustManager = new SdsX509TrustManager(certContext, mockDelegate);
     X509Certificate[] certs =
         CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE));
     try {
       trustManager.verifySubjectAltNameInChain(certs);
-      Assert.fail("no exception thrown");
+      fail("no exception thrown");
     } catch (CertificateException expected) {
       assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed");
     }
@@ -160,7 +183,7 @@
             .addVerifySubjectAltName("x.foo.com")
             .addVerifySubjectAltName("abc.test.youtube.com") // should match *.test.youtube.com
             .build();
-    SdsX509TrustManager trustManager = new SdsX509TrustManager(certContext, mockDelegate);
+    trustManager = new SdsX509TrustManager(certContext, mockDelegate);
     X509Certificate[] certs =
         CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE));
     trustManager.verifySubjectAltNameInChain(certs);
@@ -174,7 +197,7 @@
             .addVerifySubjectAltName("x.foo.com")
             .addVerifySubjectAltName("abc.test.google.fr") // should match *.test.google.fr
             .build();
-    SdsX509TrustManager trustManager = new SdsX509TrustManager(certContext, mockDelegate);
+    trustManager = new SdsX509TrustManager(certContext, mockDelegate);
     X509Certificate[] certs =
         CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE));
     trustManager.verifySubjectAltNameInChain(certs);
@@ -190,12 +213,12 @@
         CertificateValidationContext.newBuilder()
             .addVerifySubjectAltName("sub.abc.test.youtube.com")
             .build();
-    SdsX509TrustManager trustManager = new SdsX509TrustManager(certContext, mockDelegate);
+    trustManager = new SdsX509TrustManager(certContext, mockDelegate);
     X509Certificate[] certs =
         CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE));
     try {
       trustManager.verifySubjectAltNameInChain(certs);
-      Assert.fail("no exception thrown");
+      fail("no exception thrown");
     } catch (CertificateException expected) {
       assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed");
     }
@@ -208,7 +231,7 @@
             .addVerifySubjectAltName("x.foo.com")
             .addVerifySubjectAltName("192.168.1.3")
             .build();
-    SdsX509TrustManager trustManager = new SdsX509TrustManager(certContext, mockDelegate);
+    trustManager = new SdsX509TrustManager(certContext, mockDelegate);
     X509Certificate[] certs =
         CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE));
     trustManager.verifySubjectAltNameInChain(certs);
@@ -221,14 +244,132 @@
             .addVerifySubjectAltName("x.foo.com")
             .addVerifySubjectAltName("192.168.2.3")
             .build();
-    SdsX509TrustManager trustManager = new SdsX509TrustManager(certContext, mockDelegate);
+    trustManager = new SdsX509TrustManager(certContext, mockDelegate);
     X509Certificate[] certs =
         CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE));
     try {
       trustManager.verifySubjectAltNameInChain(certs);
-      Assert.fail("no exception thrown");
+      fail("no exception thrown");
     } catch (CertificateException expected) {
       assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed");
     }
   }
+
+  @Test
+  public void checkServerTrustedSslEngine()
+      throws CertificateException, IOException, CertStoreException {
+    TestSslEngine sslEngine = buildTrustManagerAndGetSslEngine();
+    X509Certificate[] serverCerts =
+        CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE));
+    trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslEngine);
+    verify(sslEngine, times(1)).getHandshakeSession();
+  }
+
+  @Test
+  public void checkServerTrustedSslEngine_untrustedServer_expectException()
+      throws CertificateException, IOException, CertStoreException {
+    TestSslEngine sslEngine = buildTrustManagerAndGetSslEngine();
+    X509Certificate[] badServerCert =
+        CertificateUtils.toX509Certificates(TestUtils.loadCert(BAD_SERVER_PEM_FILE));
+    try {
+      trustManager.checkServerTrusted(badServerCert, "ECDHE_ECDSA", sslEngine);
+      fail("exception expected");
+    } catch (ValidatorException expected) {
+      assertThat(expected).hasMessageThat()
+          .endsWith("unable to find valid certification path to requested target");
+    }
+    verify(sslEngine, times(1)).getHandshakeSession();
+  }
+
+  @Test
+  public void checkServerTrustedSslSocket()
+      throws CertificateException, IOException, CertStoreException {
+    TestSslSocket sslSocket = buildTrustManagerAndGetSslSocket();
+    X509Certificate[] serverCerts =
+        CertificateUtils.toX509Certificates(TestUtils.loadCert(SERVER_1_PEM_FILE));
+    trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslSocket);
+    verify(sslSocket, times(1)).isConnected();
+    verify(sslSocket, times(1)).getHandshakeSession();
+  }
+
+  @Test
+  public void checkServerTrustedSslSocket_untrustedServer_expectException()
+      throws CertificateException, IOException, CertStoreException {
+    TestSslSocket sslSocket = buildTrustManagerAndGetSslSocket();
+    X509Certificate[] badServerCert =
+        CertificateUtils.toX509Certificates(TestUtils.loadCert(BAD_SERVER_PEM_FILE));
+    try {
+      trustManager.checkServerTrusted(badServerCert, "ECDHE_ECDSA", sslSocket);
+      fail("exception expected");
+    } catch (ValidatorException expected) {
+      assertThat(expected).hasMessageThat()
+          .endsWith("unable to find valid certification path to requested target");
+    }
+    verify(sslSocket, times(1)).isConnected();
+    verify(sslSocket, times(1)).getHandshakeSession();
+  }
+
+  private TestSslEngine buildTrustManagerAndGetSslEngine()
+      throws CertificateException, IOException, CertStoreException {
+    SSLParameters sslParams = buildTrustManagerAndGetSslParameters();
+
+    TestSslEngine sslEngine = mock(TestSslEngine.class, CALLS_REAL_METHODS);
+    sslEngine.setSSLParameters(sslParams);
+    doReturn(mockSession).when(sslEngine).getHandshakeSession();
+    return sslEngine;
+  }
+
+  private TestSslSocket buildTrustManagerAndGetSslSocket()
+      throws CertificateException, IOException, CertStoreException {
+    SSLParameters sslParams = buildTrustManagerAndGetSslParameters();
+
+    TestSslSocket sslSocket = mock(TestSslSocket.class, CALLS_REAL_METHODS);
+    sslSocket.setSSLParameters(sslParams);
+    doReturn(true).when(sslSocket).isConnected();
+    doReturn(mockSession).when(sslSocket).getHandshakeSession();
+    return sslSocket;
+  }
+
+  private SSLParameters buildTrustManagerAndGetSslParameters()
+      throws CertificateException, IOException, CertStoreException {
+    X509Certificate[] caCerts =
+        CertificateUtils.toX509Certificates(TestUtils.loadCert(CA_PEM_FILE));
+    trustManager = SdsTrustManagerFactory.createSdsX509TrustManager(caCerts,
+        null);
+    when(mockSession.getProtocol()).thenReturn("TLSv1.2");
+    when(mockSession.getPeerHost()).thenReturn("peer-host-from-mock");
+    SSLParameters sslParams = new SSLParameters();
+    sslParams.setEndpointIdentificationAlgorithm("HTTPS");
+    return sslParams;
+  }
+
+  private abstract static class TestSslSocket extends SSLSocket {
+
+    @Override
+    public SSLParameters getSSLParameters() {
+      return sslParameters;
+    }
+
+    @Override
+    public void setSSLParameters(SSLParameters sslParameters) {
+      this.sslParameters = sslParameters;
+    }
+
+    private SSLParameters sslParameters;
+  }
+
+  private abstract static class TestSslEngine extends SSLEngine {
+
+    @Override
+    public SSLParameters getSSLParameters() {
+      return sslParameters;
+    }
+
+    @Override
+    public void setSSLParameters(SSLParameters sslParameters) {
+      this.sslParameters = sslParameters;
+    }
+
+    private SSLParameters sslParameters;
+  }
 }