blob: bb27cf06e0fd0d04f1e99b0add564fc4ed7f98cf [file] [log] [blame]
/*
* Copyright 2019 Google Inc.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "server_impl.h"
#include <algorithm>
#include "crypto/paillier.h"
#include "util/status.inc"
#include "crypto/ec_commutative_cipher.h"
#include "absl/memory/memory.h"
using ::private_join_and_compute::BigNum;
using ::private_join_and_compute::ECCommutativeCipher;
using ::private_join_and_compute::PublicPaillier;
namespace private_join_and_compute {
// Generates a fresh EC commutative-cipher key (curve NID_secp224r1, SHA512
// hashing) and encrypts every element of inputs_ under it, producing the
// server's first-round message.
//
// Returns:
//   The ServerRoundOne message whose encrypted_set holds one element per
//   entry of inputs_, in the same order.
// Errors:
//   InvalidArgumentError if a cipher has already been committed (i.e. a
//   previous call succeeded); otherwise propagates any error from key
//   creation or from ECCommutativeCipher::Encrypt.
//
// The cipher is stored in ec_cipher_ only after all inputs were encrypted
// successfully, so a failed call leaves the object unchanged and can be
// retried.
StatusOr<PrivateIntersectionSumServerMessage::ServerRoundOne>
PrivateIntersectionSumProtocolServerImpl::EncryptSet() {
  if (ec_cipher_ != nullptr) {
    return InvalidArgumentError("Attempted to call EncryptSet twice.");
  }
  StatusOr<std::unique_ptr<ECCommutativeCipher>> maybe_cipher =
      ECCommutativeCipher::CreateWithNewKey(
          NID_secp224r1, ECCommutativeCipher::HashType::SHA512);
  if (!maybe_cipher.ok()) {
    return maybe_cipher.status();
  }
  // Keep the cipher local until encryption has fully succeeded. Committing it
  // to ec_cipher_ early would make a failed call indistinguishable from a
  // successful one: retries would be rejected as "twice", and
  // ComputeIntersection (which only checks ec_cipher_ != nullptr) would accept
  // client messages even though no round-one message was ever produced.
  std::unique_ptr<ECCommutativeCipher> cipher =
      std::move(maybe_cipher.ValueOrDie());

  PrivateIntersectionSumServerMessage::ServerRoundOne result;
  for (const std::string& input : inputs_) {
    StatusOr<std::string> encrypted_element = cipher->Encrypt(input);
    if (!encrypted_element.ok()) {
      return encrypted_element.status();
    }
    // Move the ciphertext into the message rather than copying it.
    *result.mutable_encrypted_set()->add_elements()->mutable_element() =
        std::move(encrypted_element.ValueOrDie());
  }

  ec_cipher_ = std::move(cipher);
  return result;
}
// Computes the server's second-round message from the client's first-round
// message.
//
// Inputs read from client_message:
//   - public_key: bytes turned into the BigNum N used to build a
//     PublicPaillier instance.
//   - encrypted_set: the client's elements, each with associated_data. Each
//     element is re-encrypted here with the server's ec_cipher_.
//   - reencrypted_set: compared, by element bytes, against the re-encrypted
//     client elements (per the protocol comment below, this is the server's
//     set as re-encrypted by the client).
//
// Returns:
//   A ServerRoundTwo whose intersection_size is the number of matches and
//   whose encrypted_sum is an encryption of zero combined via
//   PublicPaillier::Add with the associated_data of every matched client
//   element. This assumes associated_data holds Paillier ciphertexts under N;
//   confirm against the client.
// Errors:
//   InvalidArgumentError if EncryptSet has not succeeded yet. Otherwise
//   propagates errors from ReEncrypt and PublicPaillier::Encrypt.
StatusOr<PrivateIntersectionSumServerMessage::ServerRoundTwo>
PrivateIntersectionSumProtocolServerImpl::ComputeIntersection(
    const PrivateIntersectionSumClientMessage::ClientRoundOne& client_message) {
  if (ec_cipher_ == nullptr) {
    return InvalidArgumentError(
        "Called ComputeIntersection before EncryptSet.");
  }
  BigNum N = ctx_->CreateBigNum(client_message.public_key());
  PublicPaillier public_paillier(ctx_, N, 2);

  // The single ordering used for both sorting and intersecting. It compares
  // element bytes only and ignores associated_data.
  const auto by_element = [](const EncryptedElement& a,
                             const EncryptedElement& b) {
    return a.element() < b.element();
  };

  // First, re-encrypt the client party's set so that it can be compared with
  // the re-encrypted set received from the client.
  std::vector<EncryptedElement> client_set;
  client_set.reserve(client_message.encrypted_set().elements_size());
  for (const EncryptedElement& element :
       client_message.encrypted_set().elements()) {
    StatusOr<std::string> reenc = ec_cipher_->ReEncrypt(element.element());
    if (!reenc.ok()) {
      return reenc.status();
    }
    EncryptedElement reencrypted;
    *reencrypted.mutable_associated_data() = element.associated_data();
    *reencrypted.mutable_element() = std::move(reenc.ValueOrDie());
    client_set.push_back(std::move(reencrypted));
  }

  const auto& reencrypted_elements =
      client_message.reencrypted_set().elements();
  std::vector<EncryptedElement> server_set(reencrypted_elements.begin(),
                                           reencrypted_elements.end());

  // std::set_intersection requires both ranges sorted by the same ordering.
  std::sort(client_set.begin(), client_set.end(), by_element);
  std::sort(server_set.begin(), server_set.end(), by_element);

  // std::set_intersection copies matching elements from the FIRST range.
  // client_set must therefore come first: its entries carry the
  // associated_data summed below.
  //
  // With duplicates, a value present m times in one range and n times in the
  // other yields min(m, n) matches.
  std::vector<EncryptedElement> intersection;
  std::set_intersection(client_set.begin(), client_set.end(),
                        server_set.begin(), server_set.end(),
                        std::back_inserter(intersection), by_element);

  // From the intersection, compute the sum of the associated values; this is
  // the result returned to the client. Start from a fresh encryption of zero
  // so the output is well-formed even for an empty intersection.
  StatusOr<BigNum> encrypted_zero =
      public_paillier.Encrypt(ctx_->CreateBigNum(0));
  if (!encrypted_zero.ok()) {
    return encrypted_zero.status();
  }
  BigNum sum = encrypted_zero.ValueOrDie();
  for (const EncryptedElement& element : intersection) {
    sum =
        public_paillier.Add(sum, ctx_->CreateBigNum(element.associated_data()));
  }

  PrivateIntersectionSumServerMessage::ServerRoundTwo result;
  *result.mutable_encrypted_sum() = sum.ToBytes();
  result.set_intersection_size(intersection.size());
  return result;
}
// Dispatches one client message for the private-intersection-sum protocol and
// forwards the resulting server message to server_message_sink.
//
// Accepted messages:
//   - start_protocol_request: answered with the EncryptSet() round-one
//     message.
//   - client_round_one: answered with the ComputeIntersection() round-two
//     message. On success the protocol is marked finished.
//
// Returns InvalidArgumentError if protocol_finished() is already true, if the
// request is not a private-intersection-sum message, or if it is of an unknown
// kind. Errors from EncryptSet/ComputeIntersection are returned unchanged.
// Otherwise returns whatever server_message_sink->Send reports.
Status PrivateIntersectionSumProtocolServerImpl::Handle(
    const ClientMessage& request,
    MessageSink<ServerMessage>* server_message_sink) {
  // Guard: nothing is accepted once round two has been produced.
  if (protocol_finished()) {
    return InvalidArgumentError(
        "PrivateIntersectionSumProtocolServerImpl: Protocol is already "
        "complete.");
  }
  // Guard: the envelope must carry this protocol's payload.
  if (!request.has_private_intersection_sum_client_message()) {
    return InvalidArgumentError(
        "PrivateIntersectionSumProtocolServerImpl: Received a message for the "
        "wrong protocol type");
  }

  const auto& incoming = request.private_intersection_sum_client_message();
  // A start request takes precedence over a round-one payload.
  const bool is_start = incoming.has_start_protocol_request();
  const bool is_round_one = incoming.has_client_round_one();
  if (!is_start && !is_round_one) {
    return InvalidArgumentError(
        "PrivateIntersectionSumProtocolServerImpl: Received a client message "
        "of an unknown type.");
  }

  ServerMessage reply;
  if (is_start) {
    auto round_one = EncryptSet();
    if (!round_one.ok()) {
      return round_one.status();
    }
    *(reply.mutable_private_intersection_sum_server_message()
          ->mutable_server_round_one()) = std::move(round_one.ValueOrDie());
  } else {
    auto round_two = ComputeIntersection(incoming.client_round_one());
    if (!round_two.ok()) {
      return round_two.status();
    }
    *(reply.mutable_private_intersection_sum_server_message()
          ->mutable_server_round_two()) = std::move(round_two.ValueOrDie());
    // Round two is the final server step; flag completion before sending.
    protocol_finished_ = true;
  }
  return server_message_sink->Send(reply);
}
} // namespace private_join_and_compute