Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 101 additions & 32 deletions packages/google-auth/google/auth/_regional_access_boundary_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(self):
)
self.refresh_manager = _RegionalAccessBoundaryRefreshManager()
self._update_lock = threading.Lock()
self._use_blocking_regional_access_boundary_lookup = False

def __getstate__(self):
"""Pickle helper that serializes the _update_lock attribute."""
Expand All @@ -109,6 +110,36 @@ def __setstate__(self, state):
self.__dict__.update(state)
self._update_lock = threading.Lock()

def __eq__(self, other):
"""Checks if two managers are equal."""
if not isinstance(other, _RegionalAccessBoundaryManager):
return NotImplemented
return (
self._data == other._data
and self._use_blocking_regional_access_boundary_lookup
== other._use_blocking_regional_access_boundary_lookup
)

def use_blocking_regional_access_boundary_lookup(self):
"""Enables blocking regional access boundary lookup to true"""
self._use_blocking_regional_access_boundary_lookup = True

def set_initial_regional_access_boundary(self, seed):
"""Manually sets the regional access boundary to the client provided seed

Args:
seed (Mapping[str, str]): The regional access boundary to use for the
credential. This should be a map with, at a minimum, an "encodedLocations"
key that maps to a hex string and an "expiry" key which maps to a
datetime.datetime.
"""
self._data = _RegionalAccessBoundaryData(
encoded_locations=seed.get("encodedLocations", None),
expiry=seed.get("expiry", None),
cooldown_expiry=None,
cooldown_duration=DEFAULT_REGIONAL_ACCESS_BOUNDARY_COOLDOWN,
)

def apply_headers(self, headers):
"""Applies the Regional Access Boundary header to the provided dictionary.

Expand Down Expand Up @@ -151,48 +182,47 @@ def maybe_start_refresh(self, credentials, request):
return

# If all checks pass, start the background refresh.
self.refresh_manager.start_refresh(credentials, request, self)


class _RegionalAccessBoundaryRefreshThread(threading.Thread):
"""Thread for background refreshing of the Regional Access Boundary."""

def __init__(self, credentials, request, rab_manager):
super().__init__()
self.daemon = True
self._credentials = credentials
self._request = request
self._rab_manager = rab_manager
if self._use_blocking_regional_access_boundary_lookup:
self.start_blocking_refresh(credentials, request)
else:
self.refresh_manager.start_refresh(credentials, request, self)

def run(self):
"""
Performs the Regional Access Boundary lookup and updates the state.
def start_blocking_refresh(self, credentials, request):
"""Initiates a blocking lookup of the Regional Access Boundary.

This method is run in a separate thread. It delegates the actual lookup
to the credentials object's `_lookup_regional_access_boundary` method.
Based on the lookup's outcome (success or complete failure after retries),
it updates the cached Regional Access Boundary information,
its expiry, its cooldown expiry, and its exponential cooldown duration.
Args:
credentials (google.auth.credentials.Credentials): The credentials to refresh.
request (google.auth.transport.Request): The object used to make HTTP requests.
"""
# Catch exceptions (e.g., from the underlying transport) to prevent the
# background thread from crashing. This ensures we can gracefully enter
# an exponential cooldown state on failure.
try:
# A blocking parameter is passed here to indicate this is a blocking lookup,
# which in turn will do two things: 1) set a timeout to 3s instead of the
# default 120s and 2) ensure we do not retry at all
blocking = True
regional_access_boundary_info = (
self._credentials._lookup_regional_access_boundary(self._request)
credentials._lookup_regional_access_boundary(request, blocking)
)
except Exception as e:
if _helpers.is_logging_enabled(_LOGGER):
_LOGGER.warning(
"Asynchronous Regional Access Boundary lookup raised an exception: %s",
"Blocking Regional Access Boundary lookup raised an exception: %s",
e,
exc_info=True,
)
regional_access_boundary_info = None

with self._rab_manager._update_lock:
self.process_regional_access_boundary_info(regional_access_boundary_info)

def process_regional_access_boundary_info(self, regional_access_boundary_info):
"""Processes the regional access boundary info and updates the state.

Args:
regional_access_boundary_info (Optional[Mapping[str, str]]): The regional access
boundary info to process.
"""
with self._update_lock:
# Capture the current state before calculating updates.
current_data = self._rab_manager._data
current_data = self._data

if regional_access_boundary_info:
# On success, update the boundary and its expiry, and clear any cooldown.
Expand All @@ -206,14 +236,12 @@ def run(self):
cooldown_duration=DEFAULT_REGIONAL_ACCESS_BOUNDARY_COOLDOWN,
)
if _helpers.is_logging_enabled(_LOGGER):
_LOGGER.debug(
"Asynchronous Regional Access Boundary lookup successful."
)
_LOGGER.debug("Regional Access Boundary lookup successful.")
else:
# On failure, calculate cooldown and update state.
if _helpers.is_logging_enabled(_LOGGER):
_LOGGER.warning(
"Asynchronous Regional Access Boundary lookup failed. Entering cooldown."
"Regional Access Boundary lookup failed. Entering cooldown."
)

