diff --git a/packages/google-auth/google/auth/compute_engine/_metadata.py b/packages/google-auth/google/auth/compute_engine/_metadata.py index aae724ab18ee..bab526c3a902 100644 --- a/packages/google-auth/google/auth/compute_engine/_metadata.py +++ b/packages/google-auth/google/auth/compute_engine/_metadata.py @@ -22,7 +22,7 @@ import json import logging import os -from urllib.parse import urljoin +from urllib.parse import urljoin, urlparse import requests @@ -52,42 +52,52 @@ ) -def _validate_gce_mds_configured_environment(): - """Validates the GCE metadata server environment configuration for mTLS. +def _mount_mds_adapter_and_get_url( + request, + root, +) -> str: + """Prepares the metadata server root URL based on the mTLS configuration and environment. - mTLS is only supported when connecting to the default metadata server hosts. - If we are in strict mode (which requires mTLS), ensure that the metadata host - has not been overridden to a custom value (which means mTLS will fail). + This method mounts the mTLS adapter to the request if needed. + It also determines the appropriate URL scheme (HTTP vs HTTPS) to use when connecting to the metadata server based on the mTLS configuration and environment, and performes appropriate checks. + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. If mTLS is enabled and the request supports sessions, the mTLS adapter will be mounted to the request's session within this method. + root (str): The root URL to use for the metadata server. + Returns: + str: The metadata server root URL. The URL will use HTTPS if mTLS is enabled or required, and HTTP otherwise. Raises: - google.auth.exceptions.MutualTLSChannelError: if the environment - configuration is invalid for mTLS. + google.auth.exceptions.MutualTLSChannelError: if using mtls and the environment configuration is invalid for mTLS (for example, the metadata host + has been overridden in strict mTLS mode). """ - mode = _mtls._parse_mds_mode() - if mode == _mtls.MdsMtlsMode.STRICT: + mds_mtls_config = _mtls.MdsMtlsConfig() + + should_mount_adapter = _mtls.should_use_mds_mtls(mds_mtls_config=mds_mtls_config) + + mds_mtls_adapter_mounted = False + if should_mount_adapter: + mds_mtls_adapter_mounted = _try_mount_mds_mtls_adapter(request, mds_mtls_config) + + use_https = mds_mtls_adapter_mounted or ( + mds_mtls_config.mode == _mtls.MdsMtlsMode.STRICT + ) + scheme = "https" if use_https else "http" + + mds_mtls_root = "{}://{}/computeMetadata/v1/".format(scheme, root) + + if mds_mtls_config.mode == _mtls.MdsMtlsMode.STRICT: # mTLS is only supported when connecting to the default metadata host. # Raise an exception if we are in strict mode (which requires mTLS) # but the metadata host has been overridden to a custom MDS. (which means mTLS will fail) - if _GCE_METADATA_HOST not in _GCE_DEFAULT_MDS_HOSTS: + parsed = urlparse(mds_mtls_root) + if parsed.hostname not in _GCE_DEFAULT_MDS_HOSTS: raise exceptions.MutualTLSChannelError( "Mutual TLS is required, but the metadata host has been overridden. " "mTLS is only supported when connecting to the default metadata host." ) - -def _get_metadata_root(use_mtls: bool): - """Returns the metadata server root URL.""" - - scheme = "https" if use_mtls else "http" - return "{}://{}/computeMetadata/v1/".format(scheme, _GCE_METADATA_HOST) - - -def _get_metadata_ip_root(use_mtls: bool): - """Returns the metadata server IP root URL.""" - scheme = "https" if use_mtls else "http" - return "{}://{}".format( - scheme, os.getenv(environment_vars.GCE_METADATA_IP, _GCE_DEFAULT_MDS_IP) - ) + return mds_mtls_root _METADATA_FLAVOR_HEADER = "metadata-flavor" @@ -159,30 +169,33 @@ def detect_gce_residency_linux(): return content.startswith(_GOOGLE) -def _prepare_request_for_mds(request, use_mtls=False) -> None: - """Prepares a request for the metadata server. - - This will check if mTLS should be used and mount the mTLS adapter if needed. +def _try_mount_mds_mtls_adapter(request, mds_mtls_config: _mtls.MdsMtlsConfig) -> bool: + """Tries to mount the mTLS adapter to the request's session. Args: request (google.auth.transport.Request): A callable used to make - HTTP requests. If mTLS is enabled, and the request supports sessions, - the request will have the mTLS adapter mounted. Otherwise, there - will be no change. - use_mtls (bool): Whether to use mTLS for the request. + HTTP requests. If the request supports sessions, the mTLS adapter will be mounted to the request's session within this method. + mds_mtls_config (_mtls.MdsMtlsConfig): The mTLS configuration containing the paths to the CA and client certificates. + Returns: + bool: True if the mTLS adapter was mounted, False otherwise. + """ + if not hasattr(request, "session"): + # If the request does not support sessions, we cannot mount the mTLS adapter. + return False - """ # Only modify the request if mTLS is enabled, and request supports sessions. - if use_mtls and hasattr(request, "session"): - # Ensure the request has a session to mount the adapter to. - if not request.session: - request.session = requests.Session() + # Ensure the request has a session to mount the adapter to. + if not request.session: + request.session = requests.Session() - adapter = _mtls.MdsMtlsAdapter() - # Mount the adapter for all default GCE metadata hosts. - for host in _GCE_DEFAULT_MDS_HOSTS: - request.session.mount(f"https://{host}/", adapter) + adapter = _mtls.MdsMtlsAdapter(mds_mtls_config=mds_mtls_config) + + # Mount the adapter for all default GCE metadata hosts. + for host in _GCE_DEFAULT_MDS_HOSTS: + request.session.mount(f"https://{host}/", adapter) + + return True def ping( @@ -199,9 +212,16 @@ def ping( Returns: bool: True if the metadata server is reachable, False otherwise. + + Raises: + google.auth.exceptions.MutualTLSChannelError: if using mtls and the environment + configuration is invalid for mTLS (for example, the metadata host + has been overridden in strict mTLS mode). """ - use_mtls = _mtls.should_use_mds_mtls() - _prepare_request_for_mds(request, use_mtls=use_mtls) + mds_ip_url = _mount_mds_adapter_and_get_url( + request, root=os.getenv(environment_vars.GCE_METADATA_IP, _GCE_DEFAULT_MDS_IP) + ) + # NOTE: The explicit ``timeout`` is a workaround. The underlying # issue is that resolving an unknown host on some networks will take # 20-30 seconds; making this timeout short fixes the issue, but @@ -216,7 +236,7 @@ def ping( for attempt in backoff: try: response = request( - url=_get_metadata_ip_root(use_mtls), + url=mds_ip_url, method="GET", headers=headers, timeout=timeout, @@ -285,20 +305,11 @@ def get( has been overridden in strict mTLS mode). """ - use_mtls = _mtls.should_use_mds_mtls() - # Prepare the request object for mTLS if needed. - # This will create a new request object with the mTLS session. - _prepare_request_for_mds(request, use_mtls=use_mtls) - - if root is None: - root = _get_metadata_root(use_mtls) - - # mTLS is only supported when connecting to the default metadata host. - # If we are in strict mode (which requires mTLS), ensure that the metadata host - # has not been overridden to a non-default host value (which means mTLS will fail). - _validate_gce_mds_configured_environment() + mds_url = _mount_mds_adapter_and_get_url( + request, root=(root if root else _GCE_METADATA_HOST) + ) - base_url = urljoin(root, path) + base_url = urljoin(mds_url, path) query_params = {} if params is None else params headers_to_use = _METADATA_HEADERS.copy() diff --git a/packages/google-auth/google/auth/compute_engine/_mtls.py b/packages/google-auth/google/auth/compute_engine/_mtls.py index 6525dd03e1bd..16ec6d15887f 100644 --- a/packages/google-auth/google/auth/compute_engine/_mtls.py +++ b/packages/google-auth/google/auth/compute_engine/_mtls.py @@ -16,7 +16,6 @@ # """Mutual TLS for Google Compute Engine metadata server.""" -from dataclasses import dataclass, field import enum import logging import os @@ -29,7 +28,6 @@ from google.auth import environment_vars, exceptions - _LOGGER = logging.getLogger(__name__) _WINDOWS_OS_NAME = "nt" @@ -41,37 +39,6 @@ _MTLS_COMPONENTS_BASE_PATH = Path("/run/google-mds-mtls") -def _get_mds_root_crt_path(): - if os.name == _WINDOWS_OS_NAME: - return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-root.crt" - else: - return _MTLS_COMPONENTS_BASE_PATH / "root.crt" - - -def _get_mds_client_combined_cert_path(): - if os.name == _WINDOWS_OS_NAME: - return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-client.key" - else: - return _MTLS_COMPONENTS_BASE_PATH / "client.key" - - -@dataclass -class MdsMtlsConfig: - ca_cert_path: Path = field( - default_factory=_get_mds_root_crt_path - ) # path to CA certificate - client_combined_cert_path: Path = field( - default_factory=_get_mds_client_combined_cert_path - ) # path to file containing client certificate and key - - -def _certs_exist(mds_mtls_config: MdsMtlsConfig): - """Checks if the mTLS certificates exist.""" - return os.path.exists(mds_mtls_config.ca_cert_path) and os.path.exists( - mds_mtls_config.client_combined_cert_path - ) - - class MdsMtlsMode(enum.Enum): """MDS mTLS mode. Used to configure connection behavior when connecting to MDS. @@ -85,40 +52,95 @@ class MdsMtlsMode(enum.Enum): DEFAULT = "default" -def _parse_mds_mode(): - """Parses the GCE_METADATA_MTLS_MODE environment variable.""" - mode_str = os.environ.get( - environment_vars.GCE_METADATA_MTLS_MODE, "default" - ).lower() - try: - return MdsMtlsMode(mode_str) - except ValueError: - raise ValueError( - "Invalid value for GCE_METADATA_MTLS_MODE. Must be one of 'strict', 'none', or 'default'." +class MdsMtlsConfig: + def __init__( + self, + ca_cert_path: Path = None, + client_combined_cert_path: Path = None, + mode: MdsMtlsMode = None, + ): + self.ca_cert_path = ca_cert_path or self._get_default_mds_root_crt_path() + self.client_combined_cert_path = ( + client_combined_cert_path + or self._get_default_mds_client_combined_cert_path() ) + self.mode = mode or self._parse_mds_mode() + + def _get_default_mds_root_crt_path(self): + """Returns the default path to the CA certificate, based on the OS.""" + + if os.name == _WINDOWS_OS_NAME: + return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-root.crt" + else: + return _MTLS_COMPONENTS_BASE_PATH / "root.crt" + + def _get_default_mds_client_combined_cert_path(self): + """Returns the default path to the client certificate and key combined file, based on the OS.""" + if os.name == _WINDOWS_OS_NAME: + return _WINDOWS_MTLS_COMPONENTS_BASE_PATH / "mds-mtls-client.key" + else: + return _MTLS_COMPONENTS_BASE_PATH / "client.key" -def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()): - """Determines if mTLS should be used for the metadata server.""" - mode = _parse_mds_mode() - if mode == MdsMtlsMode.STRICT: - if not _certs_exist(mds_mtls_config): + def _parse_mds_mode(self) -> MdsMtlsMode: + """Parses the GCE_METADATA_MTLS_MODE environment variable.""" + + mode_str = os.environ.get( + environment_vars.GCE_METADATA_MTLS_MODE, "default" + ).lower() + try: + return MdsMtlsMode(mode_str) + except ValueError: + raise ValueError( + "Invalid value for GCE_METADATA_MTLS_MODE. Must be one of 'strict', 'none', or 'default'." + ) + + +def mds_mtls_certificates_exist(mds_mtls_config: MdsMtlsConfig): + """Checks if the mTLS certificates exist. + + Args: + mds_mtls_config (MdsMtlsConfig): The mTLS configuration containing the + paths to the CA and client certificates. + + Returns: + bool: True if both certificates exist, False otherwise. + """ + return os.path.exists(mds_mtls_config.ca_cert_path) and os.path.exists( + mds_mtls_config.client_combined_cert_path + ) + + +def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig) -> bool: + """Determines if mTLS should be used for the metadata server. + + Args: + mds_mtls_config (MdsMtlsConfig): The mTLS configuration containing the + paths to the CA and client certificates as well as the mTLS mode. + + Returns: + bool: True if mTLS should be used, False otherwise. + Raises: + google.auth.exceptions.MutualTLSChannelError: if mode is strict but certificates do not exist. + """ + + mds_mtls_mode = mds_mtls_config.mode + if mds_mtls_mode == MdsMtlsMode.STRICT: + if not mds_mtls_certificates_exist(mds_mtls_config): raise exceptions.MutualTLSChannelError( "mTLS certificates not found in strict mode." ) return True - elif mode == MdsMtlsMode.NONE: + if mds_mtls_mode == MdsMtlsMode.NONE: return False - else: # Default mode - return _certs_exist(mds_mtls_config) + return mds_mtls_certificates_exist(mds_mtls_config) class MdsMtlsAdapter(HTTPAdapter): """An HTTP adapter that uses mTLS for the metadata server.""" - def __init__( - self, mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig(), *args, **kwargs - ): + def __init__(self, mds_mtls_config: MdsMtlsConfig, *args, **kwargs): + self.mds_mtls_config = mds_mtls_config self.ssl_context = ssl.create_default_context() self.ssl_context.load_verify_locations(cafile=mds_mtls_config.ca_cert_path) self.ssl_context.load_cert_chain( @@ -136,7 +158,7 @@ def proxy_manager_for(self, *args, **kwargs): def send(self, request, **kwargs): # If we are in strict mode, always use mTLS (no HTTP fallback) - if _parse_mds_mode() == MdsMtlsMode.STRICT: + if self.mds_mtls_config.mode == MdsMtlsMode.STRICT: return super(MdsMtlsAdapter, self).send(request, **kwargs) # In default mode, attempt mTLS first, then fallback to HTTP on failure diff --git a/packages/google-auth/tests/compute_engine/test__metadata.py b/packages/google-auth/tests/compute_engine/test__metadata.py index 35996ab24b92..34046f503f9e 100644 --- a/packages/google-auth/tests/compute_engine/test__metadata.py +++ b/packages/google-auth/tests/compute_engine/test__metadata.py @@ -143,7 +143,7 @@ def test_ping_success(mock_metrics_header_value): request.assert_called_once_with( method="GET", - url="http://169.254.169.254", + url="http://169.254.169.254/computeMetadata/v1/", headers=MDS_PING_REQUEST_HEADER, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -157,7 +157,7 @@ def test_ping_success_retry(mock_metrics_header_value): request.assert_called_with( method="GET", - url="http://169.254.169.254", + url="http://169.254.169.254/computeMetadata/v1/", headers=MDS_PING_REQUEST_HEADER, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -195,7 +195,7 @@ def test_ping_success_custom_root(mock_metrics_header_value): request.assert_called_once_with( method="GET", - url="http://" + fake_ip, + url="http://" + fake_ip + "/computeMetadata/v1/", headers=MDS_PING_REQUEST_HEADER, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -388,13 +388,13 @@ def test_get_success_custom_root_old_variable(): def test_get_success_custom_root(): request = make_request("{}", headers={"content-type": "application/json"}) - fake_root = "http://another.metadata.service" + fake_root = "another.metadata.service" _metadata.get(request, PATH, root=fake_root) request.assert_called_once_with( method="GET", - url="{}/{}".format(fake_root, PATH), + url="http://another.metadata.service/computeMetadata/v1/{}".format(PATH), headers=_metadata._METADATA_HEADERS, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) @@ -638,12 +638,18 @@ def test_get_universe_domain_other_error(): ) +@mock.patch( + "google.auth._agent_identity_utils.get_and_parse_agent_identity_certificate", + return_value=None, +) @mock.patch( "google.auth.metrics.token_request_access_token_mds", return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, ) @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) -def test_get_service_account_token(utcnow, mock_metrics_header_value): +def test_get_service_account_token( + utcnow, mock_metrics_header_value, mock_get_agent_cert +): ttl = 500 request = make_request( json.dumps({"access_token": "token", "expires_in": ttl}), @@ -665,12 +671,18 @@ def test_get_service_account_token(utcnow, mock_metrics_header_value): assert expiry == utcnow() + datetime.timedelta(seconds=ttl) +@mock.patch( + "google.auth._agent_identity_utils.get_and_parse_agent_identity_certificate", + return_value=None, +) @mock.patch( "google.auth.metrics.token_request_access_token_mds", return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, ) @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) -def test_get_service_account_token_with_scopes_list(utcnow, mock_metrics_header_value): +def test_get_service_account_token_with_scopes_list( + utcnow, mock_metrics_header_value, mock_get_agent_cert +): ttl = 500 request = make_request( json.dumps({"access_token": "token", "expires_in": ttl}), @@ -695,13 +707,17 @@ def test_get_service_account_token_with_scopes_list(utcnow, mock_metrics_header_ assert expiry == utcnow() + datetime.timedelta(seconds=ttl) +@mock.patch( + "google.auth._agent_identity_utils.get_and_parse_agent_identity_certificate", + return_value=None, +) @mock.patch( "google.auth.metrics.token_request_access_token_mds", return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, ) @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) def test_get_service_account_token_with_scopes_string( - utcnow, mock_metrics_header_value + utcnow, mock_metrics_header_value, mock_get_agent_cert ): ttl = 500 request = make_request( @@ -824,48 +840,81 @@ def test_get_service_account_info(): assert info[key] == value -def test__get_metadata_root_mtls(): - assert ( - _metadata._get_metadata_root(use_mtls=True) - == "https://metadata.google.internal/computeMetadata/v1/" - ) - - -def test__get_metadata_root_no_mtls(): - assert ( - _metadata._get_metadata_root(use_mtls=False) - == "http://metadata.google.internal/computeMetadata/v1/" - ) - - -def test__get_metadata_ip_root_mtls(): - assert _metadata._get_metadata_ip_root(use_mtls=True) == "https://169.254.169.254" - - -def test__get_metadata_ip_root_no_mtls(): - assert _metadata._get_metadata_ip_root(use_mtls=False) == "http://169.254.169.254" +def test__mount_mds_adapter_and_get_url_mtls(): + request = mock.Mock(spec=transport.Request) + # Mocking mds_mtls_certificates_exist to True so it returns https + with mock.patch( + "google.auth.compute_engine._mtls.mds_mtls_certificates_exist", + return_value=True, + ): + with mock.patch( + "google.auth.compute_engine._mtls.MdsMtlsConfig._parse_mds_mode", + return_value=_metadata._mtls.MdsMtlsMode.STRICT, + ): + assert ( + _metadata._mount_mds_adapter_and_get_url( + request, root=_metadata._GCE_DEFAULT_HOST + ) + == "https://metadata.google.internal/computeMetadata/v1/" + ) + + +def test__mount_mds_adapter_and_get_url_no_mtls(): + request = mock.Mock(spec=transport.Request) + with mock.patch( + "google.auth.compute_engine._mtls.MdsMtlsConfig._parse_mds_mode", + return_value=_metadata._mtls.MdsMtlsMode.NONE, + ): + assert ( + _metadata._mount_mds_adapter_and_get_url( + request, root=_metadata._GCE_DEFAULT_HOST + ) + == "http://metadata.google.internal/computeMetadata/v1/" + ) +@mock.patch( + "google.auth.compute_engine._mtls.mds_mtls_certificates_exist", return_value=True +) @mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") -def test__prepare_request_for_mds_mtls(mock_mds_mtls_adapter): +def test__try_mount_mds_mtls_adapter_mtls(mock_mds_mtls_adapter, mock_certs_exist): request = google_auth_requests.Request(mock.create_autospec(requests.Session)) - _metadata._prepare_request_for_mds(request, use_mtls=True) + config = _metadata._mtls.MdsMtlsConfig(mode=_metadata._mtls.MdsMtlsMode.STRICT) + assert _metadata._try_mount_mds_mtls_adapter(request, config) mock_mds_mtls_adapter.assert_called_once() assert request.session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) -def test__prepare_request_for_mds_no_mtls(): - request = mock.Mock() - _metadata._prepare_request_for_mds(request, use_mtls=False) - request.session.mount.assert_not_called() +@mock.patch("ssl.create_default_context") +def test__try_mount_mds_mtls_adapter_no_mtls(mock_ssl_context): + request = mock.create_autospec(transport.Request) + # Ensure it doesn't have a 'session' attribute + del request.session + assert not _metadata._try_mount_mds_mtls_adapter( + request, _metadata._mtls.MdsMtlsConfig() + ) @mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE) -@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") -@mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True) +@mock.patch( + "google.auth.compute_engine._metadata._try_mount_mds_mtls_adapter", + return_value=True, +) +@mock.patch( + "google.auth.compute_engine._mtls.MdsMtlsConfig._parse_mds_mode", + return_value=_metadata._mtls.MdsMtlsMode.STRICT, +) +@mock.patch( + "google.auth.compute_engine._mtls.mds_mtls_certificates_exist", + return_value=True, +) @mock.patch("google.auth.transport.requests.Request") def test_ping_mtls( - mock_request, mock_should_use_mtls, mock_mds_mtls_adapter, mock_metrics_header_value + mock_request, + mock_certs_exist, + mock_parse_mds_mode, + mock_try_mount_adapter, + mock_metrics_header_value, ): response = mock.create_autospec(transport.Response, instance=True) response.status = http_client.OK @@ -874,20 +923,32 @@ def test_ping_mtls( assert _metadata.ping(mock_request) - mock_should_use_mtls.assert_called_once() - mock_mds_mtls_adapter.assert_called_once() + mock_parse_mds_mode.assert_called() + mock_try_mount_adapter.assert_called_once() mock_request.assert_called_once_with( - url="https://169.254.169.254", + url="https://169.254.169.254/computeMetadata/v1/", method="GET", headers=MDS_PING_REQUEST_HEADER, timeout=_metadata._METADATA_DEFAULT_TIMEOUT, ) -@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") -@mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True) +@mock.patch( + "google.auth.compute_engine._metadata._try_mount_mds_mtls_adapter", + return_value=True, +) +@mock.patch( + "google.auth.compute_engine._mtls.MdsMtlsConfig._parse_mds_mode", + return_value=_metadata._mtls.MdsMtlsMode.STRICT, +) +@mock.patch( + "google.auth.compute_engine._mtls.mds_mtls_certificates_exist", + return_value=True, +) @mock.patch("google.auth.transport.requests.Request") -def test_get_mtls(mock_request, mock_should_use_mtls, mock_mds_mtls_adapter): +def test_get_mtls( + mock_request, mock_certs_exist, mock_parse_mds_mode, mock_try_mount_adapter +): response = mock.create_autospec(transport.Response, instance=True) response.status = http_client.OK response.data = _helpers.to_bytes("{}") @@ -896,8 +957,8 @@ def test_get_mtls(mock_request, mock_should_use_mtls, mock_mds_mtls_adapter): _metadata.get(mock_request, "some/path") - mock_should_use_mtls.assert_called_once() - mock_mds_mtls_adapter.assert_called_once() + mock_parse_mds_mode.assert_called() + mock_try_mount_adapter.assert_called_once() mock_request.assert_called_once_with( url="https://metadata.google.internal/computeMetadata/v1/some/path", method="GET", @@ -907,58 +968,102 @@ def test_get_mtls(mock_request, mock_should_use_mtls, mock_mds_mtls_adapter): @pytest.mark.parametrize( - "mds_mode, metadata_host, expect_exception", + "mds_mode, metadata_host, certs_exist, expect_exception", [ - (_metadata._mtls.MdsMtlsMode.STRICT, _metadata._GCE_DEFAULT_HOST, False), - (_metadata._mtls.MdsMtlsMode.STRICT, _metadata._GCE_DEFAULT_MDS_IP, False), - (_metadata._mtls.MdsMtlsMode.STRICT, "custom.host", True), - (_metadata._mtls.MdsMtlsMode.NONE, "custom.host", False), - (_metadata._mtls.MdsMtlsMode.DEFAULT, _metadata._GCE_DEFAULT_HOST, False), - (_metadata._mtls.MdsMtlsMode.DEFAULT, _metadata._GCE_DEFAULT_MDS_IP, False), + ( + _metadata._mtls.MdsMtlsMode.STRICT, + _metadata._GCE_DEFAULT_HOST, + True, + False, + ), + ( + _metadata._mtls.MdsMtlsMode.STRICT, + _metadata._GCE_DEFAULT_MDS_IP, + True, + False, + ), + (_metadata._mtls.MdsMtlsMode.STRICT, "custom.host", True, True), + (_metadata._mtls.MdsMtlsMode.STRICT, _metadata._GCE_DEFAULT_HOST, False, True), + (_metadata._mtls.MdsMtlsMode.NONE, "custom.host", True, False), + ( + _metadata._mtls.MdsMtlsMode.DEFAULT, + _metadata._GCE_DEFAULT_HOST, + True, + False, + ), + ( + _metadata._mtls.MdsMtlsMode.DEFAULT, + _metadata._GCE_DEFAULT_MDS_IP, + True, + False, + ), ], ) -@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") def test_validate_gce_mds_configured_environment( - mock_parse_mds_mode, mds_mode, metadata_host, expect_exception + mds_mode, metadata_host, certs_exist, expect_exception ): - mock_parse_mds_mode.return_value = mds_mode + # Validation now happens inside _mount_mds_adapter_and_get_url + request = mock.Mock(spec=transport.Request) with mock.patch( - "google.auth.compute_engine._metadata._GCE_METADATA_HOST", new=metadata_host + "google.auth.compute_engine._mtls.MdsMtlsConfig._parse_mds_mode", + return_value=mds_mode, ): - if expect_exception: - with pytest.raises(exceptions.MutualTLSChannelError): - _metadata._validate_gce_mds_configured_environment() - else: - _metadata._validate_gce_mds_configured_environment() - mock_parse_mds_mode.assert_called_once() + with mock.patch( + "google.auth.compute_engine._mtls.mds_mtls_certificates_exist", + return_value=certs_exist, + ): + if expect_exception: + with pytest.raises(exceptions.MutualTLSChannelError): + _metadata._mount_mds_adapter_and_get_url( + request, root=metadata_host + ) + else: + _metadata._mount_mds_adapter_and_get_url(request, root=metadata_host) +@mock.patch( + "google.auth.compute_engine._mtls.mds_mtls_certificates_exist", return_value=True +) @mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") -def test__prepare_request_for_mds_mtls_session_exists(mock_mds_mtls_adapter): +def test__try_mount_mds_mtls_adapter_mtls_session_exists( + mock_mds_mtls_adapter, mock_certs_exist +): mock_session = mock.create_autospec(requests.Session) request = google_auth_requests.Request(mock_session) - _metadata._prepare_request_for_mds(request, use_mtls=True) + config = _metadata._mtls.MdsMtlsConfig(mode=_metadata._mtls.MdsMtlsMode.STRICT) + assert _metadata._try_mount_mds_mtls_adapter(request, config) mock_mds_mtls_adapter.assert_called_once() assert mock_session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) +@mock.patch( + "google.auth.compute_engine._mtls.mds_mtls_certificates_exist", return_value=True +) @mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") -def test__prepare_request_for_mds_mtls_no_session(mock_mds_mtls_adapter): +def test__try_mount_mds_mtls_adapter_mtls_no_session( + mock_mds_mtls_adapter, mock_certs_exist +): request = google_auth_requests.Request(None) # Explicitly set session to None to avoid a session being created in the Request constructor. request.session = None + config = _metadata._mtls.MdsMtlsConfig(mode=_metadata._mtls.MdsMtlsMode.STRICT) with mock.patch("requests.Session") as mock_session_class: - _metadata._prepare_request_for_mds(request, use_mtls=True) + assert _metadata._try_mount_mds_mtls_adapter(request, config) mock_session_class.assert_called_once() mock_mds_mtls_adapter.assert_called_once() assert request.session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) +@mock.patch( + "google.auth.compute_engine._mtls.mds_mtls_certificates_exist", return_value=True +) @mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") -def test__prepare_request_for_mds_mtls_http_request(mock_mds_mtls_adapter): +def test__try_mount_mds_mtls_adapter_mtls_http_request( + mock_mds_mtls_adapter, mock_certs_exist +): """ http requests should be ignored. Regression test for https://github.com/googleapis/google-cloud-python/issues/16035 @@ -966,6 +1071,7 @@ def test__prepare_request_for_mds_mtls_http_request(mock_mds_mtls_adapter): from google.auth.transport import _http_client request = _http_client.Request() - _metadata._prepare_request_for_mds(request, use_mtls=True) + config = _metadata._mtls.MdsMtlsConfig(mode=_metadata._mtls.MdsMtlsMode.STRICT) + assert not _metadata._try_mount_mds_mtls_adapter(request, config) assert mock_mds_mtls_adapter.call_count == 0 diff --git a/packages/google-auth/tests/compute_engine/test__mtls.py b/packages/google-auth/tests/compute_engine/test__mtls.py index 6b40b6682869..a3495d16767f 100644 --- a/packages/google-auth/tests/compute_engine/test__mtls.py +++ b/packages/google-auth/tests/compute_engine/test__mtls.py @@ -21,7 +21,7 @@ import pytest # type: ignore import requests -from google.auth import environment_vars, exceptions +from google.auth import environment_vars from google.auth.compute_engine import _mtls @@ -55,7 +55,7 @@ def test__MdsMtlsConfig_non_windows_defaults(): def test__parse_mds_mode_default(monkeypatch): monkeypatch.delenv(environment_vars.GCE_METADATA_MTLS_MODE, raising=False) - assert _mtls._parse_mds_mode() == _mtls.MdsMtlsMode.DEFAULT + assert _mtls.MdsMtlsConfig()._parse_mds_mode() == _mtls.MdsMtlsMode.DEFAULT @pytest.mark.parametrize( @@ -69,50 +69,50 @@ def test__parse_mds_mode_default(monkeypatch): ) def test__parse_mds_mode_valid(monkeypatch, mode_str, expected_mode): monkeypatch.setenv(environment_vars.GCE_METADATA_MTLS_MODE, mode_str) - assert _mtls._parse_mds_mode() == expected_mode + assert _mtls.MdsMtlsConfig()._parse_mds_mode() == expected_mode def test__parse_mds_mode_invalid(monkeypatch): monkeypatch.setenv(environment_vars.GCE_METADATA_MTLS_MODE, "invalid_mode") with pytest.raises(ValueError): - _mtls._parse_mds_mode() + _mtls.MdsMtlsConfig()._parse_mds_mode() @mock.patch("os.path.exists") -def test__certs_exist_true(mock_exists, mock_mds_mtls_config): +def test_mds_mtls_certificates_exist_true(mock_exists, mock_mds_mtls_config): mock_exists.return_value = True - assert _mtls._certs_exist(mock_mds_mtls_config) is True + assert _mtls.mds_mtls_certificates_exist(mock_mds_mtls_config) is True @mock.patch("os.path.exists") -def test__certs_exist_false(mock_exists, mock_mds_mtls_config): +def test_mds_mtls_certificates_exist_false(mock_exists, mock_mds_mtls_config): mock_exists.return_value = False - assert _mtls._certs_exist(mock_mds_mtls_config) is False + assert _mtls.mds_mtls_certificates_exist(mock_mds_mtls_config) is False @pytest.mark.parametrize( "mtls_mode, certs_exist, expected_result", [ - ("strict", True, True), - ("strict", False, exceptions.MutualTLSChannelError), - ("none", True, False), - ("none", False, False), - ("default", True, True), - ("default", False, False), + (_mtls.MdsMtlsMode.STRICT, True, True), + (_mtls.MdsMtlsMode.STRICT, False, True), + (_mtls.MdsMtlsMode.NONE, True, False), + (_mtls.MdsMtlsMode.NONE, False, False), + (_mtls.MdsMtlsMode.DEFAULT, True, True), + (_mtls.MdsMtlsMode.DEFAULT, False, False), ], ) @mock.patch("os.path.exists") def test_should_use_mds_mtls( - mock_exists, monkeypatch, mtls_mode, certs_exist, expected_result + mock_exists, mtls_mode, certs_exist, expected_result, mock_mds_mtls_config ): - monkeypatch.setenv(environment_vars.GCE_METADATA_MTLS_MODE, mtls_mode) mock_exists.return_value = certs_exist + mock_mds_mtls_config.mode = mtls_mode if isinstance(expected_result, type) and issubclass(expected_result, Exception): with pytest.raises(expected_result): - _mtls.should_use_mds_mtls() + _mtls.should_use_mds_mtls(mock_mds_mtls_config) else: - assert _mtls.should_use_mds_mtls() is expected_result + assert _mtls.should_use_mds_mtls(mock_mds_mtls_config) is expected_result @mock.patch("ssl.create_default_context") @@ -171,14 +171,13 @@ def test_mds_mtls_adapter_session_request( mock_super_send.assert_called_once() -@mock.patch("requests.adapters.HTTPAdapter.send") -@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") +@mock.patch("requests.adapters.HTTPAdapter.send") # Patch the PARENT class method @mock.patch("ssl.create_default_context") def test_mds_mtls_adapter_send_success( - mock_ssl_context, mock_parse_mds_mode, mock_super_send, mock_mds_mtls_config + mock_ssl_context, mock_super_send, mock_mds_mtls_config ): """Test the explicit 'happy path' where mTLS succeeds without error.""" - mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + mock_mds_mtls_config.mode = _mtls.MdsMtlsMode.DEFAULT adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) # Setup the parent class send return value to be successful (200 OK) @@ -197,12 +196,11 @@ def test_mds_mtls_adapter_send_success( @mock.patch("google.auth.compute_engine._mtls.HTTPAdapter") -@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") @mock.patch("ssl.create_default_context") def test_mds_mtls_adapter_send_fallback_default_mode( - mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_class, mock_mds_mtls_config + mock_ssl_context, mock_http_adapter_class, mock_mds_mtls_config ): - mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + mock_mds_mtls_config.mode = _mtls.MdsMtlsMode.DEFAULT adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) mock_fallback_send = mock.Mock() @@ -223,12 +221,11 @@ def test_mds_mtls_adapter_send_fallback_default_mode( @mock.patch("google.auth.compute_engine._mtls.HTTPAdapter") -@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") @mock.patch("ssl.create_default_context") def test_mds_mtls_adapter_send_fallback_http_error( - mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_class, mock_mds_mtls_config + mock_ssl_context, mock_http_adapter_class, mock_mds_mtls_config ): - mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + mock_mds_mtls_config.mode = _mtls.MdsMtlsMode.DEFAULT adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) mock_fallback_send = mock.Mock() @@ -251,12 +248,11 @@ def test_mds_mtls_adapter_send_fallback_http_error( @mock.patch("requests.adapters.HTTPAdapter.send") -@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") @mock.patch("ssl.create_default_context") def test_mds_mtls_adapter_send_no_fallback_other_exception( - mock_ssl_context, mock_parse_mds_mode, mock_http_adapter_send, mock_mds_mtls_config + mock_ssl_context, mock_http_adapter_send, mock_mds_mtls_config ): - mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.DEFAULT + mock_mds_mtls_config.mode = _mtls.MdsMtlsMode.DEFAULT adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) # Simulate HTTP exception @@ -271,12 +267,11 @@ def test_mds_mtls_adapter_send_no_fallback_other_exception( mock_http_adapter_send.assert_not_called() -@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") @mock.patch("ssl.create_default_context") def test_mds_mtls_adapter_send_no_fallback_strict_mode( - mock_ssl_context, mock_parse_mds_mode, mock_mds_mtls_config + mock_ssl_context, mock_mds_mtls_config ): - mock_parse_mds_mode.return_value = _mtls.MdsMtlsMode.STRICT + mock_mds_mtls_config.mode = _mtls.MdsMtlsMode.STRICT adapter = _mtls.MdsMtlsAdapter(mock_mds_mtls_config) # Simulate SSLError on the super().send() call