feat: auto-enable mTLS when supported certificates are detected (#2686)

* feat: auto-enable mTLS when supported certificates are detected

Signed-off-by: Radhika Agrawal <[email protected]>

* feat: Add docstring, update version check and fix lint errors

Signed-off-by: Radhika Agrawal <[email protected]>

* chore: Update the test cases to check the google-auth version and skip on unsupported versions

Signed-off-by: Radhika Agrawal <[email protected]>

* fix: fix the import for parse_version_to_tuple

Signed-off-by: Radhika Agrawal <[email protected]>

* fix: Fix the tests to add the parse_version_to_tuple function

Signed-off-by: Radhika Agrawal <[email protected]>

* fix: Minor fix for parse version function

Signed-off-by: Radhika Agrawal <[email protected]>

---------

Signed-off-by: Radhika Agrawal <[email protected]>
diff --git a/googleapiclient/discovery.py b/googleapiclient/discovery.py
index f7bbd77..62d243b 100644
--- a/googleapiclient/discovery.py
+++ b/googleapiclient/discovery.py
@@ -649,16 +649,23 @@
 
         # Obtain client cert and create mTLS http channel if cert exists.
         client_cert_to_use = None
-        use_client_cert = os.getenv(GOOGLE_API_USE_CLIENT_CERTIFICATE, "false")
-        if not use_client_cert in ("true", "false"):
-            raise MutualTLSChannelError(
-                "Unsupported GOOGLE_API_USE_CLIENT_CERTIFICATE value. Accepted values: true, false"
+        if hasattr(mtls, "should_use_client_cert"):
+            use_client_cert = mtls.should_use_client_cert()
+        else:
+            # Older google-auth without mtls.should_use_client_cert: read the env var.
+            use_client_cert_str = os.getenv(
+                GOOGLE_API_USE_CLIENT_CERTIFICATE, "false"
+            ).lower()
+            use_client_cert = use_client_cert_str == "true"
+            if use_client_cert_str not in ("true", "false"):
+                raise MutualTLSChannelError(
+                    "Unsupported GOOGLE_API_USE_CLIENT_CERTIFICATE value. Accepted values: true, false"
             )
         if client_options and client_options.client_cert_source:
             raise MutualTLSChannelError(
                 "ClientOptions.client_cert_source is not supported, please use ClientOptions.client_encrypted_cert_source."
             )
-        if use_client_cert == "true":
+        if use_client_cert:
             if (
                 client_options
                 and hasattr(client_options, "client_encrypted_cert_source")
diff --git a/tests/test_discovery.py b/tests/test_discovery.py
index 9bf7cf4..f60f84c 100644
--- a/tests/test_discovery.py
+++ b/tests/test_discovery.py
@@ -32,6 +32,7 @@
 import json
 import os
 import pickle
+import pytest
 import re
 import sys
 import unittest
@@ -40,6 +41,7 @@
 
 import google.api_core.exceptions
 import google.auth.credentials
+from google.auth import __version__ as auth_version
 from google.auth.exceptions import MutualTLSChannelError
 import google_auth_httplib2
 import httplib2
@@ -62,46 +64,29 @@
     HAS_UNIVERSE = False
 
 from googleapiclient import _helpers as util
-from googleapiclient.discovery import (
-    DISCOVERY_URI,
-    MEDIA_BODY_PARAMETER_DEFAULT_VALUE,
-    MEDIA_MIME_TYPE_PARAMETER_DEFAULT_VALUE,
-    STACK_QUERY_PARAMETER_DEFAULT_VALUE,
-    STACK_QUERY_PARAMETERS,
-    V1_DISCOVERY_URI,
-    V2_DISCOVERY_URI,
-    APICoreVersionError,
-    ResourceMethodParameters,
-    _fix_up_media_path_base_url,
-    _fix_up_media_upload,
-    _fix_up_method_description,
-    _fix_up_parameters,
-    _urljoin,
-    build,
-    build_from_document,
-    key2param,
-)
+from googleapiclient.discovery import (DISCOVERY_URI,
+                                       MEDIA_BODY_PARAMETER_DEFAULT_VALUE,
+                                       MEDIA_MIME_TYPE_PARAMETER_DEFAULT_VALUE,
+                                       STACK_QUERY_PARAMETER_DEFAULT_VALUE,
+                                       STACK_QUERY_PARAMETERS,
+                                       V1_DISCOVERY_URI, V2_DISCOVERY_URI,
+                                       APICoreVersionError,
+                                       ResourceMethodParameters,
+                                       _fix_up_media_path_base_url,
+                                       _fix_up_media_upload,
+                                       _fix_up_method_description,
+                                       _fix_up_parameters, _urljoin, build,
+                                       build_from_document, key2param)
 from googleapiclient.discovery_cache import DISCOVERY_DOC_MAX_AGE
 from googleapiclient.discovery_cache.base import Cache
-from googleapiclient.errors import (
-    HttpError,
-    InvalidJsonError,
-    MediaUploadSizeError,
-    ResumableUploadError,
-    UnacceptableMimeTypeError,
-    UnknownApiNameOrVersion,
-    UnknownFileType,
-)
-from googleapiclient.http import (
-    HttpMock,
-    HttpMockSequence,
-    MediaFileUpload,
-    MediaIoBaseUpload,
-    MediaUpload,
-    MediaUploadProgress,
-    build_http,
-    tunnel_patch,
-)
+from googleapiclient.errors import (HttpError, InvalidJsonError,
+                                    MediaUploadSizeError, ResumableUploadError,
+                                    UnacceptableMimeTypeError,
+                                    UnknownApiNameOrVersion, UnknownFileType)
+from googleapiclient.http import (HttpMock, HttpMockSequence, MediaFileUpload,
+                                  MediaIoBaseUpload, MediaUpload,
+                                  MediaUploadProgress, build_http,
+                                  tunnel_patch)
 from googleapiclient.model import JsonModel
 from googleapiclient.schema import Schemas
 
@@ -156,6 +141,28 @@
     with open(datafile(filename), mode=mode) as f:
         return f.read()
 
+def parse_version_to_tuple(version_string):
+    """Convert a dotted version string into a comparable tuple of integers.
+
+    Examples: "4.25.8" -> (4, 25, 8); "2.43.0rc1" -> (2, 43, 0).
+    Each dot-separated component contributes its leading digits; parsing
+    stops after the first component that carries a non-numeric suffix.
+
+    Args:
+        version_string: Version string in the format "x.y.z" or "x.y.z<suffix>"
+
+    Returns:
+        Tuple of integers for the parsed version string.
+    """
+    parts = []
+    for part in version_string.split("."):
+        match = re.match(r"\d+", part)
+        if match is None:
+            break  # e.g. "1.0.dev0": stop before the non-numeric "dev0".
+        parts.append(int(match.group()))
+        if match.end() != len(part):
+            break  # e.g. "0rc1": keep 0, ignore the suffix and later parts.
+    return tuple(parts)
 
 class SetupHttplib2(unittest.TestCase):
     def test_retries(self):
@@ -778,7 +785,19 @@
 
 REGULAR_ENDPOINT = "https://www.googleapis.com/plus/v1/"
 MTLS_ENDPOINT = "https://www.mtls.googleapis.com/plus/v1/"
-
+
+# Contents served for the certificate config file (GOOGLE_API_CERTIFICATE_CONFIG).
+CONFIG_DATA_WITH_WORKLOAD = {
+    "version": 1,
+    "cert_configs": {
+        "workload": {
+            "cert_path": "path/to/cert/file",
+            "key_path": "path/to/key/file",
+        }
+    },
+}
+CONFIG_DATA_WITHOUT_WORKLOAD = {"version": 1, "cert_configs": {}}
+
 
 class DiscoveryFromDocumentMutualTLS(unittest.TestCase):
     MOCK_CREDENTIALS = mock.Mock(spec=google.auth.credentials.Credentials)
@@ -886,6 +905,55 @@
 
     @parameterized.expand(
         [
+            ("never", "", CONFIG_DATA_WITH_WORKLOAD, REGULAR_ENDPOINT),
+            ("auto", "", CONFIG_DATA_WITH_WORKLOAD, MTLS_ENDPOINT),
+            ("always", "", CONFIG_DATA_WITH_WORKLOAD, MTLS_ENDPOINT),
+            ("never", "", CONFIG_DATA_WITHOUT_WORKLOAD, REGULAR_ENDPOINT),
+            ("auto", "", CONFIG_DATA_WITHOUT_WORKLOAD, REGULAR_ENDPOINT),
+            ("always", "", CONFIG_DATA_WITHOUT_WORKLOAD, MTLS_ENDPOINT),
+        ]
+    )
+    @pytest.mark.skipif(
+        parse_version_to_tuple(auth_version) < (2, 43, 0),
+        reason="automatic mtls enablement when supported certs present is only "
+        "enabled in google-auth>=2.43.0",
+    )
+    def test_mtls_with_provided_client_cert_unset_environment_variable(
+        self, use_mtls_env, use_client_cert, config_data, base_url
+    ):
+        """Tests that mTLS is correctly handled when a client certificate is provided.
+
+        Verifies that when a client certificate is explicitly provided via
+        `client_options` and GOOGLE_API_USE_CLIENT_CERTIFICATE is empty, the base
+        URL follows `GOOGLE_API_USE_MTLS_ENDPOINT` and, in "auto" mode, whether the
+        certificate config (GOOGLE_API_CERTIFICATE_CONFIG) has a workload cert.
+        """
+        if not hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+            self.skipTest("mtls.should_use_client_cert is unavailable")
+        discovery = read_datafile("plus.json")
+        config_filename = "mock_certificate_config.json"
+        config_file_content = json.dumps(config_data)
+        # Any open() during the build returns the serialized certificate config.
+        m = mock.mock_open(read_data=config_file_content)
+
+        env = {
+            "GOOGLE_API_USE_MTLS_ENDPOINT": use_mtls_env,
+            "GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert,
+            "GOOGLE_API_CERTIFICATE_CONFIG": config_filename,
+        }
+        with mock.patch.dict("os.environ", env), mock.patch("builtins.open", m):
+            plus = build_from_document(
+                discovery,
+                credentials=self.MOCK_CREDENTIALS,
+                client_options={
+                    "client_encrypted_cert_source": self.client_encrypted_cert_source
+                },
+            )
+            self.assertIsNotNone(plus)
+            self.assertEqual(plus._baseUrl, base_url)
+
+    @parameterized.expand(
+        [
             ("never", "true"),
             ("auto", "true"),
             ("always", "true"),
@@ -961,6 +1029,71 @@
                 self.assertIsNotNone(plus)
                 self.check_http_client_cert(plus, has_client_cert=use_client_cert)
                 self.assertEqual(plus._baseUrl, base_url)
+
+    @parameterized.expand(
+        [
+            ("never", "", CONFIG_DATA_WITH_WORKLOAD, REGULAR_ENDPOINT),
+            ("auto", "", CONFIG_DATA_WITH_WORKLOAD, MTLS_ENDPOINT),
+            ("always", "", CONFIG_DATA_WITH_WORKLOAD, MTLS_ENDPOINT),
+            ("never", "", CONFIG_DATA_WITHOUT_WORKLOAD, REGULAR_ENDPOINT),
+            ("auto", "", CONFIG_DATA_WITHOUT_WORKLOAD, REGULAR_ENDPOINT),
+            ("always", "", CONFIG_DATA_WITHOUT_WORKLOAD, MTLS_ENDPOINT),
+        ]
+    )
+    @mock.patch(
+        "google.auth.transport.mtls.has_default_client_cert_source", autospec=True
+    )
+    @mock.patch(
+        "google.auth.transport.mtls.default_client_encrypted_cert_source", autospec=True
+    )
+    @pytest.mark.skipif(
+        parse_version_to_tuple(auth_version) < (2, 43, 0),
+        reason="automatic mtls enablement when supported certs present is only "
+        "enabled in google-auth>=2.43.0",
+    )
+    def test_mtls_with_default_client_cert_with_unset_environment_variable(
+        self,
+        use_mtls_env,
+        use_client_cert,
+        config_data,
+        base_url,
+        default_client_encrypted_cert_source,
+        has_default_client_cert_source,
+    ):
+        """Tests mTLS handling when falling back to a default client certificate.
+
+        This test simulates the scenario where no client certificate is explicitly
+        provided and a (mocked) default client certificate is available while
+        GOOGLE_API_USE_CLIENT_CERTIFICATE is empty. It checks that the base URL is
+        set for mTLS or regular endpoints depending on the
+        `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable and, in "auto" mode, on
+        whether the certificate config contains a workload cert.
+        """
+        if not hasattr(google.auth.transport.mtls, "should_use_client_cert"):
+            self.skipTest("mtls.should_use_client_cert is unavailable")
+        has_default_client_cert_source.return_value = True
+        default_client_encrypted_cert_source.return_value = (
+            self.client_encrypted_cert_source
+        )
+        discovery = read_datafile("plus.json")
+        config_filename = "mock_certificate_config.json"
+        config_file_content = json.dumps(config_data)
+        m = mock.mock_open(read_data=config_file_content)
+
+        env = {
+            "GOOGLE_API_USE_MTLS_ENDPOINT": use_mtls_env,
+            "GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert,
+            "GOOGLE_API_CERTIFICATE_CONFIG": config_filename,
+        }
+        with mock.patch.dict("os.environ", env), mock.patch("builtins.open", m):
+            plus = build_from_document(
+                discovery,
+                credentials=self.MOCK_CREDENTIALS,
+                adc_cert_path=self.ADC_CERT_PATH,
+                adc_key_path=self.ADC_KEY_PATH,
+            )
+            self.assertIsNotNone(plus)
+            self.assertEqual(plus._baseUrl, base_url)
 
     @parameterized.expand(
         [