| # Copyright 2021 Google LLC |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS-IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """JWT testing service API implementations in Python.""" |
| |
| import datetime |
| import io |
| import json |
| |
| from typing import Tuple |
| |
| import grpc |
| import tink |
| from tink import cleartext_keyset_handle |
| |
| from tink import jwt |
| |
| from google.protobuf import duration_pb2 |
| from google.protobuf import timestamp_pb2 |
| |
| from protos import testing_api_pb2 |
| from protos import testing_api_pb2_grpc |
| |
| |
def _to_timestamp_tuple(t: datetime.datetime) -> Tuple[int, int]:
  """Splits an aware datetime into (seconds, nanos) relative to the Unix epoch.

  The pair is written into the `seconds`/`nanos` fields of a
  google.protobuf.Timestamp, so it follows that message's normalization:
  `seconds` is the floor of the offset from 1970-01-01T00:00:00Z and `nanos`
  is the non-negative remainder in [0, 999_999_999]. The split uses exact
  integer arithmetic. Going through the float returned by
  `datetime.timestamp()` would add rounding error to `nanos`, and truncating
  that float would produce negative `nanos` for instants before the epoch.

  Args:
    t: a timezone-aware datetime.

  Returns:
    A (seconds, nanos) tuple.

  Raises:
    ValueError: if `t` has no tzinfo.
  """
  if not t.tzinfo:
    raise ValueError('datetime must have tzinfo')
  epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
  since_epoch = t - epoch
  # timedelta is normalized to (days, seconds, microseconds) where only `days`
  # can be negative. This sum is therefore already the floor in whole
  # seconds, and the microsecond remainder is always non-negative.
  seconds = since_epoch.days * 86400 + since_epoch.seconds
  nanos = since_epoch.microseconds * 1000
  return (seconds, nanos)
| |
| |
def _from_timestamp_proto(
    timestamp: 'timestamp_pb2.Timestamp') -> datetime.datetime:
  """Converts a google.protobuf.Timestamp into an aware UTC datetime.

  The result is computed as epoch + timedelta with the integer `seconds`
  kept exact. This avoids two problems with building one float
  `seconds + nanos / 1e9`: precision loss for large timestamps, and the
  platform-dependent range limits of `datetime.fromtimestamp`.

  `nanos` is rounded to the nearest microsecond (ties to even), the finest
  resolution a datetime can hold.

  Args:
    timestamp: object exposing integer `seconds` and `nanos` fields.

  Returns:
    A datetime whose tzinfo is datetime.timezone.utc.
  """
  epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
  # timedelta rounds the fractional microseconds half-to-even.
  return epoch + datetime.timedelta(
      seconds=timestamp.seconds, microseconds=timestamp.nanos / 1000)
| |
| |
def _from_duration_proto(
    duration: 'duration_pb2.Duration') -> datetime.timedelta:
  """Converts a google.protobuf.Duration into a datetime.timedelta.

  Both `seconds` and `nanos` are honored. Previously `nanos` was dropped, so
  any sub-second part of the duration was silently lost. `nanos` is rounded
  to the nearest microsecond (ties to even), the finest resolution a
  timedelta supports.

  Args:
    duration: object exposing integer `seconds` and `nanos` fields.

  Returns:
    The equivalent timedelta.
  """
  return datetime.timedelta(
      seconds=duration.seconds, microseconds=duration.nanos / 1000)