next_cooldown_expiry = (
Expand Down Expand Up @@ -241,7 +269,48 @@ def run(self):
)

# Perform the atomic swap of the state object.
self._rab_manager._data = updated_data
self._data = updated_data


class _RegionalAccessBoundaryRefreshThread(threading.Thread):
"""Thread for background refreshing of the Regional Access Boundary."""

def __init__(self, credentials, request, rab_manager):
super().__init__()
self.daemon = True
self._credentials = credentials
self._request = request
self._rab_manager = rab_manager

def run(self):
"""
Performs the Regional Access Boundary lookup and updates the state.

This method is run in a separate thread. It delegates the actual lookup
to the credentials object's `_lookup_regional_access_boundary` method.
Based on the lookup's outcome (success or complete failure after retries),
it updates the cached Regional Access Boundary information,
its expiry, its cooldown expiry, and its exponential cooldown duration.
"""
# Catch exceptions (e.g., from the underlying transport) to prevent the
# background thread from crashing. This ensures we can gracefully enter
# an exponential cooldown state on failure.
try:
regional_access_boundary_info = (
self._credentials._lookup_regional_access_boundary(self._request)
)
except Exception as e:
if _helpers.is_logging_enabled(_LOGGER):
_LOGGER.warning(
"Asynchronous Regional Access Boundary lookup raised an exception: %s",
e,
exc_info=True,
)
regional_access_boundary_info = None

self._rab_manager.process_regional_access_boundary_info(
regional_access_boundary_info
)


class _RegionalAccessBoundaryRefreshManager(object):
Expand Down
46 changes: 41 additions & 5 deletions packages/google-auth/google/auth/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,32 @@ def _copy_regional_access_boundary_manager(self, target):
new_manager._data = self._rab_manager._data
target._rab_manager = new_manager

def _with_regional_access_boundary(self, seed):
"""Returns a copy of these credentials with the the regional_access_boundary
set to the provided seed. This is intended for internal use only as invalid
seeds would produce unexpected results until automatic recovery is supported.
Currently this is used by the gcloud CLI and therefore changes to the
contract MUST be backwards compatible (e.g. the method signature must be
unchanged and a copy of the credenials with the RAB set must be returned).


Returns:
google.auth.credentials.Credentials: A new credentials instance.
"""
creds = self._make_copy()
creds._rab_manager.set_initial_regional_access_boundary(seed)
return creds

def with_blocking_regional_access_boundary_lookup(self):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does gcloud need this method?
If we have to have it, can we make it private? Since the blocking refresh was specifically added to support gcloud, I don't want users accidentally choosing a blocking lookup not knowing what it entails.

"""Returns a copy of these credentials with the blocking lookup mode enabled.

Returns:
google.auth.credentials.Credentials: A new credentials instance.
"""
creds = self._make_copy()
creds._rab_manager.use_blocking_regional_access_boundary_lookup()
return creds

def _maybe_start_regional_access_boundary_refresh(self, request, url):
"""
Starts a background thread to refresh the Regional Access Boundary if needed.
Expand Down Expand Up @@ -421,11 +447,16 @@ def before_request(self, request, method, url, headers):
"""Refreshes the access token and triggers the Regional Access Boundary
lookup if necessary.
"""
super(CredentialsWithRegionalAccessBoundary, self).before_request(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we removed the call to super().before_request in and duplicated the logic to control the execution order, it would be good to add a test in test_credentials.py to verify that before_request correctly triggers the RAB refresh flow in sequence with the token refresh.

a small, unrelated nit: I noticed that in test_before_request, the arguments for url and method are swapped in the call to before_request. It doesn't cause failures because they aren't used by the base class. If you end up adding the unit test to this file and want to do a quick cleanup, you could swap them back to the correct order!

request, method, url, headers
)
if self._use_non_blocking_refresh:
self._non_blocking_refresh(request)
else:
self._blocking_refresh(request)

self._maybe_start_regional_access_boundary_refresh(request, url)

metrics.add_metric_header(headers, self._metric_header_for_usage())
self.apply(headers)

def refresh(self, request):
"""Refreshes the access token.

Expand All @@ -435,13 +466,16 @@ def refresh(self, request):
self._perform_refresh_token(request)

