feat: add asyncio based auth flow (#612)
* feat: asyncio http request logic and asynchronous credentials logic (#572)
Co-authored-by: Anirudh Baddepudi <[email protected]>
diff --git a/docs/reference/google.auth.credentials_async.rst b/docs/reference/google.auth.credentials_async.rst
new file mode 100644
index 0000000..683139a
--- /dev/null
+++ b/docs/reference/google.auth.credentials_async.rst
@@ -0,0 +1,7 @@
+google.auth.credentials\_async module
+=====================================
+
+.. automodule:: google.auth._credentials_async
+ :members:
+ :inherited-members:
+ :show-inheritance:
diff --git a/docs/reference/google.auth.jwt_async.rst b/docs/reference/google.auth.jwt_async.rst
new file mode 100644
index 0000000..4e56a6e
--- /dev/null
+++ b/docs/reference/google.auth.jwt_async.rst
@@ -0,0 +1,7 @@
+google.auth.jwt\_async module
+=============================
+
+.. automodule:: google.auth.jwt_async
+ :members:
+ :inherited-members:
+ :show-inheritance:
diff --git a/docs/reference/google.auth.rst b/docs/reference/google.auth.rst
index cfcf703..3acf7df 100644
--- a/docs/reference/google.auth.rst
+++ b/docs/reference/google.auth.rst
@@ -24,8 +24,10 @@
google.auth.app_engine
google.auth.credentials
+ google.auth._credentials_async
google.auth.environment_vars
google.auth.exceptions
google.auth.iam
google.auth.impersonated_credentials
google.auth.jwt
+ google.auth.jwt_async
diff --git a/docs/reference/google.auth.transport.aiohttp_requests.rst b/docs/reference/google.auth.transport.aiohttp_requests.rst
new file mode 100644
index 0000000..44fc4e5
--- /dev/null
+++ b/docs/reference/google.auth.transport.aiohttp_requests.rst
@@ -0,0 +1,7 @@
+google.auth.transport.aiohttp\_requests module
+==============================================
+
+.. automodule:: google.auth.transport._aiohttp_requests
+ :members:
+ :inherited-members:
+ :show-inheritance:
diff --git a/docs/reference/google.auth.transport.rst b/docs/reference/google.auth.transport.rst
index 8921863..f1d1988 100644
--- a/docs/reference/google.auth.transport.rst
+++ b/docs/reference/google.auth.transport.rst
@@ -12,6 +12,7 @@
.. toctree::
:maxdepth: 4
+ google.auth.transport._aiohttp_requests
google.auth.transport.grpc
google.auth.transport.mtls
google.auth.transport.requests
diff --git a/docs/reference/google.oauth2.credentials_async.rst b/docs/reference/google.oauth2.credentials_async.rst
new file mode 100644
index 0000000..d0df1e8
--- /dev/null
+++ b/docs/reference/google.oauth2.credentials_async.rst
@@ -0,0 +1,7 @@
+google.oauth2.credentials\_async module
+=======================================
+
+.. automodule:: google.oauth2._credentials_async
+ :members:
+ :inherited-members:
+ :show-inheritance:
diff --git a/docs/reference/google.oauth2.rst b/docs/reference/google.oauth2.rst
index 1ac9c73..6f3ba50 100644
--- a/docs/reference/google.oauth2.rst
+++ b/docs/reference/google.oauth2.rst
@@ -13,5 +13,7 @@
:maxdepth: 4
google.oauth2.credentials
+ google.oauth2._credentials_async
google.oauth2.id_token
google.oauth2.service_account
+ google.oauth2._service_account_async
diff --git a/docs/reference/google.oauth2.service_account_async.rst b/docs/reference/google.oauth2.service_account_async.rst
new file mode 100644
index 0000000..8aba0d8
--- /dev/null
+++ b/docs/reference/google.oauth2.service_account_async.rst
@@ -0,0 +1,7 @@
+google.oauth2.service\_account\_async module
+============================================
+
+.. automodule:: google.oauth2._service_account_async
+ :members:
+ :inherited-members:
+ :show-inheritance:
diff --git a/google/auth/__init__.py b/google/auth/__init__.py
index 5ca20a3..22d61c6 100644
--- a/google/auth/__init__.py
+++ b/google/auth/__init__.py
@@ -18,9 +18,7 @@
from google.auth._default import default, load_credentials_from_file
-
__all__ = ["default", "load_credentials_from_file"]
-
# Set default logging handler to avoid "No handler found" warnings.
logging.getLogger(__name__).addHandler(logging.NullHandler())
diff --git a/google/auth/_credentials_async.py b/google/auth/_credentials_async.py
new file mode 100644
index 0000000..d4d4e2c
--- /dev/null
+++ b/google/auth/_credentials_async.py
@@ -0,0 +1,176 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Interfaces for credentials."""
+
+import abc
+import inspect
+
+import six
+
+from google.auth import credentials
+
+
[email protected]_metaclass(abc.ABCMeta)
+class Credentials(credentials.Credentials):
+ """Async inherited credentials class from google.auth.credentials.
+ The added functionality is the before_request call which requires
+ async/await syntax.
+ All credentials have a :attr:`token` that is used for authentication and
+ may also optionally set an :attr:`expiry` to indicate when the token will
+ no longer be valid.
+
+ Most credentials will be :attr:`invalid` until :meth:`refresh` is called.
+ Credentials can do this automatically before the first HTTP request in
+ :meth:`before_request`.
+
+ Although the token and expiration will change as the credentials are
+ :meth:`refreshed <refresh>` and used, credentials should be considered
+ immutable. Various credentials will accept configuration such as private
+ keys, scopes, and other options. These options are not changeable after
+ construction. Some classes will provide mechanisms to copy the credentials
+ with modifications such as :meth:`ScopedCredentials.with_scopes`.
+ """
+
+ async def before_request(self, request, method, url, headers):
+ """Performs credential-specific before request logic.
+
+ Refreshes the credentials if necessary, then calls :meth:`apply` to
+ apply the token to the authentication header.
+
+ Args:
+ request (google.auth.transport.Request): The object used to make
+ HTTP requests.
+ method (str): The request's HTTP method or the RPC method being
+ invoked.
+ url (str): The request's URI or the RPC service's URI.
+ headers (Mapping): The request's headers.
+ """
+ # pylint: disable=unused-argument
+ # (Subclasses may use these arguments to ascertain information about
+ # the http request.)
+
+ if not self.valid:
+ if inspect.iscoroutinefunction(self.refresh):
+ await self.refresh(request)
+ else:
+ self.refresh(request)
+ self.apply(headers)
+
+
+class CredentialsWithQuotaProject(credentials.CredentialsWithQuotaProject):
+ """Abstract base for credentials supporting ``with_quota_project`` factory"""
+
+
+class AnonymousCredentials(credentials.AnonymousCredentials, Credentials):
+ """Credentials that do not provide any authentication information.
+
+ These are useful in the case of services that support anonymous access or
+ local service emulators that do not use credentials. This class inherits
+ from the sync anonymous credentials file, but is kept if async credentials
+ is initialized and we would like anonymous credentials.
+ """
+
+
[email protected]_metaclass(abc.ABCMeta)
+class ReadOnlyScoped(credentials.ReadOnlyScoped):
+ """Interface for credentials whose scopes can be queried.
+
+ OAuth 2.0-based credentials allow limiting access using scopes as described
+ in `RFC6749 Section 3.3`_.
+ If a credential class implements this interface then the credentials either
+ use scopes in their implementation.
+
+ Some credentials require scopes in order to obtain a token. You can check
+ if scoping is necessary with :attr:`requires_scopes`::
+
+ if credentials.requires_scopes:
+ # Scoping is required.
+ credentials = _credentials_async.with_scopes(scopes=['one', 'two'])
+
+ Credentials that require scopes must either be constructed with scopes::
+
+ credentials = SomeScopedCredentials(scopes=['one', 'two'])
+
+ Or must copy an existing instance using :meth:`with_scopes`::
+
+ scoped_credentials = _credentials_async.with_scopes(scopes=['one', 'two'])
+
+ Some credentials have scopes but do not allow or require scopes to be set,
+ these credentials can be used as-is.
+
+ .. _RFC6749 Section 3.3: https://tools.ietf.org/html/rfc6749#section-3.3
+ """
+
+
+class Scoped(credentials.Scoped):
+ """Interface for credentials whose scopes can be replaced while copying.
+
+ OAuth 2.0-based credentials allow limiting access using scopes as described
+ in `RFC6749 Section 3.3`_.
+ If a credential class implements this interface then the credentials either
+ use scopes in their implementation.
+
+ Some credentials require scopes in order to obtain a token. You can check
+ if scoping is necessary with :attr:`requires_scopes`::
+
+ if credentials.requires_scopes:
+ # Scoping is required.
+ credentials = _credentials_async.create_scoped(['one', 'two'])
+
+ Credentials that require scopes must either be constructed with scopes::
+
+ credentials = SomeScopedCredentials(scopes=['one', 'two'])
+
+ Or must copy an existing instance using :meth:`with_scopes`::
+
+ scoped_credentials = credentials.with_scopes(scopes=['one', 'two'])
+
+ Some credentials have scopes but do not allow or require scopes to be set,
+ these credentials can be used as-is.
+
+ .. _RFC6749 Section 3.3: https://tools.ietf.org/html/rfc6749#section-3.3
+ """
+
+
+def with_scopes_if_required(credentials, scopes):
+ """Creates a copy of the credentials with scopes if scoping is required.
+
+ This helper function is useful when you do not know (or care to know) the
+ specific type of credentials you are using (such as when you use
+ :func:`google.auth.default`). This function will call
+ :meth:`Scoped.with_scopes` if the credentials are scoped credentials and if
+ the credentials require scoping. Otherwise, it will return the credentials
+ as-is.
+
+ Args:
+ credentials (google.auth.credentials.Credentials): The credentials to
+ scope if necessary.
+ scopes (Sequence[str]): The list of scopes to use.
+
+ Returns:
+ google.auth._credentials_async.Credentials: Either a new set of scoped
+ credentials, or the passed in credentials instance if no scoping
+ was required.
+ """
+ if isinstance(credentials, Scoped) and credentials.requires_scopes:
+ return credentials.with_scopes(scopes)
+ else:
+ return credentials
+
+
[email protected]_metaclass(abc.ABCMeta)
+class Signing(credentials.Signing):
+ """Interface for credentials that can cryptographically sign messages."""
diff --git a/google/auth/_default_async.py b/google/auth/_default_async.py
new file mode 100644
index 0000000..3347fbf
--- /dev/null
+++ b/google/auth/_default_async.py
@@ -0,0 +1,266 @@
+# Copyright 2020 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Application default credentials.
+
+Implements application default credentials and project ID detection.
+"""
+
+import io
+import json
+import os
+
+import six
+
+from google.auth import _default
+from google.auth import environment_vars
+from google.auth import exceptions
+
+
+def load_credentials_from_file(filename, scopes=None, quota_project_id=None):
+ """Loads Google credentials from a file.
+
+ The credentials file must be a service account key or stored authorized
+ user credentials.
+
+ Args:
+ filename (str): The full path to the credentials file.
+ scopes (Optional[Sequence[str]]): The list of scopes for the credentials. If
+ specified, the credentials will automatically be scoped if
+ necessary
+ quota_project_id (Optional[str]): The project ID used for
+ quota and billing.
+
+ Returns:
+ Tuple[google.auth.credentials.Credentials, Optional[str]]: Loaded
+ credentials and the project ID. Authorized user credentials do not
+ have the project ID information.
+
+ Raises:
+ google.auth.exceptions.DefaultCredentialsError: if the file is in the
+ wrong format or is missing.
+ """
+ if not os.path.exists(filename):
+ raise exceptions.DefaultCredentialsError(
+ "File {} was not found.".format(filename)
+ )
+
+ with io.open(filename, "r") as file_obj:
+ try:
+ info = json.load(file_obj)
+ except ValueError as caught_exc:
+ new_exc = exceptions.DefaultCredentialsError(
+ "File {} is not a valid json file.".format(filename), caught_exc
+ )
+ six.raise_from(new_exc, caught_exc)
+
+ # The type key should indicate that the file is either a service account
+ # credentials file or an authorized user credentials file.
+ credential_type = info.get("type")
+
+ if credential_type == _default._AUTHORIZED_USER_TYPE:
+ from google.oauth2 import _credentials_async as credentials
+
+ try:
+ credentials = credentials.Credentials.from_authorized_user_info(
+ info, scopes=scopes
+ ).with_quota_project(quota_project_id)
+ except ValueError as caught_exc:
+ msg = "Failed to load authorized user credentials from {}".format(filename)
+ new_exc = exceptions.DefaultCredentialsError(msg, caught_exc)
+ six.raise_from(new_exc, caught_exc)
+ if not credentials.quota_project_id:
+ _default._warn_about_problematic_credentials(credentials)
+ return credentials, None
+
+ elif credential_type == _default._SERVICE_ACCOUNT_TYPE:
+ from google.oauth2 import _service_account_async as service_account
+
+ try:
+ credentials = service_account.Credentials.from_service_account_info(
+ info, scopes=scopes
+ ).with_quota_project(quota_project_id)
+ except ValueError as caught_exc:
+ msg = "Failed to load service account credentials from {}".format(filename)
+ new_exc = exceptions.DefaultCredentialsError(msg, caught_exc)
+ six.raise_from(new_exc, caught_exc)
+ return credentials, info.get("project_id")
+
+ else:
+ raise exceptions.DefaultCredentialsError(
+ "The file {file} does not have a valid type. "
+ "Type is {type}, expected one of {valid_types}.".format(
+ file=filename, type=credential_type, valid_types=_default._VALID_TYPES
+ )
+ )
+
+
+def _get_gcloud_sdk_credentials():
+ """Gets the credentials and project ID from the Cloud SDK."""
+ from google.auth import _cloud_sdk
+
+ # Check if application default credentials exist.
+ credentials_filename = _cloud_sdk.get_application_default_credentials_path()
+
+ if not os.path.isfile(credentials_filename):
+ return None, None
+
+ credentials, project_id = load_credentials_from_file(credentials_filename)
+
+ if not project_id:
+ project_id = _cloud_sdk.get_project_id()
+
+ return credentials, project_id
+
+
+def _get_explicit_environ_credentials():
+ """Gets credentials from the GOOGLE_APPLICATION_CREDENTIALS environment
+ variable."""
+ explicit_file = os.environ.get(environment_vars.CREDENTIALS)
+
+ if explicit_file is not None:
+ credentials, project_id = load_credentials_from_file(
+ os.environ[environment_vars.CREDENTIALS]
+ )
+
+ return credentials, project_id
+
+ else:
+ return None, None
+
+
+def _get_gae_credentials():
+ """Gets Google App Engine App Identity credentials and project ID."""
+ # While this library is normally bundled with app_engine, there are
+ # some cases where it's not available, so we tolerate ImportError.
+
+ return _default._get_gae_credentials()
+
+
+def _get_gce_credentials(request=None):
+ """Gets credentials and project ID from the GCE Metadata Service."""
+ # Ping requires a transport, but we want application default credentials
+ # to require no arguments. So, we'll use the _http_client transport which
+ # uses http.client. This is only acceptable because the metadata server
+ # doesn't do SSL and never requires proxies.
+
+ # While this library is normally bundled with compute_engine, there are
+ # some cases where it's not available, so we tolerate ImportError.
+
+ return _default._get_gce_credentials(request)
+
+
+def default_async(scopes=None, request=None, quota_project_id=None):
+ """Gets the default credentials for the current environment.
+
+ `Application Default Credentials`_ provides an easy way to obtain
+ credentials to call Google APIs for server-to-server or local applications.
+ This function acquires credentials from the environment in the following
+ order:
+
+ 1. If the environment variable ``GOOGLE_APPLICATION_CREDENTIALS`` is set
+ to the path of a valid service account JSON private key file, then it is
+ loaded and returned. The project ID returned is the project ID defined
+ in the service account file if available (some older files do not
+ contain project ID information).
+ 2. If the `Google Cloud SDK`_ is installed and has application default
+ credentials set they are loaded and returned.
+
+ To enable application default credentials with the Cloud SDK run::
+
+ gcloud auth application-default login
+
+ If the Cloud SDK has an active project, the project ID is returned. The
+ active project can be set using::
+
+ gcloud config set project
+
+ 3. If the application is running in the `App Engine standard environment`_
+ then the credentials and project ID from the `App Identity Service`_
+ are used.
+ 4. If the application is running in `Compute Engine`_ or the
+ `App Engine flexible environment`_ then the credentials and project ID
+ are obtained from the `Metadata Service`_.
+ 5. If no credentials are found,
+ :class:`~google.auth.exceptions.DefaultCredentialsError` will be raised.
+
+ .. _Application Default Credentials: https://developers.google.com\
+ /identity/protocols/application-default-credentials
+ .. _Google Cloud SDK: https://cloud.google.com/sdk
+ .. _App Engine standard environment: https://cloud.google.com/appengine
+ .. _App Identity Service: https://cloud.google.com/appengine/docs/python\
+ /appidentity/
+ .. _Compute Engine: https://cloud.google.com/compute
+ .. _App Engine flexible environment: https://cloud.google.com\
+ /appengine/flexible
+ .. _Metadata Service: https://cloud.google.com/compute/docs\
+ /storing-retrieving-metadata
+
+ Example::
+
+ import google.auth
+
+ credentials, project_id = google.auth.default()
+
+ Args:
+ scopes (Sequence[str]): The list of scopes for the credentials. If
+ specified, the credentials will automatically be scoped if
+ necessary.
+ request (google.auth.transport.Request): An object used to make
+ HTTP requests. This is used to detect whether the application
+ is running on Compute Engine. If not specified, then it will
+ use the standard library http client to make requests.
+ quota_project_id (Optional[str]): The project ID used for
+ quota and billing.
+ Returns:
+ Tuple[~google.auth.credentials.Credentials, Optional[str]]:
+ the current environment's credentials and project ID. Project ID
+ may be None, which indicates that the Project ID could not be
+ ascertained from the environment.
+
+ Raises:
+ ~google.auth.exceptions.DefaultCredentialsError:
+ If no credentials were found, or if the credentials found were
+ invalid.
+ """
+ from google.auth._credentials_async import with_scopes_if_required
+
+ explicit_project_id = os.environ.get(
+ environment_vars.PROJECT, os.environ.get(environment_vars.LEGACY_PROJECT)
+ )
+
+ checkers = (
+ _get_explicit_environ_credentials,
+ _get_gcloud_sdk_credentials,
+ _get_gae_credentials,
+ lambda: _get_gce_credentials(request),
+ )
+
+ for checker in checkers:
+ credentials, project_id = checker()
+ if credentials is not None:
+ credentials = with_scopes_if_required(
+ credentials, scopes
+ ).with_quota_project(quota_project_id)
+ effective_project_id = explicit_project_id or project_id
+ if not effective_project_id:
+ _default._LOGGER.warning(
+ "No project ID could be determined. Consider running "
+ "`gcloud config set project` or setting the %s "
+ "environment variable",
+ environment_vars.PROJECT,
+ )
+ return credentials, effective_project_id
+
+ raise exceptions.DefaultCredentialsError(_default._HELP_MESSAGE)
diff --git a/google/auth/_jwt_async.py b/google/auth/_jwt_async.py
new file mode 100644
index 0000000..49e3026
--- /dev/null
+++ b/google/auth/_jwt_async.py
@@ -0,0 +1,168 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""JSON Web Tokens
+
+Provides support for creating (encoding) and verifying (decoding) JWTs,
+especially JWTs generated and consumed by Google infrastructure.
+
+See `rfc7519`_ for more details on JWTs.
+
+To encode a JWT use :func:`encode`::
+
+ from google.auth import crypt
+ from google.auth import jwt_async
+
+ signer = crypt.Signer(private_key)
+ payload = {'some': 'payload'}
+ encoded = jwt_async.encode(signer, payload)
+
+To decode a JWT and verify claims use :func:`decode`::
+
+ claims = jwt_async.decode(encoded, certs=public_certs)
+
+You can also skip verification::
+
+ claims = jwt_async.decode(encoded, verify=False)
+
+.. _rfc7519: https://tools.ietf.org/html/rfc7519
+
+
+NOTE: This async support is experimental and marked internal. This surface may
+change in minor releases.
+"""
+
+import google.auth
+from google.auth import jwt
+
+
+def encode(signer, payload, header=None, key_id=None):
+ """Make a signed JWT.
+
+ Args:
+ signer (google.auth.crypt.Signer): The signer used to sign the JWT.
+ payload (Mapping[str, str]): The JWT payload.
+ header (Mapping[str, str]): Additional JWT header payload.
+ key_id (str): The key id to add to the JWT header. If the
+ signer has a key id it will be used as the default. If this is
+ specified it will override the signer's key id.
+
+ Returns:
+ bytes: The encoded JWT.
+ """
+ return jwt.encode(signer, payload, header, key_id)
+
+
+def decode(token, certs=None, verify=True, audience=None):
+ """Decode and verify a JWT.
+
+ Args:
+ token (str): The encoded JWT.
+ certs (Union[str, bytes, Mapping[str, Union[str, bytes]]]): The
+ certificate used to validate the JWT signature. If bytes or string,
+ it must the the public key certificate in PEM format. If a mapping,
+ it must be a mapping of key IDs to public key certificates in PEM
+ format. The mapping must contain the same key ID that's specified
+ in the token's header.
+ verify (bool): Whether to perform signature and claim validation.
+ Verification is done by default.
+ audience (str): The audience claim, 'aud', that this JWT should
+ contain. If None then the JWT's 'aud' parameter is not verified.
+
+ Returns:
+ Mapping[str, str]: The deserialized JSON payload in the JWT.
+
+ Raises:
+ ValueError: if any verification checks failed.
+ """
+
+ return jwt.decode(token, certs, verify, audience)
+
+
+class Credentials(
+ jwt.Credentials,
+ google.auth._credentials_async.Signing,
+ google.auth._credentials_async.Credentials,
+):
+ """Credentials that use a JWT as the bearer token.
+
+ These credentials require an "audience" claim. This claim identifies the
+ intended recipient of the bearer token.
+
+ The constructor arguments determine the claims for the JWT that is
+ sent with requests. Usually, you'll construct these credentials with
+ one of the helper constructors as shown in the next section.
+
+ To create JWT credentials using a Google service account private key
+ JSON file::
+
+ audience = 'https://pubsub.googleapis.com/google.pubsub.v1.Publisher'
+ credentials = jwt_async.Credentials.from_service_account_file(
+ 'service-account.json',
+ audience=audience)
+
+ If you already have the service account file loaded and parsed::
+
+ service_account_info = json.load(open('service_account.json'))
+ credentials = jwt_async.Credentials.from_service_account_info(
+ service_account_info,
+ audience=audience)
+
+ Both helper methods pass on arguments to the constructor, so you can
+ specify the JWT claims::
+
+ credentials = jwt_async.Credentials.from_service_account_file(
+ 'service-account.json',
+ audience=audience,
+ additional_claims={'meta': 'data'})
+
+ You can also construct the credentials directly if you have a
+ :class:`~google.auth.crypt.Signer` instance::
+
+ credentials = jwt_async.Credentials(
+ signer,
+ issuer='your-issuer',
+ subject='your-subject',
+ audience=audience)
+
+ The claims are considered immutable. If you want to modify the claims,
+ you can easily create another instance using :meth:`with_claims`::
+
+ new_audience = (
+ 'https://pubsub.googleapis.com/google.pubsub.v1.Subscriber')
+ new_credentials = credentials.with_claims(audience=new_audience)
+ """
+
+
+class OnDemandCredentials(
+ jwt.OnDemandCredentials,
+ google.auth._credentials_async.Signing,
+ google.auth._credentials_async.Credentials,
+):
+ """On-demand JWT credentials.
+
+ Like :class:`Credentials`, this class uses a JWT as the bearer token for
+ authentication. However, this class does not require the audience at
+ construction time. Instead, it will generate a new token on-demand for
+ each request using the request URI as the audience. It caches tokens
+ so that multiple requests to the same URI do not incur the overhead
+ of generating a new token every time.
+
+ This behavior is especially useful for `gRPC`_ clients. A gRPC service may
+ have multiple audience and gRPC clients may not know all of the audiences
+ required for accessing a particular service. With these credentials,
+ no knowledge of the audiences is required ahead of time.
+
+ .. _grpc: http://www.grpc.io/
+ """
diff --git a/google/auth/transport/_aiohttp_requests.py b/google/auth/transport/_aiohttp_requests.py
new file mode 100644
index 0000000..aaf4e2c
--- /dev/null
+++ b/google/auth/transport/_aiohttp_requests.py
@@ -0,0 +1,384 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Transport adapter for Async HTTP (aiohttp).
+
+NOTE: This async support is experimental and marked internal. This surface may
+change in minor releases.
+"""
+
+from __future__ import absolute_import
+
+import asyncio
+import functools
+
+import aiohttp
+import six
+import urllib3
+
+from google.auth import exceptions
+from google.auth import transport
+from google.auth.transport import requests
+
+# Timeout can be re-defined depending on async requirement. Currently made 60s more than
+# sync timeout.
+_DEFAULT_TIMEOUT = 180 # in seconds
+
+
+class _CombinedResponse(transport.Response):
+ """
+ In order to more closely resemble the `requests` interface, where a raw
+ and deflated content could be accessed at once, this class lazily reads the
+ stream in `transport.Response` so both return forms can be used.
+
+ The gzip and deflate transfer-encodings are automatically decoded for you
+ because the default parameter for autodecompress into the ClientSession is set
+ to False, and therefore we add this class to act as a wrapper for a user to be
+ able to access both the raw and decoded response bodies - mirroring the sync
+ implementation.
+ """
+
+ def __init__(self, response):
+ self._response = response
+ self._raw_content = None
+
+ def _is_compressed(self):
+ headers = self._response.headers
+ return "Content-Encoding" in headers and (
+ headers["Content-Encoding"] == "gzip"
+ or headers["Content-Encoding"] == "deflate"
+ )
+
+ @property
+ def status(self):
+ return self._response.status
+
+ @property
+ def headers(self):
+ return self._response.headers
+
+ @property
+ def data(self):
+ return self._response.content
+
+ async def raw_content(self):
+ if self._raw_content is None:
+ self._raw_content = await self._response.content.read()
+ return self._raw_content
+
+ async def content(self):
+ # Load raw_content if necessary
+ await self.raw_content()
+ if self._is_compressed():
+ decoder = urllib3.response.MultiDecoder(
+ self._response.headers["Content-Encoding"]
+ )
+ decompressed = decoder.decompress(self._raw_content)
+ return decompressed
+
+ return self._raw_content
+
+
+class _Response(transport.Response):
+ """
+ Requests transport response adapter.
+
+ Args:
+ response (requests.Response): The raw Requests response.
+ """
+
+ def __init__(self, response):
+ self._response = response
+
+ @property
+ def status(self):
+ return self._response.status
+
+ @property
+ def headers(self):
+ return self._response.headers
+
+ @property
+ def data(self):
+ return self._response.content
+
+
+class Request(transport.Request):
+ """Requests request adapter.
+
+ This class is used internally for making requests using asyncio transports
+ in a consistent way. If you use :class:`AuthorizedSession` you do not need
+ to construct or use this class directly.
+
+ This class can be useful if you want to manually refresh a
+ :class:`~google.auth.credentials.Credentials` instance::
+
+ import google.auth.transport.aiohttp_requests
+
+ request = google.auth.transport.aiohttp_requests.Request()
+
+ credentials.refresh(request)
+
+ Args:
+ session (aiohttp.ClientSession): An instance :class: aiohttp.ClientSession used
+ to make HTTP requests. If not specified, a session will be created.
+
+ .. automethod:: __call__
+ """
+
+ def __init__(self, session=None):
+ self.session = None
+
+ async def __call__(
+ self,
+ url,
+ method="GET",
+ body=None,
+ headers=None,
+ timeout=_DEFAULT_TIMEOUT,
+ **kwargs,
+ ):
+ """
+ Make an HTTP request using aiohttp.
+
+ Args:
+ url (str): The URL to be requested.
+ method (str): The HTTP method to use for the request. Defaults
+ to 'GET'.
+ body (bytes): The payload / body in HTTP request.
+ headers (Mapping[str, str]): Request headers.
+ timeout (Optional[int]): The number of seconds to wait for a
+ response from the server. If not specified or if None, the
+ requests default timeout will be used.
+ kwargs: Additional arguments passed through to the underlying
+ requests :meth:`~requests.Session.request` method.
+
+ Returns:
+ google.auth.transport.Response: The HTTP response.
+
+ Raises:
+ google.auth.exceptions.TransportError: If any exception occurred.
+ """
+
+ try:
+ if self.session is None: # pragma: NO COVER
+ self.session = aiohttp.ClientSession(
+ auto_decompress=False
+ ) # pragma: NO COVER
+ requests._LOGGER.debug("Making request: %s %s", method, url)
+ response = await self.session.request(
+ method, url, data=body, headers=headers, timeout=timeout, **kwargs
+ )
+ return _CombinedResponse(response)
+
+ except aiohttp.ClientError as caught_exc:
+ new_exc = exceptions.TransportError(caught_exc)
+ six.raise_from(new_exc, caught_exc)
+
+ except asyncio.TimeoutError as caught_exc:
+ new_exc = exceptions.TransportError(caught_exc)
+ six.raise_from(new_exc, caught_exc)
+
+
+class AuthorizedSession(aiohttp.ClientSession):
+ """This is an async implementation of the Authorized Session class. We utilize an
+ aiohttp transport instance, and the interface mirrors the google.auth.transport.requests
+ Authorized Session class, except for the change in the transport used in the async use case.
+
+ A Requests Session class with credentials.
+
+ This class is used to perform requests to API endpoints that require
+ authorization::
+
+ from google.auth.transport import aiohttp_requests
+
+ async with aiohttp_requests.AuthorizedSession(credentials) as authed_session:
+ response = await authed_session.request(
+ 'GET', 'https://www.googleapis.com/storage/v1/b')
+
+ The underlying :meth:`request` implementation handles adding the
+ credentials' headers to the request and refreshing credentials as needed.
+
+ Args:
+ credentials (google.auth._credentials_async.Credentials): The credentials to
+ add to the request.
+ refresh_status_codes (Sequence[int]): Which HTTP status codes indicate
+ that credentials should be refreshed and the request should be
+ retried.
+ max_refresh_attempts (int): The maximum number of times to attempt to
+ refresh the credentials and retry the request.
+ refresh_timeout (Optional[int]): The timeout value in seconds for
+ credential refresh HTTP requests.
+ auth_request (google.auth.transport.aiohttp_requests.Request):
+ (Optional) An instance of
+ :class:`~google.auth.transport.aiohttp_requests.Request` used when
+ refreshing credentials. If not passed,
+ an instance of :class:`~google.auth.transport.aiohttp_requests.Request`
+ is created.
+ """
+
+ def __init__(
+ self,
+ credentials,
+ refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES,
+ max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS,
+ refresh_timeout=None,
+ auth_request=None,
+ auto_decompress=False,
+ ):
+ super(AuthorizedSession, self).__init__()
+ self.credentials = credentials
+ self._refresh_status_codes = refresh_status_codes
+ self._max_refresh_attempts = max_refresh_attempts
+ self._refresh_timeout = refresh_timeout
+ self._is_mtls = False
+ self._auth_request = auth_request
+ self._auth_request_session = None
+ self._loop = asyncio.get_event_loop()
+ self._refresh_lock = asyncio.Lock()
+ self._auto_decompress = auto_decompress
+
+ async def request(
+ self,
+ method,
+ url,
+ data=None,
+ headers=None,
+ max_allowed_time=None,
+ timeout=_DEFAULT_TIMEOUT,
+ auto_decompress=False,
+ **kwargs,
+ ):
+
+ """Implementation of Authorized Session aiohttp request.
+
+ Args:
+ method: The http request method used (e.g. GET, PUT, DELETE)
+
+ url: The url at which the http request is sent.
+
+ data, headers: These fields parallel the associated data and headers
+ fields of a regular http request. Using the aiohttp client session to
+ send the http request allows us to use this parallel corresponding structure
+ in our Authorized Session class.
+
+ timeout (Optional[Union[float, aiohttp.ClientTimeout]]):
+ The amount of time in seconds to wait for the server response
+ with each individual request.
+
+ Can also be passed as an `aiohttp.ClientTimeout` object.
+
+ max_allowed_time (Optional[float]):
+ If the method runs longer than this, a ``Timeout`` exception is
+ automatically raised. Unlike the ``timeout` parameter, this
+ value applies to the total method execution time, even if
+ multiple requests are made under the hood.
+
+ Mind that it is not guaranteed that the timeout error is raised
+ at ``max_allowed_time`. It might take longer, for example, if
+ an underlying request takes a lot of time, but the request
+ itself does not timeout, e.g. if a large file is being
+ transmitted. The timout error will be raised after such
+ request completes.
+ """
+ # Headers come in as bytes which isn't expected behavior, the resumable
+ # media libraries in some cases expect a str type for the header values,
+ # but sometimes the operations return these in bytes types.
+ if headers:
+ for key in headers.keys():
+ if type(headers[key]) is bytes:
+ headers[key] = headers[key].decode("utf-8")
+
+ async with aiohttp.ClientSession(
+ auto_decompress=self._auto_decompress
+ ) as self._auth_request_session:
+ auth_request = Request(self._auth_request_session)
+ self._auth_request = auth_request
+
+ # Use a kwarg for this instead of an attribute to maintain
+ # thread-safety.
+ _credential_refresh_attempt = kwargs.pop("_credential_refresh_attempt", 0)
+ # Make a copy of the headers. They will be modified by the credentials
+ # and we want to pass the original headers if we recurse.
+ request_headers = headers.copy() if headers is not None else {}
+
+ # Do not apply the timeout unconditionally in order to not override the
+ # _auth_request's default timeout.
+ auth_request = (
+ self._auth_request
+ if timeout is None
+ else functools.partial(self._auth_request, timeout=timeout)
+ )
+
+ remaining_time = max_allowed_time
+
+ with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard:
+ await self.credentials.before_request(
+ auth_request, method, url, request_headers
+ )
+
+ with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard:
+ response = await super(AuthorizedSession, self).request(
+ method,
+ url,
+ data=data,
+ headers=request_headers,
+ timeout=timeout,
+ **kwargs,
+ )
+
+ remaining_time = guard.remaining_timeout
+
+ if (
+ response.status in self._refresh_status_codes
+ and _credential_refresh_attempt < self._max_refresh_attempts
+ ):
+
+ requests._LOGGER.info(
+ "Refreshing credentials due to a %s response. Attempt %s/%s.",
+ response.status,
+ _credential_refresh_attempt + 1,
+ self._max_refresh_attempts,
+ )
+
+ # Do not apply the timeout unconditionally in order to not override the
+ # _auth_request's default timeout.
+ auth_request = (
+ self._auth_request
+ if timeout is None
+ else functools.partial(self._auth_request, timeout=timeout)
+ )
+
+ with requests.TimeoutGuard(
+ remaining_time, asyncio.TimeoutError
+ ) as guard:
+ async with self._refresh_lock:
+ await self._loop.run_in_executor(
+ None, self.credentials.refresh, auth_request
+ )
+
+ remaining_time = guard.remaining_timeout
+
+ return await self.request(
+ method,
+ url,
+ data=data,
+ headers=headers,
+ max_allowed_time=remaining_time,
+ timeout=timeout,
+ _credential_refresh_attempt=_credential_refresh_attempt + 1,
+ **kwargs,
+ )
+
+ return response
diff --git a/google/auth/transport/mtls.py b/google/auth/transport/mtls.py
index 5b74230..b40bfbe 100644
--- a/google/auth/transport/mtls.py
+++ b/google/auth/transport/mtls.py
@@ -86,9 +86,12 @@
def callback():
try:
- _, cert_bytes, key_bytes, passphrase_bytes = _mtls_helper.get_client_ssl_credentials(
- generate_encrypted_key=True
- )
+ (
+ _,
+ cert_bytes,
+ key_bytes,
+ passphrase_bytes,
+ ) = _mtls_helper.get_client_ssl_credentials(generate_encrypted_key=True)
with open(cert_path, "wb") as cert_file:
cert_file.write(cert_bytes)
with open(key_path, "wb") as key_file:
diff --git a/google/oauth2/_client_async.py b/google/oauth2/_client_async.py
new file mode 100644
index 0000000..4817ea4
--- /dev/null
+++ b/google/oauth2/_client_async.py
@@ -0,0 +1,264 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""OAuth 2.0 async client.
+
+This is a client for interacting with an OAuth 2.0 authorization server's
+token endpoint.
+
+For more information about the token endpoint, see
+`Section 3.1 of rfc6749`_
+
+.. _Section 3.1 of rfc6749: https://tools.ietf.org/html/rfc6749#section-3.2
+"""
+
+import datetime
+import json
+
+import six
+from six.moves import http_client
+from six.moves import urllib
+
+from google.auth import _helpers
+from google.auth import exceptions
+from google.auth import jwt
+from google.oauth2 import _client as client
+
+
+def _handle_error_response(response_body):
+ """"Translates an error response into an exception.
+
+ Args:
+ response_body (str): The decoded response data.
+
+ Raises:
+ google.auth.exceptions.RefreshError
+ """
+ try:
+ error_data = json.loads(response_body)
+ error_details = "{}: {}".format(
+ error_data["error"], error_data.get("error_description")
+ )
+ # If no details could be extracted, use the response data.
+ except (KeyError, ValueError):
+ error_details = response_body
+
+ raise exceptions.RefreshError(error_details, response_body)
+
+
+def _parse_expiry(response_data):
+ """Parses the expiry field from a response into a datetime.
+
+ Args:
+ response_data (Mapping): The JSON-parsed response data.
+
+ Returns:
+ Optional[datetime]: The expiration or ``None`` if no expiration was
+ specified.
+ """
+ expires_in = response_data.get("expires_in", None)
+
+ if expires_in is not None:
+ return _helpers.utcnow() + datetime.timedelta(seconds=expires_in)
+ else:
+ return None
+
+
+async def _token_endpoint_request(request, token_uri, body):
+ """Makes a request to the OAuth 2.0 authorization server's token endpoint.
+
+ Args:
+ request (google.auth.transport.Request): A callable used to make
+ HTTP requests.
+ token_uri (str): The OAuth 2.0 authorizations server's token endpoint
+ URI.
+ body (Mapping[str, str]): The parameters to send in the request body.
+
+ Returns:
+ Mapping[str, str]: The JSON-decoded response data.
+
+ Raises:
+ google.auth.exceptions.RefreshError: If the token endpoint returned
+ an error.
+ """
+ body = urllib.parse.urlencode(body).encode("utf-8")
+ headers = {"content-type": client._URLENCODED_CONTENT_TYPE}
+
+ retry = 0
+ # retry to fetch token for maximum of two times if any internal failure
+ # occurs.
+ while True:
+
+ response = await request(
+ method="POST", url=token_uri, headers=headers, body=body
+ )
+
+ # Using data.read() resulted in zlib decompression errors. This may require future investigation.
+ response_body1 = await response.content()
+
+ response_body = (
+ response_body1.decode("utf-8")
+ if hasattr(response_body1, "decode")
+ else response_body1
+ )
+
+ response_data = json.loads(response_body)
+
+ if response.status == http_client.OK:
+ break
+ else:
+ error_desc = response_data.get("error_description") or ""
+ error_code = response_data.get("error") or ""
+ if (
+ any(e == "internal_failure" for e in (error_code, error_desc))
+ and retry < 1
+ ):
+ retry += 1
+ continue
+ _handle_error_response(response_body)
+
+ return response_data
+
+
+async def jwt_grant(request, token_uri, assertion):
+ """Implements the JWT Profile for OAuth 2.0 Authorization Grants.
+
+ For more details, see `rfc7523 section 4`_.
+
+ Args:
+ request (google.auth.transport.Request): A callable used to make
+ HTTP requests.
+ token_uri (str): The OAuth 2.0 authorizations server's token endpoint
+ URI.
+ assertion (str): The OAuth 2.0 assertion.
+
+ Returns:
+ Tuple[str, Optional[datetime], Mapping[str, str]]: The access token,
+ expiration, and additional data returned by the token endpoint.
+
+ Raises:
+ google.auth.exceptions.RefreshError: If the token endpoint returned
+ an error.
+
+ .. _rfc7523 section 4: https://tools.ietf.org/html/rfc7523#section-4
+ """
+ body = {"assertion": assertion, "grant_type": client._JWT_GRANT_TYPE}
+
+ response_data = await _token_endpoint_request(request, token_uri, body)
+
+ try:
+ access_token = response_data["access_token"]
+ except KeyError as caught_exc:
+ new_exc = exceptions.RefreshError("No access token in response.", response_data)
+ six.raise_from(new_exc, caught_exc)
+
+ expiry = _parse_expiry(response_data)
+
+ return access_token, expiry, response_data
+
+
+async def id_token_jwt_grant(request, token_uri, assertion):
+ """Implements the JWT Profile for OAuth 2.0 Authorization Grants, but
+ requests an OpenID Connect ID Token instead of an access token.
+
+ This is a variant on the standard JWT Profile that is currently unique
+ to Google. This was added for the benefit of authenticating to services
+ that require ID Tokens instead of access tokens or JWT bearer tokens.
+
+ Args:
+ request (google.auth.transport.Request): A callable used to make
+ HTTP requests.
+ token_uri (str): The OAuth 2.0 authorization server's token endpoint
+ URI.
+ assertion (str): JWT token signed by a service account. The token's
+ payload must include a ``target_audience`` claim.
+
+ Returns:
+ Tuple[str, Optional[datetime], Mapping[str, str]]:
+ The (encoded) Open ID Connect ID Token, expiration, and additional
+ data returned by the endpoint.
+
+ Raises:
+ google.auth.exceptions.RefreshError: If the token endpoint returned
+ an error.
+ """
+ body = {"assertion": assertion, "grant_type": client._JWT_GRANT_TYPE}
+
+ response_data = await _token_endpoint_request(request, token_uri, body)
+
+ try:
+ id_token = response_data["id_token"]
+ except KeyError as caught_exc:
+ new_exc = exceptions.RefreshError("No ID token in response.", response_data)
+ six.raise_from(new_exc, caught_exc)
+
+ payload = jwt.decode(id_token, verify=False)
+ expiry = datetime.datetime.utcfromtimestamp(payload["exp"])
+
+ return id_token, expiry, response_data
+
+
+async def refresh_grant(
+ request, token_uri, refresh_token, client_id, client_secret, scopes=None
+):
+ """Implements the OAuth 2.0 refresh token grant.
+
+ For more details, see `rfc678 section 6`_.
+
+ Args:
+ request (google.auth.transport.Request): A callable used to make
+ HTTP requests.
+ token_uri (str): The OAuth 2.0 authorizations server's token endpoint
+ URI.
+ refresh_token (str): The refresh token to use to get a new access
+ token.
+ client_id (str): The OAuth 2.0 application's client ID.
+ client_secret (str): The Oauth 2.0 appliaction's client secret.
+ scopes (Optional(Sequence[str])): Scopes to request. If present, all
+ scopes must be authorized for the refresh token. Useful if refresh
+ token has a wild card scope (e.g.
+ 'https://www.googleapis.com/auth/any-api').
+
+ Returns:
+ Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The
+ access token, new refresh token, expiration, and additional data
+ returned by the token endpoint.
+
+ Raises:
+ google.auth.exceptions.RefreshError: If the token endpoint returned
+ an error.
+
+ .. _rfc6748 section 6: https://tools.ietf.org/html/rfc6749#section-6
+ """
+ body = {
+ "grant_type": client._REFRESH_GRANT_TYPE,
+ "client_id": client_id,
+ "client_secret": client_secret,
+ "refresh_token": refresh_token,
+ }
+ if scopes:
+ body["scope"] = " ".join(scopes)
+
+ response_data = await _token_endpoint_request(request, token_uri, body)
+
+ try:
+ access_token = response_data["access_token"]
+ except KeyError as caught_exc:
+ new_exc = exceptions.RefreshError("No access token in response.", response_data)
+ six.raise_from(new_exc, caught_exc)
+
+ refresh_token = response_data.get("refresh_token", refresh_token)
+ expiry = _parse_expiry(response_data)
+
+ return access_token, refresh_token, expiry, response_data
diff --git a/google/oauth2/_credentials_async.py b/google/oauth2/_credentials_async.py
new file mode 100644
index 0000000..eb3e97c
--- /dev/null
+++ b/google/oauth2/_credentials_async.py
@@ -0,0 +1,108 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""OAuth 2.0 Async Credentials.
+
+This module provides credentials based on OAuth 2.0 access and refresh tokens.
+These credentials usually access resources on behalf of a user (resource
+owner).
+
+Specifically, this is intended to use access tokens acquired using the
+`Authorization Code grant`_ and can refresh those tokens using a
+optional `refresh token`_.
+
+Obtaining the initial access and refresh token is outside of the scope of this
+module. Consult `rfc6749 section 4.1`_ for complete details on the
+Authorization Code grant flow.
+
+.. _Authorization Code grant: https://tools.ietf.org/html/rfc6749#section-1.3.1
+.. _refresh token: https://tools.ietf.org/html/rfc6749#section-6
+.. _rfc6749 section 4.1: https://tools.ietf.org/html/rfc6749#section-4.1
+"""
+
+from google.auth import _credentials_async as credentials
+from google.auth import _helpers
+from google.auth import exceptions
+from google.oauth2 import _client_async as _client
+from google.oauth2 import credentials as oauth2_credentials
+
+
+class Credentials(oauth2_credentials.Credentials):
+ """Credentials using OAuth 2.0 access and refresh tokens.
+
+ The credentials are considered immutable. If you want to modify the
+ quota project, use :meth:`with_quota_project` or ::
+
+ credentials = credentials.with_quota_project('myproject-123)
+ """
+
+ @_helpers.copy_docstring(credentials.Credentials)
+ async def refresh(self, request):
+ if (
+ self._refresh_token is None
+ or self._token_uri is None
+ or self._client_id is None
+ or self._client_secret is None
+ ):
+ raise exceptions.RefreshError(
+ "The credentials do not contain the necessary fields need to "
+ "refresh the access token. You must specify refresh_token, "
+ "token_uri, client_id, and client_secret."
+ )
+
+ (
+ access_token,
+ refresh_token,
+ expiry,
+ grant_response,
+ ) = await _client.refresh_grant(
+ request,
+ self._token_uri,
+ self._refresh_token,
+ self._client_id,
+ self._client_secret,
+ self._scopes,
+ )
+
+ self.token = access_token
+ self.expiry = expiry
+ self._refresh_token = refresh_token
+ self._id_token = grant_response.get("id_token")
+
+ if self._scopes and "scopes" in grant_response:
+ requested_scopes = frozenset(self._scopes)
+ granted_scopes = frozenset(grant_response["scopes"].split())
+ scopes_requested_but_not_granted = requested_scopes - granted_scopes
+ if scopes_requested_but_not_granted:
+ raise exceptions.RefreshError(
+ "Not all requested scopes were granted by the "
+ "authorization server, missing scopes {}.".format(
+ ", ".join(scopes_requested_but_not_granted)
+ )
+ )
+
+
+class UserAccessTokenCredentials(oauth2_credentials.UserAccessTokenCredentials):
+ """Access token credentials for user account.
+
+ Obtain the access token for a given user account or the current active
+ user account with the ``gcloud auth print-access-token`` command.
+
+ Args:
+ account (Optional[str]): Account to get the access token for. If not
+ specified, the current active account will be used.
+ quota_project_id (Optional[str]): The project ID used for quota
+ and billing.
+
+ """
diff --git a/google/oauth2/_id_token_async.py b/google/oauth2/_id_token_async.py
new file mode 100644
index 0000000..f5ef8ba
--- /dev/null
+++ b/google/oauth2/_id_token_async.py
@@ -0,0 +1,267 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Google ID Token helpers.
+
+Provides support for verifying `OpenID Connect ID Tokens`_, especially ones
+generated by Google infrastructure.
+
+To parse and verify an ID Token issued by Google's OAuth 2.0 authorization
+server use :func:`verify_oauth2_token`. To verify an ID Token issued by
+Firebase, use :func:`verify_firebase_token`.
+
+A general purpose ID Token verifier is available as :func:`verify_token`.
+
+Example::
+
+ from google.oauth2 import _id_token_async
+ from google.auth.transport import aiohttp_requests
+
+ request = aiohttp_requests.Request()
+
+ id_info = await _id_token_async.verify_oauth2_token(
+ token, request, 'my-client-id.example.com')
+
+ if id_info['iss'] != 'https://accounts.google.com':
+ raise ValueError('Wrong issuer.')
+
+ userid = id_info['sub']
+
+By default, this will re-fetch certificates for each verification. Because
+Google's public keys are only changed infrequently (on the order of once per
+day), you may wish to take advantage of caching to reduce latency and the
+potential for network errors. This can be accomplished using an external
+library like `CacheControl`_ to create a cache-aware
+:class:`google.auth.transport.Request`::
+
+ import cachecontrol
+ import google.auth.transport.requests
+ import requests
+
+ session = requests.session()
+ cached_session = cachecontrol.CacheControl(session)
+ request = google.auth.transport.requests.Request(session=cached_session)
+
+.. _OpenID Connect ID Token:
+ http://openid.net/specs/openid-connect-core-1_0.html#IDToken
+.. _CacheControl: https://cachecontrol.readthedocs.io
+"""
+
+import json
+import os
+
+import six
+from six.moves import http_client
+
+from google.auth import environment_vars
+from google.auth import exceptions
+from google.auth import jwt
+from google.auth.transport import requests
+from google.oauth2 import id_token as sync_id_token
+
+
+async def _fetch_certs(request, certs_url):
+ """Fetches certificates.
+
+ Google-style cerificate endpoints return JSON in the format of
+ ``{'key id': 'x509 certificate'}``.
+
+ Args:
+ request (google.auth.transport.Request): The object used to make
+ HTTP requests. This must be an aiohttp request.
+ certs_url (str): The certificate endpoint URL.
+
+ Returns:
+ Mapping[str, str]: A mapping of public key ID to x.509 certificate
+ data.
+ """
+ response = await request(certs_url, method="GET")
+
+ if response.status != http_client.OK:
+ raise exceptions.TransportError(
+ "Could not fetch certificates at {}".format(certs_url)
+ )
+
+ data = await response.data.read()
+
+ return json.loads(json.dumps(data))
+
+
+async def verify_token(
+ id_token, request, audience=None, certs_url=sync_id_token._GOOGLE_OAUTH2_CERTS_URL
+):
+ """Verifies an ID token and returns the decoded token.
+
+ Args:
+ id_token (Union[str, bytes]): The encoded token.
+ request (google.auth.transport.Request): The object used to make
+ HTTP requests. This must be an aiohttp request.
+ audience (str): The audience that this token is intended for. If None
+ then the audience is not verified.
+ certs_url (str): The URL that specifies the certificates to use to
+ verify the token. This URL should return JSON in the format of
+ ``{'key id': 'x509 certificate'}``.
+
+ Returns:
+ Mapping[str, Any]: The decoded token.
+ """
+ certs = await _fetch_certs(request, certs_url)
+
+ return jwt.decode(id_token, certs=certs, audience=audience)
+
+
+async def verify_oauth2_token(id_token, request, audience=None):
+ """Verifies an ID Token issued by Google's OAuth 2.0 authorization server.
+
+ Args:
+ id_token (Union[str, bytes]): The encoded token.
+ request (google.auth.transport.Request): The object used to make
+ HTTP requests. This must be an aiohttp request.
+ audience (str): The audience that this token is intended for. This is
+ typically your application's OAuth 2.0 client ID. If None then the
+ audience is not verified.
+
+ Returns:
+ Mapping[str, Any]: The decoded token.
+
+ Raises:
+ exceptions.GoogleAuthError: If the issuer is invalid.
+ """
+ idinfo = await verify_token(
+ id_token,
+ request,
+ audience=audience,
+ certs_url=sync_id_token._GOOGLE_OAUTH2_CERTS_URL,
+ )
+
+ if idinfo["iss"] not in sync_id_token._GOOGLE_ISSUERS:
+ raise exceptions.GoogleAuthError(
+ "Wrong issuer. 'iss' should be one of the following: {}".format(
+ sync_id_token._GOOGLE_ISSUERS
+ )
+ )
+
+ return idinfo
+
+
+async def verify_firebase_token(id_token, request, audience=None):
+ """Verifies an ID Token issued by Firebase Authentication.
+
+ Args:
+ id_token (Union[str, bytes]): The encoded token.
+ request (google.auth.transport.Request): The object used to make
+ HTTP requests. This must be an aiohttp request.
+ audience (str): The audience that this token is intended for. This is
+ typically your Firebase application ID. If None then the audience
+ is not verified.
+
+ Returns:
+ Mapping[str, Any]: The decoded token.
+ """
+ return await verify_token(
+ id_token,
+ request,
+ audience=audience,
+ certs_url=sync_id_token._GOOGLE_APIS_CERTS_URL,
+ )
+
+
+async def fetch_id_token(request, audience):
+ """Fetch the ID Token from the current environment.
+
+ This function acquires ID token from the environment in the following order:
+
+ 1. If the application is running in Compute Engine, App Engine or Cloud Run,
+ then the ID token are obtained from the metadata server.
+ 2. If the environment variable ``GOOGLE_APPLICATION_CREDENTIALS`` is set
+ to the path of a valid service account JSON file, then ID token is
+ acquired using this service account credentials.
+ 3. If metadata server doesn't exist and no valid service account credentials
+ are found, :class:`~google.auth.exceptions.DefaultCredentialsError` will
+ be raised.
+
+ Example::
+
+ import google.oauth2._id_token_async
+ import google.auth.transport.aiohttp_requests
+
+ request = google.auth.transport.aiohttp_requests.Request()
+ target_audience = "https://pubsub.googleapis.com"
+
+ id_token = await google.oauth2._id_token_async.fetch_id_token(request, target_audience)
+
+ Args:
+ request (google.auth.transport.aiohttp_requests.Request): A callable used to make
+ HTTP requests.
+ audience (str): The audience that this ID token is intended for.
+
+ Returns:
+ str: The ID token.
+
+ Raises:
+ ~google.auth.exceptions.DefaultCredentialsError:
+ If metadata server doesn't exist and no valid service account
+ credentials are found.
+ """
+ # 1. First try to fetch ID token from metadata server if it exists. The code
+ # works for GAE and Cloud Run metadata server as well.
+ try:
+ from google.auth import compute_engine
+
+ request_new = requests.Request()
+ credentials = compute_engine.IDTokenCredentials(
+ request_new, audience, use_metadata_identity_endpoint=True
+ )
+ credentials.refresh(request_new)
+
+ return credentials.token
+
+ except (ImportError, exceptions.TransportError, exceptions.RefreshError):
+ pass
+
+ # 2. Try to use service account credentials to get ID token.
+
+ # Try to get credentials from the GOOGLE_APPLICATION_CREDENTIALS environment
+ # variable.
+ credentials_filename = os.environ.get(environment_vars.CREDENTIALS)
+ if not (
+ credentials_filename
+ and os.path.exists(credentials_filename)
+ and os.path.isfile(credentials_filename)
+ ):
+ raise exceptions.DefaultCredentialsError(
+ "Neither metadata server or valid service account credentials are found."
+ )
+
+ try:
+ with open(credentials_filename, "r") as f:
+ info = json.load(f)
+ credentials_content = (
+ (info.get("type") == "service_account") and info or None
+ )
+
+ from google.oauth2 import _service_account_async as service_account
+
+ credentials = service_account.IDTokenCredentials.from_service_account_info(
+ credentials_content, target_audience=audience
+ )
+ except ValueError as caught_exc:
+ new_exc = exceptions.DefaultCredentialsError(
+ "Neither metadata server or valid service account credentials are found.",
+ caught_exc,
+ )
+ six.raise_from(new_exc, caught_exc)
+
+ await credentials.refresh(request)
+ return credentials.token
diff --git a/google/oauth2/_service_account_async.py b/google/oauth2/_service_account_async.py
new file mode 100644
index 0000000..0a4e724
--- /dev/null
+++ b/google/oauth2/_service_account_async.py
@@ -0,0 +1,132 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Service Accounts: JSON Web Token (JWT) Profile for OAuth 2.0
+
+NOTE: This file adds asynchronous refresh methods to both credentials
+classes, and therefore async/await syntax is required when calling this
+method when using service account credentials with asynchronous functionality.
+Otherwise, all other methods are inherited from the regular service account
+credentials file google.oauth2.service_account
+
+"""
+
+from google.auth import _credentials_async as credentials_async
+from google.auth import _helpers
+from google.oauth2 import _client_async
+from google.oauth2 import service_account
+
+
+class Credentials(
+ service_account.Credentials, credentials_async.Scoped, credentials_async.Credentials
+):
+ """Service account credentials
+
+ Usually, you'll create these credentials with one of the helper
+ constructors. To create credentials using a Google service account
+ private key JSON file::
+
+ credentials = _service_account_async.Credentials.from_service_account_file(
+ 'service-account.json')
+
+ Or if you already have the service account file loaded::
+
+ service_account_info = json.load(open('service_account.json'))
+ credentials = _service_account_async.Credentials.from_service_account_info(
+ service_account_info)
+
+ Both helper methods pass on arguments to the constructor, so you can
+ specify additional scopes and a subject if necessary::
+
+ credentials = _service_account_async.Credentials.from_service_account_file(
+ 'service-account.json',
+ scopes=['email'],
+ subject='[email protected]')
+
+ The credentials are considered immutable. If you want to modify the scopes
+ or the subject used for delegation, use :meth:`with_scopes` or
+ :meth:`with_subject`::
+
+ scoped_credentials = credentials.with_scopes(['email'])
+ delegated_credentials = credentials.with_subject(subject)
+
+ To add a quota project, use :meth:`with_quota_project`::
+
+ credentials = credentials.with_quota_project('myproject-123')
+ """
+
+ @_helpers.copy_docstring(credentials_async.Credentials)
+ async def refresh(self, request):
+ assertion = self._make_authorization_grant_assertion()
+ access_token, expiry, _ = await _client_async.jwt_grant(
+ request, self._token_uri, assertion
+ )
+ self.token = access_token
+ self.expiry = expiry
+
+
+class IDTokenCredentials(
+ service_account.IDTokenCredentials,
+ credentials_async.Signing,
+ credentials_async.Credentials,
+):
+ """Open ID Connect ID Token-based service account credentials.
+
+ These credentials are largely similar to :class:`.Credentials`, but instead
+ of using an OAuth 2.0 Access Token as the bearer token, they use an Open
+ ID Connect ID Token as the bearer token. These credentials are useful when
+ communicating to services that require ID Tokens and can not accept access
+ tokens.
+
+ Usually, you'll create these credentials with one of the helper
+ constructors. To create credentials using a Google service account
+ private key JSON file::
+
+ credentials = (
+ _service_account_async.IDTokenCredentials.from_service_account_file(
+ 'service-account.json'))
+
+ Or if you already have the service account file loaded::
+
+ service_account_info = json.load(open('service_account.json'))
+ credentials = (
+ _service_account_async.IDTokenCredentials.from_service_account_info(
+ service_account_info))
+
+ Both helper methods pass on arguments to the constructor, so you can
+ specify additional scopes and a subject if necessary::
+
+ credentials = (
+ _service_account_async.IDTokenCredentials.from_service_account_file(
+ 'service-account.json',
+ scopes=['email'],
+ subject='[email protected]'))
+`
+ The credentials are considered immutable. If you want to modify the scopes
+ or the subject used for delegation, use :meth:`with_scopes` or
+ :meth:`with_subject`::
+
+ scoped_credentials = credentials.with_scopes(['email'])
+ delegated_credentials = credentials.with_subject(subject)
+
+ """
+
+ @_helpers.copy_docstring(credentials_async.Credentials)
+ async def refresh(self, request):
+ assertion = self._make_authorization_grant_assertion()
+ access_token, expiry, _ = await _client_async.id_token_jwt_grant(
+ request, self._token_uri, assertion
+ )
+ self.token = access_token
+ self.expiry = expiry
diff --git a/noxfile.py b/noxfile.py
index c39f27c..d497f53 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -29,8 +29,18 @@
"responses",
"grpcio",
]
+
+ASYNC_DEPENDENCIES = ["pytest-asyncio", "aioresponses"]
+
BLACK_VERSION = "black==19.3b0"
-BLACK_PATHS = ["google", "tests", "noxfile.py", "setup.py", "docs/conf.py"]
+BLACK_PATHS = [
+ "google",
+ "tests",
+ "tests_async",
+ "noxfile.py",
+ "setup.py",
+ "docs/conf.py",
+]
@nox.session(python="3.7")
@@ -44,6 +54,7 @@
"--application-import-names=google,tests,system_tests",
"google",
"tests",
+ "tests_async",
)
session.run(
"python", "setup.py", "check", "--metadata", "--restructuredtext", "--strict"
@@ -64,9 +75,24 @@
session.run("black", *BLACK_PATHS)
[email protected](python=["2.7", "3.5", "3.6", "3.7", "3.8"])
[email protected](python=["3.6", "3.7", "3.8"])
def unit(session):
session.install(*TEST_DEPENDENCIES)
+ session.install(*(ASYNC_DEPENDENCIES))
+ session.install(".")
+ session.run(
+ "pytest",
+ "--cov=google.auth",
+ "--cov=google.oauth2",
+ "--cov=tests",
+ "tests",
+ "tests_async",
+ )
+
+
[email protected](python=["2.7", "3.5"])
+def unit_prev_versions(session):
+ session.install(*TEST_DEPENDENCIES)
session.install(".")
session.run(
"pytest", "--cov=google.auth", "--cov=google.oauth2", "--cov=tests", "tests"
@@ -76,14 +102,17 @@
@nox.session(python="3.7")
def cover(session):
session.install(*TEST_DEPENDENCIES)
+ session.install(*(ASYNC_DEPENDENCIES))
session.install(".")
session.run(
"pytest",
"--cov=google.auth",
"--cov=google.oauth2",
"--cov=tests",
+ "--cov=tests_async",
"--cov-report=",
"tests",
+ "tests_async",
)
session.run("coverage", "report", "--show-missing", "--fail-under=100")
@@ -117,5 +146,10 @@
session.install(*TEST_DEPENDENCIES)
session.install(".")
session.run(
- "pytest", "--cov=google.auth", "--cov=google.oauth2", "--cov=tests", "tests"
+ "pytest",
+ "--cov=google.auth",
+ "--cov=google.oauth2",
+ "--cov=tests",
+ "tests",
+ "tests_async",
)
diff --git a/setup.py b/setup.py
index 1c3578f..dd58f30 100644
--- a/setup.py
+++ b/setup.py
@@ -27,6 +27,7 @@
'rsa>=3.1.4,<5; python_version >= "3.5"',
"setuptools>=40.3.0",
"six>=1.9.0",
+ 'aiohttp >= 3.6.2, < 4.0.0dev; python_version>="3.6"',
)
diff --git a/system_tests/noxfile.py b/system_tests/noxfile.py
index 14cd3db..a039228 100644
--- a/system_tests/noxfile.py
+++ b/system_tests/noxfile.py
@@ -1,4 +1,4 @@
-# Copyright 2016 Google LLC
+# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,7 +29,6 @@
import nox
import py.path
-
HERE = os.path.abspath(os.path.dirname(__file__))
LIBRARY_DIR = os.path.join(HERE, "..")
DATA_DIR = os.path.join(HERE, "data")
@@ -169,92 +168,79 @@
# Test sesssions
-TEST_DEPENDENCIES = ["pytest", "requests"]
-PYTHON_VERSIONS = ["2.7", "3.7"]
+TEST_DEPENDENCIES_ASYNC = ["aiohttp", "pytest-asyncio", "nest-asyncio"]
+TEST_DEPENDENCIES_SYNC = ["pytest", "requests"]
+PYTHON_VERSIONS_ASYNC = ["3.7"]
+PYTHON_VERSIONS_SYNC = ["2.7", "3.7"]
[email protected](python=PYTHON_VERSIONS)
-def service_account(session):
- session.install(*TEST_DEPENDENCIES)
[email protected](python=PYTHON_VERSIONS_SYNC)
+def service_account_sync(session):
+ session.install(*TEST_DEPENDENCIES_SYNC)
session.install(LIBRARY_DIR)
- session.run("pytest", "test_service_account.py")
+ session.run("pytest", "system_tests_sync/test_service_account.py")
[email protected](python=PYTHON_VERSIONS)
-def oauth2_credentials(session):
- session.install(*TEST_DEPENDENCIES)
- session.install(LIBRARY_DIR)
- session.run("pytest", "test_oauth2_credentials.py")
-
-
[email protected](python=PYTHON_VERSIONS)
-def impersonated_credentials(session):
- session.install(*TEST_DEPENDENCIES)
- session.install(LIBRARY_DIR)
- session.run("pytest", "test_impersonated_credentials.py")
-
-
[email protected](python=PYTHON_VERSIONS)
[email protected](python=PYTHON_VERSIONS_SYNC)
def default_explicit_service_account(session):
session.env[EXPLICIT_CREDENTIALS_ENV] = SERVICE_ACCOUNT_FILE
session.env[EXPECT_PROJECT_ENV] = "1"
- session.install(*TEST_DEPENDENCIES)
+ session.install(*TEST_DEPENDENCIES_SYNC)
session.install(LIBRARY_DIR)
- session.run("pytest", "test_default.py", "test_id_token.py")
+ session.run("pytest", "system_tests_sync/test_default.py", "system_tests_sync/test_id_token.py")
[email protected](python=PYTHON_VERSIONS)
[email protected](python=PYTHON_VERSIONS_SYNC)
def default_explicit_authorized_user(session):
session.env[EXPLICIT_CREDENTIALS_ENV] = AUTHORIZED_USER_FILE
- session.install(*TEST_DEPENDENCIES)
+ session.install(*TEST_DEPENDENCIES_SYNC)
session.install(LIBRARY_DIR)
- session.run("pytest", "test_default.py")
+ session.run("pytest", "system_tests_sync/test_default.py")
[email protected](python=PYTHON_VERSIONS)
[email protected](python=PYTHON_VERSIONS_SYNC)
def default_explicit_authorized_user_explicit_project(session):
session.env[EXPLICIT_CREDENTIALS_ENV] = AUTHORIZED_USER_FILE
session.env[EXPLICIT_PROJECT_ENV] = "example-project"
session.env[EXPECT_PROJECT_ENV] = "1"
- session.install(*TEST_DEPENDENCIES)
+ session.install(*TEST_DEPENDENCIES_SYNC)
session.install(LIBRARY_DIR)
- session.run("pytest", "test_default.py")
+ session.run("pytest", "system_tests_sync/test_default.py")
[email protected](python=PYTHON_VERSIONS)
[email protected](python=PYTHON_VERSIONS_SYNC)
def default_cloud_sdk_service_account(session):
configure_cloud_sdk(session, SERVICE_ACCOUNT_FILE)
session.env[EXPECT_PROJECT_ENV] = "1"
- session.install(*TEST_DEPENDENCIES)
+ session.install(*TEST_DEPENDENCIES_SYNC)
session.install(LIBRARY_DIR)
- session.run("pytest", "test_default.py")
+ session.run("pytest", "system_tests_sync/test_default.py")
[email protected](python=PYTHON_VERSIONS)
[email protected](python=PYTHON_VERSIONS_SYNC)
def default_cloud_sdk_authorized_user(session):
configure_cloud_sdk(session, AUTHORIZED_USER_FILE)
- session.install(*TEST_DEPENDENCIES)
+ session.install(*TEST_DEPENDENCIES_SYNC)
session.install(LIBRARY_DIR)
- session.run("pytest", "test_default.py")
+ session.run("pytest", "system_tests_sync/test_default.py")
[email protected](python=PYTHON_VERSIONS)
[email protected](python=PYTHON_VERSIONS_SYNC)
def default_cloud_sdk_authorized_user_configured_project(session):
configure_cloud_sdk(session, AUTHORIZED_USER_FILE, project=True)
session.env[EXPECT_PROJECT_ENV] = "1"
- session.install(*TEST_DEPENDENCIES)
+ session.install(*TEST_DEPENDENCIES_SYNC)
session.install(LIBRARY_DIR)
- session.run("pytest", "test_default.py")
+ session.run("pytest", "system_tests_sync/test_default.py")
-
[email protected](python=PYTHON_VERSIONS)
[email protected](python=PYTHON_VERSIONS_SYNC)
def compute_engine(session):
- session.install(*TEST_DEPENDENCIES)
+ session.install(*TEST_DEPENDENCIES_SYNC)
# unset Application Default Credentials so
# credentials are detected from environment
del session.virtualenv.env["GOOGLE_APPLICATION_CREDENTIALS"]
session.install(LIBRARY_DIR)
- session.run("pytest", "test_compute_engine.py")
+ session.run("pytest", "system_tests_sync/test_compute_engine.py")
@nox.session(python=["2.7"])
@@ -283,8 +269,8 @@
application_url = GAE_APP_URL_TMPL.format(GAE_TEST_APP_SERVICE, project_id)
# Vendor in the test application's dependencies
- session.chdir(os.path.join(HERE, "app_engine_test_app"))
- session.install(*TEST_DEPENDENCIES)
+ session.chdir(os.path.join(HERE, "../app_engine_test_app"))
+ session.install(*TEST_DEPENDENCIES_SYNC)
session.install(LIBRARY_DIR)
session.run(
"pip", "install", "--target", "lib", "-r", "requirements.txt", silent=True
@@ -296,20 +282,82 @@
# Run the tests
session.env["TEST_APP_URL"] = application_url
session.chdir(HERE)
- session.run("pytest", "test_app_engine.py")
+ session.run("pytest", "system_tests_sync/test_app_engine.py")
[email protected](python=PYTHON_VERSIONS)
[email protected](python=PYTHON_VERSIONS_SYNC)
def grpc(session):
session.install(LIBRARY_DIR)
- session.install(*TEST_DEPENDENCIES, "google-cloud-pubsub==1.0.0")
+ session.install(*TEST_DEPENDENCIES_SYNC, "google-cloud-pubsub==1.0.0")
session.env[EXPLICIT_CREDENTIALS_ENV] = SERVICE_ACCOUNT_FILE
- session.run("pytest", "test_grpc.py")
+ session.run("pytest", "system_tests_sync/test_grpc.py")
[email protected](python=PYTHON_VERSIONS)
[email protected](python=PYTHON_VERSIONS_SYNC)
def mtls_http(session):
session.install(LIBRARY_DIR)
- session.install(*TEST_DEPENDENCIES, "pyopenssl")
+ session.install(*TEST_DEPENDENCIES_SYNC, "pyopenssl")
session.env[EXPLICIT_CREDENTIALS_ENV] = SERVICE_ACCOUNT_FILE
- session.run("pytest", "test_mtls_http.py")
+ session.run("pytest", "system_tests_sync/test_mtls_http.py")
+
+#ASYNC SYSTEM TESTS
+
[email protected](python=PYTHON_VERSIONS_ASYNC)
+def service_account_async(session):
+ session.install(*(TEST_DEPENDENCIES_SYNC+TEST_DEPENDENCIES_ASYNC))
+ session.install(LIBRARY_DIR)
+ session.run("pytest", "system_tests_async/test_service_account.py")
+
+
[email protected](python=PYTHON_VERSIONS_ASYNC)
+def default_explicit_service_account_async(session):
+ session.env[EXPLICIT_CREDENTIALS_ENV] = SERVICE_ACCOUNT_FILE
+ session.env[EXPECT_PROJECT_ENV] = "1"
+ session.install(*(TEST_DEPENDENCIES_SYNC + TEST_DEPENDENCIES_ASYNC))
+ session.install(LIBRARY_DIR)
+ session.run("pytest", "system_tests_async/test_default.py",
+ "system_tests_async/test_id_token.py")
+
+
[email protected](python=PYTHON_VERSIONS_ASYNC)
+def default_explicit_authorized_user_async(session):
+ session.env[EXPLICIT_CREDENTIALS_ENV] = AUTHORIZED_USER_FILE
+ session.install(*(TEST_DEPENDENCIES_SYNC + TEST_DEPENDENCIES_ASYNC))
+ session.install(LIBRARY_DIR)
+ session.run("pytest", "system_tests_async/test_default.py")
+
+
[email protected](python=PYTHON_VERSIONS_ASYNC)
+def default_explicit_authorized_user_explicit_project_async(session):
+ session.env[EXPLICIT_CREDENTIALS_ENV] = AUTHORIZED_USER_FILE
+ session.env[EXPLICIT_PROJECT_ENV] = "example-project"
+ session.env[EXPECT_PROJECT_ENV] = "1"
+ session.install(*(TEST_DEPENDENCIES_SYNC + TEST_DEPENDENCIES_ASYNC))
+ session.install(LIBRARY_DIR)
+ session.run("pytest", "system_tests_async/test_default.py")
+
+
[email protected](python=PYTHON_VERSIONS_ASYNC)
+def default_cloud_sdk_service_account_async(session):
+ configure_cloud_sdk(session, SERVICE_ACCOUNT_FILE)
+ session.env[EXPECT_PROJECT_ENV] = "1"
+ session.install(*(TEST_DEPENDENCIES_SYNC + TEST_DEPENDENCIES_ASYNC))
+ session.install(LIBRARY_DIR)
+ session.run("pytest", "system_tests_async/test_default.py")
+
+
[email protected](python=PYTHON_VERSIONS_ASYNC)
+def default_cloud_sdk_authorized_user_async(session):
+ configure_cloud_sdk(session, AUTHORIZED_USER_FILE)
+ session.install(*(TEST_DEPENDENCIES_SYNC + TEST_DEPENDENCIES_ASYNC))
+ session.install(LIBRARY_DIR)
+ session.run("pytest", "system_tests_async/test_default.py")
+
+
[email protected](python=PYTHON_VERSIONS_ASYNC)
+def default_cloud_sdk_authorized_user_configured_project_async(session):
+ configure_cloud_sdk(session, AUTHORIZED_USER_FILE, project=True)
+ session.env[EXPECT_PROJECT_ENV] = "1"
+ session.install(*(TEST_DEPENDENCIES_SYNC + TEST_DEPENDENCIES_ASYNC))
+ session.install(LIBRARY_DIR)
+ session.run("pytest", "system_tests_async/test_default.py")
diff --git a/system_tests/conftest.py b/system_tests/system_tests_async/conftest.py
similarity index 63%
copy from system_tests/conftest.py
copy to system_tests/system_tests_async/conftest.py
index 02de846..ecff74c 100644
--- a/system_tests/conftest.py
+++ b/system_tests/system_tests_async/conftest.py
@@ -1,4 +1,4 @@
-# Copyright 2016 Google LLC
+# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,52 +22,43 @@
import requests
import urllib3
+import aiohttp
+from google.auth.transport import _aiohttp_requests as aiohttp_requests
+from system_tests.system_tests_sync import conftest as sync_conftest
-HERE = os.path.dirname(__file__)
-DATA_DIR = os.path.join(HERE, "data")
-IMPERSONATED_SERVICE_ACCOUNT_FILE = os.path.join(
- DATA_DIR, "impersonated_service_account.json"
-)
-SERVICE_ACCOUNT_FILE = os.path.join(DATA_DIR, "service_account.json")
-AUTHORIZED_USER_FILE = os.path.join(DATA_DIR, "authorized_user.json")
-URLLIB3_HTTP = urllib3.PoolManager(retries=False)
-REQUESTS_SESSION = requests.Session()
-REQUESTS_SESSION.verify = False
+ASYNC_REQUESTS_SESSION = aiohttp.ClientSession()
+
+ASYNC_REQUESTS_SESSION.verify = False
TOKEN_INFO_URL = "https://www.googleapis.com/oauth2/v3/tokeninfo"
@pytest.fixture
def service_account_file():
"""The full path to a valid service account key file."""
- yield SERVICE_ACCOUNT_FILE
+ yield sync_conftest.SERVICE_ACCOUNT_FILE
@pytest.fixture
def impersonated_service_account_file():
"""The full path to a valid service account key file."""
- yield IMPERSONATED_SERVICE_ACCOUNT_FILE
+ yield sync_conftest.IMPERSONATED_SERVICE_ACCOUNT_FILE
@pytest.fixture
def authorized_user_file():
"""The full path to a valid authorized user file."""
- yield AUTHORIZED_USER_FILE
+ yield sync_conftest.AUTHORIZED_USER_FILE
-
[email protected](params=["urllib3", "requests"])
-def http_request(request):
[email protected](params=["aiohttp"])
+async def http_request(request):
"""A transport.request object."""
- if request.param == "urllib3":
- yield google.auth.transport.urllib3.Request(URLLIB3_HTTP)
- elif request.param == "requests":
- yield google.auth.transport.requests.Request(REQUESTS_SESSION)
-
+ yield aiohttp_requests.Request(ASYNC_REQUESTS_SESSION)
@pytest.fixture
-def token_info(http_request):
+async def token_info(http_request):
"""Returns a function that obtains OAuth2 token info."""
- def _token_info(access_token=None, id_token=None):
+ async def _token_info(access_token=None, id_token=None):
query_params = {}
if access_token is not None:
@@ -77,24 +68,25 @@
else:
raise ValueError("No token specified.")
- url = _helpers.update_query(TOKEN_INFO_URL, query_params)
+ url = _helpers.update_query(sync_conftest.TOKEN_INFO_URL, query_params)
- response = http_request(url=url, method="GET")
+ response = await http_request(url=url, method="GET")
+ data = await response.data.read()
- return json.loads(response.data.decode("utf-8"))
+ return json.loads(data.decode("utf-8"))
yield _token_info
@pytest.fixture
-def verify_refresh(http_request):
+async def verify_refresh(http_request):
"""Returns a function that verifies that credentials can be refreshed."""
- def _verify_refresh(credentials):
+ async def _verify_refresh(credentials):
if credentials.requires_scopes:
credentials = credentials.with_scopes(["email", "profile"])
- credentials.refresh(http_request)
+ await credentials.refresh(http_request)
assert credentials.token
assert credentials.valid
@@ -104,7 +96,7 @@
def verify_environment():
"""Checks to make sure that requisite data files are available."""
- if not os.path.isdir(DATA_DIR):
+ if not os.path.isdir(sync_conftest.DATA_DIR):
raise EnvironmentError(
"In order to run system tests, test data must exist in "
"system_tests/data. See CONTRIBUTING.rst for details."
diff --git a/system_tests/system_tests_async/test_default.py b/system_tests/system_tests_async/test_default.py
new file mode 100644
index 0000000..383cbff
--- /dev/null
+++ b/system_tests/system_tests_async/test_default.py
@@ -0,0 +1,30 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import pytest
+
+import google.auth
+
+EXPECT_PROJECT_ID = os.environ.get("EXPECT_PROJECT_ID")
+
[email protected]
+async def test_application_default_credentials(verify_refresh):
+ credentials, project_id = google.auth.default_async()
+ #breakpoint()
+
+ if EXPECT_PROJECT_ID is not None:
+ assert project_id is not None
+
+ await verify_refresh(credentials)
diff --git a/system_tests/system_tests_async/test_id_token.py b/system_tests/system_tests_async/test_id_token.py
new file mode 100644
index 0000000..a21b137
--- /dev/null
+++ b/system_tests/system_tests_async/test_id_token.py
@@ -0,0 +1,25 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+
+from google.auth import jwt
+import google.oauth2._id_token_async
+
[email protected]
+async def test_fetch_id_token(http_request):
+ audience = "https://pubsub.googleapis.com"
+ token = await google.oauth2._id_token_async.fetch_id_token(http_request, audience)
+
+ _, payload, _, _ = jwt._unverified_decode(token)
+ assert payload["aud"] == audience
diff --git a/system_tests/system_tests_async/test_service_account.py b/system_tests/system_tests_async/test_service_account.py
new file mode 100644
index 0000000..c1c16cc
--- /dev/null
+++ b/system_tests/system_tests_async/test_service_account.py
@@ -0,0 +1,53 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from google.auth import _helpers
+from google.auth import exceptions
+from google.auth import iam
+from google.oauth2 import _service_account_async
+
+
[email protected]
+def credentials(service_account_file):
+ yield _service_account_async.Credentials.from_service_account_file(service_account_file)
+
+
[email protected]
+async def test_refresh_no_scopes(http_request, credentials):
+ """
+ We expect the http request to refresh credentials
+ without scopes provided to throw an error.
+ """
+ with pytest.raises(exceptions.RefreshError):
+ await credentials.refresh(http_request)
+
[email protected]
+async def test_refresh_success(http_request, credentials, token_info):
+ credentials = credentials.with_scopes(["email", "profile"])
+ await credentials.refresh(http_request)
+
+ assert credentials.token
+
+ info = await token_info(credentials.token)
+
+ assert info["email"] == credentials.service_account_email
+ info_scopes = _helpers.string_to_scopes(info["scope"])
+ assert set(info_scopes) == set(
+ [
+ "https://www.googleapis.com/auth/userinfo.email",
+ "https://www.googleapis.com/auth/userinfo.profile",
+ ]
+ )
diff --git a/system_tests/.gitignore b/system_tests/system_tests_sync/.gitignore
similarity index 100%
rename from system_tests/.gitignore
rename to system_tests/system_tests_sync/.gitignore
diff --git a/system_tests/__init__.py b/system_tests/system_tests_sync/__init__.py
similarity index 100%
rename from system_tests/__init__.py
rename to system_tests/system_tests_sync/__init__.py
diff --git a/system_tests/app_engine_test_app/.gitignore b/system_tests/system_tests_sync/app_engine_test_app/.gitignore
similarity index 100%
rename from system_tests/app_engine_test_app/.gitignore
rename to system_tests/system_tests_sync/app_engine_test_app/.gitignore
diff --git a/system_tests/app_engine_test_app/app.yaml b/system_tests/system_tests_sync/app_engine_test_app/app.yaml
similarity index 100%
rename from system_tests/app_engine_test_app/app.yaml
rename to system_tests/system_tests_sync/app_engine_test_app/app.yaml
diff --git a/system_tests/app_engine_test_app/appengine_config.py b/system_tests/system_tests_sync/app_engine_test_app/appengine_config.py
similarity index 100%
rename from system_tests/app_engine_test_app/appengine_config.py
rename to system_tests/system_tests_sync/app_engine_test_app/appengine_config.py
diff --git a/system_tests/app_engine_test_app/main.py b/system_tests/system_tests_sync/app_engine_test_app/main.py
similarity index 100%
rename from system_tests/app_engine_test_app/main.py
rename to system_tests/system_tests_sync/app_engine_test_app/main.py
diff --git a/system_tests/app_engine_test_app/requirements.txt b/system_tests/system_tests_sync/app_engine_test_app/requirements.txt
similarity index 100%
rename from system_tests/app_engine_test_app/requirements.txt
rename to system_tests/system_tests_sync/app_engine_test_app/requirements.txt
diff --git a/system_tests/conftest.py b/system_tests/system_tests_sync/conftest.py
similarity index 96%
rename from system_tests/conftest.py
rename to system_tests/system_tests_sync/conftest.py
index 02de846..37a6fd3 100644
--- a/system_tests/conftest.py
+++ b/system_tests/system_tests_sync/conftest.py
@@ -24,12 +24,11 @@
HERE = os.path.dirname(__file__)
-DATA_DIR = os.path.join(HERE, "data")
+DATA_DIR = os.path.join(HERE, "../data")
IMPERSONATED_SERVICE_ACCOUNT_FILE = os.path.join(
DATA_DIR, "impersonated_service_account.json"
)
SERVICE_ACCOUNT_FILE = os.path.join(DATA_DIR, "service_account.json")
-AUTHORIZED_USER_FILE = os.path.join(DATA_DIR, "authorized_user.json")
URLLIB3_HTTP = urllib3.PoolManager(retries=False)
REQUESTS_SESSION = requests.Session()
REQUESTS_SESSION.verify = False
diff --git a/system_tests/secrets.tar.enc b/system_tests/system_tests_sync/secrets.tar.enc
similarity index 100%
rename from system_tests/secrets.tar.enc
rename to system_tests/system_tests_sync/secrets.tar.enc
Binary files differ
diff --git a/system_tests/test_app_engine.py b/system_tests/system_tests_sync/test_app_engine.py
similarity index 100%
rename from system_tests/test_app_engine.py
rename to system_tests/system_tests_sync/test_app_engine.py
diff --git a/system_tests/test_compute_engine.py b/system_tests/system_tests_sync/test_compute_engine.py
similarity index 100%
rename from system_tests/test_compute_engine.py
rename to system_tests/system_tests_sync/test_compute_engine.py
diff --git a/system_tests/test_default.py b/system_tests/system_tests_sync/test_default.py
similarity index 100%
rename from system_tests/test_default.py
rename to system_tests/system_tests_sync/test_default.py
diff --git a/system_tests/test_grpc.py b/system_tests/system_tests_sync/test_grpc.py
similarity index 100%
rename from system_tests/test_grpc.py
rename to system_tests/system_tests_sync/test_grpc.py
diff --git a/system_tests/test_id_token.py b/system_tests/system_tests_sync/test_id_token.py
similarity index 100%
rename from system_tests/test_id_token.py
rename to system_tests/system_tests_sync/test_id_token.py
diff --git a/system_tests/test_impersonated_credentials.py b/system_tests/system_tests_sync/test_impersonated_credentials.py
similarity index 100%
rename from system_tests/test_impersonated_credentials.py
rename to system_tests/system_tests_sync/test_impersonated_credentials.py
diff --git a/system_tests/test_mtls_http.py b/system_tests/system_tests_sync/test_mtls_http.py
similarity index 100%
rename from system_tests/test_mtls_http.py
rename to system_tests/system_tests_sync/test_mtls_http.py
diff --git a/system_tests/test_oauth2_credentials.py b/system_tests/system_tests_sync/test_oauth2_credentials.py
similarity index 100%
rename from system_tests/test_oauth2_credentials.py
rename to system_tests/system_tests_sync/test_oauth2_credentials.py
diff --git a/system_tests/test_service_account.py b/system_tests/system_tests_sync/test_service_account.py
similarity index 100%
rename from system_tests/test_service_account.py
rename to system_tests/system_tests_sync/test_service_account.py
diff --git a/system_tests/__init__.py b/tests_async/__init__.py
similarity index 100%
copy from system_tests/__init__.py
copy to tests_async/__init__.py
diff --git a/tests_async/conftest.py b/tests_async/conftest.py
new file mode 100644
index 0000000..b4e90f0
--- /dev/null
+++ b/tests_async/conftest.py
@@ -0,0 +1,51 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+import mock
+import pytest
+
+
+def pytest_configure():
+ """Load public certificate and private key."""
+ pytest.data_dir = os.path.join(
+ os.path.abspath(os.path.join(__file__, "../..")), "tests/data"
+ )
+
+ with open(os.path.join(pytest.data_dir, "privatekey.pem"), "rb") as fh:
+ pytest.private_key_bytes = fh.read()
+
+ with open(os.path.join(pytest.data_dir, "public_cert.pem"), "rb") as fh:
+ pytest.public_cert_bytes = fh.read()
+
+
[email protected]
+def mock_non_existent_module(monkeypatch):
+ """Mocks a non-existing module in sys.modules.
+
+ Additionally mocks any non-existing modules specified in the dotted path.
+ """
+
+ def _mock_non_existent_module(path):
+ parts = path.split(".")
+ partial = []
+ for part in parts:
+ partial.append(part)
+ current_module = ".".join(partial)
+ if current_module not in sys.modules:
+ monkeypatch.setitem(sys.modules, current_module, mock.MagicMock())
+
+ return _mock_non_existent_module
diff --git a/tests_async/oauth2/test__client_async.py b/tests_async/oauth2/test__client_async.py
new file mode 100644
index 0000000..458937a
--- /dev/null
+++ b/tests_async/oauth2/test__client_async.py
@@ -0,0 +1,297 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import json
+
+import mock
+import pytest
+import six
+from six.moves import http_client
+from six.moves import urllib
+
+from google.auth import _helpers
+from google.auth import _jwt_async as jwt
+from google.auth import exceptions
+from google.oauth2 import _client as sync_client
+from google.oauth2 import _client_async as _client
+from tests.oauth2 import test__client as test_client
+
+
+def test__handle_error_response():
+ response_data = json.dumps({"error": "help", "error_description": "I'm alive"})
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ _client._handle_error_response(response_data)
+
+ assert excinfo.match(r"help: I\'m alive")
+
+
+def test__handle_error_response_non_json():
+ response_data = "Help, I'm alive"
+
+ with pytest.raises(exceptions.RefreshError) as excinfo:
+ _client._handle_error_response(response_data)
+
+ assert excinfo.match(r"Help, I\'m alive")
+
+
[email protected]("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
+def test__parse_expiry(unused_utcnow):
+ result = _client._parse_expiry({"expires_in": 500})
+ assert result == datetime.datetime.min + datetime.timedelta(seconds=500)
+
+
+def test__parse_expiry_none():
+ assert _client._parse_expiry({}) is None
+
+
+def make_request(response_data, status=http_client.OK):
+ response = mock.AsyncMock(spec=["transport.Response"])
+ response.status = status
+ data = json.dumps(response_data).encode("utf-8")
+ response.data = mock.AsyncMock(spec=["__call__", "read"])
+ response.data.read = mock.AsyncMock(spec=["__call__"], return_value=data)
+ response.content = mock.AsyncMock(spec=["__call__"], return_value=data)
+ request = mock.AsyncMock(spec=["transport.Request"])
+ request.return_value = response
+ return request
+
+
[email protected]
+async def test__token_endpoint_request():
+
+ request = make_request({"test": "response"})
+
+ result = await _client._token_endpoint_request(
+ request, "http://example.com", {"test": "params"}
+ )
+
+ # Check request call
+ request.assert_called_with(
+ method="POST",
+ url="http://example.com",
+ headers={"content-type": "application/x-www-form-urlencoded"},
+ body="test=params".encode("utf-8"),
+ )
+
+ # Check result
+ assert result == {"test": "response"}
+
+
[email protected]
+async def test__token_endpoint_request_error():
+ request = make_request({}, status=http_client.BAD_REQUEST)
+
+ with pytest.raises(exceptions.RefreshError):
+ await _client._token_endpoint_request(request, "http://example.com", {})
+
+
[email protected]
+async def test__token_endpoint_request_internal_failure_error():
+ request = make_request(
+ {"error_description": "internal_failure"}, status=http_client.BAD_REQUEST
+ )
+
+ with pytest.raises(exceptions.RefreshError):
+ await _client._token_endpoint_request(
+ request, "http://example.com", {"error_description": "internal_failure"}
+ )
+
+ request = make_request(
+ {"error": "internal_failure"}, status=http_client.BAD_REQUEST
+ )
+
+ with pytest.raises(exceptions.RefreshError):
+ await _client._token_endpoint_request(
+ request, "http://example.com", {"error": "internal_failure"}
+ )
+
+
+def verify_request_params(request, params):
+ request_body = request.call_args[1]["body"].decode("utf-8")
+ request_params = urllib.parse.parse_qs(request_body)
+
+ for key, value in six.iteritems(params):
+ assert request_params[key][0] == value
+
+
[email protected]("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
[email protected]
+async def test_jwt_grant(utcnow):
+ request = make_request(
+ {"access_token": "token", "expires_in": 500, "extra": "data"}
+ )
+
+ token, expiry, extra_data = await _client.jwt_grant(
+ request, "http://example.com", "assertion_value"
+ )
+
+ # Check request call
+ verify_request_params(
+ request,
+ {"grant_type": sync_client._JWT_GRANT_TYPE, "assertion": "assertion_value"},
+ )
+
+ # Check result
+ assert token == "token"
+ assert expiry == utcnow() + datetime.timedelta(seconds=500)
+ assert extra_data["extra"] == "data"
+
+
[email protected]
+async def test_jwt_grant_no_access_token():
+ request = make_request(
+ {
+ # No access token.
+ "expires_in": 500,
+ "extra": "data",
+ }
+ )
+
+ with pytest.raises(exceptions.RefreshError):
+ await _client.jwt_grant(request, "http://example.com", "assertion_value")
+
+
[email protected]
+async def test_id_token_jwt_grant():
+ now = _helpers.utcnow()
+ id_token_expiry = _helpers.datetime_to_secs(now)
+ id_token = jwt.encode(test_client.SIGNER, {"exp": id_token_expiry}).decode("utf-8")
+ request = make_request({"id_token": id_token, "extra": "data"})
+
+ token, expiry, extra_data = await _client.id_token_jwt_grant(
+ request, "http://example.com", "assertion_value"
+ )
+
+ # Check request call
+ verify_request_params(
+ request,
+ {"grant_type": sync_client._JWT_GRANT_TYPE, "assertion": "assertion_value"},
+ )
+
+ # Check result
+ assert token == id_token
+ # JWT does not store microseconds
+ now = now.replace(microsecond=0)
+ assert expiry == now
+ assert extra_data["extra"] == "data"
+
+
[email protected]
+async def test_id_token_jwt_grant_no_access_token():
+ request = make_request(
+ {
+ # No access token.
+ "expires_in": 500,
+ "extra": "data",
+ }
+ )
+
+ with pytest.raises(exceptions.RefreshError):
+ await _client.id_token_jwt_grant(
+ request, "http://example.com", "assertion_value"
+ )
+
+
[email protected]("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
[email protected]
+async def test_refresh_grant(unused_utcnow):
+ request = make_request(
+ {
+ "access_token": "token",
+ "refresh_token": "new_refresh_token",
+ "expires_in": 500,
+ "extra": "data",
+ }
+ )
+
+ token, refresh_token, expiry, extra_data = await _client.refresh_grant(
+ request, "http://example.com", "refresh_token", "client_id", "client_secret"
+ )
+
+ # Check request call
+ verify_request_params(
+ request,
+ {
+ "grant_type": sync_client._REFRESH_GRANT_TYPE,
+ "refresh_token": "refresh_token",
+ "client_id": "client_id",
+ "client_secret": "client_secret",
+ },
+ )
+
+ # Check result
+ assert token == "token"
+ assert refresh_token == "new_refresh_token"
+ assert expiry == datetime.datetime.min + datetime.timedelta(seconds=500)
+ assert extra_data["extra"] == "data"
+
+
[email protected]("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
[email protected]
+async def test_refresh_grant_with_scopes(unused_utcnow):
+ request = make_request(
+ {
+ "access_token": "token",
+ "refresh_token": "new_refresh_token",
+ "expires_in": 500,
+ "extra": "data",
+ "scope": test_client.SCOPES_AS_STRING,
+ }
+ )
+
+ token, refresh_token, expiry, extra_data = await _client.refresh_grant(
+ request,
+ "http://example.com",
+ "refresh_token",
+ "client_id",
+ "client_secret",
+ test_client.SCOPES_AS_LIST,
+ )
+
+ # Check request call.
+ verify_request_params(
+ request,
+ {
+ "grant_type": sync_client._REFRESH_GRANT_TYPE,
+ "refresh_token": "refresh_token",
+ "client_id": "client_id",
+ "client_secret": "client_secret",
+ "scope": test_client.SCOPES_AS_STRING,
+ },
+ )
+
+ # Check result.
+ assert token == "token"
+ assert refresh_token == "new_refresh_token"
+ assert expiry == datetime.datetime.min + datetime.timedelta(seconds=500)
+ assert extra_data["extra"] == "data"
+
+
[email protected]
+async def test_refresh_grant_no_access_token():
+ request = make_request(
+ {
+ # No access token.
+ "refresh_token": "new_refresh_token",
+ "expires_in": 500,
+ "extra": "data",
+ }
+ )
+
+ with pytest.raises(exceptions.RefreshError):
+ await _client.refresh_grant(
+ request, "http://example.com", "refresh_token", "client_id", "client_secret"
+ )
diff --git a/tests_async/oauth2/test_credentials_async.py b/tests_async/oauth2/test_credentials_async.py
new file mode 100644
index 0000000..5c883d6
--- /dev/null
+++ b/tests_async/oauth2/test_credentials_async.py
@@ -0,0 +1,478 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import json
+import os
+import pickle
+import sys
+
+import mock
+import pytest
+
+from google.auth import _helpers
+from google.auth import exceptions
+from google.oauth2 import _credentials_async as _credentials_async
+from google.oauth2 import credentials
+from tests.oauth2 import test_credentials
+
+
+class TestCredentials:
+
+ TOKEN_URI = "https://example.com/oauth2/token"
+ REFRESH_TOKEN = "refresh_token"
+ CLIENT_ID = "client_id"
+ CLIENT_SECRET = "client_secret"
+
+ @classmethod
+ def make_credentials(cls):
+ return _credentials_async.Credentials(
+ token=None,
+ refresh_token=cls.REFRESH_TOKEN,
+ token_uri=cls.TOKEN_URI,
+ client_id=cls.CLIENT_ID,
+ client_secret=cls.CLIENT_SECRET,
+ )
+
+ def test_default_state(self):
+ credentials = self.make_credentials()
+ assert not credentials.valid
+ # Expiration hasn't been set yet
+ assert not credentials.expired
+ # Scopes aren't required for these credentials
+ assert not credentials.requires_scopes
+ # Test properties
+ assert credentials.refresh_token == self.REFRESH_TOKEN
+ assert credentials.token_uri == self.TOKEN_URI
+ assert credentials.client_id == self.CLIENT_ID
+ assert credentials.client_secret == self.CLIENT_SECRET
+
+ @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True)
+ @mock.patch(
+ "google.auth._helpers.utcnow",
+ return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
+ )
+ @pytest.mark.asyncio
+ async def test_refresh_success(self, unused_utcnow, refresh_grant):
+ token = "token"
+ expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
+ grant_response = {"id_token": mock.sentinel.id_token}
+ refresh_grant.return_value = (
+ # Access token
+ token,
+ # New refresh token
+ None,
+ # Expiry,
+ expiry,
+ # Extra data
+ grant_response,
+ )
+
+ request = mock.AsyncMock(spec=["transport.Request"])
+ creds = self.make_credentials()
+
+ # Refresh credentials
+ await creds.refresh(request)
+
+ # Check jwt grant call.
+ refresh_grant.assert_called_with(
+ request,
+ self.TOKEN_URI,
+ self.REFRESH_TOKEN,
+ self.CLIENT_ID,
+ self.CLIENT_SECRET,
+ None,
+ )
+
+ # Check that the credentials have the token and expiry
+ assert creds.token == token
+ assert creds.expiry == expiry
+ assert creds.id_token == mock.sentinel.id_token
+
+ # Check that the credentials are valid (have a token and are not
+ # expired)
+ assert creds.valid
+
+ @pytest.mark.asyncio
+ async def test_refresh_no_refresh_token(self):
+ request = mock.AsyncMock(spec=["transport.Request"])
+ credentials_ = _credentials_async.Credentials(token=None, refresh_token=None)
+
+ with pytest.raises(exceptions.RefreshError, match="necessary fields"):
+ await credentials_.refresh(request)
+
+ request.assert_not_called()
+
+ @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True)
+ @mock.patch(
+ "google.auth._helpers.utcnow",
+ return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
+ )
+ @pytest.mark.asyncio
+ async def test_credentials_with_scopes_requested_refresh_success(
+ self, unused_utcnow, refresh_grant
+ ):
+ scopes = ["email", "profile"]
+ token = "token"
+ expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
+ grant_response = {"id_token": mock.sentinel.id_token}
+ refresh_grant.return_value = (
+ # Access token
+ token,
+ # New refresh token
+ None,
+ # Expiry,
+ expiry,
+ # Extra data
+ grant_response,
+ )
+
+ request = mock.AsyncMock(spec=["transport.Request"])
+ creds = _credentials_async.Credentials(
+ token=None,
+ refresh_token=self.REFRESH_TOKEN,
+ token_uri=self.TOKEN_URI,
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ scopes=scopes,
+ )
+
+ # Refresh credentials
+ await creds.refresh(request)
+
+ # Check jwt grant call.
+ refresh_grant.assert_called_with(
+ request,
+ self.TOKEN_URI,
+ self.REFRESH_TOKEN,
+ self.CLIENT_ID,
+ self.CLIENT_SECRET,
+ scopes,
+ )
+
+ # Check that the credentials have the token and expiry
+ assert creds.token == token
+ assert creds.expiry == expiry
+ assert creds.id_token == mock.sentinel.id_token
+ assert creds.has_scopes(scopes)
+
+ # Check that the credentials are valid (have a token and are not
+ # expired.)
+ assert creds.valid
+
+ @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True)
+ @mock.patch(
+ "google.auth._helpers.utcnow",
+ return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
+ )
+ @pytest.mark.asyncio
+ async def test_credentials_with_scopes_returned_refresh_success(
+ self, unused_utcnow, refresh_grant
+ ):
+ scopes = ["email", "profile"]
+ token = "token"
+ expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
+ grant_response = {
+ "id_token": mock.sentinel.id_token,
+ "scopes": " ".join(scopes),
+ }
+ refresh_grant.return_value = (
+ # Access token
+ token,
+ # New refresh token
+ None,
+ # Expiry,
+ expiry,
+ # Extra data
+ grant_response,
+ )
+
+ request = mock.AsyncMock(spec=["transport.Request"])
+ creds = _credentials_async.Credentials(
+ token=None,
+ refresh_token=self.REFRESH_TOKEN,
+ token_uri=self.TOKEN_URI,
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ scopes=scopes,
+ )
+
+ # Refresh credentials
+ await creds.refresh(request)
+
+ # Check jwt grant call.
+ refresh_grant.assert_called_with(
+ request,
+ self.TOKEN_URI,
+ self.REFRESH_TOKEN,
+ self.CLIENT_ID,
+ self.CLIENT_SECRET,
+ scopes,
+ )
+
+ # Check that the credentials have the token and expiry
+ assert creds.token == token
+ assert creds.expiry == expiry
+ assert creds.id_token == mock.sentinel.id_token
+ assert creds.has_scopes(scopes)
+
+ # Check that the credentials are valid (have a token and are not
+ # expired.)
+ assert creds.valid
+
+ @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True)
+ @mock.patch(
+ "google.auth._helpers.utcnow",
+ return_value=datetime.datetime.min + _helpers.CLOCK_SKEW,
+ )
+ @pytest.mark.asyncio
+ async def test_credentials_with_scopes_refresh_failure_raises_refresh_error(
+ self, unused_utcnow, refresh_grant
+ ):
+ scopes = ["email", "profile"]
+ scopes_returned = ["email"]
+ token = "token"
+ expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
+ grant_response = {
+ "id_token": mock.sentinel.id_token,
+ "scopes": " ".join(scopes_returned),
+ }
+ refresh_grant.return_value = (
+ # Access token
+ token,
+ # New refresh token
+ None,
+ # Expiry,
+ expiry,
+ # Extra data
+ grant_response,
+ )
+
+ request = mock.AsyncMock(spec=["transport.Request"])
+ creds = _credentials_async.Credentials(
+ token=None,
+ refresh_token=self.REFRESH_TOKEN,
+ token_uri=self.TOKEN_URI,
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ scopes=scopes,
+ )
+
+ # Refresh credentials
+ with pytest.raises(
+ exceptions.RefreshError, match="Not all requested scopes were granted"
+ ):
+ await creds.refresh(request)
+
+ # Check jwt grant call.
+ refresh_grant.assert_called_with(
+ request,
+ self.TOKEN_URI,
+ self.REFRESH_TOKEN,
+ self.CLIENT_ID,
+ self.CLIENT_SECRET,
+ scopes,
+ )
+
+ # Check that the credentials have the token and expiry
+ assert creds.token == token
+ assert creds.expiry == expiry
+ assert creds.id_token == mock.sentinel.id_token
+ assert creds.has_scopes(scopes)
+
+ # Check that the credentials are valid (have a token and are not
+ # expired.)
+ assert creds.valid
+
+ def test_apply_with_quota_project_id(self):
+ creds = _credentials_async.Credentials(
+ token="token",
+ refresh_token=self.REFRESH_TOKEN,
+ token_uri=self.TOKEN_URI,
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ quota_project_id="quota-project-123",
+ )
+
+ headers = {}
+ creds.apply(headers)
+ assert headers["x-goog-user-project"] == "quota-project-123"
+
+ def test_apply_with_no_quota_project_id(self):
+ creds = _credentials_async.Credentials(
+ token="token",
+ refresh_token=self.REFRESH_TOKEN,
+ token_uri=self.TOKEN_URI,
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ )
+
+ headers = {}
+ creds.apply(headers)
+ assert "x-goog-user-project" not in headers
+
+ def test_with_quota_project(self):
+ creds = _credentials_async.Credentials(
+ token="token",
+ refresh_token=self.REFRESH_TOKEN,
+ token_uri=self.TOKEN_URI,
+ client_id=self.CLIENT_ID,
+ client_secret=self.CLIENT_SECRET,
+ quota_project_id="quota-project-123",
+ )
+
+ new_creds = creds.with_quota_project("new-project-456")
+ assert new_creds.quota_project_id == "new-project-456"
+ headers = {}
+ creds.apply(headers)
+ assert "x-goog-user-project" in headers
+
+ def test_from_authorized_user_info(self):
+ info = test_credentials.AUTH_USER_INFO.copy()
+
+ creds = _credentials_async.Credentials.from_authorized_user_info(info)
+ assert creds.client_secret == info["client_secret"]
+ assert creds.client_id == info["client_id"]
+ assert creds.refresh_token == info["refresh_token"]
+ assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
+ assert creds.scopes is None
+
+ scopes = ["email", "profile"]
+ creds = _credentials_async.Credentials.from_authorized_user_info(info, scopes)
+ assert creds.client_secret == info["client_secret"]
+ assert creds.client_id == info["client_id"]
+ assert creds.refresh_token == info["refresh_token"]
+ assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
+ assert creds.scopes == scopes
+
+ def test_from_authorized_user_file(self):
+ info = test_credentials.AUTH_USER_INFO.copy()
+
+ creds = _credentials_async.Credentials.from_authorized_user_file(
+ test_credentials.AUTH_USER_JSON_FILE
+ )
+ assert creds.client_secret == info["client_secret"]
+ assert creds.client_id == info["client_id"]
+ assert creds.refresh_token == info["refresh_token"]
+ assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
+ assert creds.scopes is None
+
+ scopes = ["email", "profile"]
+ creds = _credentials_async.Credentials.from_authorized_user_file(
+ test_credentials.AUTH_USER_JSON_FILE, scopes
+ )
+ assert creds.client_secret == info["client_secret"]
+ assert creds.client_id == info["client_id"]
+ assert creds.refresh_token == info["refresh_token"]
+ assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
+ assert creds.scopes == scopes
+
+ def test_to_json(self):
+ info = test_credentials.AUTH_USER_INFO.copy()
+ creds = _credentials_async.Credentials.from_authorized_user_info(info)
+
+ # Test with no `strip` arg
+ json_output = creds.to_json()
+ json_asdict = json.loads(json_output)
+ assert json_asdict.get("token") == creds.token
+ assert json_asdict.get("refresh_token") == creds.refresh_token
+ assert json_asdict.get("token_uri") == creds.token_uri
+ assert json_asdict.get("client_id") == creds.client_id
+ assert json_asdict.get("scopes") == creds.scopes
+ assert json_asdict.get("client_secret") == creds.client_secret
+
+ # Test with a `strip` arg
+ json_output = creds.to_json(strip=["client_secret"])
+ json_asdict = json.loads(json_output)
+ assert json_asdict.get("token") == creds.token
+ assert json_asdict.get("refresh_token") == creds.refresh_token
+ assert json_asdict.get("token_uri") == creds.token_uri
+ assert json_asdict.get("client_id") == creds.client_id
+ assert json_asdict.get("scopes") == creds.scopes
+ assert json_asdict.get("client_secret") is None
+
+ def test_pickle_and_unpickle(self):
+ creds = self.make_credentials()
+ unpickled = pickle.loads(pickle.dumps(creds))
+
+ # make sure attributes aren't lost during pickling
+ assert list(creds.__dict__).sort() == list(unpickled.__dict__).sort()
+
+ for attr in list(creds.__dict__):
+ assert getattr(creds, attr) == getattr(unpickled, attr)
+
+ def test_pickle_with_missing_attribute(self):
+ creds = self.make_credentials()
+
+ # remove an optional attribute before pickling
+ # this mimics a pickle created with a previous class definition with
+ # fewer attributes
+ del creds.__dict__["_quota_project_id"]
+
+ unpickled = pickle.loads(pickle.dumps(creds))
+
+ # Attribute should be initialized by `__setstate__`
+ assert unpickled.quota_project_id is None
+
+ # pickles are not compatible across versions
+ @pytest.mark.skipif(
+ sys.version_info < (3, 5),
+ reason="pickle file can only be loaded with Python >= 3.5",
+ )
+ def test_unpickle_old_credentials_pickle(self):
+ # make sure a credentials file pickled with an older
+ # library version (google-auth==1.5.1) can be unpickled
+ with open(
+ os.path.join(test_credentials.DATA_DIR, "old_oauth_credentials_py3.pickle"),
+ "rb",
+ ) as f:
+ credentials = pickle.load(f)
+ assert credentials.quota_project_id is None
+
+
+class TestUserAccessTokenCredentials(object):
+ def test_instance(self):
+ cred = _credentials_async.UserAccessTokenCredentials()
+ assert cred._account is None
+
+ cred = cred.with_account("account")
+ assert cred._account == "account"
+
+ @mock.patch("google.auth._cloud_sdk.get_auth_access_token", autospec=True)
+ def test_refresh(self, get_auth_access_token):
+ get_auth_access_token.return_value = "access_token"
+ cred = _credentials_async.UserAccessTokenCredentials()
+ cred.refresh(None)
+ assert cred.token == "access_token"
+
+ def test_with_quota_project(self):
+ cred = _credentials_async.UserAccessTokenCredentials()
+ quota_project_cred = cred.with_quota_project("project-foo")
+
+ assert quota_project_cred._quota_project_id == "project-foo"
+ assert quota_project_cred._account == cred._account
+
+ @mock.patch(
+ "google.oauth2._credentials_async.UserAccessTokenCredentials.apply",
+ autospec=True,
+ )
+ @mock.patch(
+ "google.oauth2._credentials_async.UserAccessTokenCredentials.refresh",
+ autospec=True,
+ )
+ def test_before_request(self, refresh, apply):
+ cred = _credentials_async.UserAccessTokenCredentials()
+ cred.before_request(mock.Mock(), "GET", "https://example.com", {})
+ refresh.assert_called()
+ apply.assert_called()
diff --git a/tests_async/oauth2/test_id_token.py b/tests_async/oauth2/test_id_token.py
new file mode 100644
index 0000000..a46bd61
--- /dev/null
+++ b/tests_async/oauth2/test_id_token.py
@@ -0,0 +1,205 @@
+# Copyright 2020 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import mock
+import pytest
+
+from google.auth import environment_vars
+from google.auth import exceptions
+import google.auth.compute_engine._metadata
+from google.oauth2 import _id_token_async as id_token
+from google.oauth2 import id_token as sync_id_token
+from tests.oauth2 import test_id_token
+
+
+def make_request(status, data=None):
+ response = mock.AsyncMock(spec=["transport.Response"])
+ response.status = status
+
+ if data is not None:
+ response.data = mock.AsyncMock(spec=["__call__", "read"])
+ response.data.read = mock.AsyncMock(spec=["__call__"], return_value=data)
+
+ request = mock.AsyncMock(spec=["transport.Request"])
+ request.return_value = response
+ return request
+
+
[email protected]
+async def test__fetch_certs_success():
+ certs = {"1": "cert"}
+ request = make_request(200, certs)
+
+ returned_certs = await id_token._fetch_certs(request, mock.sentinel.cert_url)
+
+ request.assert_called_once_with(mock.sentinel.cert_url, method="GET")
+ assert returned_certs == certs
+
+
[email protected]
+async def test__fetch_certs_failure():
+ request = make_request(404)
+
+ with pytest.raises(exceptions.TransportError):
+ await id_token._fetch_certs(request, mock.sentinel.cert_url)
+
+ request.assert_called_once_with(mock.sentinel.cert_url, method="GET")
+
+
[email protected]("google.auth.jwt.decode", autospec=True)
[email protected]("google.oauth2._id_token_async._fetch_certs", autospec=True)
[email protected]
+async def test_verify_token(_fetch_certs, decode):
+ result = await id_token.verify_token(mock.sentinel.token, mock.sentinel.request)
+
+ assert result == decode.return_value
+ _fetch_certs.assert_called_once_with(
+ mock.sentinel.request, sync_id_token._GOOGLE_OAUTH2_CERTS_URL
+ )
+ decode.assert_called_once_with(
+ mock.sentinel.token, certs=_fetch_certs.return_value, audience=None
+ )
+
+
[email protected]("google.auth.jwt.decode", autospec=True)
[email protected]("google.oauth2._id_token_async._fetch_certs", autospec=True)
[email protected]
+async def test_verify_token_args(_fetch_certs, decode):
+ result = await id_token.verify_token(
+ mock.sentinel.token,
+ mock.sentinel.request,
+ audience=mock.sentinel.audience,
+ certs_url=mock.sentinel.certs_url,
+ )
+
+ assert result == decode.return_value
+ _fetch_certs.assert_called_once_with(mock.sentinel.request, mock.sentinel.certs_url)
+ decode.assert_called_once_with(
+ mock.sentinel.token,
+ certs=_fetch_certs.return_value,
+ audience=mock.sentinel.audience,
+ )
+
+
[email protected]("google.oauth2._id_token_async.verify_token", autospec=True)
[email protected]
+async def test_verify_oauth2_token(verify_token):
+ verify_token.return_value = {"iss": "accounts.google.com"}
+ result = await id_token.verify_oauth2_token(
+ mock.sentinel.token, mock.sentinel.request, audience=mock.sentinel.audience
+ )
+
+ assert result == verify_token.return_value
+ verify_token.assert_called_once_with(
+ mock.sentinel.token,
+ mock.sentinel.request,
+ audience=mock.sentinel.audience,
+ certs_url=sync_id_token._GOOGLE_OAUTH2_CERTS_URL,
+ )
+
+
[email protected]("google.oauth2._id_token_async.verify_token", autospec=True)
[email protected]
+async def test_verify_oauth2_token_invalid_iss(verify_token):
+ verify_token.return_value = {"iss": "invalid_issuer"}
+
+ with pytest.raises(exceptions.GoogleAuthError):
+ await id_token.verify_oauth2_token(
+ mock.sentinel.token, mock.sentinel.request, audience=mock.sentinel.audience
+ )
+
+
[email protected]("google.oauth2._id_token_async.verify_token", autospec=True)
[email protected]
+async def test_verify_firebase_token(verify_token):
+ result = await id_token.verify_firebase_token(
+ mock.sentinel.token, mock.sentinel.request, audience=mock.sentinel.audience
+ )
+
+ assert result == verify_token.return_value
+ verify_token.assert_called_once_with(
+ mock.sentinel.token,
+ mock.sentinel.request,
+ audience=mock.sentinel.audience,
+ certs_url=sync_id_token._GOOGLE_APIS_CERTS_URL,
+ )
+
+
[email protected]
+async def test_fetch_id_token_from_metadata_server():
+ def mock_init(self, request, audience, use_metadata_identity_endpoint):
+ assert use_metadata_identity_endpoint
+ self.token = "id_token"
+
+ with mock.patch.multiple(
+ google.auth.compute_engine.IDTokenCredentials,
+ __init__=mock_init,
+ refresh=mock.Mock(),
+ ):
+ request = mock.AsyncMock()
+ token = await id_token.fetch_id_token(request, "https://pubsub.googleapis.com")
+ assert token == "id_token"
+
+
[email protected](
+ google.auth.compute_engine.IDTokenCredentials,
+ "__init__",
+ side_effect=exceptions.TransportError(),
+)
[email protected]
+async def test_fetch_id_token_from_explicit_cred_json_file(mock_init, monkeypatch):
+ monkeypatch.setenv(environment_vars.CREDENTIALS, test_id_token.SERVICE_ACCOUNT_FILE)
+
+ async def mock_refresh(self, request):
+ self.token = "id_token"
+
+ with mock.patch.object(
+ google.oauth2._service_account_async.IDTokenCredentials, "refresh", mock_refresh
+ ):
+ request = mock.AsyncMock()
+ token = await id_token.fetch_id_token(request, "https://pubsub.googleapis.com")
+ assert token == "id_token"
+
+
[email protected](
+ google.auth.compute_engine.IDTokenCredentials,
+ "__init__",
+ side_effect=exceptions.TransportError(),
+)
[email protected]
+async def test_fetch_id_token_no_cred_json_file(mock_init, monkeypatch):
+ monkeypatch.delenv(environment_vars.CREDENTIALS, raising=False)
+
+ with pytest.raises(exceptions.DefaultCredentialsError):
+ request = mock.AsyncMock()
+ await id_token.fetch_id_token(request, "https://pubsub.googleapis.com")
+
+
[email protected](
+ google.auth.compute_engine.IDTokenCredentials,
+ "__init__",
+ side_effect=exceptions.TransportError(),
+)
[email protected]
+async def test_fetch_id_token_invalid_cred_file(mock_init, monkeypatch):
+ not_json_file = os.path.join(
+ os.path.dirname(__file__), "../../tests/data/public_cert.pem"
+ )
+ monkeypatch.setenv(environment_vars.CREDENTIALS, not_json_file)
+
+ with pytest.raises(exceptions.DefaultCredentialsError):
+ request = mock.AsyncMock()
+ await id_token.fetch_id_token(request, "https://pubsub.googleapis.com")
diff --git a/tests_async/oauth2/test_service_account_async.py b/tests_async/oauth2/test_service_account_async.py
new file mode 100644
index 0000000..4079453
--- /dev/null
+++ b/tests_async/oauth2/test_service_account_async.py
@@ -0,0 +1,372 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+
+import mock
+import pytest
+
+from google.auth import _helpers
+from google.auth import crypt
+from google.auth import jwt
+from google.auth import transport
+from google.oauth2 import _service_account_async as service_account
+from tests.oauth2 import test_service_account
+
+
+class TestCredentials(object):
+ SERVICE_ACCOUNT_EMAIL = "[email protected]"
+ TOKEN_URI = "https://example.com/oauth2/token"
+
+ @classmethod
+ def make_credentials(cls):
+ return service_account.Credentials(
+ test_service_account.SIGNER, cls.SERVICE_ACCOUNT_EMAIL, cls.TOKEN_URI
+ )
+
+ def test_from_service_account_info(self):
+ credentials = service_account.Credentials.from_service_account_info(
+ test_service_account.SERVICE_ACCOUNT_INFO
+ )
+
+ assert (
+ credentials._signer.key_id
+ == test_service_account.SERVICE_ACCOUNT_INFO["private_key_id"]
+ )
+ assert (
+ credentials.service_account_email
+ == test_service_account.SERVICE_ACCOUNT_INFO["client_email"]
+ )
+ assert (
+ credentials._token_uri
+ == test_service_account.SERVICE_ACCOUNT_INFO["token_uri"]
+ )
+
+ def test_from_service_account_info_args(self):
+ info = test_service_account.SERVICE_ACCOUNT_INFO.copy()
+ scopes = ["email", "profile"]
+ subject = "subject"
+ additional_claims = {"meta": "data"}
+
+ credentials = service_account.Credentials.from_service_account_info(
+ info, scopes=scopes, subject=subject, additional_claims=additional_claims
+ )
+
+ assert credentials.service_account_email == info["client_email"]
+ assert credentials.project_id == info["project_id"]
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._token_uri == info["token_uri"]
+ assert credentials._scopes == scopes
+ assert credentials._subject == subject
+ assert credentials._additional_claims == additional_claims
+
+ def test_from_service_account_file(self):
+ info = test_service_account.SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = service_account.Credentials.from_service_account_file(
+ test_service_account.SERVICE_ACCOUNT_JSON_FILE
+ )
+
+ assert credentials.service_account_email == info["client_email"]
+ assert credentials.project_id == info["project_id"]
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._token_uri == info["token_uri"]
+
+ def test_from_service_account_file_args(self):
+ info = test_service_account.SERVICE_ACCOUNT_INFO.copy()
+ scopes = ["email", "profile"]
+ subject = "subject"
+ additional_claims = {"meta": "data"}
+
+ credentials = service_account.Credentials.from_service_account_file(
+ test_service_account.SERVICE_ACCOUNT_JSON_FILE,
+ subject=subject,
+ scopes=scopes,
+ additional_claims=additional_claims,
+ )
+
+ assert credentials.service_account_email == info["client_email"]
+ assert credentials.project_id == info["project_id"]
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._token_uri == info["token_uri"]
+ assert credentials._scopes == scopes
+ assert credentials._subject == subject
+ assert credentials._additional_claims == additional_claims
+
+ def test_default_state(self):
+ credentials = self.make_credentials()
+ assert not credentials.valid
+ # Expiration hasn't been set yet
+ assert not credentials.expired
+ # Scopes haven't been specified yet
+ assert credentials.requires_scopes
+
+ def test_sign_bytes(self):
+ credentials = self.make_credentials()
+ to_sign = b"123"
+ signature = credentials.sign_bytes(to_sign)
+ assert crypt.verify_signature(
+ to_sign, signature, test_service_account.PUBLIC_CERT_BYTES
+ )
+
+ def test_signer(self):
+ credentials = self.make_credentials()
+ assert isinstance(credentials.signer, crypt.Signer)
+
+ def test_signer_email(self):
+ credentials = self.make_credentials()
+ assert credentials.signer_email == self.SERVICE_ACCOUNT_EMAIL
+
+ def test_create_scoped(self):
+ credentials = self.make_credentials()
+ scopes = ["email", "profile"]
+ credentials = credentials.with_scopes(scopes)
+ assert credentials._scopes == scopes
+
+ def test_with_claims(self):
+ credentials = self.make_credentials()
+ new_credentials = credentials.with_claims({"meep": "moop"})
+ assert new_credentials._additional_claims == {"meep": "moop"}
+
+ def test_with_quota_project(self):
+ credentials = self.make_credentials()
+ new_credentials = credentials.with_quota_project("new-project-456")
+ assert new_credentials.quota_project_id == "new-project-456"
+ hdrs = {}
+ new_credentials.apply(hdrs, token="tok")
+ assert "x-goog-user-project" in hdrs
+
+ def test__make_authorization_grant_assertion(self):
+ credentials = self.make_credentials()
+ token = credentials._make_authorization_grant_assertion()
+ payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES)
+ assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
+ assert payload["aud"] == self.TOKEN_URI
+
+ def test__make_authorization_grant_assertion_scoped(self):
+ credentials = self.make_credentials()
+ scopes = ["email", "profile"]
+ credentials = credentials.with_scopes(scopes)
+ token = credentials._make_authorization_grant_assertion()
+ payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES)
+ assert payload["scope"] == "email profile"
+
+ def test__make_authorization_grant_assertion_subject(self):
+ credentials = self.make_credentials()
+ subject = "[email protected]"
+ credentials = credentials.with_subject(subject)
+ token = credentials._make_authorization_grant_assertion()
+ payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES)
+ assert payload["sub"] == subject
+
+ @mock.patch("google.oauth2._client_async.jwt_grant", autospec=True)
+ @pytest.mark.asyncio
+ async def test_refresh_success(self, jwt_grant):
+ credentials = self.make_credentials()
+ token = "token"
+ jwt_grant.return_value = (
+ token,
+ _helpers.utcnow() + datetime.timedelta(seconds=500),
+ {},
+ )
+ request = mock.create_autospec(transport.Request, instance=True)
+
+ # Refresh credentials
+ await credentials.refresh(request)
+
+ # Check jwt grant call.
+ assert jwt_grant.called
+
+ called_request, token_uri, assertion = jwt_grant.call_args[0]
+ assert called_request == request
+ assert token_uri == credentials._token_uri
+ assert jwt.decode(assertion, test_service_account.PUBLIC_CERT_BYTES)
+ # No further assertion done on the token, as there are separate tests
+ # for checking the authorization grant assertion.
+
+ # Check that the credentials have the token.
+ assert credentials.token == token
+
+ # Check that the credentials are valid (have a token and are not
+ # expired)
+ assert credentials.valid
+
+ @mock.patch("google.oauth2._client_async.jwt_grant", autospec=True)
+ @pytest.mark.asyncio
+ async def test_before_request_refreshes(self, jwt_grant):
+ credentials = self.make_credentials()
+ token = "token"
+ jwt_grant.return_value = (
+ token,
+ _helpers.utcnow() + datetime.timedelta(seconds=500),
+ None,
+ )
+ request = mock.create_autospec(transport.Request, instance=True)
+
+ # Credentials should start as invalid
+ assert not credentials.valid
+
+ # before_request should cause a refresh
+ await credentials.before_request(request, "GET", "http://example.com?a=1#3", {})
+
+ # The refresh endpoint should've been called.
+ assert jwt_grant.called
+
+ # Credentials should now be valid.
+ assert credentials.valid
+
+
+class TestIDTokenCredentials(object):
+ SERVICE_ACCOUNT_EMAIL = "[email protected]"
+ TOKEN_URI = "https://example.com/oauth2/token"
+ TARGET_AUDIENCE = "https://example.com"
+
+ @classmethod
+ def make_credentials(cls):
+ return service_account.IDTokenCredentials(
+ test_service_account.SIGNER,
+ cls.SERVICE_ACCOUNT_EMAIL,
+ cls.TOKEN_URI,
+ cls.TARGET_AUDIENCE,
+ )
+
+ def test_from_service_account_info(self):
+ credentials = service_account.IDTokenCredentials.from_service_account_info(
+ test_service_account.SERVICE_ACCOUNT_INFO,
+ target_audience=self.TARGET_AUDIENCE,
+ )
+
+ assert (
+ credentials._signer.key_id
+ == test_service_account.SERVICE_ACCOUNT_INFO["private_key_id"]
+ )
+ assert (
+ credentials.service_account_email
+ == test_service_account.SERVICE_ACCOUNT_INFO["client_email"]
+ )
+ assert (
+ credentials._token_uri
+ == test_service_account.SERVICE_ACCOUNT_INFO["token_uri"]
+ )
+ assert credentials._target_audience == self.TARGET_AUDIENCE
+
+ def test_from_service_account_file(self):
+ info = test_service_account.SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = service_account.IDTokenCredentials.from_service_account_file(
+ test_service_account.SERVICE_ACCOUNT_JSON_FILE,
+ target_audience=self.TARGET_AUDIENCE,
+ )
+
+ assert credentials.service_account_email == info["client_email"]
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._token_uri == info["token_uri"]
+ assert credentials._target_audience == self.TARGET_AUDIENCE
+
+ def test_default_state(self):
+ credentials = self.make_credentials()
+ assert not credentials.valid
+ # Expiration hasn't been set yet
+ assert not credentials.expired
+
+ def test_sign_bytes(self):
+ credentials = self.make_credentials()
+ to_sign = b"123"
+ signature = credentials.sign_bytes(to_sign)
+ assert crypt.verify_signature(
+ to_sign, signature, test_service_account.PUBLIC_CERT_BYTES
+ )
+
+ def test_signer(self):
+ credentials = self.make_credentials()
+ assert isinstance(credentials.signer, crypt.Signer)
+
+ def test_signer_email(self):
+ credentials = self.make_credentials()
+ assert credentials.signer_email == self.SERVICE_ACCOUNT_EMAIL
+
+ def test_with_target_audience(self):
+ credentials = self.make_credentials()
+ new_credentials = credentials.with_target_audience("https://new.example.com")
+ assert new_credentials._target_audience == "https://new.example.com"
+
+ def test_with_quota_project(self):
+ credentials = self.make_credentials()
+ new_credentials = credentials.with_quota_project("project-foo")
+ assert new_credentials._quota_project_id == "project-foo"
+
+ def test__make_authorization_grant_assertion(self):
+ credentials = self.make_credentials()
+ token = credentials._make_authorization_grant_assertion()
+ payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES)
+ assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
+ assert payload["aud"] == self.TOKEN_URI
+ assert payload["target_audience"] == self.TARGET_AUDIENCE
+
+ @mock.patch("google.oauth2._client_async.id_token_jwt_grant", autospec=True)
+ @pytest.mark.asyncio
+ async def test_refresh_success(self, id_token_jwt_grant):
+ credentials = self.make_credentials()
+ token = "token"
+ id_token_jwt_grant.return_value = (
+ token,
+ _helpers.utcnow() + datetime.timedelta(seconds=500),
+ {},
+ )
+
+ request = mock.AsyncMock(spec=["transport.Request"])
+
+ # Refresh credentials
+ await credentials.refresh(request)
+
+ # Check jwt grant call.
+ assert id_token_jwt_grant.called
+
+ called_request, token_uri, assertion = id_token_jwt_grant.call_args[0]
+ assert called_request == request
+ assert token_uri == credentials._token_uri
+ assert jwt.decode(assertion, test_service_account.PUBLIC_CERT_BYTES)
+ # No further assertion done on the token, as there are separate tests
+ # for checking the authorization grant assertion.
+
+ # Check that the credentials have the token.
+ assert credentials.token == token
+
+ # Check that the credentials are valid (have a token and are not
+ # expired)
+ assert credentials.valid
+
+ @mock.patch("google.oauth2._client_async.id_token_jwt_grant", autospec=True)
+ @pytest.mark.asyncio
+ async def test_before_request_refreshes(self, id_token_jwt_grant):
+ credentials = self.make_credentials()
+ token = "token"
+ id_token_jwt_grant.return_value = (
+ token,
+ _helpers.utcnow() + datetime.timedelta(seconds=500),
+ None,
+ )
+ request = mock.AsyncMock(spec=["transport.Request"])
+
+ # Credentials should start as invalid
+ assert not credentials.valid
+
+ # before_request should cause a refresh
+ await credentials.before_request(request, "GET", "http://example.com?a=1#3", {})
+
+ # The refresh endpoint should've been called.
+ assert id_token_jwt_grant.called
+
+ # Credentials should now be valid.
+ assert credentials.valid
diff --git a/tests_async/test__default_async.py b/tests_async/test__default_async.py
new file mode 100644
index 0000000..bca396a
--- /dev/null
+++ b/tests_async/test__default_async.py
@@ -0,0 +1,468 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+
+import mock
+import pytest
+
+from google.auth import _credentials_async as credentials
+from google.auth import _default_async as _default
+from google.auth import app_engine
+from google.auth import compute_engine
+from google.auth import environment_vars
+from google.auth import exceptions
+from google.oauth2 import _service_account_async as service_account
+import google.oauth2.credentials
+from tests import test__default as test_default
+
+MOCK_CREDENTIALS = mock.Mock(spec=credentials.CredentialsWithQuotaProject)
+MOCK_CREDENTIALS.with_quota_project.return_value = MOCK_CREDENTIALS
+
+LOAD_FILE_PATCH = mock.patch(
+ "google.auth._default_async.load_credentials_from_file",
+ return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id),
+ autospec=True,
+)
+
+
+def test_load_credentials_from_missing_file():
+ with pytest.raises(exceptions.DefaultCredentialsError) as excinfo:
+ _default.load_credentials_from_file("")
+
+ assert excinfo.match(r"not found")
+
+
+def test_load_credentials_from_file_invalid_json(tmpdir):
+ jsonfile = tmpdir.join("invalid.json")
+ jsonfile.write("{")
+
+ with pytest.raises(exceptions.DefaultCredentialsError) as excinfo:
+ _default.load_credentials_from_file(str(jsonfile))
+
+ assert excinfo.match(r"not a valid json file")
+
+
+def test_load_credentials_from_file_invalid_type(tmpdir):
+ jsonfile = tmpdir.join("invalid.json")
+ jsonfile.write(json.dumps({"type": "not-a-real-type"}))
+
+ with pytest.raises(exceptions.DefaultCredentialsError) as excinfo:
+ _default.load_credentials_from_file(str(jsonfile))
+
+ assert excinfo.match(r"does not have a valid type")
+
+
+def test_load_credentials_from_file_authorized_user():
+ credentials, project_id = _default.load_credentials_from_file(
+ test_default.AUTHORIZED_USER_FILE
+ )
+ assert isinstance(credentials, google.oauth2._credentials_async.Credentials)
+ assert project_id is None
+
+
+def test_load_credentials_from_file_no_type(tmpdir):
+ # use the client_secrets.json, which is valid json but not a
+ # loadable credentials type
+ with pytest.raises(exceptions.DefaultCredentialsError) as excinfo:
+ _default.load_credentials_from_file(test_default.CLIENT_SECRETS_FILE)
+
+ assert excinfo.match(r"does not have a valid type")
+ assert excinfo.match(r"Type is None")
+
+
+def test_load_credentials_from_file_authorized_user_bad_format(tmpdir):
+ filename = tmpdir.join("authorized_user_bad.json")
+ filename.write(json.dumps({"type": "authorized_user"}))
+
+ with pytest.raises(exceptions.DefaultCredentialsError) as excinfo:
+ _default.load_credentials_from_file(str(filename))
+
+ assert excinfo.match(r"Failed to load authorized user")
+ assert excinfo.match(r"missing fields")
+
+
+def test_load_credentials_from_file_authorized_user_cloud_sdk():
+ with pytest.warns(UserWarning, match="Cloud SDK"):
+ credentials, project_id = _default.load_credentials_from_file(
+ test_default.AUTHORIZED_USER_CLOUD_SDK_FILE
+ )
+ assert isinstance(credentials, google.oauth2._credentials_async.Credentials)
+ assert project_id is None
+
+ # No warning if the json file has quota project id.
+ credentials, project_id = _default.load_credentials_from_file(
+ test_default.AUTHORIZED_USER_CLOUD_SDK_WITH_QUOTA_PROJECT_ID_FILE
+ )
+ assert isinstance(credentials, google.oauth2._credentials_async.Credentials)
+ assert project_id is None
+
+
+def test_load_credentials_from_file_authorized_user_cloud_sdk_with_scopes():
+ with pytest.warns(UserWarning, match="Cloud SDK"):
+ credentials, project_id = _default.load_credentials_from_file(
+ test_default.AUTHORIZED_USER_CLOUD_SDK_FILE,
+ scopes=["https://www.google.com/calendar/feeds"],
+ )
+ assert isinstance(credentials, google.oauth2._credentials_async.Credentials)
+ assert project_id is None
+ assert credentials.scopes == ["https://www.google.com/calendar/feeds"]
+
+
+def test_load_credentials_from_file_authorized_user_cloud_sdk_with_quota_project():
+ credentials, project_id = _default.load_credentials_from_file(
+ test_default.AUTHORIZED_USER_CLOUD_SDK_FILE, quota_project_id="project-foo"
+ )
+
+ assert isinstance(credentials, google.oauth2._credentials_async.Credentials)
+ assert project_id is None
+ assert credentials.quota_project_id == "project-foo"
+
+
+def test_load_credentials_from_file_service_account():
+ credentials, project_id = _default.load_credentials_from_file(
+ test_default.SERVICE_ACCOUNT_FILE
+ )
+ assert isinstance(credentials, service_account.Credentials)
+ assert project_id == test_default.SERVICE_ACCOUNT_FILE_DATA["project_id"]
+
+
+def test_load_credentials_from_file_service_account_with_scopes():
+ credentials, project_id = _default.load_credentials_from_file(
+ test_default.SERVICE_ACCOUNT_FILE,
+ scopes=["https://www.google.com/calendar/feeds"],
+ )
+ assert isinstance(credentials, service_account.Credentials)
+ assert project_id == test_default.SERVICE_ACCOUNT_FILE_DATA["project_id"]
+ assert credentials.scopes == ["https://www.google.com/calendar/feeds"]
+
+
+def test_load_credentials_from_file_service_account_bad_format(tmpdir):
+ filename = tmpdir.join("serivce_account_bad.json")
+ filename.write(json.dumps({"type": "service_account"}))
+
+ with pytest.raises(exceptions.DefaultCredentialsError) as excinfo:
+ _default.load_credentials_from_file(str(filename))
+
+ assert excinfo.match(r"Failed to load service account")
+ assert excinfo.match(r"missing fields")
+
+
[email protected](os.environ, {}, clear=True)
+def test__get_explicit_environ_credentials_no_env():
+ assert _default._get_explicit_environ_credentials() == (None, None)
+
+
+@LOAD_FILE_PATCH
+def test__get_explicit_environ_credentials(load, monkeypatch):
+ monkeypatch.setenv(environment_vars.CREDENTIALS, "filename")
+
+ credentials, project_id = _default._get_explicit_environ_credentials()
+
+ assert credentials is MOCK_CREDENTIALS
+ assert project_id is mock.sentinel.project_id
+ load.assert_called_with("filename")
+
+
+@LOAD_FILE_PATCH
+def test__get_explicit_environ_credentials_no_project_id(load, monkeypatch):
+ load.return_value = MOCK_CREDENTIALS, None
+ monkeypatch.setenv(environment_vars.CREDENTIALS, "filename")
+
+ credentials, project_id = _default._get_explicit_environ_credentials()
+
+ assert credentials is MOCK_CREDENTIALS
+ assert project_id is None
+
+
+@LOAD_FILE_PATCH
[email protected](
+ "google.auth._cloud_sdk.get_application_default_credentials_path", autospec=True
+)
+def test__get_gcloud_sdk_credentials(get_adc_path, load):
+ get_adc_path.return_value = test_default.SERVICE_ACCOUNT_FILE
+
+ credentials, project_id = _default._get_gcloud_sdk_credentials()
+
+ assert credentials is MOCK_CREDENTIALS
+ assert project_id is mock.sentinel.project_id
+ load.assert_called_with(test_default.SERVICE_ACCOUNT_FILE)
+
+
[email protected](
+ "google.auth._cloud_sdk.get_application_default_credentials_path", autospec=True
+)
+def test__get_gcloud_sdk_credentials_non_existent(get_adc_path, tmpdir):
+ non_existent = tmpdir.join("non-existent")
+ get_adc_path.return_value = str(non_existent)
+
+ credentials, project_id = _default._get_gcloud_sdk_credentials()
+
+ assert credentials is None
+ assert project_id is None
+
+
[email protected](
+ "google.auth._cloud_sdk.get_project_id",
+ return_value=mock.sentinel.project_id,
+ autospec=True,
+)
[email protected]("os.path.isfile", return_value=True, autospec=True)
+@LOAD_FILE_PATCH
+def test__get_gcloud_sdk_credentials_project_id(load, unused_isfile, get_project_id):
+ # Don't return a project ID from load file, make the function check
+ # the Cloud SDK project.
+ load.return_value = MOCK_CREDENTIALS, None
+
+ credentials, project_id = _default._get_gcloud_sdk_credentials()
+
+ assert credentials == MOCK_CREDENTIALS
+ assert project_id == mock.sentinel.project_id
+ assert get_project_id.called
+
+
[email protected]("google.auth._cloud_sdk.get_project_id", return_value=None, autospec=True)
[email protected]("os.path.isfile", return_value=True)
+@LOAD_FILE_PATCH
+def test__get_gcloud_sdk_credentials_no_project_id(load, unused_isfile, get_project_id):
+ # Don't return a project ID from load file, make the function check
+ # the Cloud SDK project.
+ load.return_value = MOCK_CREDENTIALS, None
+
+ credentials, project_id = _default._get_gcloud_sdk_credentials()
+
+ assert credentials == MOCK_CREDENTIALS
+ assert project_id is None
+ assert get_project_id.called
+
+
+class _AppIdentityModule(object):
+ """The interface of the App Idenity app engine module.
+ See https://cloud.google.com/appengine/docs/standard/python/refdocs\
+ /google.appengine.api.app_identity.app_identity
+ """
+
+ def get_application_id(self):
+ raise NotImplementedError()
+
+
[email protected]
+def app_identity(monkeypatch):
+ """Mocks the app_identity module for google.auth.app_engine."""
+ app_identity_module = mock.create_autospec(_AppIdentityModule, instance=True)
+ monkeypatch.setattr(app_engine, "app_identity", app_identity_module)
+ yield app_identity_module
+
+
+def test__get_gae_credentials(app_identity):
+ app_identity.get_application_id.return_value = mock.sentinel.project
+
+ credentials, project_id = _default._get_gae_credentials()
+
+ assert isinstance(credentials, app_engine.Credentials)
+ assert project_id == mock.sentinel.project
+
+
+def test__get_gae_credentials_no_app_engine():
+ import sys
+
+ with mock.patch.dict("sys.modules"):
+ sys.modules["google.auth.app_engine"] = None
+ credentials, project_id = _default._get_gae_credentials()
+ assert credentials is None
+ assert project_id is None
+
+
+def test__get_gae_credentials_no_apis():
+ assert _default._get_gae_credentials() == (None, None)
+
+
[email protected](
+ "google.auth.compute_engine._metadata.ping", return_value=True, autospec=True
+)
[email protected](
+ "google.auth.compute_engine._metadata.get_project_id",
+ return_value="example-project",
+ autospec=True,
+)
+def test__get_gce_credentials(unused_get, unused_ping):
+ credentials, project_id = _default._get_gce_credentials()
+
+ assert isinstance(credentials, compute_engine.Credentials)
+ assert project_id == "example-project"
+
+
[email protected](
+ "google.auth.compute_engine._metadata.ping", return_value=False, autospec=True
+)
+def test__get_gce_credentials_no_ping(unused_ping):
+ credentials, project_id = _default._get_gce_credentials()
+
+ assert credentials is None
+ assert project_id is None
+
+
[email protected](
+ "google.auth.compute_engine._metadata.ping", return_value=True, autospec=True
+)
[email protected](
+ "google.auth.compute_engine._metadata.get_project_id",
+ side_effect=exceptions.TransportError(),
+ autospec=True,
+)
+def test__get_gce_credentials_no_project_id(unused_get, unused_ping):
+ credentials, project_id = _default._get_gce_credentials()
+
+ assert isinstance(credentials, compute_engine.Credentials)
+ assert project_id is None
+
+
+def test__get_gce_credentials_no_compute_engine():
+ import sys
+
+ with mock.patch.dict("sys.modules"):
+ sys.modules["google.auth.compute_engine"] = None
+ credentials, project_id = _default._get_gce_credentials()
+ assert credentials is None
+ assert project_id is None
+
+
[email protected](
+ "google.auth.compute_engine._metadata.ping", return_value=False, autospec=True
+)
+def test__get_gce_credentials_explicit_request(ping):
+ _default._get_gce_credentials(mock.sentinel.request)
+ ping.assert_called_with(request=mock.sentinel.request)
+
+
[email protected](
+ "google.auth._default_async._get_explicit_environ_credentials",
+ return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id),
+ autospec=True,
+)
+def test_default_early_out(unused_get):
+ assert _default.default_async() == (MOCK_CREDENTIALS, mock.sentinel.project_id)
+
+
[email protected](
+ "google.auth._default_async._get_explicit_environ_credentials",
+ return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id),
+ autospec=True,
+)
+def test_default_explict_project_id(unused_get, monkeypatch):
+ monkeypatch.setenv(environment_vars.PROJECT, "explicit-env")
+ assert _default.default_async() == (MOCK_CREDENTIALS, "explicit-env")
+
+
[email protected](
+ "google.auth._default_async._get_explicit_environ_credentials",
+ return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id),
+ autospec=True,
+)
+def test_default_explict_legacy_project_id(unused_get, monkeypatch):
+ monkeypatch.setenv(environment_vars.LEGACY_PROJECT, "explicit-env")
+ assert _default.default_async() == (MOCK_CREDENTIALS, "explicit-env")
+
+
[email protected]("logging.Logger.warning", autospec=True)
[email protected](
+ "google.auth._default_async._get_explicit_environ_credentials",
+ return_value=(MOCK_CREDENTIALS, None),
+ autospec=True,
+)
[email protected](
+ "google.auth._default_async._get_gcloud_sdk_credentials",
+ return_value=(MOCK_CREDENTIALS, None),
+ autospec=True,
+)
[email protected](
+ "google.auth._default_async._get_gae_credentials",
+ return_value=(MOCK_CREDENTIALS, None),
+ autospec=True,
+)
[email protected](
+ "google.auth._default_async._get_gce_credentials",
+ return_value=(MOCK_CREDENTIALS, None),
+ autospec=True,
+)
+def test_default_without_project_id(
+ unused_gce, unused_gae, unused_sdk, unused_explicit, logger_warning
+):
+ assert _default.default_async() == (MOCK_CREDENTIALS, None)
+ logger_warning.assert_called_with(mock.ANY, mock.ANY, mock.ANY)
+
+
[email protected](
+ "google.auth._default_async._get_explicit_environ_credentials",
+ return_value=(None, None),
+ autospec=True,
+)
[email protected](
+ "google.auth._default_async._get_gcloud_sdk_credentials",
+ return_value=(None, None),
+ autospec=True,
+)
[email protected](
+ "google.auth._default_async._get_gae_credentials",
+ return_value=(None, None),
+ autospec=True,
+)
[email protected](
+ "google.auth._default_async._get_gce_credentials",
+ return_value=(None, None),
+ autospec=True,
+)
+def test_default_fail(unused_gce, unused_gae, unused_sdk, unused_explicit):
+ with pytest.raises(exceptions.DefaultCredentialsError):
+ assert _default.default_async()
+
+
[email protected](
+ "google.auth._default_async._get_explicit_environ_credentials",
+ return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id),
+ autospec=True,
+)
[email protected](
+ "google.auth._credentials_async.with_scopes_if_required",
+ return_value=MOCK_CREDENTIALS,
+ autospec=True,
+)
+def test_default_scoped(with_scopes, unused_get):
+ scopes = ["one", "two"]
+
+ credentials, project_id = _default.default_async(scopes=scopes)
+
+ assert credentials == with_scopes.return_value
+ assert project_id == mock.sentinel.project_id
+ with_scopes.assert_called_once_with(MOCK_CREDENTIALS, scopes)
+
+
[email protected](
+ "google.auth._default_async._get_explicit_environ_credentials",
+ return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id),
+ autospec=True,
+)
+def test_default_no_app_engine_compute_engine_module(unused_get):
+ """
+ google.auth.compute_engine and google.auth.app_engine are both optional
+ to allow not including them when using this package. This verifies
+ that default fails gracefully if these modules are absent
+ """
+ import sys
+
+ with mock.patch.dict("sys.modules"):
+ sys.modules["google.auth.compute_engine"] = None
+ sys.modules["google.auth.app_engine"] = None
+ assert _default.default_async() == (MOCK_CREDENTIALS, mock.sentinel.project_id)
diff --git a/tests_async/test_credentials_async.py b/tests_async/test_credentials_async.py
new file mode 100644
index 0000000..0a48908
--- /dev/null
+++ b/tests_async/test_credentials_async.py
@@ -0,0 +1,177 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+
+import pytest
+
+from google.auth import _credentials_async as credentials
+from google.auth import _helpers
+
+
+class CredentialsImpl(credentials.Credentials):
+ def refresh(self, request):
+ self.token = request
+
+ def with_quota_project(self, quota_project_id):
+ raise NotImplementedError()
+
+
+def test_credentials_constructor():
+ credentials = CredentialsImpl()
+ assert not credentials.token
+ assert not credentials.expiry
+ assert not credentials.expired
+ assert not credentials.valid
+
+
+def test_expired_and_valid():
+ credentials = CredentialsImpl()
+ credentials.token = "token"
+
+ assert credentials.valid
+ assert not credentials.expired
+
+ # Set the expiration to one second more than now plus the clock skew
+ # accomodation. These credentials should be valid.
+ credentials.expiry = (
+ datetime.datetime.utcnow() + _helpers.CLOCK_SKEW + datetime.timedelta(seconds=1)
+ )
+
+ assert credentials.valid
+ assert not credentials.expired
+
+ # Set the credentials expiration to now. Because of the clock skew
+ # accomodation, these credentials should report as expired.
+ credentials.expiry = datetime.datetime.utcnow()
+
+ assert not credentials.valid
+ assert credentials.expired
+
+
[email protected]
+async def test_before_request():
+ credentials = CredentialsImpl()
+ request = "token"
+ headers = {}
+
+ # First call should call refresh, setting the token.
+ await credentials.before_request(request, "http://example.com", "GET", headers)
+ assert credentials.valid
+ assert credentials.token == "token"
+ assert headers["authorization"] == "Bearer token"
+
+ request = "token2"
+ headers = {}
+
+ # Second call shouldn't call refresh.
+ credentials.before_request(request, "http://example.com", "GET", headers)
+
+ assert credentials.valid
+ assert credentials.token == "token"
+
+
+def test_anonymous_credentials_ctor():
+ anon = credentials.AnonymousCredentials()
+
+ assert anon.token is None
+ assert anon.expiry is None
+ assert not anon.expired
+ assert anon.valid
+
+
+def test_anonymous_credentials_refresh():
+ anon = credentials.AnonymousCredentials()
+
+ request = object()
+ with pytest.raises(ValueError):
+ anon.refresh(request)
+
+
+def test_anonymous_credentials_apply_default():
+ anon = credentials.AnonymousCredentials()
+ headers = {}
+ anon.apply(headers)
+ assert headers == {}
+ with pytest.raises(ValueError):
+ anon.apply(headers, token="TOKEN")
+
+
+def test_anonymous_credentials_before_request():
+ anon = credentials.AnonymousCredentials()
+ request = object()
+ method = "GET"
+ url = "https://example.com/api/endpoint"
+ headers = {}
+ anon.before_request(request, method, url, headers)
+ assert headers == {}
+
+
+class ReadOnlyScopedCredentialsImpl(credentials.ReadOnlyScoped, CredentialsImpl):
+ @property
+ def requires_scopes(self):
+ return super(ReadOnlyScopedCredentialsImpl, self).requires_scopes
+
+
+def test_readonly_scoped_credentials_constructor():
+ credentials = ReadOnlyScopedCredentialsImpl()
+ assert credentials._scopes is None
+
+
+def test_readonly_scoped_credentials_scopes():
+ credentials = ReadOnlyScopedCredentialsImpl()
+ credentials._scopes = ["one", "two"]
+ assert credentials.scopes == ["one", "two"]
+ assert credentials.has_scopes(["one"])
+ assert credentials.has_scopes(["two"])
+ assert credentials.has_scopes(["one", "two"])
+ assert not credentials.has_scopes(["three"])
+
+
+def test_readonly_scoped_credentials_requires_scopes():
+ credentials = ReadOnlyScopedCredentialsImpl()
+ assert not credentials.requires_scopes
+
+
+class RequiresScopedCredentialsImpl(credentials.Scoped, CredentialsImpl):
+ def __init__(self, scopes=None):
+ super(RequiresScopedCredentialsImpl, self).__init__()
+ self._scopes = scopes
+
+ @property
+ def requires_scopes(self):
+ return not self.scopes
+
+ def with_scopes(self, scopes):
+ return RequiresScopedCredentialsImpl(scopes=scopes)
+
+
+def test_create_scoped_if_required_scoped():
+ unscoped_credentials = RequiresScopedCredentialsImpl()
+ scoped_credentials = credentials.with_scopes_if_required(
+ unscoped_credentials, ["one", "two"]
+ )
+
+ assert scoped_credentials is not unscoped_credentials
+ assert not scoped_credentials.requires_scopes
+ assert scoped_credentials.has_scopes(["one", "two"])
+
+
+def test_create_scoped_if_required_not_scopes():
+ unscoped_credentials = CredentialsImpl()
+ scoped_credentials = credentials.with_scopes_if_required(
+ unscoped_credentials, ["one", "two"]
+ )
+
+ assert scoped_credentials is unscoped_credentials
diff --git a/tests_async/test_jwt_async.py b/tests_async/test_jwt_async.py
new file mode 100644
index 0000000..a35b837
--- /dev/null
+++ b/tests_async/test_jwt_async.py
@@ -0,0 +1,356 @@
+# Copyright 2020 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import json
+
+import mock
+import pytest
+
+from google.auth import _jwt_async as jwt_async
+from google.auth import crypt
+from google.auth import exceptions
+from tests import test_jwt
+
+
[email protected]
+def signer():
+ return crypt.RSASigner.from_string(test_jwt.PRIVATE_KEY_BYTES, "1")
+
+
+class TestCredentials(object):
+ SERVICE_ACCOUNT_EMAIL = "[email protected]"
+ SUBJECT = "subject"
+ AUDIENCE = "audience"
+ ADDITIONAL_CLAIMS = {"meta": "data"}
+ credentials = None
+
+ @pytest.fixture(autouse=True)
+ def credentials_fixture(self, signer):
+ self.credentials = jwt_async.Credentials(
+ signer,
+ self.SERVICE_ACCOUNT_EMAIL,
+ self.SERVICE_ACCOUNT_EMAIL,
+ self.AUDIENCE,
+ )
+
+ def test_from_service_account_info(self):
+ with open(test_jwt.SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
+ info = json.load(fh)
+
+ credentials = jwt_async.Credentials.from_service_account_info(
+ info, audience=self.AUDIENCE
+ )
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == info["client_email"]
+ assert credentials._audience == self.AUDIENCE
+
+ def test_from_service_account_info_args(self):
+ info = test_jwt.SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = jwt_async.Credentials.from_service_account_info(
+ info,
+ subject=self.SUBJECT,
+ audience=self.AUDIENCE,
+ additional_claims=self.ADDITIONAL_CLAIMS,
+ )
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == self.SUBJECT
+ assert credentials._audience == self.AUDIENCE
+ assert credentials._additional_claims == self.ADDITIONAL_CLAIMS
+
+ def test_from_service_account_file(self):
+ info = test_jwt.SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = jwt_async.Credentials.from_service_account_file(
+ test_jwt.SERVICE_ACCOUNT_JSON_FILE, audience=self.AUDIENCE
+ )
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == info["client_email"]
+ assert credentials._audience == self.AUDIENCE
+
+ def test_from_service_account_file_args(self):
+ info = test_jwt.SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = jwt_async.Credentials.from_service_account_file(
+ test_jwt.SERVICE_ACCOUNT_JSON_FILE,
+ subject=self.SUBJECT,
+ audience=self.AUDIENCE,
+ additional_claims=self.ADDITIONAL_CLAIMS,
+ )
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == self.SUBJECT
+ assert credentials._audience == self.AUDIENCE
+ assert credentials._additional_claims == self.ADDITIONAL_CLAIMS
+
+ def test_from_signing_credentials(self):
+ jwt_from_signing = self.credentials.from_signing_credentials(
+ self.credentials, audience=mock.sentinel.new_audience
+ )
+ jwt_from_info = jwt_async.Credentials.from_service_account_info(
+ test_jwt.SERVICE_ACCOUNT_INFO, audience=mock.sentinel.new_audience
+ )
+
+ assert isinstance(jwt_from_signing, jwt_async.Credentials)
+ assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id
+ assert jwt_from_signing._issuer == jwt_from_info._issuer
+ assert jwt_from_signing._subject == jwt_from_info._subject
+ assert jwt_from_signing._audience == jwt_from_info._audience
+
+ def test_default_state(self):
+ assert not self.credentials.valid
+ # Expiration hasn't been set yet
+ assert not self.credentials.expired
+
+ def test_with_claims(self):
+ new_audience = "new_audience"
+ new_credentials = self.credentials.with_claims(audience=new_audience)
+
+ assert new_credentials._signer == self.credentials._signer
+ assert new_credentials._issuer == self.credentials._issuer
+ assert new_credentials._subject == self.credentials._subject
+ assert new_credentials._audience == new_audience
+ assert new_credentials._additional_claims == self.credentials._additional_claims
+ assert new_credentials._quota_project_id == self.credentials._quota_project_id
+
+ def test_with_quota_project(self):
+ quota_project_id = "project-foo"
+
+ new_credentials = self.credentials.with_quota_project(quota_project_id)
+ assert new_credentials._signer == self.credentials._signer
+ assert new_credentials._issuer == self.credentials._issuer
+ assert new_credentials._subject == self.credentials._subject
+ assert new_credentials._audience == self.credentials._audience
+ assert new_credentials._additional_claims == self.credentials._additional_claims
+ assert new_credentials._quota_project_id == quota_project_id
+
+ def test_sign_bytes(self):
+ to_sign = b"123"
+ signature = self.credentials.sign_bytes(to_sign)
+ assert crypt.verify_signature(to_sign, signature, test_jwt.PUBLIC_CERT_BYTES)
+
+ def test_signer(self):
+ assert isinstance(self.credentials.signer, crypt.RSASigner)
+
+ def test_signer_email(self):
+ assert (
+ self.credentials.signer_email
+ == test_jwt.SERVICE_ACCOUNT_INFO["client_email"]
+ )
+
+ def _verify_token(self, token):
+ payload = jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES)
+ assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
+ return payload
+
+ def test_refresh(self):
+ self.credentials.refresh(None)
+ assert self.credentials.valid
+ assert not self.credentials.expired
+
+ def test_expired(self):
+ assert not self.credentials.expired
+
+ self.credentials.refresh(None)
+ assert not self.credentials.expired
+
+ with mock.patch("google.auth._helpers.utcnow") as now:
+ one_day = datetime.timedelta(days=1)
+ now.return_value = self.credentials.expiry + one_day
+ assert self.credentials.expired
+
+ @pytest.mark.asyncio
+ async def test_before_request(self):
+ headers = {}
+
+ self.credentials.refresh(None)
+ await self.credentials.before_request(
+ None, "GET", "http://example.com?a=1#3", headers
+ )
+
+ header_value = headers["authorization"]
+ _, token = header_value.split(" ")
+
+ # Since the audience is set, it should use the existing token.
+ assert token.encode("utf-8") == self.credentials.token
+
+ payload = self._verify_token(token)
+ assert payload["aud"] == self.AUDIENCE
+
+ @pytest.mark.asyncio
+ async def test_before_request_refreshes(self):
+ assert not self.credentials.valid
+ await self.credentials.before_request(
+ None, "GET", "http://example.com?a=1#3", {}
+ )
+ assert self.credentials.valid
+
+
+class TestOnDemandCredentials(object):
+ SERVICE_ACCOUNT_EMAIL = "[email protected]"
+ SUBJECT = "subject"
+ ADDITIONAL_CLAIMS = {"meta": "data"}
+ credentials = None
+
+ @pytest.fixture(autouse=True)
+ def credentials_fixture(self, signer):
+ self.credentials = jwt_async.OnDemandCredentials(
+ signer,
+ self.SERVICE_ACCOUNT_EMAIL,
+ self.SERVICE_ACCOUNT_EMAIL,
+ max_cache_size=2,
+ )
+
+ def test_from_service_account_info(self):
+ with open(test_jwt.SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
+ info = json.load(fh)
+
+ credentials = jwt_async.OnDemandCredentials.from_service_account_info(info)
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == info["client_email"]
+
+ def test_from_service_account_info_args(self):
+ info = test_jwt.SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = jwt_async.OnDemandCredentials.from_service_account_info(
+ info, subject=self.SUBJECT, additional_claims=self.ADDITIONAL_CLAIMS
+ )
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == self.SUBJECT
+ assert credentials._additional_claims == self.ADDITIONAL_CLAIMS
+
+ def test_from_service_account_file(self):
+ info = test_jwt.SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = jwt_async.OnDemandCredentials.from_service_account_file(
+ test_jwt.SERVICE_ACCOUNT_JSON_FILE
+ )
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == info["client_email"]
+
+ def test_from_service_account_file_args(self):
+ info = test_jwt.SERVICE_ACCOUNT_INFO.copy()
+
+ credentials = jwt_async.OnDemandCredentials.from_service_account_file(
+ test_jwt.SERVICE_ACCOUNT_JSON_FILE,
+ subject=self.SUBJECT,
+ additional_claims=self.ADDITIONAL_CLAIMS,
+ )
+
+ assert credentials._signer.key_id == info["private_key_id"]
+ assert credentials._issuer == info["client_email"]
+ assert credentials._subject == self.SUBJECT
+ assert credentials._additional_claims == self.ADDITIONAL_CLAIMS
+
+ def test_from_signing_credentials(self):
+ jwt_from_signing = self.credentials.from_signing_credentials(self.credentials)
+ jwt_from_info = jwt_async.OnDemandCredentials.from_service_account_info(
+ test_jwt.SERVICE_ACCOUNT_INFO
+ )
+
+ assert isinstance(jwt_from_signing, jwt_async.OnDemandCredentials)
+ assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id
+ assert jwt_from_signing._issuer == jwt_from_info._issuer
+ assert jwt_from_signing._subject == jwt_from_info._subject
+
+ def test_default_state(self):
+ # Credentials are *always* valid.
+ assert self.credentials.valid
+ # Credentials *never* expire.
+ assert not self.credentials.expired
+
+ def test_with_claims(self):
+ new_claims = {"meep": "moop"}
+ new_credentials = self.credentials.with_claims(additional_claims=new_claims)
+
+ assert new_credentials._signer == self.credentials._signer
+ assert new_credentials._issuer == self.credentials._issuer
+ assert new_credentials._subject == self.credentials._subject
+ assert new_credentials._additional_claims == new_claims
+
+ def test_with_quota_project(self):
+ quota_project_id = "project-foo"
+ new_credentials = self.credentials.with_quota_project(quota_project_id)
+
+ assert new_credentials._signer == self.credentials._signer
+ assert new_credentials._issuer == self.credentials._issuer
+ assert new_credentials._subject == self.credentials._subject
+ assert new_credentials._additional_claims == self.credentials._additional_claims
+ assert new_credentials._quota_project_id == quota_project_id
+
+ def test_sign_bytes(self):
+ to_sign = b"123"
+ signature = self.credentials.sign_bytes(to_sign)
+ assert crypt.verify_signature(to_sign, signature, test_jwt.PUBLIC_CERT_BYTES)
+
+ def test_signer(self):
+ assert isinstance(self.credentials.signer, crypt.RSASigner)
+
+ def test_signer_email(self):
+ assert (
+ self.credentials.signer_email
+ == test_jwt.SERVICE_ACCOUNT_INFO["client_email"]
+ )
+
+ def _verify_token(self, token):
+ payload = jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES)
+ assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
+ return payload
+
+ def test_refresh(self):
+ with pytest.raises(exceptions.RefreshError):
+ self.credentials.refresh(None)
+
+ def test_before_request(self):
+ headers = {}
+
+ self.credentials.before_request(
+ None, "GET", "http://example.com?a=1#3", headers
+ )
+
+ _, token = headers["authorization"].split(" ")
+ payload = self._verify_token(token)
+
+ assert payload["aud"] == "http://example.com"
+
+ # Making another request should re-use the same token.
+ self.credentials.before_request(None, "GET", "http://example.com?b=2", headers)
+
+ _, new_token = headers["authorization"].split(" ")
+
+ assert new_token == token
+
+ def test_expired_token(self):
+ self.credentials._cache["audience"] = (
+ mock.sentinel.token,
+ datetime.datetime.min,
+ )
+
+ token = self.credentials._get_jwt_for_audience("audience")
+
+ assert token != mock.sentinel.token
diff --git a/system_tests/__init__.py b/tests_async/transport/__init__.py
similarity index 100%
copy from system_tests/__init__.py
copy to tests_async/transport/__init__.py
diff --git a/tests_async/transport/async_compliance.py b/tests_async/transport/async_compliance.py
new file mode 100644
index 0000000..9c4b173
--- /dev/null
+++ b/tests_async/transport/async_compliance.py
@@ -0,0 +1,133 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+import flask
+import pytest
+from pytest_localserver.http import WSGIServer
+from six.moves import http_client
+
+from google.auth import exceptions
+from tests.transport import compliance
+
+
+class RequestResponseTests(object):
+ @pytest.fixture(scope="module")
+ def server(self):
+ """Provides a test HTTP server.
+
+ The test server is automatically created before
+ a test and destroyed at the end. The server is serving a test
+ application that can be used to verify requests.
+ """
+ app = flask.Flask(__name__)
+ app.debug = True
+
+ # pylint: disable=unused-variable
+ # (pylint thinks the flask routes are unusued.)
+ @app.route("/basic")
+ def index():
+ header_value = flask.request.headers.get("x-test-header", "value")
+ headers = {"X-Test-Header": header_value}
+ return "Basic Content", http_client.OK, headers
+
+ @app.route("/server_error")
+ def server_error():
+ return "Error", http_client.INTERNAL_SERVER_ERROR
+
+ @app.route("/wait")
+ def wait():
+ time.sleep(3)
+ return "Waited"
+
+ # pylint: enable=unused-variable
+
+ server = WSGIServer(application=app.wsgi_app)
+ server.start()
+ yield server
+ server.stop()
+
+ @pytest.mark.asyncio
+ async def test_request_basic(self, server):
+ request = self.make_request()
+ response = await request(url=server.url + "/basic", method="GET")
+ assert response.status == http_client.OK
+ assert response.headers["x-test-header"] == "value"
+
+ # Use 13 as this is the length of the data written into the stream.
+
+ data = await response.data.read(13)
+ assert data == b"Basic Content"
+
+ @pytest.mark.asyncio
+ async def test_request_basic_with_http(self, server):
+ request = self.make_with_parameter_request()
+ response = await request(url=server.url + "/basic", method="GET")
+ assert response.status == http_client.OK
+ assert response.headers["x-test-header"] == "value"
+
+ # Use 13 as this is the length of the data written into the stream.
+
+ data = await response.data.read(13)
+ assert data == b"Basic Content"
+
+ @pytest.mark.asyncio
+ async def test_request_with_timeout_success(self, server):
+ request = self.make_request()
+ response = await request(url=server.url + "/basic", method="GET", timeout=2)
+
+ assert response.status == http_client.OK
+ assert response.headers["x-test-header"] == "value"
+
+ data = await response.data.read(13)
+ assert data == b"Basic Content"
+
+ @pytest.mark.asyncio
+ async def test_request_with_timeout_failure(self, server):
+ request = self.make_request()
+
+ with pytest.raises(exceptions.TransportError):
+ await request(url=server.url + "/wait", method="GET", timeout=1)
+
+ @pytest.mark.asyncio
+ async def test_request_headers(self, server):
+ request = self.make_request()
+ response = await request(
+ url=server.url + "/basic",
+ method="GET",
+ headers={"x-test-header": "hello world"},
+ )
+
+ assert response.status == http_client.OK
+ assert response.headers["x-test-header"] == "hello world"
+
+ data = await response.data.read(13)
+ assert data == b"Basic Content"
+
+ @pytest.mark.asyncio
+ async def test_request_error(self, server):
+ request = self.make_request()
+
+ response = await request(url=server.url + "/server_error", method="GET")
+ assert response.status == http_client.INTERNAL_SERVER_ERROR
+ data = await response.data.read(5)
+ assert data == b"Error"
+
+ @pytest.mark.asyncio
+ async def test_connection_error(self):
+ request = self.make_request()
+
+ with pytest.raises(exceptions.TransportError):
+ await request(url="http://{}".format(compliance.NXDOMAIN), method="GET")
diff --git a/tests_async/transport/test_aiohttp_requests.py b/tests_async/transport/test_aiohttp_requests.py
new file mode 100644
index 0000000..10c31db
--- /dev/null
+++ b/tests_async/transport/test_aiohttp_requests.py
@@ -0,0 +1,245 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import aiohttp
+from aioresponses import aioresponses, core
+import mock
+import pytest
+from tests_async.transport import async_compliance
+
+import google.auth._credentials_async
+from google.auth.transport import _aiohttp_requests as aiohttp_requests
+import google.auth.transport._mtls_helper
+
+
+class TestCombinedResponse:
+ @pytest.mark.asyncio
+ async def test__is_compressed(self):
+ response = core.CallbackResult(headers={"Content-Encoding": "gzip"})
+ combined_response = aiohttp_requests._CombinedResponse(response)
+ compressed = combined_response._is_compressed()
+ assert compressed
+
+ def test__is_compressed_not(self):
+ response = core.CallbackResult(headers={"Content-Encoding": "not"})
+ combined_response = aiohttp_requests._CombinedResponse(response)
+ compressed = combined_response._is_compressed()
+ assert not compressed
+
+ @pytest.mark.asyncio
+ async def test_raw_content(self):
+
+ mock_response = mock.AsyncMock()
+ mock_response.content.read.return_value = mock.sentinel.read
+ combined_response = aiohttp_requests._CombinedResponse(response=mock_response)
+ raw_content = await combined_response.raw_content()
+ assert raw_content == mock.sentinel.read
+
+ # Second call to validate the preconfigured path.
+ combined_response._raw_content = mock.sentinel.stored_raw
+ raw_content = await combined_response.raw_content()
+ assert raw_content == mock.sentinel.stored_raw
+
+ @pytest.mark.asyncio
+ async def test_content(self):
+ mock_response = mock.AsyncMock()
+ mock_response.content.read.return_value = mock.sentinel.read
+ combined_response = aiohttp_requests._CombinedResponse(response=mock_response)
+ content = await combined_response.content()
+ assert content == mock.sentinel.read
+
+ @mock.patch(
+ "google.auth.transport._aiohttp_requests.urllib3.response.MultiDecoder.decompress",
+ return_value="decompressed",
+ autospec=True,
+ )
+ @pytest.mark.asyncio
+ async def test_content_compressed(self, urllib3_mock):
+ rm = core.RequestMatch(
+ "url", headers={"Content-Encoding": "gzip"}, payload="compressed"
+ )
+ response = await rm.build_response(core.URL("url"))
+
+ combined_response = aiohttp_requests._CombinedResponse(response=response)
+ content = await combined_response.content()
+
+ urllib3_mock.assert_called_once()
+ assert content == "decompressed"
+
+
+class TestResponse:
+ def test_ctor(self):
+ response = aiohttp_requests._Response(mock.sentinel.response)
+ assert response._response == mock.sentinel.response
+
+ @pytest.mark.asyncio
+ async def test_headers_prop(self):
+ rm = core.RequestMatch("url", headers={"Content-Encoding": "header prop"})
+ mock_response = await rm.build_response(core.URL("url"))
+
+ response = aiohttp_requests._Response(mock_response)
+ assert response.headers["Content-Encoding"] == "header prop"
+
+ @pytest.mark.asyncio
+ async def test_status_prop(self):
+ rm = core.RequestMatch("url", status=123)
+ mock_response = await rm.build_response(core.URL("url"))
+ response = aiohttp_requests._Response(mock_response)
+ assert response.status == 123
+
+ @pytest.mark.asyncio
+ async def test_data_prop(self):
+ mock_response = mock.AsyncMock()
+ mock_response.content.read.return_value = mock.sentinel.read
+ response = aiohttp_requests._Response(mock_response)
+ data = await response.data.read()
+ assert data == mock.sentinel.read
+
+
+class TestRequestResponse(async_compliance.RequestResponseTests):
+ def make_request(self):
+ return aiohttp_requests.Request()
+
+ def make_with_parameter_request(self):
+ http = mock.create_autospec(aiohttp.ClientSession, instance=True)
+ return aiohttp_requests.Request(http)
+
+ def test_timeout(self):
+ http = mock.create_autospec(aiohttp.ClientSession, instance=True)
+ request = aiohttp_requests.Request(http)
+ request(url="http://example.com", method="GET", timeout=5)
+
+
+class CredentialsStub(google.auth._credentials_async.Credentials):
+ def __init__(self, token="token"):
+ super(CredentialsStub, self).__init__()
+ self.token = token
+
+ def apply(self, headers, token=None):
+ headers["authorization"] = self.token
+
+ def refresh(self, request):
+ self.token += "1"
+
+
+class TestAuthorizedSession(object):
+ TEST_URL = "http://example.com/"
+ method = "GET"
+
+ def test_constructor(self):
+ authed_session = aiohttp_requests.AuthorizedSession(mock.sentinel.credentials)
+ assert authed_session.credentials == mock.sentinel.credentials
+
+ def test_constructor_with_auth_request(self):
+ http = mock.create_autospec(aiohttp.ClientSession)
+ auth_request = aiohttp_requests.Request(http)
+
+ authed_session = aiohttp_requests.AuthorizedSession(
+ mock.sentinel.credentials, auth_request=auth_request
+ )
+
+ assert authed_session._auth_request == auth_request
+
+ @pytest.mark.asyncio
+ async def test_request(self):
+ with aioresponses() as mocked:
+ credentials = mock.Mock(wraps=CredentialsStub())
+
+ mocked.get(self.TEST_URL, status=200, body="test")
+ session = aiohttp_requests.AuthorizedSession(credentials)
+ resp = await session.request(
+ "GET",
+ "http://example.com/",
+ headers={"Keep-Alive": "timeout=5, max=1000", "fake": b"bytes"},
+ )
+
+ assert resp.status == 200
+ assert "test" == await resp.text()
+
+ await session.close()
+
+ @pytest.mark.asyncio
+ async def test_ctx(self):
+ with aioresponses() as mocked:
+ credentials = mock.Mock(wraps=CredentialsStub())
+ mocked.get("http://test.example.com", payload=dict(foo="bar"))
+ session = aiohttp_requests.AuthorizedSession(credentials)
+ resp = await session.request("GET", "http://test.example.com")
+ data = await resp.json()
+
+ assert dict(foo="bar") == data
+
+ await session.close()
+
+ @pytest.mark.asyncio
+ async def test_http_headers(self):
+ with aioresponses() as mocked:
+ credentials = mock.Mock(wraps=CredentialsStub())
+ mocked.post(
+ "http://example.com",
+ payload=dict(),
+ headers=dict(connection="keep-alive"),
+ )
+
+ session = aiohttp_requests.AuthorizedSession(credentials)
+ resp = await session.request("POST", "http://example.com")
+
+ assert resp.headers["Connection"] == "keep-alive"
+
+ await session.close()
+
+ @pytest.mark.asyncio
+ async def test_regexp_example(self):
+ with aioresponses() as mocked:
+ credentials = mock.Mock(wraps=CredentialsStub())
+ mocked.get("http://example.com", status=500)
+ mocked.get("http://example.com", status=200)
+
+ session1 = aiohttp_requests.AuthorizedSession(credentials)
+
+ resp1 = await session1.request("GET", "http://example.com")
+ session2 = aiohttp_requests.AuthorizedSession(credentials)
+ resp2 = await session2.request("GET", "http://example.com")
+
+ assert resp1.status == 500
+ assert resp2.status == 200
+
+ await session1.close()
+ await session2.close()
+
+ @pytest.mark.asyncio
+ async def test_request_no_refresh(self):
+ credentials = mock.Mock(wraps=CredentialsStub())
+ with aioresponses() as mocked:
+ mocked.get("http://example.com", status=200)
+ authed_session = aiohttp_requests.AuthorizedSession(credentials)
+ response = await authed_session.request("GET", "http://example.com")
+ assert response.status == 200
+ assert credentials.before_request.called
+ assert not credentials.refresh.called
+
+ await authed_session.close()
+
+ @pytest.mark.asyncio
+ async def test_request_refresh(self):
+ credentials = mock.Mock(wraps=CredentialsStub())
+ with aioresponses() as mocked:
+ mocked.get("http://example.com", status=401)
+ mocked.get("http://example.com", status=200)
+ authed_session = aiohttp_requests.AuthorizedSession(credentials)
+ response = await authed_session.request("GET", "http://example.com")
+ assert credentials.refresh.called
+ assert response.status == 200
+
+ await authed_session.close()