| # Copyright 2022 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| import copy |
| import datetime |
| import json |
| import os |
| from unittest import mock |
| |
| import pytest # type: ignore |
| import requests |
| |
| from google.auth import exceptions |
| from google.auth import jwt |
| import google.auth.transport.requests |
| from google.oauth2 import gdch_credentials |
| from google.oauth2.gdch_credentials import ServiceAccountCredentials |
| |
| |
| class TestServiceAccountCredentials(object): |
| AUDIENCE = "https://service-identity.<Domain>/authenticate" |
| PROJECT = "project_foo" |
| PRIVATE_KEY_ID = "key_foo" |
| NAME = "service_identity_name" |
| CA_CERT_PATH = "/path/to/ca/cert" |
| TOKEN_URI = "https://service-identity.<Domain>/authenticate" |
| |
| JSON_PATH = os.path.join( |
| os.path.dirname(__file__), "..", "data", "gdch_service_account.json" |
| ) |
| with open(JSON_PATH, "rb") as fh: |
| INFO = json.load(fh) |
| |
| def test_with_gdch_audience(self): |
| mock_signer = mock.Mock() |
| creds = ServiceAccountCredentials._from_signer_and_info(mock_signer, self.INFO) |
| assert creds._signer == mock_signer |
| assert creds._service_identity_name == self.NAME |
| assert creds._audience is None |
| assert creds._token_uri == self.TOKEN_URI |
| assert creds._ca_cert_path == self.CA_CERT_PATH |
| |
| new_creds = creds.with_gdch_audience(self.AUDIENCE) |
| assert new_creds._signer == mock_signer |
| assert new_creds._service_identity_name == self.NAME |
| assert new_creds._audience == self.AUDIENCE |
| assert new_creds._token_uri == self.TOKEN_URI |
| assert new_creds._ca_cert_path == self.CA_CERT_PATH |
| |
| def test__create_jwt(self): |
| creds = ServiceAccountCredentials.from_service_account_file(self.JSON_PATH) |
| with mock.patch("google.auth._helpers.utcnow") as utcnow: |
| utcnow.return_value = datetime.datetime.now() |
| jwt_token = creds._create_jwt() |
| header, payload, _, _ = jwt._unverified_decode(jwt_token) |
| |
| expected_iss_sub_value = ( |
| "system:serviceaccount:project_foo:service_identity_name" |
| ) |
| assert isinstance(jwt_token, str) |
| assert header["alg"] == "ES256" |
| assert header["kid"] == self.PRIVATE_KEY_ID |
| assert payload["iss"] == expected_iss_sub_value |
| assert payload["sub"] == expected_iss_sub_value |
| assert payload["aud"] == self.AUDIENCE |
| assert payload["exp"] == (payload["iat"] + 3600) |
| |
| @mock.patch( |
| "google.oauth2.gdch_credentials.ServiceAccountCredentials._create_jwt", |
| autospec=True, |
| ) |
| @mock.patch("google.oauth2._client._token_endpoint_request", autospec=True) |
| def test_refresh(self, token_endpoint_request, create_jwt): |
| creds = ServiceAccountCredentials.from_service_account_info(self.INFO) |
| creds = creds.with_gdch_audience(self.AUDIENCE) |
| req = google.auth.transport.requests.Request() |
| |
| mock_jwt_token = "jwt token" |
| create_jwt.return_value = mock_jwt_token |
| sts_token = "STS token" |
| token_endpoint_request.return_value = { |
| "access_token": sts_token, |
| "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", |
| "token_type": "Bearer", |
| "expires_in": 3600, |
| } |
| |
| creds.refresh(req) |
| |
| token_endpoint_request.assert_called_with( |
| req, |
| self.TOKEN_URI, |
| { |
| "grant_type": gdch_credentials.TOKEN_EXCHANGE_TYPE, |
| "audience": self.AUDIENCE, |
| "requested_token_type": gdch_credentials.ACCESS_TOKEN_TOKEN_TYPE, |
| "subject_token": mock_jwt_token, |
| "subject_token_type": gdch_credentials.SERVICE_ACCOUNT_TOKEN_TYPE, |
| }, |
| access_token=None, |
| use_json=True, |
| verify=self.CA_CERT_PATH, |
| ) |
| assert creds.token == sts_token |
| |
| def test_refresh_wrong_requests_object(self): |
| creds = ServiceAccountCredentials.from_service_account_info(self.INFO) |
| creds = creds.with_gdch_audience(self.AUDIENCE) |
| req = requests.Request() |
| |
| with pytest.raises(exceptions.RefreshError) as excinfo: |
| creds.refresh(req) |
| assert excinfo.match( |
| "request must be a google.auth.transport.requests.Request object" |
| ) |
| |
| def test__from_signer_and_info_wrong_format_version(self): |
| with pytest.raises(ValueError) as excinfo: |
| ServiceAccountCredentials._from_signer_and_info( |
| mock.Mock(), {"format_version": "2"} |
| ) |
| assert excinfo.match("Only format version 1 is supported") |
| |
| def test_from_service_account_info_miss_field(self): |
| for field in [ |
| "format_version", |
| "private_key_id", |
| "private_key", |
| "name", |
| "project", |
| "token_uri", |
| ]: |
| info_with_missing_field = copy.deepcopy(self.INFO) |
| del info_with_missing_field[field] |
| with pytest.raises(ValueError) as excinfo: |
| ServiceAccountCredentials.from_service_account_info( |
| info_with_missing_field |
| ) |
| assert excinfo.match("missing fields") |
| |
| @mock.patch("google.auth._service_account_info.from_filename") |
| def test_from_service_account_file(self, from_filename): |
| mock_signer = mock.Mock() |
| from_filename.return_value = (self.INFO, mock_signer) |
| creds = ServiceAccountCredentials.from_service_account_file(self.JSON_PATH) |
| from_filename.assert_called_with( |
| self.JSON_PATH, |
| require=[ |
| "format_version", |
| "private_key_id", |
| "private_key", |
| "name", |
| "project", |
| "token_uri", |
| ], |
| use_rsa_signer=False, |
| ) |
| assert creds._signer == mock_signer |
| assert creds._service_identity_name == self.NAME |
| assert creds._audience is None |
| assert creds._token_uri == self.TOKEN_URI |
| assert creds._ca_cert_path == self.CA_CERT_PATH |