def _lookup_regional_access_boundary(
self, request: "google.auth.transport.Request" # noqa: F821
self,
request: "google.auth.transport.Request", # noqa: F821
blocking: bool = False,
) -> "Optional[Dict[str, str]]":
"""Calls the Regional Access Boundary lookup API to retrieve the Regional Access Boundary information.

Args:
request (google.auth.transport.Request): The object used to make
HTTP requests.
blocking (bool): Whether the lookup should be blocking.

Returns:
Optional[Dict[str, str]]: The Regional Access Boundary information returned by the lookup API, or None if the lookup failed.
Expand All @@ -456,7 +490,9 @@ def _lookup_regional_access_boundary(
headers: Dict[str, str] = {}
self._apply(headers)
self._rab_manager.apply_headers(headers)
return _client._lookup_regional_access_boundary(request, url, headers=headers)
return _client._lookup_regional_access_boundary(
request, url, headers=headers, blocking=blocking
)

@abc.abstractmethod
def _build_regional_access_boundary_lookup_url(
Expand Down
21 changes: 14 additions & 7 deletions packages/google-auth/google/oauth2/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
_JSON_CONTENT_TYPE = "application/json"
_JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"
_REFRESH_GRANT_TYPE = "refresh_token"
_BLOCKING_REGIONAL_ACCESS_BOUNDARY_LOOKUP_TIMEOUT = 3


def _handle_error_response(response_data, retryable_error):
Expand Down Expand Up @@ -517,7 +518,7 @@ def refresh_grant(
return _handle_refresh_grant_response(response_data, refresh_token)


def _lookup_regional_access_boundary(request, url, headers=None):
def _lookup_regional_access_boundary(request, url, headers=None, blocking=False):
"""Implements the global lookup of a credential Regional Access Boundary.
For the lookup, we send a request to the global lookup endpoint and then
parse the response. Service account credentials, workload identity
Expand All @@ -527,6 +528,7 @@ def _lookup_regional_access_boundary(request, url, headers=None):
HTTP requests.
url (str): The Regional Access Boundary lookup url.
headers (Optional[Mapping[str, str]]): The headers for the request.
blocking (bool): Whether the lookup should be blocking.
Returns:
Optional[Mapping[str,list|str]]: A dictionary containing
"locations" as a list of allowed locations as strings and
Expand All @@ -541,7 +543,7 @@ def _lookup_regional_access_boundary(request, url, headers=None):
"""

response_data = _lookup_regional_access_boundary_request(
request, url, headers=headers
request, url, headers=headers, blocking=blocking
)
if response_data is None:
# Error was already logged by _lookup_regional_access_boundary_request
Expand All @@ -557,7 +559,7 @@ def _lookup_regional_access_boundary(request, url, headers=None):


def _lookup_regional_access_boundary_request(
request, url, can_retry=True, headers=None
request, url, can_retry=True, headers=None, blocking=False
):
"""Makes a request to the Regional Access Boundary lookup endpoint.

Expand All @@ -567,6 +569,7 @@ def _lookup_regional_access_boundary_request(
url (str): The Regional Access Boundary lookup url.
can_retry (bool): Enable or disable request retry behavior. Defaults to true.
headers (Optional[Mapping[str, str]]): The headers for the request.
blocking (bool): Whether the lookup should be blocking.

Returns:
Optional[Mapping[str, str]]: The JSON-decoded response data on success, or None on failure.
Expand All @@ -576,7 +579,7 @@ def _lookup_regional_access_boundary_request(
response_data,
retryable_error,
) = _lookup_regional_access_boundary_request_no_throw(
request, url, can_retry, headers
request, url, can_retry, headers, blocking
)
if not response_status_ok:
_LOGGER.warning(
Expand All @@ -589,7 +592,7 @@ def _lookup_regional_access_boundary_request(


def _lookup_regional_access_boundary_request_no_throw(
request, url, can_retry=True, headers=None
request, url, can_retry=True, headers=None, blocking=False
):
"""Makes a request to the Regional Access Boundary lookup endpoint. This
function doesn't throw on response errors.
Expand All @@ -600,6 +603,7 @@ def _lookup_regional_access_boundary_request_no_throw(
url (str): The Regional Access Boundary lookup url.
can_retry (bool): Enable or disable request retry behavior. Defaults to true.
headers (Optional[Mapping[str, str]]): The headers for the request.
blocking (bool): Whether the lookup should be blocking.

Returns:
Tuple(bool, Mapping[str, str], Optional[bool]): A boolean indicating
Expand All @@ -611,9 +615,12 @@ def _lookup_regional_access_boundary_request_no_throw(
response_data = {}
retryable_error = False

retries = _exponential_backoff.ExponentialBackoff(total_attempts=6)
timeout = _BLOCKING_REGIONAL_ACCESS_BOUNDARY_LOOKUP_TIMEOUT if blocking else None
total_attempts = 1 if blocking else 6
retries = _exponential_backoff.ExponentialBackoff(total_attempts=total_attempts)

for _ in retries:
response = request(method="GET", url=url, headers=headers)
response = request(method="GET", url=url, headers=headers, timeout=timeout)
response_body = (
response.data.decode("utf-8")
if hasattr(response.data, "decode")
Expand Down
Loading
Loading