| // Copyright 2022 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| //////////////////////////////////////////////////////////////////////////////// |
| |
| package jwt |
| |
| import ( |
| "bytes" |
| "fmt" |
| "math/rand" |
| |
| spb "google.golang.org/protobuf/types/known/structpb" |
| "google.golang.org/protobuf/proto" |
| "github.com/google/tink/go/keyset" |
| jepb "github.com/google/tink/go/proto/jwt_ecdsa_go_proto" |
| jrsppb "github.com/google/tink/go/proto/jwt_rsa_ssa_pkcs1_go_proto" |
| jrpsspb "github.com/google/tink/go/proto/jwt_rsa_ssa_pss_go_proto" |
| tinkpb "github.com/google/tink/go/proto/tink_go_proto" |
| ) |
| |
// Type URLs of the serialized JWT public key protos that this file converts
// to and from JWK entries.
const (
	// jwtECDSAPublicKeyType identifies JwtEcdsaPublicKey (ES* algorithms).
	jwtECDSAPublicKeyType = "type.googleapis.com/google.crypto.tink.JwtEcdsaPublicKey"
	// jwtRSPublicKeyType identifies JwtRsaSsaPkcs1PublicKey (RS* algorithms).
	jwtRSPublicKeyType = "type.googleapis.com/google.crypto.tink.JwtRsaSsaPkcs1PublicKey"
	// jwtPSPublicKeyType identifies JwtRsaSsaPssPublicKey (PS* algorithms).
	jwtPSPublicKeyType = "type.googleapis.com/google.crypto.tink.JwtRsaSsaPssPublicKey"
)
| |
| func keysetHasID(ks *tinkpb.Keyset, keyID uint32) bool { |
| for _, k := range ks.GetKey() { |
| if k.GetKeyId() == keyID { |
| return true |
| } |
| } |
| return false |
| } |
| |
| func generateUnusedID(ks *tinkpb.Keyset) uint32 { |
| for { |
| keyID := rand.Uint32() |
| if !keysetHasID(ks, keyID) { |
| return keyID |
| } |
| } |
| } |
| |
| func hasItem(s *spb.Struct, name string) bool { |
| if s.GetFields() == nil { |
| return false |
| } |
| _, ok := s.Fields[name] |
| return ok |
| } |
| |
| func stringItem(s *spb.Struct, name string) (string, error) { |
| fields := s.GetFields() |
| if fields == nil { |
| return "", fmt.Errorf("no fields") |
| } |
| val, ok := fields[name] |
| if !ok { |
| return "", fmt.Errorf("field %q not found", name) |
| } |
| r, ok := val.Kind.(*spb.Value_StringValue) |
| if !ok { |
| return "", fmt.Errorf("field %q is not a string", name) |
| } |
| return r.StringValue, nil |
| } |
| |
| func listValue(s *spb.Struct, name string) (*spb.ListValue, error) { |
| fields := s.GetFields() |
| if fields == nil { |
| return nil, fmt.Errorf("empty set") |
| } |
| vals, ok := fields[name] |
| if !ok { |
| return nil, fmt.Errorf("%q not found", name) |
| } |
| list, ok := vals.Kind.(*spb.Value_ListValue) |
| if !ok { |
| return nil, fmt.Errorf("%q is not a list", name) |
| } |
| if list.ListValue == nil || len(list.ListValue.GetValues()) == 0 { |
| return nil, fmt.Errorf("%q list is empty", name) |
| } |
| return list.ListValue, nil |
| } |
| |
| func expectStringItem(s *spb.Struct, name, value string) error { |
| item, err := stringItem(s, name) |
| if err != nil { |
| return err |
| } |
| if item != value { |
| return fmt.Errorf("unexpected value %q for %q", value, name) |
| } |
| return nil |
| } |
| |
| func decodeItem(s *spb.Struct, name string) ([]byte, error) { |
| e, err := stringItem(s, name) |
| if err != nil { |
| return nil, err |
| } |
| return base64Decode(e) |
| } |
| |
| func validateKeyOPSIsVerify(s *spb.Struct) error { |
| if !hasItem(s, "key_ops") { |
| return nil |
| } |
| keyOPSList, err := listValue(s, "key_ops") |
| if err != nil { |
| return err |
| } |
| if len(keyOPSList.GetValues()) != 1 { |
| return fmt.Errorf("key_ops size is not 1") |
| } |
| value, ok := keyOPSList.GetValues()[0].Kind.(*spb.Value_StringValue) |
| if !ok { |
| return fmt.Errorf("key_ops is not a string") |
| } |
| if value.StringValue != "verify" { |
| return fmt.Errorf("key_ops is not equal to [\"verify\"]") |
| } |
| return nil |
| } |
| |
| func validateUseIsSig(s *spb.Struct) error { |
| if !hasItem(s, "use") { |
| return nil |
| } |
| return expectStringItem(s, "use", "sig") |
| } |
| |
| func algorithmPrefix(s *spb.Struct) (string, error) { |
| alg, err := stringItem(s, "alg") |
| if err != nil { |
| return "", err |
| } |
| if len(alg) < 2 { |
| return "", fmt.Errorf("invalid algorithm") |
| } |
| return alg[0:2], nil |
| } |
| |
| var psNameToAlg = map[string]jrpsspb.JwtRsaSsaPssAlgorithm{ |
| "PS256": jrpsspb.JwtRsaSsaPssAlgorithm_PS256, |
| "PS384": jrpsspb.JwtRsaSsaPssAlgorithm_PS384, |
| "PS512": jrpsspb.JwtRsaSsaPssAlgorithm_PS512, |
| } |
| |
| func psPublicKeyDataFromStruct(keyStruct *spb.Struct) (*tinkpb.KeyData, error) { |
| alg, err := stringItem(keyStruct, "alg") |
| if err != nil { |
| return nil, err |
| } |
| algorithm, ok := psNameToAlg[alg] |
| if !ok { |
| return nil, fmt.Errorf("invalid alg header: %q", alg) |
| } |
| rsaPubKey, err := rsaPubKeyFromStruct(keyStruct) |
| if err != nil { |
| return nil, err |
| } |
| jwtPubKey := &jrpsspb.JwtRsaSsaPssPublicKey{ |
| Version: jwtECDSASignerKeyVersion, |
| Algorithm: algorithm, |
| E: rsaPubKey.exponent, |
| N: rsaPubKey.modulus, |
| } |
| if rsaPubKey.customKID != nil { |
| jwtPubKey.CustomKid = &jrpsspb.JwtRsaSsaPssPublicKey_CustomKid{ |
| Value: *rsaPubKey.customKID, |
| } |
| } |
| serializedPubKey, err := proto.Marshal(jwtPubKey) |
| if err != nil { |
| return nil, err |
| } |
| return &tinkpb.KeyData{ |
| TypeUrl: jwtPSPublicKeyType, |
| Value: serializedPubKey, |
| KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PUBLIC, |
| }, nil |
| } |
| |
| var rsNameToAlg = map[string]jrsppb.JwtRsaSsaPkcs1Algorithm{ |
| "RS256": jrsppb.JwtRsaSsaPkcs1Algorithm_RS256, |
| "RS384": jrsppb.JwtRsaSsaPkcs1Algorithm_RS384, |
| "RS512": jrsppb.JwtRsaSsaPkcs1Algorithm_RS512, |
| } |
| |
| func rsPublicKeyDataFromStruct(keyStruct *spb.Struct) (*tinkpb.KeyData, error) { |
| alg, err := stringItem(keyStruct, "alg") |
| if err != nil { |
| return nil, err |
| } |
| algorithm, ok := rsNameToAlg[alg] |
| if !ok { |
| return nil, fmt.Errorf("invalid alg header: %q", alg) |
| } |
| rsaPubKey, err := rsaPubKeyFromStruct(keyStruct) |
| if err != nil { |
| return nil, err |
| } |
| jwtPubKey := &jrsppb.JwtRsaSsaPkcs1PublicKey{ |
| Version: 0, |
| Algorithm: algorithm, |
| E: rsaPubKey.exponent, |
| N: rsaPubKey.modulus, |
| } |
| if rsaPubKey.customKID != nil { |
| jwtPubKey.CustomKid = &jrsppb.JwtRsaSsaPkcs1PublicKey_CustomKid{ |
| Value: *rsaPubKey.customKID, |
| } |
| } |
| serializedPubKey, err := proto.Marshal(jwtPubKey) |
| if err != nil { |
| return nil, err |
| } |
| return &tinkpb.KeyData{ |
| TypeUrl: jwtRSPublicKeyType, |
| Value: serializedPubKey, |
| KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PUBLIC, |
| }, nil |
| } |
| |
| type rsaPubKey struct { |
| exponent []byte |
| modulus []byte |
| customKID *string |
| } |
| |
| func rsaPubKeyFromStruct(keyStruct *spb.Struct) (*rsaPubKey, error) { |
| if hasItem(keyStruct, "p") || |
| hasItem(keyStruct, "q") || |
| hasItem(keyStruct, "dq") || |
| hasItem(keyStruct, "dp") || |
| hasItem(keyStruct, "d") || |
| hasItem(keyStruct, "qi") { |
| return nil, fmt.Errorf("private key can't be converted") |
| } |
| if err := expectStringItem(keyStruct, "kty", "RSA"); err != nil { |
| return nil, err |
| } |
| if err := validateUseIsSig(keyStruct); err != nil { |
| return nil, err |
| } |
| if err := validateKeyOPSIsVerify(keyStruct); err != nil { |
| return nil, err |
| } |
| e, err := decodeItem(keyStruct, "e") |
| if err != nil { |
| return nil, err |
| } |
| n, err := decodeItem(keyStruct, "n") |
| if err != nil { |
| return nil, err |
| } |
| var customKID *string = nil |
| if hasItem(keyStruct, "kid") { |
| kid, err := stringItem(keyStruct, "kid") |
| if err != nil { |
| return nil, err |
| } |
| customKID = &kid |
| } |
| return &rsaPubKey{ |
| exponent: e, |
| modulus: n, |
| customKID: customKID, |
| }, nil |
| } |
| |
| func esPublicKeyDataFromStruct(keyStruct *spb.Struct) (*tinkpb.KeyData, error) { |
| alg, err := stringItem(keyStruct, "alg") |
| if err != nil { |
| return nil, err |
| } |
| curve, err := stringItem(keyStruct, "crv") |
| if err != nil { |
| return nil, err |
| } |
| var algorithm jepb.JwtEcdsaAlgorithm = jepb.JwtEcdsaAlgorithm_ES_UNKNOWN |
| if alg == "ES256" && curve == "P-256" { |
| algorithm = jepb.JwtEcdsaAlgorithm_ES256 |
| } |
| if alg == "ES384" && curve == "P-384" { |
| algorithm = jepb.JwtEcdsaAlgorithm_ES384 |
| } |
| if alg == "ES512" && curve == "P-521" { |
| algorithm = jepb.JwtEcdsaAlgorithm_ES512 |
| } |
| if algorithm == jepb.JwtEcdsaAlgorithm_ES_UNKNOWN { |
| return nil, fmt.Errorf("invalid algorithm %q and curve %q", alg, curve) |
| } |
| if hasItem(keyStruct, "d") { |
| return nil, fmt.Errorf("private keys cannot be converted") |
| } |
| if err := expectStringItem(keyStruct, "kty", "EC"); err != nil { |
| return nil, err |
| } |
| if err := validateUseIsSig(keyStruct); err != nil { |
| return nil, err |
| } |
| if err := validateKeyOPSIsVerify(keyStruct); err != nil { |
| return nil, err |
| } |
| x, err := decodeItem(keyStruct, "x") |
| if err != nil { |
| return nil, fmt.Errorf("failed to decode x: %v", err) |
| } |
| y, err := decodeItem(keyStruct, "y") |
| if err != nil { |
| return nil, fmt.Errorf("failed to decode y: %v", err) |
| } |
| var customKID *jepb.JwtEcdsaPublicKey_CustomKid = nil |
| if hasItem(keyStruct, "kid") { |
| kid, err := stringItem(keyStruct, "kid") |
| if err != nil { |
| return nil, err |
| } |
| customKID = &jepb.JwtEcdsaPublicKey_CustomKid{Value: kid} |
| } |
| pubKey := &jepb.JwtEcdsaPublicKey{ |
| Version: 0, |
| Algorithm: algorithm, |
| X: x, |
| Y: y, |
| CustomKid: customKID, |
| } |
| serializedPubKey, err := proto.Marshal(pubKey) |
| if err != nil { |
| return nil, err |
| } |
| return &tinkpb.KeyData{ |
| TypeUrl: jwtECDSAPublicKeyType, |
| Value: serializedPubKey, |
| KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PUBLIC, |
| }, nil |
| } |
| |
| func keysetKeyFromStruct(val *spb.Value, keyID uint32) (*tinkpb.Keyset_Key, error) { |
| keyStruct := val.GetStructValue() |
| if keyStruct == nil { |
| return nil, fmt.Errorf("key is not a JSON object") |
| } |
| algPrefix, err := algorithmPrefix(keyStruct) |
| if err != nil { |
| return nil, err |
| } |
| var keyData *tinkpb.KeyData |
| switch algPrefix { |
| case "ES": |
| keyData, err = esPublicKeyDataFromStruct(keyStruct) |
| case "RS": |
| keyData, err = rsPublicKeyDataFromStruct(keyStruct) |
| case "PS": |
| keyData, err = psPublicKeyDataFromStruct(keyStruct) |
| default: |
| return nil, fmt.Errorf("unsupported algorithm prefix: %v", algPrefix) |
| } |
| if err != nil { |
| return nil, err |
| } |
| return &tinkpb.Keyset_Key{ |
| KeyData: keyData, |
| Status: tinkpb.KeyStatusType_ENABLED, |
| OutputPrefixType: tinkpb.OutputPrefixType_RAW, |
| KeyId: keyID, |
| }, nil |
| } |
| |
| // JWKSetToPublicKeysetHandle converts a Json Web Key (JWK) set into a Tink KeysetHandle. |
| // It requires that all keys in the set have the "alg" field set. Currently, only |
| // public keys for algorithms ES256, ES384, ES512, RS256, RS384, and RS512 are supported. |
| // JWK is defined in https://www.rfc-editor.org/rfc/rfc7517.txt. |
| func JWKSetToPublicKeysetHandle(jwkSet []byte) (*keyset.Handle, error) { |
| jwk := &spb.Struct{} |
| if err := jwk.UnmarshalJSON(jwkSet); err != nil { |
| return nil, err |
| } |
| keyList, err := listValue(jwk, "keys") |
| if err != nil { |
| return nil, err |
| } |
| |
| ks := &tinkpb.Keyset{} |
| for _, keyStruct := range keyList.GetValues() { |
| key, err := keysetKeyFromStruct(keyStruct, generateUnusedID(ks)) |
| if err != nil { |
| return nil, err |
| } |
| ks.Key = append(ks.Key, key) |
| } |
| ks.PrimaryKeyId = ks.Key[len(ks.Key)-1].GetKeyId() |
| return keyset.NewHandleWithNoSecrets(ks) |
| } |
| |
| func addKeyOPSVerify(s *spb.Struct) { |
| s.GetFields()["key_ops"] = spb.NewListValue(&spb.ListValue{Values: []*spb.Value{spb.NewStringValue("verify")}}) |
| } |
| |
| func addStringEntry(s *spb.Struct, key, val string) { |
| s.GetFields()[key] = spb.NewStringValue(val) |
| } |
| |
| var psAlgToStr map[jrpsspb.JwtRsaSsaPssAlgorithm]string = map[jrpsspb.JwtRsaSsaPssAlgorithm]string{ |
| jrpsspb.JwtRsaSsaPssAlgorithm_PS256: "PS256", |
| jrpsspb.JwtRsaSsaPssAlgorithm_PS384: "PS384", |
| jrpsspb.JwtRsaSsaPssAlgorithm_PS512: "PS512", |
| } |
| |
| func psPublicKeyToStruct(key *tinkpb.Keyset_Key) (*spb.Struct, error) { |
| pubKey := &jrpsspb.JwtRsaSsaPssPublicKey{} |
| if err := proto.Unmarshal(key.GetKeyData().GetValue(), pubKey); err != nil { |
| return nil, err |
| } |
| alg, ok := psAlgToStr[pubKey.GetAlgorithm()] |
| if !ok { |
| return nil, fmt.Errorf("invalid algorithm") |
| } |
| outKey := &spb.Struct{ |
| Fields: map[string]*spb.Value{}, |
| } |
| addStringEntry(outKey, "alg", alg) |
| addStringEntry(outKey, "kty", "RSA") |
| addStringEntry(outKey, "e", base64Encode(pubKey.GetE())) |
| addStringEntry(outKey, "n", base64Encode(pubKey.GetN())) |
| addStringEntry(outKey, "use", "sig") |
| addKeyOPSVerify(outKey) |
| var customKID *string = nil |
| if pubKey.GetCustomKid() != nil { |
| ck := pubKey.GetCustomKid().GetValue() |
| customKID = &ck |
| } |
| if err := setKeyID(outKey, key, customKID); err != nil { |
| return nil, err |
| } |
| return outKey, nil |
| } |
| |
| var rsAlgToStr map[jrsppb.JwtRsaSsaPkcs1Algorithm]string = map[jrsppb.JwtRsaSsaPkcs1Algorithm]string{ |
| jrsppb.JwtRsaSsaPkcs1Algorithm_RS256: "RS256", |
| jrsppb.JwtRsaSsaPkcs1Algorithm_RS384: "RS384", |
| jrsppb.JwtRsaSsaPkcs1Algorithm_RS512: "RS512", |
| } |
| |
| func rsPublicKeyToStruct(key *tinkpb.Keyset_Key) (*spb.Struct, error) { |
| pubKey := &jrsppb.JwtRsaSsaPkcs1PublicKey{} |
| if err := proto.Unmarshal(key.GetKeyData().GetValue(), pubKey); err != nil { |
| return nil, err |
| } |
| alg, ok := rsAlgToStr[pubKey.GetAlgorithm()] |
| if !ok { |
| return nil, fmt.Errorf("invalid algorithm") |
| } |
| outKey := &spb.Struct{ |
| Fields: map[string]*spb.Value{}, |
| } |
| addStringEntry(outKey, "alg", alg) |
| addStringEntry(outKey, "kty", "RSA") |
| addStringEntry(outKey, "e", base64Encode(pubKey.GetE())) |
| addStringEntry(outKey, "n", base64Encode(pubKey.GetN())) |
| addStringEntry(outKey, "use", "sig") |
| addKeyOPSVerify(outKey) |
| |
| var customKID *string = nil |
| if pubKey.GetCustomKid() != nil { |
| ck := pubKey.GetCustomKid().GetValue() |
| customKID = &ck |
| } |
| if err := setKeyID(outKey, key, customKID); err != nil { |
| return nil, err |
| } |
| return outKey, nil |
| } |
| |
| func esPublicKeyToStruct(key *tinkpb.Keyset_Key) (*spb.Struct, error) { |
| pubKey := &jepb.JwtEcdsaPublicKey{} |
| if err := proto.Unmarshal(key.GetKeyData().GetValue(), pubKey); err != nil { |
| return nil, err |
| } |
| outKey := &spb.Struct{ |
| Fields: map[string]*spb.Value{}, |
| } |
| var algorithm, curve string |
| switch pubKey.GetAlgorithm() { |
| case jepb.JwtEcdsaAlgorithm_ES256: |
| curve, algorithm = "P-256", "ES256" |
| case jepb.JwtEcdsaAlgorithm_ES384: |
| curve, algorithm = "P-384", "ES384" |
| case jepb.JwtEcdsaAlgorithm_ES512: |
| curve, algorithm = "P-521", "ES512" |
| default: |
| return nil, fmt.Errorf("invalid algorithm") |
| } |
| addStringEntry(outKey, "crv", curve) |
| addStringEntry(outKey, "alg", algorithm) |
| addStringEntry(outKey, "kty", "EC") |
| addStringEntry(outKey, "x", base64Encode(pubKey.GetX())) |
| addStringEntry(outKey, "y", base64Encode(pubKey.GetY())) |
| addStringEntry(outKey, "use", "sig") |
| addKeyOPSVerify(outKey) |
| |
| var customKID *string = nil |
| if pubKey.GetCustomKid() != nil { |
| ck := pubKey.GetCustomKid().GetValue() |
| customKID = &ck |
| } |
| if err := setKeyID(outKey, key, customKID); err != nil { |
| return nil, err |
| } |
| return outKey, nil |
| } |
| |
| func setKeyID(outKey *spb.Struct, key *tinkpb.Keyset_Key, customKID *string) error { |
| if key.GetOutputPrefixType() == tinkpb.OutputPrefixType_TINK { |
| if customKID != nil { |
| return fmt.Errorf("TINK keys shouldn't have custom KID") |
| } |
| kid := keyID(key.KeyId, key.GetOutputPrefixType()) |
| if kid == nil { |
| return fmt.Errorf("tink KID shouldn't be nil") |
| } |
| addStringEntry(outKey, "kid", *kid) |
| } else if customKID != nil { |
| addStringEntry(outKey, "kid", *customKID) |
| } |
| return nil |
| } |
| |
| // JWKSetFromPublicKeysetHandle converts a Tink KeysetHandle with JWT keys into a Json Web Key (JWK) set. |
| // Currently only public keys for algorithms ES256, ES384, ES512, RS256, RS384, and RS512 are supported. |
| // JWK is defined in https://www.rfc-editor.org/rfc/rfc7517.html. |
| func JWKSetFromPublicKeysetHandle(kh *keyset.Handle) ([]byte, error) { |
| b := &bytes.Buffer{} |
| if err := kh.WriteWithNoSecrets(keyset.NewBinaryWriter(b)); err != nil { |
| return nil, err |
| } |
| ks := &tinkpb.Keyset{} |
| if err := proto.Unmarshal(b.Bytes(), ks); err != nil { |
| return nil, err |
| } |
| keyValList := []*spb.Value{} |
| for _, k := range ks.Key { |
| if k.GetStatus() != tinkpb.KeyStatusType_ENABLED { |
| continue |
| } |
| if k.GetOutputPrefixType() != tinkpb.OutputPrefixType_TINK && |
| k.GetOutputPrefixType() != tinkpb.OutputPrefixType_RAW { |
| return nil, fmt.Errorf("unsupported output prefix type") |
| } |
| keyData := k.GetKeyData() |
| if keyData == nil { |
| return nil, fmt.Errorf("invalid key data") |
| } |
| if keyData.GetKeyMaterialType() != tinkpb.KeyData_ASYMMETRIC_PUBLIC { |
| return nil, fmt.Errorf("only asymmetric public keys are supported") |
| } |
| keyStruct := &spb.Struct{} |
| var err error |
| switch keyData.GetTypeUrl() { |
| case jwtECDSAPublicKeyType: |
| keyStruct, err = esPublicKeyToStruct(k) |
| case jwtRSPublicKeyType: |
| keyStruct, err = rsPublicKeyToStruct(k) |
| case jwtPSPublicKeyType: |
| keyStruct, err = psPublicKeyToStruct(k) |
| default: |
| return nil, fmt.Errorf("unsupported key type url") |
| } |
| if err != nil { |
| return nil, err |
| } |
| keyValList = append(keyValList, spb.NewStructValue(keyStruct)) |
| } |
| output := &spb.Struct{ |
| Fields: map[string]*spb.Value{ |
| "keys": spb.NewListValue(&spb.ListValue{Values: keyValList}), |
| }, |
| } |
| return output.MarshalJSON() |
| } |