// blob: a48d57cfecf3c7057e68be537d25fb40a79391ba [file] [log] [blame]
/*
* Copyright 2019 Google Inc.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include <string>
#include "gflags/gflags.h"
#include "include/grpc/grpc_security_constants.h"
#include "include/grpcpp/channel.h"
#include "include/grpcpp/client_context.h"
#include "include/grpcpp/create_channel.h"
#include "include/grpcpp/grpcpp.h"
#include "include/grpcpp/security/credentials.h"
#include "include/grpcpp/support/status.h"
#include "data_util.h"
#include "client_impl.h"
#include "private_join_and_compute.grpc.pb.h"
#include "private_join_and_compute.pb.h"
#include "protocol_client.h"
#include "util/status.inc"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
// Command-line flags for the client binary.

// Address of the server to contact. Despite the flag name, the default value
// is a full "host:port" string; the flag value is passed unchanged to
// ::grpc::CreateChannel in ExecuteProtocol.
DEFINE_string(port, "0.0.0.0:10501", "Port on which to contact server");
// Path of the file holding the client's database. It is handed to
// ReadClientDatasetFromFile; the default is the empty string.
DEFINE_string(client_data_file, "",
"The file from which to read the client database.");
// Bit-length of the Paillier modulus, forwarded to the protocol client's
// constructor in ExecuteProtocol.
DEFINE_int32(
paillier_modulus_size, 1536,
"The bit-length of the modulus to use for Paillier encryption. The modulus "
"will be the product of two safe primes, each of size "
"paillier_modulus_size/2.");
namespace private_join_and_compute {
namespace {
// A MessageSink for ClientMessages that delivers each message to the server
// through the PrivateJoinAndComputeRpc stub's Handle RPC and remembers the
// ServerMessage that the most recent call wrote back.
class InvokeServerHandleClientMessageSink : public MessageSink<ClientMessage> {
 public:
  // Takes ownership of `stub`, which is used for every subsequent Send().
  explicit InvokeServerHandleClientMessageSink(
      std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub)
      : stub_(std::move(stub)) {}

  ~InvokeServerHandleClientMessageSink() override = default;

  // Sends `message` to the server via the Handle RPC, using a fresh
  // ClientContext for the call. The server's reply is written into the
  // stored response, replacing whatever an earlier call left there.
  //
  // Returns OkStatus() when the RPC succeeds; otherwise an InternalError
  // carrying the gRPC error code and error message.
  Status Send(const ClientMessage& message) override {
    ::grpc::ClientContext client_context;
    const ::grpc::Status grpc_status =
        stub_->Handle(&client_context, message, &last_server_response_);
    if (!grpc_status.ok()) {
      return InternalError(absl::StrCat(
          "GrpcClientMessageSink: Failed to send message, error code: ",
          grpc_status.error_code(),
          ", error_message: ", grpc_status.error_message()));
    }
    return OkStatus();
  }

  // Returns a reference to the response stored by the most recent Send().
  // The referenced object is overwritten by the next Send(), so callers that
  // need the value across another Send() should copy it first.
  const ServerMessage& last_server_response() const {
    return last_server_response_;
  }

 private:
  // RPC stub used to reach the server.
  std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub_;
  // Output buffer handed to every Handle RPC.
  ServerMessage last_server_response_;
};
// Runs the client side of the protocol from start to finish:
//   1. loads the client dataset from FLAGS_client_data_file,
//   2. builds the protocol client (using FLAGS_paillier_modulus_size),
//   3. opens a gRPC channel to FLAGS_port,
//   4. starts the protocol and handles the two server responses,
//   5. asks the protocol client to print its output.
//
// Progress is reported on stdout. Returns 0 on success; if any step fails, a
// diagnostic is written to stderr and 1 is returned.
int ExecuteProtocol() {
  ::private_join_and_compute::Context context;

  std::cout << "Client: Loading data..." << std::endl;
  auto maybe_client_identifiers_and_associated_values =
      ::private_join_and_compute::ReadClientDatasetFromFile(
          FLAGS_client_data_file, &context);
  if (!maybe_client_identifiers_and_associated_values.ok()) {
    std::cerr << "Client::ExecuteProtocol: failed "
              << maybe_client_identifiers_and_associated_values.status()
              << std::endl;
    return 1;
  }
  auto client_identifiers_and_associated_values =
      std::move(maybe_client_identifiers_and_associated_values.ValueOrDie());

  std::cout << "Client: Generating keys..." << std::endl;
  std::unique_ptr<::private_join_and_compute::ProtocolClient> client =
      absl::make_unique<
          ::private_join_and_compute::PrivateIntersectionSumProtocolClientImpl>(
          &context, std::move(client_identifiers_and_associated_values.first),
          std::move(client_identifiers_and_associated_values.second),
          FLAGS_paillier_modulus_size);

  // Local credentials are only suitable when client and server run on the
  // same machine. Consider grpc::SslCredentials (the client-side channel
  // credentials) if not running locally.
  std::unique_ptr<PrivateJoinAndComputeRpc::Stub> stub =
      PrivateJoinAndComputeRpc::NewStub(::grpc::CreateChannel(
          FLAGS_port, ::grpc::experimental::LocalCredentials(
                          grpc_local_connect_type::LOCAL_TCP)));
  InvokeServerHandleClientMessageSink invoke_server_handle_message_sink(
      std::move(stub));

  // Execute StartProtocol and wait for response from ServerRoundOne.
  std::cout
      << "Client: Starting the protocol." << std::endl
      << "Client: Waiting for response and encrypted set from the server..."
      << std::endl;
  auto start_protocol_status =
      client->StartProtocol(&invoke_server_handle_message_sink);
  if (!start_protocol_status.ok()) {
    std::cerr << "Client::ExecuteProtocol: failed to StartProtocol: "
              << start_protocol_status << std::endl;
    return 1;
  }
  // Copy the response rather than keeping the sink's reference: the sink
  // overwrites its stored response on every Send(), and the same sink is
  // handed to Handle() below.
  ServerMessage server_round_one =
      invoke_server_handle_message_sink.last_server_response();

  // Execute ClientRoundOne, and wait for response from ServerRoundTwo.
  std::cout
      << "Client: Received encrypted set from the server, double encrypting..."
      << std::endl;
  std::cout << "Client: Sending double encrypted server data and "
               "single-encrypted client data to the server."
            << std::endl
            << "Client: Waiting for encrypted intersection sum..." << std::endl;
  auto client_round_one_status =
      client->Handle(server_round_one, &invoke_server_handle_message_sink);
  if (!client_round_one_status.ok()) {
    std::cerr << "Client::ExecuteProtocol: failed to ReEncryptSet: "
              << client_round_one_status << std::endl;
    return 1;
  }

  // ServerRoundTwo's reply is whatever the sink stored during the Handle()
  // call above. (The "Sending.../Waiting..." progress lines are intentionally
  // not repeated here: that exchange has already completed at this point.)
  ServerMessage server_round_two =
      invoke_server_handle_message_sink.last_server_response();

  // Compute the intersection size and sum.
  std::cout << "Client: Received response from the server. Decrypting the "
               "intersection-sum."
            << std::endl;
  auto intersection_size_and_sum_status =
      client->Handle(server_round_two, &invoke_server_handle_message_sink);
  if (!intersection_size_and_sum_status.ok()) {
    std::cerr << "Client::ExecuteProtocol: failed to DecryptSum: "
              << intersection_size_and_sum_status << std::endl;
    return 1;
  }

  // Output the result.
  auto client_print_output_status = client->PrintOutput();
  if (!client_print_output_status.ok()) {
    std::cerr << "Client::ExecuteProtocol: failed to PrintOutput: "
              << client_print_output_status << std::endl;
    return 1;
  }

  return 0;
}
} // namespace
} // namespace private_join_and_compute
// Program entry point. Initializes logging with the program name, parses the
// command-line flags (removing recognized flags from argv), then runs the
// client protocol and uses its result as the process exit code.
int main(int argc, char** argv) {
  google::InitGoogleLogging(argv[0]);
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);

  const int exit_code = private_join_and_compute::ExecuteProtocol();
  return exit_code;
}