| # Copyright 2014 Google Inc. All rights reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Oauth2client tests. |
| |
| Unit tests for service account credentials implemented using RSA. |
| """ |
| |
| import datetime |
| import json |
| import os |
| import tempfile |
| |
| import httplib2 |
| import mock |
| import rsa |
| from six import BytesIO |
| import unittest2 |
| |
| from oauth2client import client |
| from oauth2client import crypt |
| from oauth2client import service_account |
| from .http_mock import HttpMockSequence |
| |
| |
def data_filename(filename):
    """Return the path of ``filename`` inside the ``data`` directory.

    The ``data`` directory is resolved relative to the directory that
    contains this module.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, 'data', filename)
| |
| |
def datafile(filename):
    """Return the raw bytes of ``filename`` from the test ``data`` directory."""
    path = data_filename(filename)
    with open(path, 'rb') as handle:
        contents = handle.read()
    return contents
| |
| |
class ServiceAccountCredentialsTests(unittest2.TestCase):
    """Tests for ``service_account.ServiceAccountCredentials``."""

    def setUp(self):
        """Create credentials signed with the PEM key from the data dir."""
        self.client_id = '123'
        self.service_account_email = '[email protected]'
        self.private_key_id = 'ABCDEF'
        self.private_key = datafile('pem_from_pkcs12.pem')
        self.scopes = ['dummy_scope']
        self.signer = crypt.Signer.from_string(self.private_key)
        self.credentials = service_account.ServiceAccountCredentials(
            self.service_account_email,
            self.signer,
            private_key_id=self.private_key_id,
            client_id=self.client_id,
        )

    def test__to_json_override(self):
        """``_to_json`` emits the supplied dict plus class metadata."""
        signer = object()
        creds = service_account.ServiceAccountCredentials(
            '[email protected]', signer)
        self.assertEqual(creds._signer, signer)
        # Serialize over-ridden data (unrelated to ``creds``).
        to_serialize = {'unrelated': 'data'}
        serialized_str = creds._to_json([], to_serialize.copy())
        serialized_data = json.loads(serialized_str)
        expected_serialized = {
            '_class': 'ServiceAccountCredentials',
            '_module': 'oauth2client.service_account',
            'token_expiry': None,
        }
        expected_serialized.update(to_serialize)
        self.assertEqual(serialized_data, expected_serialized)

    def test_sign_blob(self):
        """``sign_blob`` returns the key id and a verifiable signature."""
        private_key_id, signature = self.credentials.sign_blob('Google')
        self.assertEqual(self.private_key_id, private_key_id)

        pub_key = rsa.PublicKey.load_pkcs1_openssl_pem(
            datafile('publickey_openssl.pem'))

        self.assertTrue(rsa.pkcs1.verify(b'Google', signature, pub_key))

        # Both a different message and a different signature must be
        # rejected by the verifier.
        with self.assertRaises(rsa.pkcs1.VerificationError):
            rsa.pkcs1.verify(b'Orest', signature, pub_key)
        with self.assertRaises(rsa.pkcs1.VerificationError):
            rsa.pkcs1.verify(b'Google', b'bad signature', pub_key)

    def test_service_account_email(self):
        """The ``service_account_email`` property exposes the email."""
        self.assertEqual(self.service_account_email,
                         self.credentials.service_account_email)

    @staticmethod
    def _from_json_keyfile_name_helper(payload, scopes=None,
                                       token_uri=None, revoke_uri=None):
        """Write ``payload`` to a temp JSON file and load credentials.

        Args:
            payload: JSON-serializable dict written as the key file.
            scopes: Forwarded to ``from_json_keyfile_name``.
            token_uri: Forwarded to ``from_json_keyfile_name``.
            revoke_uri: Forwarded to ``from_json_keyfile_name``.

        Returns:
            Whatever ``ServiceAccountCredentials.from_json_keyfile_name``
            returns for the temporary file.
        """
        filehandle, filename = tempfile.mkstemp()
        # Only the name is needed; the file is reopened in text mode below.
        os.close(filehandle)
        try:
            with open(filename, 'w') as file_obj:
                json.dump(payload, file_obj)
            return (
                service_account.ServiceAccountCredentials
                .from_json_keyfile_name(
                    filename, scopes=scopes, token_uri=token_uri,
                    revoke_uri=revoke_uri))
        finally:
            # Remove the temp file even when loading raises.
            os.remove(filename)

    @mock.patch('oauth2client.crypt.Signer.from_string',
                return_value=object())
    def test_from_json_keyfile_name_factory(self, signer_factory):
        """Credentials built from a JSON key file carry the file's fields."""
        client_id = 'id123'
        client_email = '[email protected]'
        private_key_id = 'pkid456'
        private_key = 's3kr3tz'
        payload = {
            'type': client.SERVICE_ACCOUNT,
            'client_id': client_id,
            'client_email': client_email,
            'private_key_id': private_key_id,
            'private_key': private_key,
        }
        scopes = ['foo', 'bar']
        token_uri = 'baz'
        revoke_uri = 'qux'
        # First variant: the URIs are passed as keyword arguments.
        base_creds = self._from_json_keyfile_name_helper(
            payload, scopes=scopes, token_uri=token_uri,
            revoke_uri=revoke_uri)
        self.assertEqual(base_creds._signer, signer_factory.return_value)
        signer_factory.assert_called_once_with(private_key)

        # Second variant: the URIs come from the key file itself.
        payload['token_uri'] = token_uri
        payload['revoke_uri'] = revoke_uri
        creds_with_uris_from_file = self._from_json_keyfile_name_helper(
            payload, scopes=scopes)
        for creds in (base_creds, creds_with_uris_from_file):
            self.assertIsInstance(
                creds, service_account.ServiceAccountCredentials)
            self.assertEqual(creds.client_id, client_id)
            self.assertEqual(creds._service_account_email, client_email)
            self.assertEqual(creds._private_key_id, private_key_id)
            self.assertEqual(creds._private_key_pkcs8_pem, private_key)
            self.assertEqual(creds._scopes, ' '.join(scopes))
            self.assertEqual(creds.token_uri, token_uri)
            self.assertEqual(creds.revoke_uri, revoke_uri)

    def test_from_json_keyfile_name_factory_bad_type(self):
        """A key file with an unexpected ``type`` raises ``ValueError``."""
        type_ = 'bad-type'
        self.assertNotEqual(type_, client.SERVICE_ACCOUNT)
        payload = {'type': type_}
        with self.assertRaises(ValueError):
            self._from_json_keyfile_name_helper(payload)

    def test_from_json_keyfile_name_factory_missing_field(self):
        """A key file lacking required fields raises ``KeyError``."""
        payload = {
            'type': client.SERVICE_ACCOUNT,
            'client_id': 'my-client',
        }
        with self.assertRaises(KeyError):
            self._from_json_keyfile_name_helper(payload)

    def _from_p12_keyfile_helper(self, private_key_password=None, scopes='',
                                 token_uri=None, revoke_uri=None):
        """Load the PKCS#12 key by file name and by buffer; check both.

        Args:
            private_key_password: Forwarded to both factories; when not
                ``None`` it is also compared with the stored password.
            scopes: Forwarded to both factories.
            token_uri: Forwarded to both factories.
            revoke_uri: Forwarded to both factories.
        """
        service_account_email = '[email protected]'
        p12_name = 'privatekey.p12'
        filename = data_filename(p12_name)
        # Reuse the module helper rather than re-opening the file by hand.
        key_contents = datafile(p12_name)
        creds_from_filename = (
            service_account.ServiceAccountCredentials.from_p12_keyfile(
                service_account_email, filename,
                private_key_password=private_key_password,
                scopes=scopes, token_uri=token_uri, revoke_uri=revoke_uri))
        creds_from_file_contents = (
            service_account.ServiceAccountCredentials
            .from_p12_keyfile_buffer(
                service_account_email, BytesIO(key_contents),
                private_key_password=private_key_password,
                scopes=scopes, token_uri=token_uri, revoke_uri=revoke_uri))
        for creds in (creds_from_filename, creds_from_file_contents):
            self.assertIsInstance(
                creds, service_account.ServiceAccountCredentials)
            self.assertIsNone(creds.client_id)
            self.assertEqual(creds._service_account_email,
                             service_account_email)
            self.assertIsNone(creds._private_key_id)
            self.assertIsNone(creds._private_key_pkcs8_pem)
            self.assertEqual(creds._private_key_pkcs12, key_contents)
            if private_key_password is not None:
                self.assertEqual(creds._private_key_password,
                                 private_key_password)
            self.assertEqual(creds._scopes, ' '.join(scopes))
            self.assertEqual(creds.token_uri, token_uri)
            self.assertEqual(creds.revoke_uri, revoke_uri)

    def _p12_not_implemented_helper(self):
        """Assert ``from_p12_keyfile`` raises ``NotImplementedError``."""
        service_account_email = '[email protected]'
        filename = data_filename('privatekey.p12')
        with self.assertRaises(NotImplementedError):
            service_account.ServiceAccountCredentials.from_p12_keyfile(
                service_account_email, filename)

    @mock.patch('oauth2client.crypt.Signer', new=crypt.PyCryptoSigner)
    def test_from_p12_keyfile_with_pycrypto(self):
        """P12 loading is not implemented with ``PyCryptoSigner`` patched in."""
        self._p12_not_implemented_helper()

    @mock.patch('oauth2client.crypt.Signer', new=crypt.RsaSigner)
    def test_from_p12_keyfile_with_rsa(self):
        """P12 loading is not implemented with ``RsaSigner`` patched in."""
        self._p12_not_implemented_helper()

    def test_from_p12_keyfile_defaults(self):
        """P12 factories work with all optional arguments defaulted."""
        self._from_p12_keyfile_helper()

    def test_from_p12_keyfile_explicit(self):
        """P12 factories honor explicit password, scopes and URIs."""
        password = 'notasecret'
        self._from_p12_keyfile_helper(private_key_password=password,
                                      scopes=['foo', 'bar'],
                                      token_uri='baz', revoke_uri='qux')

    def test_create_scoped_required_without_scopes(self):
        """Credentials built without scopes report that scoping is needed."""
        self.assertTrue(self.credentials.create_scoped_required())

    def test_create_scoped_required_with_scopes(self):
        """Credentials built with scopes report that scoping is not needed."""
        signer = object()
        # Use a local so the fixture created in setUp is left untouched.
        scoped_credentials = service_account.ServiceAccountCredentials(
            self.service_account_email,
            signer,
            scopes=self.scopes,
            private_key_id=self.private_key_id,
            client_id=self.client_id,
        )
        self.assertFalse(scoped_credentials.create_scoped_required())

    def test_create_scoped(self):
        """``create_scoped`` returns new credentials carrying the scopes."""
        new_credentials = self.credentials.create_scoped(self.scopes)
        self.assertNotEqual(self.credentials, new_credentials)
        self.assertIsInstance(new_credentials,
                              service_account.ServiceAccountCredentials)
        self.assertEqual('dummy_scope', new_credentials._scopes)

    def test_create_delegated(self):
        """``create_delegated`` adds ``sub`` on a copy, not on the original."""
        signer = object()
        sub = '[email protected]'
        creds = service_account.ServiceAccountCredentials(
            '[email protected]', signer)
        self.assertNotIn('sub', creds._kwargs)
        delegated_creds = creds.create_delegated(sub)
        self.assertEqual(delegated_creds._kwargs['sub'], sub)
        # Make sure the original is unchanged.
        self.assertNotIn('sub', creds._kwargs)

    def test_create_delegated_existing_sub(self):
        """``create_delegated`` replaces an existing ``sub`` on a copy only."""
        signer = object()
        # The two subjects must differ; otherwise the assertions below
        # could not tell the delegated credentials from the original.
        sub1 = 'existing-sub@example.com'
        sub2 = 'new-sub@example.com'
        self.assertNotEqual(sub1, sub2)
        creds = service_account.ServiceAccountCredentials(
            '[email protected]', signer, sub=sub1)
        self.assertEqual(creds._kwargs['sub'], sub1)
        delegated_creds = creds.create_delegated(sub2)
        self.assertEqual(delegated_creds._kwargs['sub'], sub2)
        # Make sure the original is unchanged.
        self.assertEqual(creds._kwargs['sub'], sub1)

    @mock.patch('oauth2client.client._UTCNOW')
    def test_access_token(self, utcnow):
        """``get_access_token`` refreshes only when the token has expired.

        The test counts calls to the patched ``_UTCNOW`` and to the mock
        signer, so the order and number of attribute accesses on the
        credentials below is significant; do not reorder the assertions.
        """
        # Configure the patch.
        seconds = 11
        NOW = datetime.datetime(1992, 12, 31, second=seconds)
        utcnow.return_value = NOW

        # Create a custom credentials with a mock signer.
        signer = mock.MagicMock()
        signed_value = b'signed-content'
        signer.sign = mock.MagicMock(name='sign',
                                     return_value=signed_value)
        credentials = service_account.ServiceAccountCredentials(
            self.service_account_email,
            signer,
            private_key_id=self.private_key_id,
            client_id=self.client_id,
        )

        # Begin testing.
        lifetime = 2  # number of seconds in which the token expires
        EXPIRY_TIME = datetime.datetime(1992, 12, 31,
                                        second=seconds + lifetime)

        token1 = u'first_token'
        token_response_first = {
            'access_token': token1,
            'expires_in': lifetime,
        }
        token2 = u'second_token'
        token_response_second = {
            'access_token': token2,
            'expires_in': lifetime,
        }
        http = HttpMockSequence([
            ({'status': '200'},
             json.dumps(token_response_first).encode('utf-8')),
            ({'status': '200'},
             json.dumps(token_response_second).encode('utf-8')),
        ])

        # Get Access Token, First attempt.
        self.assertIsNone(credentials.access_token)
        self.assertFalse(credentials.access_token_expired)
        self.assertIsNone(credentials.token_expiry)
        token = credentials.get_access_token(http=http)
        self.assertEqual(credentials.token_expiry, EXPIRY_TIME)
        self.assertEqual(token1, token.access_token)
        self.assertEqual(lifetime, token.expires_in)
        self.assertEqual(token_response_first,
                         credentials.token_response)
        # Two utcnow calls are expected:
        # - get_access_token() -> _do_refresh_request (setting expires in)
        # - get_access_token() -> _expires_in()
        expected_utcnow_calls = [mock.call()] * 2
        self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
        # One call to sign() expected: Actual refresh was needed.
        self.assertEqual(len(signer.sign.mock_calls), 1)

        # Get Access Token, Second Attempt (not expired)
        self.assertEqual(credentials.access_token, token1)
        self.assertFalse(credentials.access_token_expired)
        token = credentials.get_access_token(http=http)
        # Make sure no refresh occurred since the token was not expired.
        self.assertEqual(token1, token.access_token)
        self.assertEqual(lifetime, token.expires_in)
        self.assertEqual(token_response_first, credentials.token_response)
        # Three more utcnow calls are expected:
        # - access_token_expired
        # - get_access_token() -> access_token_expired
        # - get_access_token -> _expires_in
        expected_utcnow_calls = [mock.call()] * (2 + 3)
        self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
        # No call to sign() expected: the token was not expired.
        self.assertEqual(len(signer.sign.mock_calls), 1 + 0)

        # Get Access Token, Third Attempt (force expiration)
        self.assertEqual(credentials.access_token, token1)
        credentials.token_expiry = NOW  # Manually force expiry.
        self.assertTrue(credentials.access_token_expired)
        token = credentials.get_access_token(http=http)
        # Make sure a refresh occurred since the token was expired.
        self.assertEqual(token2, token.access_token)
        self.assertEqual(lifetime, token.expires_in)
        self.assertFalse(credentials.access_token_expired)
        self.assertEqual(token_response_second,
                         credentials.token_response)
        # Five more utcnow calls are expected:
        # - access_token_expired
        # - get_access_token -> access_token_expired
        # - get_access_token -> _do_refresh_request
        # - get_access_token -> _expires_in
        # - access_token_expired
        expected_utcnow_calls = [mock.call()] * (2 + 3 + 5)
        self.assertEqual(expected_utcnow_calls, utcnow.mock_calls)
        # One more call to sign() expected: Actual refresh was needed.
        self.assertEqual(len(signer.sign.mock_calls), 1 + 0 + 1)

        self.assertEqual(credentials.access_token, token2)
