blob: a3898bdca78b7d488714ec99d4b8ffe7d829d72e [file] [log] [blame]
/*
* Copyright 2019 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "private_join_and_compute/crypto/context.h"
#include <cmath>
#include <memory>
#include <string>
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "private_join_and_compute/crypto/openssl_init.h"
namespace private_join_and_compute {
std::string OpenSSLErrorString() {
char buf[256];
ERR_error_string_n(ERR_get_error(), buf, sizeof(buf));
return buf;
}
Context::Context()
: bn_ctx_(BN_CTX_new()),
evp_md_ctx_(EVP_MD_CTX_create()),
zero_bn_(CreateBigNum(0)),
one_bn_(CreateBigNum(1)),
two_bn_(CreateBigNum(2)),
three_bn_(CreateBigNum(3)) {
OpenSSLInit();
CHECK(RAND_status()) << "OpenSSL PRNG is not properly seeded.";
HMAC_CTX_init(&hmac_ctx_);
}
Context::~Context() { HMAC_CTX_cleanup(&hmac_ctx_); }
BN_CTX* Context::GetBnCtx() { return bn_ctx_.get(); }
BigNum Context::CreateBigNum(absl::string_view bytes) {
return BigNum(bn_ctx_.get(), bytes);
}
BigNum Context::CreateBigNum(uint64_t number) {
return BigNum(bn_ctx_.get(), number);
}
BigNum Context::CreateBigNum(BigNum::BignumPtr bn) {
return BigNum(bn_ctx_.get(), std::move(bn));
}
std::string Context::Sha256String(absl::string_view bytes) {
unsigned char hash[EVP_MAX_MD_SIZE];
CRYPTO_CHECK(1 ==
EVP_DigestInit_ex(evp_md_ctx_.get(), EVP_sha256(), nullptr));
CRYPTO_CHECK(
1 == EVP_DigestUpdate(evp_md_ctx_.get(), bytes.data(), bytes.length()));
unsigned int md_len;
CRYPTO_CHECK(1 == EVP_DigestFinal_ex(evp_md_ctx_.get(), hash, &md_len));
return std::string(reinterpret_cast<char*>(hash), md_len);
}
std::string Context::Sha384String(absl::string_view bytes) {
unsigned char hash[EVP_MAX_MD_SIZE];
CRYPTO_CHECK(1 ==
EVP_DigestInit_ex(evp_md_ctx_.get(), EVP_sha384(), nullptr));
CRYPTO_CHECK(
1 == EVP_DigestUpdate(evp_md_ctx_.get(), bytes.data(), bytes.length()));
unsigned int md_len;
CRYPTO_CHECK(1 == EVP_DigestFinal_ex(evp_md_ctx_.get(), hash, &md_len));
return std::string(reinterpret_cast<char*>(hash), md_len);
}
std::string Context::Sha512String(absl::string_view bytes) {
unsigned char hash[EVP_MAX_MD_SIZE];
CRYPTO_CHECK(1 ==
EVP_DigestInit_ex(evp_md_ctx_.get(), EVP_sha512(), nullptr));
CRYPTO_CHECK(
1 == EVP_DigestUpdate(evp_md_ctx_.get(), bytes.data(), bytes.length()));
unsigned int md_len;
CRYPTO_CHECK(1 == EVP_DigestFinal_ex(evp_md_ctx_.get(), hash, &md_len));
return std::string(reinterpret_cast<char*>(hash), md_len);
}
BigNum Context::RandomOracle(absl::string_view x, const BigNum& max_value,
RandomOracleHashType hash_type) {
int hash_output_length = 256;
if (hash_type == SHA512) {
hash_output_length = 512;
} else if (hash_type == SHA384) {
hash_output_length = 384;
}
int output_bit_length = max_value.BitLength() + hash_output_length;
int iter_count =
std::ceil(static_cast<float>(output_bit_length) / hash_output_length);
CHECK(iter_count * hash_output_length < 130048)
<< "The domain bit length must not be greater than "
"130048. Desired bit length: "
<< output_bit_length;
int excess_bit_count = (iter_count * hash_output_length) - output_bit_length;
BigNum hash_output = CreateBigNum(0);
for (int i = 1; i < iter_count + 1; i++) {
hash_output = hash_output.Lshift(hash_output_length);
std::string bignum_bytes = absl::StrCat(CreateBigNum(i).ToBytes(), x);
std::string hashed_string;
if (hash_type == SHA512) {
hashed_string = Sha512String(bignum_bytes);
} else if (hash_type == SHA384) {
hashed_string = Sha384String(bignum_bytes);
} else {
hashed_string = Sha256String(bignum_bytes);
}
hash_output = hash_output + CreateBigNum(hashed_string);
}
return hash_output.Rshift(excess_bit_count).Mod(max_value);
}
BigNum Context::RandomOracleSha512(absl::string_view x,
const BigNum& max_value) {
return RandomOracle(x, max_value, SHA512);
}
BigNum Context::RandomOracleSha384(absl::string_view x,
const BigNum& max_value) {
return RandomOracle(x, max_value, SHA384);
}
BigNum Context::RandomOracleSha256(absl::string_view x,
const BigNum& max_value) {
return RandomOracle(x, max_value, SHA256);
}
BigNum Context::PRF(absl::string_view key, absl::string_view data,
const BigNum& max_value) {
CHECK_GE(key.size() * 8, 80);
CHECK_LE(max_value.BitLength(), 512)
<< "The requested output length is not supported. The maximum "
"supported output length is 512. The requested output length is "
<< max_value.BitLength();
CRYPTO_CHECK(1 == HMAC_Init_ex(&hmac_ctx_, key.data(), key.size(),
EVP_sha512(), nullptr));
CRYPTO_CHECK(1 ==
HMAC_Update(&hmac_ctx_,
reinterpret_cast<const unsigned char*>(data.data()),
data.size()));
unsigned int md_len;
unsigned char hash[EVP_MAX_MD_SIZE];
CRYPTO_CHECK(1 == HMAC_Final(&hmac_ctx_, hash, &md_len));
BigNum hash_bn(bn_ctx_.get(), hash, md_len);
BigNum hash_bn_reduced = hash_bn.GetLastNBits(max_value.BitLength());
if (hash_bn_reduced < max_value) {
return hash_bn_reduced;
} else {
return Context::PRF(key, hash_bn.ToBytes(), max_value);
}
}
BigNum Context::GenerateSafePrime(int prime_length) {
BigNum r(bn_ctx_.get());
CRYPTO_CHECK(1 == BN_generate_prime_ex(r.bn_.get(), prime_length, 1, nullptr,
nullptr, nullptr));
return r;
}
BigNum Context::GeneratePrime(int prime_length) {
BigNum r(bn_ctx_.get());
CRYPTO_CHECK(1 == BN_generate_prime_ex(r.bn_.get(), prime_length, 0, nullptr,
nullptr, nullptr));
return r;
}
BigNum Context::GenerateRandLessThan(const BigNum& max_value) {
BigNum r(bn_ctx_.get());
CRYPTO_CHECK(1 == BN_rand_range(r.bn_.get(), max_value.bn_.get()));
return r;
}
BigNum Context::GenerateRandBetween(const BigNum& start, const BigNum& end) {
CHECK(start < end);
return GenerateRandLessThan(end - start) + start;
}
std::string Context::GenerateRandomBytes(int num_bytes) {
CHECK_GE(num_bytes, 0) << "num_bytes must be nonnegative, provided value was "
<< num_bytes << ".";
std::unique_ptr<unsigned char[]> bytes(new unsigned char[num_bytes]);
CRYPTO_CHECK(1 == RAND_bytes(bytes.get(), num_bytes));
return std::string(reinterpret_cast<char*>(bytes.get()), num_bytes);
}
BigNum Context::RelativelyPrimeRandomLessThan(const BigNum& num) {
BigNum rand_num = GenerateRandLessThan(num);
while (rand_num.Gcd(num) > One()) {
rand_num = GenerateRandLessThan(num);
}
return rand_num;
}
} // namespace private_join_and_compute