| # -*- coding: utf-8 -*- |
| # |
| # Copyright 2024 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| """Mutual TLS for Google Compute Engine metadata server.""" |
| |
| from dataclasses import dataclass, field |
| import enum |
| import logging |
| import os |
| from pathlib import Path |
| import ssl |
| from urllib.parse import urlparse, urlunparse |
| |
| import requests |
| from requests.adapters import HTTPAdapter |
| |
| from google.auth import environment_vars, exceptions |
| |
| |
| _LOGGER = logging.getLogger(__name__) |
| |
| _WINDOWS_OS_NAME = "nt" |
| |
| # MDS mTLS certificate paths based on OS. |
| # Documentation to well known locations can be found at: |
| # https://cloud.google.com/compute/docs/metadata/overview#https-mds-certificates |
| _WINDOWS_MTLS_COMPONENTS_BASE_PATH = Path("C:/ProgramData/Google/ComputeEngine") |
| _MTLS_COMPONENTS_BASE_PATH = Path("/run/google-mds-mtls") |
| |
| |
| def _get_mds_root_crt_path(): |
| if os.name == _WINDOWS_OS_NAME: |
| return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-root.crt" |
| else: |
| return _MTLS_COMPONENTS_BASE_PATH / "root.crt" |
| |
| |
| def _get_mds_client_combined_cert_path(): |
| if os.name == _WINDOWS_OS_NAME: |
| return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-client.key" |
| else: |
| return _MTLS_COMPONENTS_BASE_PATH / "client.key" |
| |
| |
| @dataclass |
| class MdsMtlsConfig: |
| ca_cert_path: Path = field( |
| default_factory=_get_mds_root_crt_path |
| ) # path to CA certificate |
| client_combined_cert_path: Path = field( |
| default_factory=_get_mds_client_combined_cert_path |
| ) # path to file containing client certificate and key |
| |
| |
| def _certs_exist(mds_mtls_config: MdsMtlsConfig): |
| """Checks if the mTLS certificates exist.""" |
| return os.path.exists(mds_mtls_config.ca_cert_path) and os.path.exists( |
| mds_mtls_config.client_combined_cert_path |
| ) |
| |
| |
| class MdsMtlsMode(enum.Enum): |
| """MDS mTLS mode. Used to configure connection behavior when connecting to MDS. |
| |
| STRICT: Always use HTTPS/mTLS. If certificates are not found locally, an error will be returned. |
| NONE: Never use mTLS. Requests will use regular HTTP. |
| DEFAULT: Use mTLS if certificates are found locally, otherwise use regular HTTP. |
| """ |
| |
| STRICT = "strict" |
| NONE = "none" |
| DEFAULT = "default" |
| |
| |
| def _parse_mds_mode(): |
| """Parses the GCE_METADATA_MTLS_MODE environment variable.""" |
| mode_str = os.environ.get( |
| environment_vars.GCE_METADATA_MTLS_MODE, "default" |
| ).lower() |
| try: |
| return MdsMtlsMode(mode_str) |
| except ValueError: |
| raise ValueError( |
| "Invalid value for GCE_METADATA_MTLS_MODE. Must be one of 'strict', 'none', or 'default'." |
| ) |
| |
| |
| def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()): |
| """Determines if mTLS should be used for the metadata server.""" |
| mode = _parse_mds_mode() |
| if mode == MdsMtlsMode.STRICT: |
| if not _certs_exist(mds_mtls_config): |
| raise exceptions.MutualTLSChannelError( |
| "mTLS certificates not found in strict mode." |
| ) |
| return True |
| elif mode == MdsMtlsMode.NONE: |
| return False |
| else: # Default mode |
| return _certs_exist(mds_mtls_config) |
| |
| |
| class MdsMtlsAdapter(HTTPAdapter): |
| """An HTTP adapter that uses mTLS for the metadata server.""" |
| |
| def __init__( |
| self, mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig(), *args, **kwargs |
| ): |
| self.ssl_context = ssl.create_default_context() |
| self.ssl_context.load_verify_locations(cafile=mds_mtls_config.ca_cert_path) |
| self.ssl_context.load_cert_chain( |
| certfile=mds_mtls_config.client_combined_cert_path |
| ) |
| super(MdsMtlsAdapter, self).__init__(*args, **kwargs) |
| |
| def init_poolmanager(self, *args, **kwargs): |
| kwargs["ssl_context"] = self.ssl_context |
| return super(MdsMtlsAdapter, self).init_poolmanager(*args, **kwargs) |
| |
| def proxy_manager_for(self, *args, **kwargs): |
| kwargs["ssl_context"] = self.ssl_context |
| return super(MdsMtlsAdapter, self).proxy_manager_for(*args, **kwargs) |
| |
| def send(self, request, **kwargs): |
| # If we are in strict mode, always use mTLS (no HTTP fallback) |
| if _parse_mds_mode() == MdsMtlsMode.STRICT: |
| return super(MdsMtlsAdapter, self).send(request, **kwargs) |
| |
| # In default mode, attempt mTLS first, then fallback to HTTP on failure |
| try: |
| response = super(MdsMtlsAdapter, self).send(request, **kwargs) |
| response.raise_for_status() |
| return response |
| except ( |
| ssl.SSLError, |
| requests.exceptions.SSLError, |
| requests.exceptions.HTTPError, |
| ) as e: |
| _LOGGER.warning( |
| "mTLS connection to Compute Engine Metadata server failed. " |
| "Falling back to standard HTTP. Reason: %s", |
| e, |
| ) |
| # Fallback to standard HTTP |
| parsed_original_url = urlparse(request.url) |
| http_fallback_url = urlunparse(parsed_original_url._replace(scheme="http")) |
| request.url = http_fallback_url |
| |
| # Use a standard HTTPAdapter for the fallback |
| http_adapter = HTTPAdapter() |
| return http_adapter.send(request, **kwargs) |