blob: 77454110cf41d130d2b5a0999670ecd62eeb7d69 [file] [log] [blame]
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
package jwt
import (
"bytes"
"fmt"
"math/rand"
spb "google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/proto"
"github.com/google/tink/go/keyset"
jepb "github.com/google/tink/go/proto/jwt_ecdsa_go_proto"
jrsppb "github.com/google/tink/go/proto/jwt_rsa_ssa_pkcs1_go_proto"
jrpsspb "github.com/google/tink/go/proto/jwt_rsa_ssa_pss_go_proto"
tinkpb "github.com/google/tink/go/proto/tink_go_proto"
)
// Type URLs of the JWT public key protos that the JWK set converters in this
// file read and write.
const (
	// jwtECDSAPublicKeyType identifies JwtEcdsaPublicKey (ES256/ES384/ES512).
	jwtECDSAPublicKeyType = "type.googleapis.com/google.crypto.tink.JwtEcdsaPublicKey"
	// jwtRSPublicKeyType identifies JwtRsaSsaPkcs1PublicKey (RS256/RS384/RS512).
	jwtRSPublicKeyType = "type.googleapis.com/google.crypto.tink.JwtRsaSsaPkcs1PublicKey"
	// jwtPSPublicKeyType identifies JwtRsaSsaPssPublicKey (PS256/PS384/PS512).
	jwtPSPublicKeyType = "type.googleapis.com/google.crypto.tink.JwtRsaSsaPssPublicKey"
)
// keysetHasID reports whether some key in ks already uses keyID. A nil
// keyset contains no keys and therefore no IDs.
func keysetHasID(ks *tinkpb.Keyset, keyID uint32) bool {
	keys := ks.GetKey()
	for i := range keys {
		if keys[i].GetKeyId() == keyID {
			return true
		}
	}
	return false
}
// generateUnusedID draws random 32-bit key IDs from math/rand until it finds
// one that no key in ks is using, and returns it.
func generateUnusedID(ks *tinkpb.Keyset) uint32 {
	candidate := rand.Uint32()
	for keysetHasID(ks, candidate) {
		candidate = rand.Uint32()
	}
	return candidate
}
// hasItem reports whether s contains a member called name. A nil struct, or
// one without a fields map, contains nothing.
func hasItem(s *spb.Struct, name string) bool {
	// Looking up a key in a nil map is legal in Go and yields ok == false.
	_, present := s.GetFields()[name]
	return present
}
// stringItem returns the member name of s as a Go string.
//
// It returns an error if s has no fields, if the member is absent, or if the
// member is not a JSON string.
func stringItem(s *spb.Struct, name string) (string, error) {
	fields := s.GetFields()
	if fields == nil {
		return "", fmt.Errorf("no fields")
	}
	val, present := fields[name]
	if !present {
		return "", fmt.Errorf("field %q not found", name)
	}
	switch kind := val.Kind.(type) {
	case *spb.Value_StringValue:
		return kind.StringValue, nil
	default:
		return "", fmt.Errorf("field %q is not a string", name)
	}
}
// listValue returns the member name of s, which must be a non-empty JSON
// array.
//
// It returns an error if s has no fields, if the member is absent, if it is
// not a list, or if the list is nil or has no elements.
func listValue(s *spb.Struct, name string) (*spb.ListValue, error) {
	fields := s.GetFields()
	if fields == nil {
		return nil, fmt.Errorf("empty set")
	}
	member, present := fields[name]
	if !present {
		return nil, fmt.Errorf("%q not found", name)
	}
	wrapped, isList := member.Kind.(*spb.Value_ListValue)
	if !isList {
		return nil, fmt.Errorf("%q is not a list", name)
	}
	list := wrapped.ListValue
	// GetValues on a nil *ListValue returns nil, so this covers both the
	// nil-list and the zero-length cases.
	if len(list.GetValues()) == 0 {
		return nil, fmt.Errorf("%q list is empty", name)
	}
	return list, nil
}
// expectStringItem checks that the member name of s is a JSON string equal to
// value.
//
// It returns the stringItem error if the member is missing or is not a string.
// On a mismatch, the error names the value actually found and the expected
// value.
func expectStringItem(s *spb.Struct, name, value string) error {
	item, err := stringItem(s, name)
	if err != nil {
		return err
	}
	if item != value {
		// Report the value that was actually present. The message used to echo
		// the expected value, which hid what the JWK really contained.
		return fmt.Errorf("unexpected value %q for %q, want %q", item, name, value)
	}
	return nil
}
// decodeItem reads the string member name of s and decodes it with
// base64Decode, which is defined elsewhere in the package (presumably the
// unpadded base64url encoding JWK uses — confirm against its definition).
func decodeItem(s *spb.Struct, name string) ([]byte, error) {
	encoded, err := stringItem(s, name)
	if err == nil {
		return base64Decode(encoded)
	}
	return nil, err
}
// validateKeyOPSIsVerify accepts s if it has no "key_ops" member. Otherwise
// "key_ops" must be a list holding exactly one element, the string "verify".
func validateKeyOPSIsVerify(s *spb.Struct) error {
	if !hasItem(s, "key_ops") {
		return nil
	}
	ops, err := listValue(s, "key_ops")
	if err != nil {
		return err
	}
	values := ops.GetValues()
	if len(values) != 1 {
		return fmt.Errorf("key_ops size is not 1")
	}
	op, isString := values[0].Kind.(*spb.Value_StringValue)
	switch {
	case !isString:
		return fmt.Errorf("key_ops is not a string")
	case op.StringValue != "verify":
		return fmt.Errorf("key_ops is not equal to [\"verify\"]")
	}
	return nil
}
// validateUseIsSig accepts s if its optional "use" member is absent or is
// the string "sig".
func validateUseIsSig(s *spb.Struct) error {
	if hasItem(s, "use") {
		return expectStringItem(s, "use", "sig")
	}
	return nil
}
// algorithmPrefix returns the first two bytes of the "alg" member of s (e.g.
// "ES" for "ES256"). The caller uses this prefix to pick a key family.
//
// It returns an error if "alg" is missing, is not a string, or is shorter
// than two bytes.
func algorithmPrefix(s *spb.Struct) (string, error) {
	alg, err := stringItem(s, "alg")
	switch {
	case err != nil:
		return "", err
	case len(alg) < 2:
		return "", fmt.Errorf("invalid algorithm")
	default:
		return alg[:2], nil
	}
}
// psNameToAlg maps each RSASSA-PSS JWK "alg" name to the matching Tink
// JwtRsaSsaPssAlgorithm enum value.
var psNameToAlg = map[string]jrpsspb.JwtRsaSsaPssAlgorithm{
	"PS512": jrpsspb.JwtRsaSsaPssAlgorithm_PS512,
	"PS384": jrpsspb.JwtRsaSsaPssAlgorithm_PS384,
	"PS256": jrpsspb.JwtRsaSsaPssAlgorithm_PS256,
}
// psPublicKeyDataFromStruct converts a JWK object whose "alg" is PS256, PS384
// or PS512 into KeyData holding a serialized JwtRsaSsaPssPublicKey, with
// key material type ASYMMETRIC_PUBLIC.
//
// It returns an error in any of these cases:
//   - "alg" is missing or is not a PS algorithm;
//   - rsaPubKeyFromStruct rejects the object (private parameters, wrong
//     "kty", bad "use"/"key_ops", or undecodable "e"/"n");
//   - proto serialization fails.
func psPublicKeyDataFromStruct(keyStruct *spb.Struct) (*tinkpb.KeyData, error) {
	alg, err := stringItem(keyStruct, "alg")
	if err != nil {
		return nil, err
	}
	algorithm, ok := psNameToAlg[alg]
	if !ok {
		return nil, fmt.Errorf("invalid alg header: %q", alg)
	}
	rsaPubKey, err := rsaPubKeyFromStruct(keyStruct)
	if err != nil {
		return nil, err
	}
	jwtPubKey := &jrpsspb.JwtRsaSsaPssPublicKey{
		// Version 0, matching the RS and ES converters in this file. This used
		// to borrow jwtECDSASignerKeyVersion, an unrelated ECDSA signer constant.
		Version:   0,
		Algorithm: algorithm,
		E:         rsaPubKey.exponent,
		N:         rsaPubKey.modulus,
	}
	if rsaPubKey.customKID != nil {
		jwtPubKey.CustomKid = &jrpsspb.JwtRsaSsaPssPublicKey_CustomKid{
			Value: *rsaPubKey.customKID,
		}
	}
	serializedPubKey, err := proto.Marshal(jwtPubKey)
	if err != nil {
		return nil, err
	}
	return &tinkpb.KeyData{
		TypeUrl:         jwtPSPublicKeyType,
		Value:           serializedPubKey,
		KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PUBLIC,
	}, nil
}
// rsNameToAlg maps each RSASSA-PKCS1-v1_5 JWK "alg" name to the matching
// Tink JwtRsaSsaPkcs1Algorithm enum value.
var rsNameToAlg = map[string]jrsppb.JwtRsaSsaPkcs1Algorithm{
	"RS512": jrsppb.JwtRsaSsaPkcs1Algorithm_RS512,
	"RS384": jrsppb.JwtRsaSsaPkcs1Algorithm_RS384,
	"RS256": jrsppb.JwtRsaSsaPkcs1Algorithm_RS256,
}
// rsPublicKeyDataFromStruct converts a JWK object whose "alg" is RS256, RS384
// or RS512 into KeyData holding a serialized JwtRsaSsaPkcs1PublicKey
// (version 0), with key material type ASYMMETRIC_PUBLIC.
//
// It returns an error if "alg" is missing or unknown, if rsaPubKeyFromStruct
// rejects the object, or if proto serialization fails.
func rsPublicKeyDataFromStruct(keyStruct *spb.Struct) (*tinkpb.KeyData, error) {
	alg, err := stringItem(keyStruct, "alg")
	if err != nil {
		return nil, err
	}
	algorithm, known := rsNameToAlg[alg]
	if !known {
		return nil, fmt.Errorf("invalid alg header: %q", alg)
	}
	parsed, err := rsaPubKeyFromStruct(keyStruct)
	if err != nil {
		return nil, err
	}
	// Leave CustomKid nil unless the JWK carried a "kid".
	var kid *jrsppb.JwtRsaSsaPkcs1PublicKey_CustomKid
	if parsed.customKID != nil {
		kid = &jrsppb.JwtRsaSsaPkcs1PublicKey_CustomKid{Value: *parsed.customKID}
	}
	serialized, err := proto.Marshal(&jrsppb.JwtRsaSsaPkcs1PublicKey{
		Version:   0,
		Algorithm: algorithm,
		E:         parsed.exponent,
		N:         parsed.modulus,
		CustomKid: kid,
	})
	if err != nil {
		return nil, err
	}
	return &tinkpb.KeyData{
		TypeUrl:         jwtRSPublicKeyType,
		Value:           serialized,
		KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PUBLIC,
	}, nil
}
// rsaPubKey holds the RSA public key components that rsaPubKeyFromStruct
// extracts from a JWK object.
type rsaPubKey struct {
	// exponent is the decoded "e" member.
	exponent []byte
	// modulus is the decoded "n" member.
	modulus []byte
	// customKID is the "kid" member, or nil when the JWK had none.
	customKID *string
}
// rsaPubKeyFromStruct extracts the public RSA parameters from a JWK object.
//
// The object must not contain any private RSA parameter ("p", "q", "dq",
// "dp", "d", "qi"). Its "kty" must be "RSA". The "use" and "key_ops" members
// must either be absent or indicate verification. "e" and "n" must be present
// and decodable. An optional string "kid" is returned as customKID.
func rsaPubKeyFromStruct(keyStruct *spb.Struct) (*rsaPubKey, error) {
	for _, member := range []string{"p", "q", "dq", "dp", "d", "qi"} {
		if hasItem(keyStruct, member) {
			return nil, fmt.Errorf("private key can't be converted")
		}
	}
	if err := expectStringItem(keyStruct, "kty", "RSA"); err != nil {
		return nil, err
	}
	if err := validateUseIsSig(keyStruct); err != nil {
		return nil, err
	}
	if err := validateKeyOPSIsVerify(keyStruct); err != nil {
		return nil, err
	}
	out := &rsaPubKey{}
	var err error
	if out.exponent, err = decodeItem(keyStruct, "e"); err != nil {
		return nil, err
	}
	if out.modulus, err = decodeItem(keyStruct, "n"); err != nil {
		return nil, err
	}
	if hasItem(keyStruct, "kid") {
		kid, err := stringItem(keyStruct, "kid")
		if err != nil {
			return nil, err
		}
		out.customKID = &kid
	}
	return out, nil
}
// esPublicKeyDataFromStruct converts an EC JWK object into KeyData holding a
// serialized JwtEcdsaPublicKey (version 0), with key material type
// ASYMMETRIC_PUBLIC.
//
// The accepted "alg"/"crv" pairs are ES256/P-256, ES384/P-384 and
// ES512/P-521. The object must not contain "d", and its "kty" must be "EC".
// The "use" and "key_ops" members must be absent or indicate verification.
// "x" and "y" must decode. An optional string "kid" becomes the custom KID.
func esPublicKeyDataFromStruct(keyStruct *spb.Struct) (*tinkpb.KeyData, error) {
	alg, err := stringItem(keyStruct, "alg")
	if err != nil {
		return nil, err
	}
	curve, err := stringItem(keyStruct, "crv")
	if err != nil {
		return nil, err
	}
	var algorithm jepb.JwtEcdsaAlgorithm
	switch {
	case alg == "ES256" && curve == "P-256":
		algorithm = jepb.JwtEcdsaAlgorithm_ES256
	case alg == "ES384" && curve == "P-384":
		algorithm = jepb.JwtEcdsaAlgorithm_ES384
	case alg == "ES512" && curve == "P-521":
		algorithm = jepb.JwtEcdsaAlgorithm_ES512
	default:
		return nil, fmt.Errorf("invalid algorithm %q and curve %q", alg, curve)
	}
	if hasItem(keyStruct, "d") {
		return nil, fmt.Errorf("private keys cannot be converted")
	}
	if err := expectStringItem(keyStruct, "kty", "EC"); err != nil {
		return nil, err
	}
	if err := validateUseIsSig(keyStruct); err != nil {
		return nil, err
	}
	if err := validateKeyOPSIsVerify(keyStruct); err != nil {
		return nil, err
	}
	pubKey := &jepb.JwtEcdsaPublicKey{Version: 0, Algorithm: algorithm}
	if pubKey.X, err = decodeItem(keyStruct, "x"); err != nil {
		return nil, fmt.Errorf("failed to decode x: %v", err)
	}
	if pubKey.Y, err = decodeItem(keyStruct, "y"); err != nil {
		return nil, fmt.Errorf("failed to decode y: %v", err)
	}
	if hasItem(keyStruct, "kid") {
		kid, err := stringItem(keyStruct, "kid")
		if err != nil {
			return nil, err
		}
		pubKey.CustomKid = &jepb.JwtEcdsaPublicKey_CustomKid{Value: kid}
	}
	serialized, err := proto.Marshal(pubKey)
	if err != nil {
		return nil, err
	}
	return &tinkpb.KeyData{
		TypeUrl:         jwtECDSAPublicKeyType,
		Value:           serialized,
		KeyMaterialType: tinkpb.KeyData_ASYMMETRIC_PUBLIC,
	}, nil
}
// keysetKeyFromStruct converts one element of a JWK set's "keys" list into a
// Tink keyset key with the given ID. The key is ENABLED and uses the RAW
// output prefix type.
//
// The first two bytes of "alg" choose the converter: "ES", "RS" or "PS". Any
// other prefix, or a value that is not a JSON object, is an error.
func keysetKeyFromStruct(val *spb.Value, keyID uint32) (*tinkpb.Keyset_Key, error) {
	keyStruct := val.GetStructValue()
	if keyStruct == nil {
		return nil, fmt.Errorf("key is not a JSON object")
	}
	algPrefix, err := algorithmPrefix(keyStruct)
	if err != nil {
		return nil, err
	}
	var convert func(*spb.Struct) (*tinkpb.KeyData, error)
	switch algPrefix {
	case "ES":
		convert = esPublicKeyDataFromStruct
	case "RS":
		convert = rsPublicKeyDataFromStruct
	case "PS":
		convert = psPublicKeyDataFromStruct
	default:
		return nil, fmt.Errorf("unsupported algorithm prefix: %v", algPrefix)
	}
	keyData, err := convert(keyStruct)
	if err != nil {
		return nil, err
	}
	return &tinkpb.Keyset_Key{
		KeyData:          keyData,
		Status:           tinkpb.KeyStatusType_ENABLED,
		OutputPrefixType: tinkpb.OutputPrefixType_RAW,
		KeyId:            keyID,
	}, nil
}
// JWKSetToPublicKeysetHandle converts a JSON Web Key (JWK) set into a Tink KeysetHandle.
// It requires that all keys in the set have the "alg" field set. Currently, only
// public keys for algorithms ES256, ES384, ES512, RS256, RS384, RS512, PS256, PS384,
// and PS512 are supported.
//
// Each converted key is ENABLED, uses the RAW output prefix type, and gets a
// random key ID not used by any earlier key in the set. The last key in the set
// becomes the primary key. A missing or empty "keys" list is an error.
// JWK is defined in https://www.rfc-editor.org/rfc/rfc7517.txt.
func JWKSetToPublicKeysetHandle(jwkSet []byte) (*keyset.Handle, error) {
	jwk := &spb.Struct{}
	if err := jwk.UnmarshalJSON(jwkSet); err != nil {
		return nil, err
	}
	// listValue rejects an empty list, so at least one key is converted below
	// and indexing the last key is safe.
	keyList, err := listValue(jwk, "keys")
	if err != nil {
		return nil, err
	}
	values := keyList.GetValues()
	ks := &tinkpb.Keyset{Key: make([]*tinkpb.Keyset_Key, 0, len(values))}
	for _, keyStruct := range values {
		// IDs are drawn against the keys appended so far, so they stay unique.
		key, err := keysetKeyFromStruct(keyStruct, generateUnusedID(ks))
		if err != nil {
			return nil, err
		}
		ks.Key = append(ks.Key, key)
	}
	ks.PrimaryKeyId = ks.Key[len(ks.Key)-1].GetKeyId()
	return keyset.NewHandleWithNoSecrets(ks)
}
// addKeyOPSVerify sets the "key_ops" member of s to the one-element list
// ["verify"]. s.Fields must already be non-nil, because writing to a nil map
// panics.
func addKeyOPSVerify(s *spb.Struct) {
	ops := &spb.ListValue{Values: []*spb.Value{spb.NewStringValue("verify")}}
	s.GetFields()["key_ops"] = spb.NewListValue(ops)
}
// addStringEntry sets member key of s to the JSON string val, replacing any
// existing member. s.Fields must already be non-nil.
func addStringEntry(s *spb.Struct, key, val string) {
	fields := s.GetFields()
	fields[key] = spb.NewStringValue(val)
}
// psAlgToStr maps each Tink JwtRsaSsaPssAlgorithm to its JWK "alg" name. It
// is the inverse of psNameToAlg.
var psAlgToStr = map[jrpsspb.JwtRsaSsaPssAlgorithm]string{
	jrpsspb.JwtRsaSsaPssAlgorithm_PS512: "PS512",
	jrpsspb.JwtRsaSsaPssAlgorithm_PS384: "PS384",
	jrpsspb.JwtRsaSsaPssAlgorithm_PS256: "PS256",
}
// psPublicKeyToStruct renders a keyset key holding a JwtRsaSsaPssPublicKey as
// a JWK object. The object has these members:
//   - "kty": "RSA"
//   - "use": "sig"
//   - "key_ops": ["verify"]
//   - "alg": the PS algorithm name
//   - "e" and "n": the base64Encode-d exponent and modulus
//
// setKeyID decides whether a "kid" member is added.
func psPublicKeyToStruct(key *tinkpb.Keyset_Key) (*spb.Struct, error) {
	pubKey := &jrpsspb.JwtRsaSsaPssPublicKey{}
	if err := proto.Unmarshal(key.GetKeyData().GetValue(), pubKey); err != nil {
		return nil, err
	}
	alg, ok := psAlgToStr[pubKey.GetAlgorithm()]
	if !ok {
		return nil, fmt.Errorf("invalid algorithm")
	}
	jwk := &spb.Struct{Fields: map[string]*spb.Value{
		"alg": spb.NewStringValue(alg),
		"kty": spb.NewStringValue("RSA"),
		"e":   spb.NewStringValue(base64Encode(pubKey.GetE())),
		"n":   spb.NewStringValue(base64Encode(pubKey.GetN())),
		"use": spb.NewStringValue("sig"),
	}}
	addKeyOPSVerify(jwk)
	var customKID *string
	if ck := pubKey.GetCustomKid(); ck != nil {
		value := ck.GetValue()
		customKID = &value
	}
	if err := setKeyID(jwk, key, customKID); err != nil {
		return nil, err
	}
	return jwk, nil
}
// rsAlgToStr maps each Tink JwtRsaSsaPkcs1Algorithm to its JWK "alg" name.
// It is the inverse of rsNameToAlg.
var rsAlgToStr = map[jrsppb.JwtRsaSsaPkcs1Algorithm]string{
	jrsppb.JwtRsaSsaPkcs1Algorithm_RS512: "RS512",
	jrsppb.JwtRsaSsaPkcs1Algorithm_RS384: "RS384",
	jrsppb.JwtRsaSsaPkcs1Algorithm_RS256: "RS256",
}
// rsPublicKeyToStruct renders a keyset key holding a JwtRsaSsaPkcs1PublicKey
// as a JWK object. The object has these members:
//   - "kty": "RSA"
//   - "use": "sig"
//   - "key_ops": ["verify"]
//   - "alg": the RS algorithm name
//   - "e" and "n": the base64Encode-d exponent and modulus
//
// setKeyID decides whether a "kid" member is added.
func rsPublicKeyToStruct(key *tinkpb.Keyset_Key) (*spb.Struct, error) {
	pubKey := &jrsppb.JwtRsaSsaPkcs1PublicKey{}
	if err := proto.Unmarshal(key.GetKeyData().GetValue(), pubKey); err != nil {
		return nil, err
	}
	alg, known := rsAlgToStr[pubKey.GetAlgorithm()]
	if !known {
		return nil, fmt.Errorf("invalid algorithm")
	}
	jwk := &spb.Struct{Fields: map[string]*spb.Value{}}
	for _, entry := range [][2]string{
		{"alg", alg},
		{"kty", "RSA"},
		{"e", base64Encode(pubKey.GetE())},
		{"n", base64Encode(pubKey.GetN())},
		{"use", "sig"},
	} {
		addStringEntry(jwk, entry[0], entry[1])
	}
	addKeyOPSVerify(jwk)
	var customKID *string
	if ck := pubKey.GetCustomKid(); ck != nil {
		value := ck.GetValue()
		customKID = &value
	}
	if err := setKeyID(jwk, key, customKID); err != nil {
		return nil, err
	}
	return jwk, nil
}
// esPublicKeyToStruct renders a keyset key holding a JwtEcdsaPublicKey as a
// JWK object. The object has these members:
//   - "kty": "EC"
//   - "use": "sig"
//   - "key_ops": ["verify"]
//   - "alg" and "crv": ES256/P-256, ES384/P-384 or ES512/P-521
//   - "x" and "y": the base64Encode-d point coordinates
//
// setKeyID decides whether a "kid" member is added. Any other algorithm is an
// error.
func esPublicKeyToStruct(key *tinkpb.Keyset_Key) (*spb.Struct, error) {
	pubKey := &jepb.JwtEcdsaPublicKey{}
	if err := proto.Unmarshal(key.GetKeyData().GetValue(), pubKey); err != nil {
		return nil, err
	}
	algorithm, curve := "", ""
	switch pubKey.GetAlgorithm() {
	case jepb.JwtEcdsaAlgorithm_ES256:
		algorithm, curve = "ES256", "P-256"
	case jepb.JwtEcdsaAlgorithm_ES384:
		algorithm, curve = "ES384", "P-384"
	case jepb.JwtEcdsaAlgorithm_ES512:
		algorithm, curve = "ES512", "P-521"
	default:
		return nil, fmt.Errorf("invalid algorithm")
	}
	jwk := &spb.Struct{Fields: map[string]*spb.Value{
		"crv": spb.NewStringValue(curve),
		"alg": spb.NewStringValue(algorithm),
		"kty": spb.NewStringValue("EC"),
		"x":   spb.NewStringValue(base64Encode(pubKey.GetX())),
		"y":   spb.NewStringValue(base64Encode(pubKey.GetY())),
		"use": spb.NewStringValue("sig"),
	}}
	addKeyOPSVerify(jwk)
	var customKID *string
	if ck := pubKey.GetCustomKid(); ck != nil {
		value := ck.GetValue()
		customKID = &value
	}
	if err := setKeyID(jwk, key, customKID); err != nil {
		return nil, err
	}
	return jwk, nil
}
// setKeyID adds the "kid" member to outKey when one applies.
//
// For keys with a TINK output prefix, the KID comes from keyID (defined
// elsewhere in the package) using the key's ID and prefix type. Such keys must
// not also carry a custom KID. For every other prefix type, the custom KID is
// written when present and otherwise "kid" is left out.
func setKeyID(outKey *spb.Struct, key *tinkpb.Keyset_Key, customKID *string) error {
	if key.GetOutputPrefixType() != tinkpb.OutputPrefixType_TINK {
		if customKID != nil {
			addStringEntry(outKey, "kid", *customKID)
		}
		return nil
	}
	if customKID != nil {
		return fmt.Errorf("TINK keys shouldn't have custom KID")
	}
	kid := keyID(key.GetKeyId(), key.GetOutputPrefixType())
	if kid == nil {
		return fmt.Errorf("tink KID shouldn't be nil")
	}
	addStringEntry(outKey, "kid", *kid)
	return nil
}
// JWKSetFromPublicKeysetHandle converts a Tink KeysetHandle with JWT keys into a JSON Web Key (JWK) set.
// Currently only public keys for algorithms ES256, ES384, ES512, RS256, RS384, RS512,
// PS256, PS384, and PS512 are supported.
//
// Keys whose status is not ENABLED are skipped. Any enabled key is an error if
// its output prefix type is not TINK or RAW, if it is not an asymmetric public
// key, or if its type URL is not one of the supported JWT public key types.
// JWK is defined in https://www.rfc-editor.org/rfc/rfc7517.html.
func JWKSetFromPublicKeysetHandle(kh *keyset.Handle) ([]byte, error) {
	// Serialize the handle without secrets and parse it back so the Keyset
	// proto can be inspected directly.
	b := &bytes.Buffer{}
	if err := kh.WriteWithNoSecrets(keyset.NewBinaryWriter(b)); err != nil {
		return nil, err
	}
	ks := &tinkpb.Keyset{}
	if err := proto.Unmarshal(b.Bytes(), ks); err != nil {
		return nil, err
	}
	keyValList := make([]*spb.Value, 0, len(ks.GetKey()))
	for _, k := range ks.GetKey() {
		if k.GetStatus() != tinkpb.KeyStatusType_ENABLED {
			continue
		}
		if k.GetOutputPrefixType() != tinkpb.OutputPrefixType_TINK &&
			k.GetOutputPrefixType() != tinkpb.OutputPrefixType_RAW {
			return nil, fmt.Errorf("unsupported output prefix type")
		}
		keyData := k.GetKeyData()
		if keyData == nil {
			return nil, fmt.Errorf("invalid key data")
		}
		if keyData.GetKeyMaterialType() != tinkpb.KeyData_ASYMMETRIC_PUBLIC {
			return nil, fmt.Errorf("only asymmetric public keys are supported")
		}
		// Every branch either assigns keyStruct or returns, so no placeholder
		// allocation is needed.
		var keyStruct *spb.Struct
		var err error
		switch keyData.GetTypeUrl() {
		case jwtECDSAPublicKeyType:
			keyStruct, err = esPublicKeyToStruct(k)
		case jwtRSPublicKeyType:
			keyStruct, err = rsPublicKeyToStruct(k)
		case jwtPSPublicKeyType:
			keyStruct, err = psPublicKeyToStruct(k)
		default:
			return nil, fmt.Errorf("unsupported key type url")
		}
		if err != nil {
			return nil, err
		}
		keyValList = append(keyValList, spb.NewStructValue(keyStruct))
	}
	output := &spb.Struct{
		Fields: map[string]*spb.Value{
			"keys": spb.NewListValue(&spb.ListValue{Values: keyValList}),
		},
	}
	return output.MarshalJSON()
}