| |
# Fixed points in time shared by the JWT access-credential tests.  Each
# ``Tn`` is a number of seconds, ``Tn_DATE`` is the matching datetime
# (offset from 1970-01-01), and the ``*_EXPIRY*`` values lie one maximum
# token lifetime later.
TOKEN_LIFE = service_account._JWTAccessCredentials._MAX_TOKEN_LIFETIME_SECS
_EPOCH = datetime.datetime(1970, 1, 1)

# T1: the reference instant.
T1 = 42
T1_EXPIRY = T1 + TOKEN_LIFE
T1_DATE = _EPOCH + datetime.timedelta(seconds=T1)
T1_EXPIRY_DATE = _EPOCH + datetime.timedelta(seconds=T1_EXPIRY)

# T2: 100 seconds after T1.
T2 = T1 + 100
T2_EXPIRY = T2 + TOKEN_LIFE
T2_DATE = _EPOCH + datetime.timedelta(seconds=T2)
T2_EXPIRY_DATE = _EPOCH + datetime.timedelta(seconds=T2_EXPIRY)

# T3: one second past the expiry of a token issued at T1.
T3 = T1 + TOKEN_LIFE + 1
T3_EXPIRY = T3 + TOKEN_LIFE
T3_DATE = _EPOCH + datetime.timedelta(seconds=T3)
T3_EXPIRY_DATE = _EPOCH + datetime.timedelta(seconds=T3_EXPIRY)
| |
| |
class JWTAccessCredentialsTests(unittest2.TestCase):
    """Tests for ``service_account._JWTAccessCredentials``."""

    def setUp(self):
        """Create JWT credentials whose default audience is ``self.url``."""
        self.client_id = '123'
        self.service_account_email = '[email protected]'
        self.private_key_id = 'ABCDEF'
        self.private_key = datafile('pem_from_pkcs12.pem')
        self.signer = crypt.Signer.from_string(self.private_key)
        self.url = 'https://test.url.com'
        self.jwt = service_account._JWTAccessCredentials(
            self.service_account_email, self.signer,
            private_key_id=self.private_key_id, client_id=self.client_id,
            additional_claims={'aud': self.url})

    def _verify_token(self, token, audience):
        """Verify ``token`` against the test certificate; return its payload."""
        return crypt.verify_signed_jwt_with_certs(
            token,
            {'key': datafile('public_cert.pem')},
            audience=audience)

    def _make_verifying_request(self):
        """Return a fake ``Http.request`` that validates the bearer JWT.

        The returned callable asserts that the request targets
        ``self.url`` and that the ``Authorization`` header carries a
        bearer token issued at ``T1`` for the service account.
        """
        def mock_request(uri, method='GET', body=None, headers=None,
                         redirections=0, connection_type=None):
            self.assertEqual(uri, self.url)
            bearer, token = headers[b'Authorization'].split()
            payload = self._verify_token(token, self.url)
            self.assertEqual(payload['iss'], self.service_account_email)
            self.assertEqual(payload['sub'], self.service_account_email)
            self.assertEqual(payload['iat'], T1)
            self.assertEqual(payload['exp'], T1_EXPIRY)
            self.assertEqual(bearer, b'Bearer')
            return httplib2.Response({'status': '200'}), b''

        return mock_request

    @mock.patch('oauth2client.client._UTCNOW')
    @mock.patch('time.time')
    def test_get_access_token_no_claims(self, time, utcnow):
        """Tokens are reused until expiry and re-minted afterwards."""
        utcnow.return_value = T1_DATE
        time.return_value = T1

        token_info = self.jwt.get_access_token()
        payload = self._verify_token(token_info.access_token, self.url)
        self.assertEqual(payload['iss'], self.service_account_email)
        self.assertEqual(payload['sub'], self.service_account_email)
        self.assertEqual(payload['iat'], T1)
        self.assertEqual(payload['exp'], T1_EXPIRY)
        self.assertEqual(token_info.expires_in, T1_EXPIRY - T1)

        # Verify that we vend the same token after 100 seconds
        utcnow.return_value = T2_DATE
        token_info = self.jwt.get_access_token()
        payload = self._verify_token(token_info.access_token, self.url)
        self.assertEqual(payload['iat'], T1)
        self.assertEqual(payload['exp'], T1_EXPIRY)
        self.assertEqual(token_info.expires_in, T1_EXPIRY - T2)

        # Verify that we vend a new token after _MAX_TOKEN_LIFETIME_SECS
        utcnow.return_value = T3_DATE
        time.return_value = T3
        token_info = self.jwt.get_access_token()
        payload = self._verify_token(token_info.access_token, self.url)
        expires_in = token_info.expires_in
        self.assertEqual(payload['iat'], T3)
        self.assertEqual(payload['exp'], T3_EXPIRY)
        self.assertEqual(expires_in, T3_EXPIRY - T3)

    @mock.patch('oauth2client.client._UTCNOW')
    @mock.patch('time.time')
    def test_get_access_token_additional_claims(self, time, utcnow):
        """Per-call ``additional_claims`` override ``aud`` and ``sub``."""
        utcnow.return_value = T1_DATE
        time.return_value = T1

        audience = 'https://test2.url.com'
        # Must differ from the service account email, otherwise the
        # ``sub`` assertion below would pass without any override.
        subject = 'delegated-user@example.com'
        self.assertNotEqual(subject, self.service_account_email)

        token_info = self.jwt.get_access_token(
            additional_claims={'aud': audience, 'sub': subject})
        payload = self._verify_token(token_info.access_token, audience)
        expires_in = token_info.expires_in
        self.assertEqual(payload['iss'], self.service_account_email)
        self.assertEqual(payload['sub'], subject)
        self.assertEqual(payload['iat'], T1)
        self.assertEqual(payload['exp'], T1_EXPIRY)
        self.assertEqual(expires_in, T1_EXPIRY - T1)

    def test_revoke(self):
        """``revoke`` can be called with ``None`` without raising."""
        self.jwt.revoke(None)

    def test_create_scoped_required(self):
        """JWT access credentials report that scoping is required."""
        self.assertTrue(self.jwt.create_scoped_required())

    def test_create_scoped(self):
        """``create_scoped`` yields scoped ``ServiceAccountCredentials``."""
        self.jwt._private_key_pkcs12 = ''
        self.jwt._private_key_password = ''

        new_credentials = self.jwt.create_scoped('dummy_scope')
        self.assertNotEqual(self.jwt, new_credentials)
        self.assertIsInstance(
            new_credentials, service_account.ServiceAccountCredentials)
        self.assertEqual('dummy_scope', new_credentials._scopes)

    @mock.patch('oauth2client.client._UTCNOW')
    @mock.patch('time.time')
    def test_authorize_success(self, time, utcnow):
        """An authorized ``Http`` sends a valid, cached bearer token."""
        utcnow.return_value = T1_DATE
        time.return_value = T1

        h = httplib2.Http()
        h.request = self._make_verifying_request()
        self.jwt.authorize(h)
        h.request(self.url)

        # Ensure we use the cached token: the fake request still expects
        # ``iat == T1`` even though the clock has advanced to T2.
        utcnow.return_value = T2_DATE
        h.request(self.url)

    @mock.patch('oauth2client.client._UTCNOW')
    @mock.patch('time.time')
    def test_authorize_no_aud(self, time, utcnow):
        """Without a preset ``aud`` the token is verified but not cached."""
        utcnow.return_value = T1_DATE
        time.return_value = T1

        jwt = service_account._JWTAccessCredentials(
            self.service_account_email, self.signer,
            private_key_id=self.private_key_id, client_id=self.client_id)

        h = httplib2.Http()
        h.request = self._make_verifying_request()
        jwt.authorize(h)
        h.request(self.url)

        # Ensure we do not cache the token
        self.assertIsNone(jwt.access_token)

    @mock.patch('oauth2client.client._UTCNOW')
    def test_authorize_stale_token(self, utcnow):
        """A request made after expiry replaces the stored token."""
        utcnow.return_value = T1_DATE
        # Create an initial token
        h = HttpMockSequence([({'status': '200'}, b''),
                              ({'status': '200'}, b'')])
        self.jwt.authorize(h)
        h.request(self.url)
        token_1 = self.jwt.access_token

        # Expire the token
        utcnow.return_value = T3_DATE
        h.request(self.url)
        token_2 = self.jwt.access_token
        self.assertEqual(self.jwt.token_expiry, T3_EXPIRY_DATE)
        self.assertNotEqual(token_1, token_2)

    @mock.patch('oauth2client.client._UTCNOW')
    def test_authorize_401(self, utcnow):
        """A 401 response leads to a new token and a successful retry."""
        utcnow.return_value = T1_DATE

        h = HttpMockSequence([
            ({'status': '200'}, b''),
            ({'status': '401'}, b''),
            ({'status': '200'}, b'')])
        self.jwt.authorize(h)
        h.request(self.url)
        token_1 = self.jwt.access_token

        utcnow.return_value = T2_DATE
        self.assertEqual(h.request(self.url)[0].status, 200)
        token_2 = self.jwt.access_token
        # Check the 401 forced a new token
        self.assertNotEqual(token_1, token_2)

    @mock.patch('oauth2client.client._UTCNOW')
    def test_refresh(self, utcnow):
        """``refresh`` stores a new token that expires relative to now."""
        utcnow.return_value = T1_DATE
        token_1 = self.jwt.access_token

        utcnow.return_value = T2_DATE
        self.jwt.refresh(None)
        token_2 = self.jwt.access_token
        self.assertEqual(self.jwt.token_expiry, T2_EXPIRY_DATE)
        self.assertNotEqual(token_1, token_2)