| |
| |
def raw_jwt_from_proto(proto_raw_jwt: testing_api_pb2.JwtToken) -> jwt.RawJwt:
  """Builds a jwt.RawJwt from a JwtToken proto.

  Optional fields that are not set become None, and an empty audience list
  becomes None. A token without an expiration is created with
  `without_expiration=True`.

  Custom claims are decoded according to which value field is set on them.
  JSON object and JSON array claims are parsed with json.loads.

  Args:
    proto_raw_jwt: the JwtToken proto to convert.

  Returns:
    The corresponding jwt.RawJwt.

  Raises:
    ValueError: if a custom claim has none of the known value fields set.
  """

  def optional_value(field_name):
    # These fields are messages carrying a `.value`; HasField tells an unset
    # field apart from a set one.
    if not proto_raw_jwt.HasField(field_name):
      return None
    return getattr(proto_raw_jwt, field_name).value

  def optional_time(field_name):
    if not proto_raw_jwt.HasField(field_name):
      return None
    return _from_timestamp_proto(getattr(proto_raw_jwt, field_name))

  type_header = optional_value('type_header')
  issuer = optional_value('issuer')
  subject = optional_value('subject')
  audiences = list(proto_raw_jwt.audiences) or None
  jwt_id = optional_value('jwt_id')

  # (value field, decoder) pairs, tried in order. The first field that is
  # set decides how the claim value is decoded.
  claim_decoders = (
      ('null_value', lambda _: None),
      ('number_value', lambda v: v),
      ('string_value', lambda v: v),
      ('bool_value', lambda v: v),
      ('json_object_value', json.loads),
      ('json_array_value', json.loads),
  )
  custom_claims = {}
  for claim_name, claim in proto_raw_jwt.custom_claims.items():
    for value_field, decode in claim_decoders:
      if claim.HasField(value_field):
        custom_claims[claim_name] = decode(getattr(claim, value_field))
        break
    else:
      raise ValueError('claim %s has unknown type' % claim_name)

  expiration = optional_time('expiration')
  not_before = optional_time('not_before')
  issued_at = optional_time('issued_at')

  return jwt.new_raw_jwt(
      type_header=type_header,
      issuer=issuer,
      subject=subject,
      audiences=audiences,
      jwt_id=jwt_id,
      expiration=expiration,
      without_expiration=expiration is None,
      not_before=not_before,
      issued_at=issued_at,
      custom_claims=custom_claims)
| |
| |
def verifiedjwt_to_proto(
    verified_jwt: jwt.VerifiedJwt) -> testing_api_pb2.JwtToken:
  """Serializes a jwt.VerifiedJwt into a JwtToken proto.

  Only claims present on `verified_jwt` are written; absent ones leave the
  corresponding proto field unset. Timestamps are split with
  `_to_timestamp_tuple`. Each custom claim is stored in the value field that
  matches its Python type, with dicts and lists stored as JSON text.

  Args:
    verified_jwt: the verified token to convert.

  Returns:
    A JwtToken proto holding the same header and claims.

  Raises:
    ValueError: if a custom claim has a type with no proto representation.
  """
  proto = testing_api_pb2.JwtToken()

  if verified_jwt.has_type_header():
    proto.type_header.value = verified_jwt.type_header()
  if verified_jwt.has_issuer():
    proto.issuer.value = verified_jwt.issuer()
  if verified_jwt.has_subject():
    proto.subject.value = verified_jwt.subject()
  if verified_jwt.has_audiences():
    proto.audiences.extend(verified_jwt.audiences())
  if verified_jwt.has_jwt_id():
    proto.jwt_id.value = verified_jwt.jwt_id()

  time_claims = (
      ('expiration', verified_jwt.has_expiration, verified_jwt.expiration),
      ('not_before', verified_jwt.has_not_before, verified_jwt.not_before),
      ('issued_at', verified_jwt.has_issued_at, verified_jwt.issued_at),
  )
  for field_name, is_present, read_time in time_claims:
    if is_present():
      timestamp = getattr(proto, field_name)
      timestamp.seconds, timestamp.nanos = _to_timestamp_tuple(read_time())

  for claim_name in verified_jwt.custom_claim_names():
    claim_value = verified_jwt.custom_claim(claim_name)
    entry = proto.custom_claims[claim_name]
    if claim_value is None:
      entry.null_value = testing_api_pb2.NULL_VALUE
    # bool has to be tested before (int, float) because bool subclasses int.
    elif isinstance(claim_value, bool):
      entry.bool_value = claim_value
    elif isinstance(claim_value, (int, float)):
      entry.number_value = claim_value
    elif isinstance(claim_value, str):
      entry.string_value = claim_value
    elif isinstance(claim_value, dict):
      entry.json_object_value = json.dumps(claim_value)
    elif isinstance(claim_value, list):
      entry.json_array_value = json.dumps(claim_value)
    else:
      raise ValueError('claim %s has unknown type' % claim_name)
  return proto
