| // Copyright 2022 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| //////////////////////////////////////////////////////////////////////////////// |
| |
| package jwt |
| |
| import ( |
| "encoding/base64" |
| "testing" |
| "time" |
| |
| "github.com/google/go-cmp/cmp" |
| "google.golang.org/protobuf/proto" |
| "github.com/google/tink/go/core/registry" |
| "github.com/google/tink/go/subtle/random" |
| jwtmacpb "github.com/google/tink/go/proto/jwt_hmac_go_proto" |
| tinkpb "github.com/google/tink/go/proto/tink_go_proto" |
| ) |
| |
| type jwtKeyManagerTestCase struct { |
| tag string |
| keyFormat *jwtmacpb.JwtHmacKeyFormat |
| key *jwtmacpb.JwtHmacKey |
| } |
| |
// typeURL is the type URL under which the JWT HMAC key manager is looked up
// in the registry throughout these tests.
const typeURL = "type.googleapis.com/google.crypto.tink.JwtHmacKey"
| |
| func generateKeyFormat(keySize uint32, algorithm jwtmacpb.JwtHmacAlgorithm) *jwtmacpb.JwtHmacKeyFormat { |
| return &jwtmacpb.JwtHmacKeyFormat{ |
| KeySize: keySize, |
| Algorithm: algorithm, |
| } |
| } |
| |
| func TestDoesSupport(t *testing.T) { |
| km, err := registry.GetKeyManager(typeURL) |
| if err != nil { |
| t.Errorf("registry.GetKeyManager(%q) error = %v, want nil", typeURL, err) |
| } |
| if !km.DoesSupport(typeURL) { |
| t.Errorf("km.DoesSupport(%q) = false, want true", typeURL) |
| } |
| } |
| |
| func TestTypeURL(t *testing.T) { |
| km, err := registry.GetKeyManager(typeURL) |
| if err != nil { |
| t.Errorf("registry.GetKeyManager(%q) error = %v, want nil", typeURL, err) |
| } |
| if km.TypeURL() != typeURL { |
| t.Errorf("km.TypeURL() = %q, want %q", km.TypeURL(), typeURL) |
| } |
| } |
| |
| var invalidKeyFormatTestCases = []jwtKeyManagerTestCase{ |
| { |
| tag: "invalid hash algorithm", |
| keyFormat: generateKeyFormat(32, jwtmacpb.JwtHmacAlgorithm_HS_UNKNOWN), |
| }, |
| { |
| tag: "invalid HS256 key size", |
| keyFormat: generateKeyFormat(31, jwtmacpb.JwtHmacAlgorithm_HS256), |
| }, |
| { |
| tag: "invalid HS384 key size", |
| keyFormat: generateKeyFormat(47, jwtmacpb.JwtHmacAlgorithm_HS384), |
| }, |
| { |
| tag: "invalid HS512 key size", |
| keyFormat: generateKeyFormat(63, jwtmacpb.JwtHmacAlgorithm_HS512), |
| }, |
| { |
| tag: "empty key format", |
| keyFormat: &jwtmacpb.JwtHmacKeyFormat{}, |
| }, |
| { |
| tag: "nil key format", |
| keyFormat: nil, |
| }, |
| } |
| |
| func TestNewKeyInvalidFormatFails(t *testing.T) { |
| for _, tc := range invalidKeyFormatTestCases { |
| t.Run(tc.tag, func(t *testing.T) { |
| km, err := registry.GetKeyManager(typeURL) |
| if err != nil { |
| t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) |
| } |
| serializedKeyFormat, err := proto.Marshal(tc.keyFormat) |
| if err != nil { |
| t.Errorf("serializing key format: %v", err) |
| } |
| if _, err := km.NewKey(serializedKeyFormat); err == nil { |
| t.Errorf("km.NewKey() err = nil, want error") |
| } |
| }) |
| } |
| } |
| |
| func TestNewDataInvalidFormatFails(t *testing.T) { |
| for _, tc := range invalidKeyFormatTestCases { |
| t.Run(tc.tag, func(t *testing.T) { |
| km, err := registry.GetKeyManager(typeURL) |
| if err != nil { |
| t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) |
| } |
| serializedKeyFormat, err := proto.Marshal(tc.keyFormat) |
| if err != nil { |
| t.Errorf("serializing key format: %v", err) |
| } |
| if _, err := km.NewKeyData(serializedKeyFormat); err == nil { |
| t.Errorf("km.NewKey() err = nil, want error") |
| } |
| }) |
| } |
| } |
| |
| var validKeyFormatTestCases = []jwtKeyManagerTestCase{ |
| { |
| tag: "SHA256 hash algorithm", |
| keyFormat: generateKeyFormat(32, jwtmacpb.JwtHmacAlgorithm_HS256), |
| }, |
| { |
| tag: "SHA384 hash algorithm", |
| keyFormat: generateKeyFormat(48, jwtmacpb.JwtHmacAlgorithm_HS384), |
| }, |
| { |
| tag: "SHA512 hash algorithm", |
| keyFormat: generateKeyFormat(64, jwtmacpb.JwtHmacAlgorithm_HS512), |
| }, |
| } |
| |
| func TestNewKey(t *testing.T) { |
| for _, tc := range validKeyFormatTestCases { |
| t.Run(tc.tag, func(t *testing.T) { |
| km, err := registry.GetKeyManager(typeURL) |
| if err != nil { |
| t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) |
| } |
| serializedKeyFormat, err := proto.Marshal(tc.keyFormat) |
| if err != nil { |
| t.Errorf("serializing key format: %v", err) |
| } |
| k, err := km.NewKey(serializedKeyFormat) |
| if err != nil { |
| t.Errorf("km.NewKey() err = %v, want nil", err) |
| } |
| key, ok := k.(*jwtmacpb.JwtHmacKey) |
| if !ok { |
| t.Errorf("key isn't of type JwtHmacKey") |
| } |
| if key.Algorithm != tc.keyFormat.Algorithm { |
| t.Errorf("k.Algorithm = %v, want %v", key.Algorithm, tc.keyFormat.Algorithm) |
| } |
| if len(key.KeyValue) != int(tc.keyFormat.KeySize) { |
| t.Errorf("len(key.KeyValue) = %d, want %d", len(key.KeyValue), tc.keyFormat.KeySize) |
| } |
| }) |
| } |
| } |
| |
| func TestNewKeyData(t *testing.T) { |
| for _, tc := range validKeyFormatTestCases { |
| t.Run(tc.tag, func(t *testing.T) { |
| km, err := registry.GetKeyManager(typeURL) |
| if err != nil { |
| t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) |
| } |
| serializedKeyFormat, err := proto.Marshal(tc.keyFormat) |
| if err != nil { |
| t.Errorf("serializing key format: %v", err) |
| } |
| k, err := km.NewKeyData(serializedKeyFormat) |
| if err != nil { |
| t.Errorf("km.NewKeyData() err = %v, want nil", err) |
| } |
| if k.GetTypeUrl() != typeURL { |
| t.Errorf("k.GetTypeUrl() = %q, want %q", k.GetTypeUrl(), typeURL) |
| } |
| if k.GetKeyMaterialType() != tinkpb.KeyData_SYMMETRIC { |
| t.Errorf("k.GetKeyMaterialType() = %q, want %q", k.GetKeyMaterialType(), tinkpb.KeyData_SYMMETRIC) |
| } |
| }) |
| } |
| } |
| |
| func generateKey(keySize, version uint32, algorithm jwtmacpb.JwtHmacAlgorithm, kid *jwtmacpb.JwtHmacKey_CustomKid) *jwtmacpb.JwtHmacKey { |
| return &jwtmacpb.JwtHmacKey{ |
| KeyValue: random.GetRandomBytes(keySize), |
| Algorithm: algorithm, |
| CustomKid: kid, |
| Version: version, |
| } |
| } |
| |
| func TestGetPrimitiveWithValidKeys(t *testing.T) { |
| rawJWT, err := NewRawJWT(&RawJWTOptions{WithoutExpiration: true, Audiences: []string{"tink-aud"}}) |
| if err != nil { |
| t.Fatalf("NewRawJWT() err = %v, want nil", err) |
| } |
| validator, err := NewValidator(&ValidatorOpts{AllowMissingExpiration: true, ExpectedAudience: refString("tink-aud")}) |
| if err != nil { |
| t.Fatalf("NewValidator() err = %v, want nil", err) |
| } |
| for _, tc := range []jwtKeyManagerTestCase{ |
| { |
| tag: "SHA256 hash algorithm", |
| key: generateKey(32, 0, jwtmacpb.JwtHmacAlgorithm_HS256, nil), |
| }, |
| { |
| tag: "SHA384 hash algorithm", |
| key: generateKey(48, 0, jwtmacpb.JwtHmacAlgorithm_HS384, nil), |
| }, |
| { |
| tag: "SHA512 hash algorithm", |
| key: generateKey(64, 0, jwtmacpb.JwtHmacAlgorithm_HS512, nil), |
| }, |
| { |
| tag: "with custom kid", |
| key: generateKey(64, 0, jwtmacpb.JwtHmacAlgorithm_HS512, &jwtmacpb.JwtHmacKey_CustomKid{Value: "1235"}), |
| }, |
| } { |
| t.Run(tc.tag, func(t *testing.T) { |
| km, err := registry.GetKeyManager(typeURL) |
| if err != nil { |
| t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) |
| } |
| serializedKey, err := proto.Marshal(tc.key) |
| if err != nil { |
| t.Errorf("serializing key format: %v", err) |
| } |
| p, err := km.Primitive(serializedKey) |
| if err != nil { |
| t.Errorf("km.Primitive() err = %v, want nil", err) |
| } |
| primitive, ok := p.(*macWithKID) |
| if !ok { |
| t.Errorf("primitive isn't of type: macWithKID") |
| } |
| compact, err := primitive.ComputeMACAndEncodeWithKID(rawJWT, nil) |
| if err != nil { |
| t.Errorf("ComputeMACAndEncodeWithKID() err = %v, want nil", err) |
| } |
| verifiedJWT, err := primitive.VerifyMACAndDecodeWithKID(compact, validator, nil) |
| if err != nil { |
| t.Errorf("VerifyMACAndDecodeWithKID() err = %v, want nil", err) |
| } |
| audiences, err := verifiedJWT.Audiences() |
| if err != nil { |
| t.Errorf("verifiedJWT.Audiences() err = %v, want nil", err) |
| } |
| if !cmp.Equal(audiences, []string{"tink-aud"}) { |
| t.Errorf("verifiedJWT.Audiences() = %q, want ['tink-aud']", audiences) |
| } |
| |
| }) |
| } |
| } |
| |
| func TestGetPrimitiveWithInvalidKeys(t *testing.T) { |
| for _, tc := range []jwtKeyManagerTestCase{ |
| { |
| tag: "HS256", |
| key: generateKey(31, 0, jwtmacpb.JwtHmacAlgorithm_HS256, nil), |
| }, |
| { |
| tag: "HS384", |
| key: generateKey(47, 0, jwtmacpb.JwtHmacAlgorithm_HS384, nil), |
| }, |
| { |
| tag: "HS512", |
| key: generateKey(63, 0, jwtmacpb.JwtHmacAlgorithm_HS512, nil), |
| }, |
| } { |
| t.Run(tc.tag, func(t *testing.T) { |
| km, err := registry.GetKeyManager(typeURL) |
| if err != nil { |
| t.Fatalf("registry.GetKeyManager(%q) err=%q, want nil", typeURL, err) |
| } |
| serializedKey, err := proto.Marshal(tc.key) |
| if err != nil { |
| t.Fatalf("proto.Marshal(tc.key) err =%q, want nil", err) |
| } |
| _, err = km.Primitive(serializedKey) |
| if err == nil { |
| t.Error("km.Primitive(serializedKey) err = nil, want error") |
| } |
| }) |
| } |
| } |
| |
| func TestSpecyfingCustomKIDAndTINKKIDFails(t *testing.T) { |
| // key and compact are examples from: https://datatracker.ietf.org/doc/html/rfc7515#appendix-A.1.1 |
| compact := "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" |
| rawKey, err := base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString("AyM1SysPpbyDfgZld3umj1qzKObwVMkoqQ-EstJQLr_T-1qS0gZH75aKtMN3Yj0iPS4hcgUuTwjAzZr1Z9CAow") |
| if err != nil { |
| t.Fatalf("failed decoding test key: %v", err) |
| } |
| key := &jwtmacpb.JwtHmacKey{ |
| KeyValue: rawKey, |
| Algorithm: jwtmacpb.JwtHmacAlgorithm_HS256, |
| CustomKid: &jwtmacpb.JwtHmacKey_CustomKid{Value: "1235"}, |
| Version: 0, |
| } |
| km, err := registry.GetKeyManager(typeURL) |
| if err != nil { |
| t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) |
| } |
| serializedKey, err := proto.Marshal(key) |
| if err != nil { |
| t.Errorf("serializing key format: %v", err) |
| } |
| p, err := km.Primitive(serializedKey) |
| if err != nil { |
| t.Errorf("km.Primitive() err = %v, want nil", err) |
| } |
| primitive, ok := p.(*macWithKID) |
| if !ok { |
| t.Errorf("primitive isn't of type: macWithKID") |
| } |
| |
| rawJWT, err := NewRawJWT(&RawJWTOptions{WithoutExpiration: true}) |
| if err != nil { |
| t.Errorf("creating new RawJWT: %v", err) |
| } |
| opts := &ValidatorOpts{ |
| ExpectedTypeHeader: refString("JWT"), |
| ExpectedIssuer: refString("joe"), |
| FixedNow: time.Unix(12345, 0), |
| } |
| validator, err := NewValidator(opts) |
| if err != nil { |
| t.Errorf("creating new JWTValidator: %v", err) |
| } |
| if _, err := primitive.ComputeMACAndEncodeWithKID(rawJWT, refString("4566")); err == nil { |
| t.Errorf("primitive.ComputeMACAndEncodeWithKID() err = nil, want error") |
| } |
| if _, err := primitive.VerifyMACAndDecodeWithKID(compact, validator, refString("4566")); err == nil { |
| t.Errorf("primitive.VerifyMACAndDecodeWithKID(kid = 4566) err = nil, want error") |
| } |
| // Verify success without KID |
| if _, err := primitive.VerifyMACAndDecodeWithKID(compact, validator, nil); err != nil { |
| t.Errorf("primitive.VerifyMACAndDecodeWithKID(kid = nil) err = %v, want nil", err) |
| } |
| } |
| |
| func TestGetPrimitiveWithInvalidKeyFails(t *testing.T) { |
| for _, tc := range []jwtKeyManagerTestCase{ |
| { |
| tag: "empty key", |
| key: &jwtmacpb.JwtHmacKey{}, |
| }, |
| { |
| tag: "nil key", |
| key: nil, |
| }, |
| { |
| tag: "unsupported hash algorithm", |
| key: generateKey(32, 0, jwtmacpb.JwtHmacAlgorithm_HS_UNKNOWN, nil), |
| }, |
| { |
| tag: "short key length", |
| key: generateKey(20, 0, jwtmacpb.JwtHmacAlgorithm_HS384, nil), |
| }, |
| { |
| tag: "unsupported version", |
| key: generateKey(48, 1, jwtmacpb.JwtHmacAlgorithm_HS384, nil), |
| }, |
| } { |
| t.Run(tc.tag, func(t *testing.T) { |
| km, err := registry.GetKeyManager(typeURL) |
| if err != nil { |
| t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) |
| } |
| serializedKey, err := proto.Marshal(tc.key) |
| if err != nil { |
| t.Errorf("serializing key format: %v", err) |
| } |
| if _, err := km.Primitive(serializedKey); err == nil { |
| t.Errorf("km.Primitive() err = nil, want error") |
| } |
| }) |
| } |
| } |
| |
| func TestGeneratesDifferentKeys(t *testing.T) { |
| km, err := registry.GetKeyManager(typeURL) |
| if err != nil { |
| t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) |
| } |
| serializedKeyFormat, err := proto.Marshal(generateKeyFormat(32, jwtmacpb.JwtHmacAlgorithm_HS256)) |
| if err != nil { |
| t.Errorf("serializing key format: %v", err) |
| } |
| k1, err := km.NewKey(serializedKeyFormat) |
| if err != nil { |
| t.Errorf("km.NewKey() err = %v, want nil", err) |
| } |
| k2, err := km.NewKey(serializedKeyFormat) |
| if err != nil { |
| t.Errorf("km.NewKey() err = %v, want nil", err) |
| } |
| key1, ok := k1.(*jwtmacpb.JwtHmacKey) |
| if !ok { |
| t.Errorf("k1 isn't of type JwtHmacKey") |
| } |
| key2, ok := k2.(*jwtmacpb.JwtHmacKey) |
| if !ok { |
| t.Errorf("k2 isn't of type JwtHmacKey") |
| } |
| if cmp.Equal(key1.GetKeyValue(), key2.GetKeyValue()) { |
| t.Errorf("key material should differ") |
| } |
| } |