| // Copyright 2022 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| //////////////////////////////////////////////////////////////////////////////// |
| |
| package signature_test |
| |
| import ( |
| "crypto/rand" |
| "crypto/rsa" |
| "fmt" |
| "math/big" |
| "testing" |
| |
| "github.com/google/tink/go/internal/signature" |
| "github.com/google/tink/go/subtle/random" |
| "github.com/google/tink/go/subtle" |
| "github.com/google/tink/go/testutil" |
| ) |
| |
| func TestRSASSAPSSSignVerify(t *testing.T) { |
| data := random.GetRandomBytes(20) |
| sigHash := "SHA256" |
| saltLength := 10 |
| privKey, err := rsa.GenerateKey(rand.Reader, 3072) |
| if err != nil { |
| t.Fatalf("rsa.GenerateKey(rand.Reader, 3072) err = %v, want nil", err) |
| } |
| signer, err := signature.New_RSA_SSA_PSS_Signer(sigHash, saltLength, privKey) |
| if err != nil { |
| t.Fatalf("New_RSA_SSA_PSS_Signer() error = %v, want nil", err) |
| } |
| verifier, err := signature.New_RSA_SSA_PSS_Verifier(sigHash, saltLength, &privKey.PublicKey) |
| if err != nil { |
| t.Fatalf("New_RSA_SSA_PSS_Verifier() error = %v, want nil", err) |
| } |
| s, err := signer.Sign(data) |
| if err != nil { |
| t.Fatalf("Sign() err = %v, want nil", err) |
| } |
| if err = verifier.Verify(s, data); err != nil { |
| t.Fatalf("Verify() err = %v, want nil", err) |
| } |
| } |
| |
| func TestRSASSAPSSSignVerifyInvalidFails(t *testing.T) { |
| data := random.GetRandomBytes(20) |
| sigHash := "SHA256" |
| saltLength := 10 |
| privKey, err := rsa.GenerateKey(rand.Reader, 3072) |
| if err != nil { |
| t.Fatalf("rsa.GenerateKey(rand.Reader, 3072) err = %v, want nil", err) |
| } |
| signer, err := signature.New_RSA_SSA_PSS_Signer(sigHash, saltLength, privKey) |
| if err != nil { |
| t.Fatalf("New_RSA_SSA_PSS_Signer() error = %v, want nil", err) |
| } |
| verifier, err := signature.New_RSA_SSA_PSS_Verifier(sigHash, saltLength, &privKey.PublicKey) |
| if err != nil { |
| t.Fatalf("New_RSA_SSA_PSS_Verifier() error = %v, want nil", err) |
| } |
| s, err := signer.Sign(data) |
| if err != nil { |
| t.Fatalf("Sign() err = %v, want nil", err) |
| } |
| if err = verifier.Verify(s, data); err != nil { |
| t.Fatalf("Verify() err = %v, want nil", err) |
| } |
| |
| modifiedSig := s[:] |
| // modify first byte in signature |
| modifiedSig[0] = byte(uint8(modifiedSig[0]) + 1) |
| if err := verifier.Verify(modifiedSig, data); err == nil { |
| t.Errorf("Verify(modifiedSig, data) err = nil, want error") |
| } |
| if err := verifier.Verify(s, []byte("invalid_data")); err == nil { |
| t.Errorf("Verify(s, invalid_data) err = nil, want error") |
| } |
| if err := verifier.Verify([]byte("invalid_signature"), data); err == nil { |
| t.Errorf("Verify(invalid_signature, data) err = nil, want error") |
| } |
| |
| diffPrivKey, err := rsa.GenerateKey(rand.Reader, 3072) |
| if err != nil { |
| t.Fatalf("rsa.GenerateKey(rand.Reader, 3072) err = %v, want nil", err) |
| } |
| diffVerifier, err := signature.New_RSA_SSA_PSS_Verifier(sigHash, saltLength, &diffPrivKey.PublicKey) |
| if err != nil { |
| t.Fatalf("New_RSA_SSA_PSS_Verifier() error = %v, want nil", err) |
| } |
| if err := diffVerifier.Verify(s, data); err == nil { |
| t.Errorf("Verify() err = nil, want error") |
| } |
| } |
| |
| func TestNewRSASSAPSSSignerVerifierFailWithInvalidInputs(t *testing.T) { |
| type testCase struct { |
| name string |
| hash string |
| salt int |
| privKey *rsa.PrivateKey |
| } |
| validPrivKey, err := rsa.GenerateKey(rand.Reader, 3072) |
| if err != nil { |
| t.Fatalf("rsa.GenerateKey(rand.Reader, 3072) err = %v, want nil", err) |
| } |
| for _, tc := range []testCase{ |
| { |
| name: "invalid hash function", |
| hash: "SHA1", |
| privKey: validPrivKey, |
| salt: 0, |
| }, |
| { |
| name: "invalid exponent", |
| hash: "SHA256", |
| salt: 0, |
| privKey: &rsa.PrivateKey{ |
| D: validPrivKey.D, |
| PublicKey: rsa.PublicKey{ |
| N: validPrivKey.N, |
| E: 8, |
| }, |
| Primes: validPrivKey.Primes, |
| Precomputed: validPrivKey.Precomputed, |
| }, |
| }, |
| { |
| name: "invalid modulus", |
| hash: "SHA256", |
| salt: 0, |
| privKey: &rsa.PrivateKey{ |
| D: validPrivKey.D, |
| PublicKey: rsa.PublicKey{ |
| N: big.NewInt(5), |
| E: validPrivKey.E, |
| }, |
| Primes: validPrivKey.Primes, |
| Precomputed: validPrivKey.Precomputed, |
| }, |
| }, |
| { |
| name: "invalid salt", |
| hash: "SHA256", |
| salt: -1, |
| privKey: validPrivKey, |
| }, |
| } { |
| t.Run(tc.name, func(t *testing.T) { |
| if _, err := signature.New_RSA_SSA_PSS_Signer(tc.hash, tc.salt, tc.privKey); err == nil { |
| t.Errorf("New_RSA_SSA_PSS_Signer() err = nil, want error") |
| } |
| if _, err := signature.New_RSA_SSA_PSS_Verifier(tc.hash, tc.salt, &tc.privKey.PublicKey); err == nil { |
| t.Errorf("New_RSA_SSA_PSS_Verifier() err = nil, want error") |
| } |
| }) |
| } |
| } |
| |
| type rsaSSAPSSSuite struct { |
| testutil.WycheproofSuite |
| TestGroups []*rsaSSAPSSGroup `json:"testGroups"` |
| } |
| |
| type rsaSSAPSSGroup struct { |
| testutil.WycheproofGroup |
| SHA string `json:"sha"` |
| MGFSHA string `json:"mgfSha"` |
| SaltLength int `json:"sLen"` |
| E testutil.HexBytes `json:"e"` |
| N testutil.HexBytes `json:"N"` |
| Tests []*rsaSSAPSSCase `json:"tests"` |
| } |
| |
| type rsaSSAPSSCase struct { |
| testutil.WycheproofCase |
| Message testutil.HexBytes `json:"msg"` |
| Signature testutil.HexBytes `json:"sig"` |
| } |
| |
| func TestRSASSAPSSWycheproofCases(t *testing.T) { |
| testutil.SkipTestIfTestSrcDirIsNotSet(t) |
| ranTestCount := 0 |
| vectorsFiles := []string{ |
| "rsa_pss_2048_sha512_256_mgf1_28_test.json", |
| "rsa_pss_2048_sha512_256_mgf1_32_test.json", |
| "rsa_pss_2048_sha256_mgf1_0_test.json", |
| "rsa_pss_2048_sha256_mgf1_32_test.json", |
| "rsa_pss_3072_sha256_mgf1_32_test.json", |
| "rsa_pss_4096_sha256_mgf1_32_test.json", |
| "rsa_pss_4096_sha512_mgf1_32_test.json", |
| } |
| for _, v := range vectorsFiles { |
| suite := &rsaSSAPSSSuite{} |
| if err := testutil.PopulateSuite(suite, v); err != nil { |
| t.Fatalf("failed populating suite: %s", err) |
| } |
| for _, group := range suite.TestGroups { |
| sigHash := subtle.ConvertHashName(group.SHA) |
| if sigHash == "" { |
| continue |
| } |
| pubKey := &rsa.PublicKey{ |
| E: int(new(big.Int).SetBytes(group.E).Uint64()), |
| N: new(big.Int).SetBytes(group.N), |
| } |
| verifier, err := signature.New_RSA_SSA_PSS_Verifier(sigHash, group.SaltLength, pubKey) |
| if err != nil { |
| t.Fatalf("New_RSA_SSA_PSS_Verifier() err = %v, want nil", err) |
| } |
| for _, test := range group.Tests { |
| if (test.CaseID == 67 || test.CaseID == 68) && v == "rsa_pss_2048_sha256_mgf1_0_test.json" { |
| // crypto/rsa will interpret zero length salt and parse the salt length from signature. |
| // Since this test cases use a zero salt length as a parameter, even if a different parameter |
| // is provided, Golang will interpret it and parse the salt directly from the signature. |
| continue |
| } |
| ranTestCount++ |
| caseName := fmt.Sprintf("%s: %s-%s-%s-%d:Case-%d", v, group.Type, group.SHA, group.MGFSHA, group.SaltLength, test.CaseID) |
| t.Run(caseName, func(t *testing.T) { |
| err := verifier.Verify(test.Signature, test.Message) |
| switch test.Result { |
| case "valid": |
| if err != nil { |
| t.Errorf("Verify() err = %, want nil", err) |
| } |
| case "invalid": |
| if err == nil { |
| t.Errorf("Verify() err = nil, want error") |
| } |
| case "acceptable": |
| // TODO(b/230489047): Inspect flags to appropriately handle acceptable test cases. |
| default: |
| t.Errorf("unsupported test result: %q", test.Result) |
| } |
| }) |
| } |
| } |
| } |
| if ranTestCount < 578 { |
| t.Errorf("ranTestCount > %d, want > %d", ranTestCount, 578) |
| } |
| } |