| |
| |
def validator_from_proto(
    proto_validator: testing_api_pb2.JwtValidator) -> jwt.JwtValidator:
  """Builds a jwt.JwtValidator from a JwtValidator proto.

  Unset optional fields (the expected_* values, `now` and `clock_skew`) are
  passed to jwt.new_validator as None. The boolean flags are forwarded
  unchanged; the proto's `ignore_audience` maps onto the `ignore_audiences`
  keyword.

  Args:
    proto_validator: the JwtValidator proto to convert.

  Returns:
    The jwt.JwtValidator configured from the proto.
  """

  def optional_value(field_name):
    # These fields are messages carrying a `.value`; HasField tells an unset
    # field apart from a set one.
    if proto_validator.HasField(field_name):
      return getattr(proto_validator, field_name).value
    return None

  expected_type_header = optional_value('expected_type_header')
  expected_issuer = optional_value('expected_issuer')
  expected_audience = optional_value('expected_audience')
  fixed_now = (
      _from_timestamp_proto(proto_validator.now)
      if proto_validator.HasField('now') else None)
  clock_skew = (
      _from_duration_proto(proto_validator.clock_skew)
      if proto_validator.HasField('clock_skew') else None)

  return jwt.new_validator(
      expected_type_header=expected_type_header,
      expected_issuer=expected_issuer,
      expected_audience=expected_audience,
      ignore_type_header=proto_validator.ignore_type_header,
      ignore_issuer=proto_validator.ignore_issuer,
      ignore_audiences=proto_validator.ignore_audience,
      allow_missing_expiration=proto_validator.allow_missing_expiration,
      expect_issued_in_the_past=proto_validator.expect_issued_in_the_past,
      fixed_now=fixed_now,
      clock_skew=clock_skew)
| |
| |
class JwtServicer(testing_api_pb2_grpc.JwtServicer):
  """gRPC servicer for the JWT testing service.

  Every RPC reports Tink failures in-band. A tink.TinkError raised while
  serving a request is caught, and its message is returned in the response's
  `err` field. Any other exception is not caught here and propagates to gRPC.
  """

  @staticmethod
  def _read_keyset(serialized_keyset):
    """Reads a binary-serialized keyset via cleartext_keyset_handle."""
    reader = tink.BinaryKeysetReader(serialized_keyset)
    return cleartext_keyset_handle.read(reader)

  def _creation_response(self, request, primitive_class):
    """Shared body of the Create* RPCs: get the primitive, then drop it."""
    try:
      keyset_handle = self._read_keyset(
          request.annotated_keyset.serialized_keyset)
      keyset_handle.primitive(primitive_class)
      return testing_api_pb2.CreationResponse()
    except tink.TinkError as e:
      return testing_api_pb2.CreationResponse(err=str(e))

  def CreateJwtMac(
      self, request: testing_api_pb2.CreationRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.CreationResponse:
    """Checks that a JwtMac can be obtained from the keyset; it is unused."""
    return self._creation_response(request, jwt.JwtMac)

  def CreateJwtPublicKeySign(
      self, request: testing_api_pb2.CreationRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.CreationResponse:
    """Checks that a JwtPublicKeySign can be obtained; it is unused."""
    return self._creation_response(request, jwt.JwtPublicKeySign)

  def CreateJwtPublicKeyVerify(
      self, request: testing_api_pb2.CreationRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.CreationResponse:
    """Checks that a JwtPublicKeyVerify can be obtained; it is unused."""
    return self._creation_response(request, jwt.JwtPublicKeyVerify)

  def ComputeMacAndEncode(
      self, request: testing_api_pb2.JwtSignRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtSignResponse:
    """MACs `request.raw_jwt` and returns the compact serialization."""
    try:
      keyset_handle = self._read_keyset(
          request.annotated_keyset.serialized_keyset)
      mac = keyset_handle.primitive(jwt.JwtMac)
      compact = mac.compute_mac_and_encode(raw_jwt_from_proto(request.raw_jwt))
      return testing_api_pb2.JwtSignResponse(signed_compact_jwt=compact)
    except tink.TinkError as e:
      return testing_api_pb2.JwtSignResponse(err=str(e))

  def VerifyMacAndDecode(
      self, request: testing_api_pb2.JwtVerifyRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtVerifyResponse:
    """Verifies a MACed compact JWT and returns its decoded claims."""
    try:
      keyset_handle = self._read_keyset(
          request.annotated_keyset.serialized_keyset)
      validator = validator_from_proto(request.validator)
      mac = keyset_handle.primitive(jwt.JwtMac)
      verified = mac.verify_mac_and_decode(request.signed_compact_jwt,
                                           validator)
      return testing_api_pb2.JwtVerifyResponse(
          verified_jwt=verifiedjwt_to_proto(verified))
    except tink.TinkError as e:
      return testing_api_pb2.JwtVerifyResponse(err=str(e))

  def PublicKeySignAndEncode(
      self, request: testing_api_pb2.JwtSignRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtSignResponse:
    """Signs `request.raw_jwt` and returns the compact serialization."""
    try:
      keyset_handle = self._read_keyset(
          request.annotated_keyset.serialized_keyset)
      signer = keyset_handle.primitive(jwt.JwtPublicKeySign)
      compact = signer.sign_and_encode(raw_jwt_from_proto(request.raw_jwt))
      return testing_api_pb2.JwtSignResponse(signed_compact_jwt=compact)
    except tink.TinkError as e:
      return testing_api_pb2.JwtSignResponse(err=str(e))

  def PublicKeyVerifyAndDecode(
      self, request: testing_api_pb2.JwtVerifyRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtVerifyResponse:
    """Verifies a signed compact JWT and returns its decoded claims."""
    try:
      keyset_handle = self._read_keyset(
          request.annotated_keyset.serialized_keyset)
      validator = validator_from_proto(request.validator)
      verifier = keyset_handle.primitive(jwt.JwtPublicKeyVerify)
      verified = verifier.verify_and_decode(request.signed_compact_jwt,
                                            validator)
      return testing_api_pb2.JwtVerifyResponse(
          verified_jwt=verifiedjwt_to_proto(verified))
    except tink.TinkError as e:
      return testing_api_pb2.JwtVerifyResponse(err=str(e))

  def ToJwkSet(
      self, request: testing_api_pb2.JwtToJwkSetRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtToJwkSetResponse:
    """Exports a Tink keyset holding JWT keys as a JWK set."""
    try:
      jwk_set = jwt.jwk_set_from_public_keyset_handle(
          self._read_keyset(request.keyset))
      return testing_api_pb2.JwtToJwkSetResponse(jwk_set=jwk_set)
    except tink.TinkError as e:
      return testing_api_pb2.JwtToJwkSetResponse(err=str(e))

  def FromJwkSet(
      self, request: testing_api_pb2.JwtFromJwkSetRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtFromJwkSetResponse:
    """Imports a JWK set and returns it as a binary-serialized Tink keyset."""
    try:
      keyset_handle = jwt.jwk_set_to_public_keyset_handle(request.jwk_set)
      serialized = io.BytesIO()
      cleartext_keyset_handle.write(
          tink.BinaryKeysetWriter(serialized), keyset_handle)
      return testing_api_pb2.JwtFromJwkSetResponse(
          keyset=serialized.getvalue())
    except tink.TinkError as e:
      return testing_api_pb2.JwtFromJwkSetResponse(err=str(e))