// blob: e0ef58344274e5c6e5324dc060ee0b8c0a23c16f [file] [log] [blame]
/*
* Copyright 2019 Google Inc.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "data_util.h"

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <limits>
#include <memory>
#include <random>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "crypto/context.h"
#include "util/status.inc"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
namespace private_join_and_compute {
namespace {
// Alphabet from which random identifiers are drawn: the ten decimal digits
// followed by the lower-case and upper-case ASCII letters.
static const char kAlphaNumericCharacters[] =
    "1234567890qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM";
// Number of usable characters in the alphabet. Computed from the array itself
// (excluding the terminating NUL) so that it cannot drift out of sync with
// the literal above.
static const size_t kAlphaNumericSize = sizeof(kAlphaNumericCharacters) - 1;

// Creates a string of the specified length consisting of random letters and
// numbers.
//
// Every character is chosen with rand() reduced modulo the alphabet size, so
// the output is only as unpredictable as the C library generator and its
// current seed.
// NOTE(review): no srand() call is visible in this file; if none happens
// elsewhere, every run produces the same identifier sequence -- confirm that
// this is acceptable for the generated test data.
std::string GetRandomAlphaNumericString(size_t length) {
  std::string output;
  // The final size is known up front, so allocate once.
  output.reserve(length);
  for (size_t i = 0; i < length; ++i) {
    output.push_back(kAlphaNumericCharacters[rand() % kAlphaNumericSize]);
  }
  return output;
}
// Utility functions to convert a line to CSV format, and parse a CSV line into
// columns safely.

// Returns a copy of at most `max_length` characters of `the_string` in a
// buffer of `max_length + 1` bytes allocated with new[]. The buffer is always
// NUL-terminated. Returns nullptr when `the_string` is nullptr.
//
// The caller owns the returned buffer and must release it with delete[].
char* strndup_with_new(const char* the_string, size_t max_length) {
  if (the_string == nullptr) {
    return nullptr;
  }
  char* const copy = new char[max_length + 1];
  strncpy(copy, the_string, max_length);
  // strncpy leaves the destination unterminated when the source holds
  // max_length or more characters, so the final byte is always set here.
  copy[max_length] = '\0';
  return copy;
}
// Splits the NUL-terminated buffer `line` in place into columns separated by
// `delimiter` and appends a pointer to the start of each column to `cols`.
//
// The buffer is modified: every column is NUL-terminated inside it and, for
// quoted values, escaped quotes are collapsed. The pointers stored in `cols`
// point into `line` and are only usable while that buffer is alive.
//
// Rules implemented below:
//  * Leading and trailing whitespace around an unquoted value is dropped,
//    unless the whitespace character is itself the delimiter.
//  * When the delimiter is ',', a value starting with '"' is treated as
//    quoted: [""] inside it stands for a single ["], a lone ["] ends the
//    value, and everything between the closing quote and the next delimiter
//    is ignored.
//  * A delimiter that is the very last character of the line yields one more,
//    empty, column.
//  * An empty line yields no columns at all.
void SplitCSVLineWithDelimiter(char* line, char delimiter,
                               std::vector<char*>* cols) {
  // std::isspace has undefined behavior for negative arguments (other than
  // EOF). Plain char may be signed, so bytes >= 0x80 must be converted to
  // unsigned char before classification.
  const auto is_space = [](char c) {
    return std::isspace(static_cast<unsigned char>(c)) != 0;
  };
  char* end_of_line = line + strlen(line);
  char* end;
  char* start;
  for (; line < end_of_line; line++) {
    // Skip leading whitespace, unless said whitespace is the delimiter.
    while (is_space(*line) && *line != delimiter) ++line;
    if (*line == '"' && delimiter == ',') {  // Quoted value...
      start = ++line;
      end = start;
      // Copy the unescaped value over itself; `end` trails `line` because
      // every escaped quote pair shrinks the value by one character.
      for (; *line; line++) {
        if (*line == '"') {
          line++;
          if (*line != '"')  // [""] is an escaped ["]
            break;           // but just ["] is end of value
        }
        *end++ = *line;
      }
      // All characters after the closing quote and before the comma
      // are ignored.
      line = strchr(line, delimiter);
      if (!line) line = end_of_line;
    } else {
      start = line;
      line = strchr(line, delimiter);
      if (!line) line = end_of_line;
      // Skip all trailing whitespace, unless said whitespace is the delimiter.
      for (end = line; end > start; --end) {
        if (!is_space(end[-1]) || end[-1] == delimiter) break;
      }
    }
    // Must be evaluated before `*end` is overwritten: `end` can alias `line`
    // when the value has no trailing whitespace.
    const bool need_another_column =
        (*line == delimiter) && (line == end_of_line - 1);
    *end = '\0';
    cols->push_back(start);
    // If line was something like [paul,] (comma is the last character
    // and is not preceded by whitespace or quote) then we are about
    // to eliminate the last column (which is empty). This would be
    // incorrect.
    if (need_another_column) cols->push_back(end);
    assert(*line == '\0' || *line == delimiter);
  }
}
// Splits `line` into columns separated by `delimiter` and appends each column
// to `cols` as a std::string. Existing contents of `cols` are kept.
//
// The actual parsing is done by SplitCSVLineWithDelimiter on a private,
// mutable copy of `line`; `line` itself is not modified.
void SplitCSVLineWithDelimiterForStrings(const std::string& line,
                                         char delimiter,
                                         std::vector<std::string>* cols) {
  // Unfortunately, the interface requires char* instead of const char*
  // which requires copying the string. The copy is owned by a unique_ptr so
  // that it is released with delete[] on every exit path, including when
  // growing `cols` throws.
  const std::unique_ptr<char[]> cline(
      strndup_with_new(line.c_str(), line.size()));
  std::vector<char*> v;
  SplitCSVLineWithDelimiter(cline.get(), delimiter, &v);
  // The column pointers refer into `cline`; copy them out as strings before
  // the buffer goes away.
  for (const char* str : v) {
    cols->emplace_back(str);
  }
}
// Splits one comma-separated line into its columns and returns them as
// strings, delegating to SplitCSVLineWithDelimiterForStrings.
std::vector<std::string> SplitCsvLine(const std::string& line) {
  constexpr char kComma = ',';
  std::vector<std::string> fields;
  SplitCSVLineWithDelimiterForStrings(line, kComma, &fields);
  return fields;
}
// Escapes a string for CSV file writing: every double quote inside `input` is
// replaced by two double quotes, and the result is wrapped in one pair of
// double quotes.
std::string EscapeForCsv(const std::string& input) {
  const std::string with_doubled_quotes =
      absl::StrReplaceAll(input, {{"\"", "\"\""}});
  return absl::StrCat("\"", with_doubled_quotes, "\"");
}
} // namespace
// Generates a random server dataset and a random client dataset that share
// `intersection_size` freshly generated identifiers.
//
// On success returns a tuple of
//   (server identifiers,
//    (client identifiers, client associated values),
//    sum of the associated values of the client identifiers that also appear
//    in the server dataset).
// Both identifier lists are shuffled independently. Associated values are
// drawn with Context::GenerateRandLessThan using `max_associated_value` as the
// bound.
//
// Returns InvalidArgumentError when any parameter is negative, when
// `intersection_size` exceeds either dataset size, or when
// `intersection_size * max_associated_value` would not fit in int64_t.
util::StatusOr<std::tuple<
    std::vector<std::string>,
    std::pair<std::vector<std::string>, std::vector<int64_t>>, int64_t>>
GenerateRandomDatabases(int64_t server_data_size, int64_t client_data_size,
                        int64_t intersection_size,
                        int64_t max_associated_value) {
  // --- Parameter validation -------------------------------------------------
  const bool has_negative_parameter =
      intersection_size < 0 || server_data_size < 0 ||
      client_data_size < 0 || max_associated_value < 0;
  if (has_negative_parameter) {
    return util::InvalidArgumentError(
        "GenerateRandomDatabases: Sizes cannot be negative.");
  }

  const bool intersection_fits = intersection_size <= server_data_size &&
                                 intersection_size <= client_data_size;
  if (!intersection_fits) {
    return util::InvalidArgumentError(
        "GenerateRandomDatabases: intersection_size is larger than "
        "client/server data size.");
  }

  // Guard against overflow of the intersection sum. The division is only
  // performed for a strictly positive bound.
  if (max_associated_value > 0) {
    const int64_t largest_safe_intersection =
        std::numeric_limits<int64_t>::max() / max_associated_value;
    if (intersection_size > largest_safe_intersection) {
      return util::InvalidArgumentError(
          "GenerateRandomDatabases: intersection_size * max_associated_value "
          "is "
          "larger than int64_t::max.");
    }
  }

  // Engine used only for shuffling the identifier lists.
  std::random_device seed_source;
  std::mt19937 shuffle_engine(seed_source());

  // Appends `how_many` fresh random identifiers to `identifiers`.
  const auto append_random_identifiers =
      [](int64_t how_many, std::vector<std::string>* identifiers) {
        for (int64_t remaining = how_many; remaining > 0; --remaining) {
          identifiers->push_back(
              GetRandomAlphaNumericString(kRandomIdentifierLengthBytes));
        }
      };

  // --- Identifiers ----------------------------------------------------------
  // Identifiers that both parties will hold.
  std::vector<std::string> common_identifiers;
  common_identifiers.reserve(intersection_size);
  append_random_identifiers(intersection_size, &common_identifiers);

  // Server: common identifiers plus server-only ones, then shuffled.
  std::vector<std::string> server_identifiers = common_identifiers;
  server_identifiers.reserve(server_data_size);
  append_random_identifiers(server_data_size - intersection_size,
                            &server_identifiers);
  std::shuffle(server_identifiers.begin(), server_identifiers.end(),
               shuffle_engine);

  // Client: common identifiers plus client-only ones, then shuffled.
  std::vector<std::string> client_identifiers = common_identifiers;
  client_identifiers.reserve(client_data_size);
  append_random_identifiers(client_data_size - intersection_size,
                            &client_identifiers);
  std::shuffle(client_identifiers.begin(), client_identifiers.end(),
               shuffle_engine);

  // Lookup structure to decide which client identifiers the server also has.
  std::set<std::string> server_identifiers_set(server_identifiers.begin(),
                                               server_identifiers.end());

  // --- Associated values ----------------------------------------------------
  Context context;
  BigNum associated_values_bound = context.CreateBigNum(max_associated_value);
  std::vector<int64_t> client_associated_values;
  client_associated_values.reserve(client_data_size);
  int64_t intersection_sum = 0;
  for (const std::string& client_identifier : client_identifiers) {
    // Converting the associated value from BigNum to int64_t should never fail
    // because associated_values_bound is less than int64_t::max.
    const int64_t associated_value =
        context.GenerateRandLessThan(associated_values_bound)
            .ToIntValue()
            .ValueOrDie();
    client_associated_values.push_back(associated_value);
    const bool also_on_server =
        server_identifiers_set.find(client_identifier) !=
        server_identifiers_set.end();
    if (also_on_server) {
      intersection_sum += associated_value;
    }
  }

  // --- Result ---------------------------------------------------------------
  return std::make_tuple(std::move(server_identifiers),
                         std::make_pair(std::move(client_identifiers),
                                        std::move(client_associated_values)),
                         intersection_sum);
}
// Writes the server dataset to `server_data_filename`, one CSV-escaped
// identifier per line.
//
// Returns InvalidArgumentError when the file cannot be opened and
// InternalError when the stream reports a failure after writing and closing.
util::Status WriteServerDatasetToFile(
    const std::vector<std::string>& server_data,
    absl::string_view server_data_filename) {
  const std::string path(server_data_filename);
  std::ofstream out(path);
  if (!out.is_open()) {
    return util::InvalidArgumentError(absl::StrCat(
        "WriteServerDatasetToFile: Couldn't open server data file: ",
        server_data_filename));
  }

  // One escaped identifier per line.
  for (size_t row = 0; row < server_data.size(); ++row) {
    out << EscapeForCsv(server_data[row]) << "\n";
  }

  // Closing flushes the buffer; any write or close problem shows up as a
  // failed stream state afterwards.
  out.close();
  if (!out.fail()) {
    return util::OkStatus();
  }
  return util::InternalError(
      absl::StrCat("WriteServerDatasetToFile: Couldn't write to or close "
                   "server data file: ",
                   server_data_filename));
}
// Writes the client dataset to `client_data_filename`. Each line holds the
// CSV-escaped identifier, a comma, and the matching associated value.
//
// Returns InvalidArgumentError when the two input vectors differ in length or
// the file cannot be opened, and InternalError when the stream reports a
// failure after writing and closing.
util::Status WriteClientDatasetToFile(
    const std::vector<std::string>& client_identifiers,
    const std::vector<int64_t>& client_associated_values,
    absl::string_view client_data_filename) {
  const size_t row_count = client_identifiers.size();
  if (client_associated_values.size() != row_count) {
    return util::InvalidArgumentError(
        "WriteClientDatasetToFile: there should be the same number of client "
        "identifiers and associated values.");
  }

  const std::string path(client_data_filename);
  std::ofstream out(path);
  if (!out.is_open()) {
    return util::InvalidArgumentError(absl::StrCat(
        "WriteClientDatasetToFile: Couldn't open client data file: ",
        client_data_filename));
  }

  // The row is formatted with StrCat first so that number formatting does not
  // depend on the stream.
  for (size_t row = 0; row < row_count; ++row) {
    const std::string csv_row =
        absl::StrCat(EscapeForCsv(client_identifiers[row]), ",",
                     client_associated_values[row]);
    out << csv_row << "\n";
  }

  // Closing flushes the buffer; any write or close problem shows up as a
  // failed stream state afterwards.
  out.close();
  if (!out.fail()) {
    return util::OkStatus();
  }
  return util::InternalError(
      absl::StrCat("WriteClientDatasetToFile: Couldn't write to or close "
                   "client data file: ",
                   client_data_filename));
}
// Reads a server dataset from `server_data_filename`: a CSV file with exactly
// one identifier per line. Returns the unescaped identifiers in file order.
//
// Returns InvalidArgumentError when the file cannot be opened or when a line
// does not split into exactly one column (this includes empty lines, which
// split into zero columns). Line numbers in error messages are zero-based.
// Returns InternalError when the stream reports an unrecoverable read error
// or the file is still open after closing.
util::StatusOr<std::vector<std::string>> ReadServerDatasetFromFile(
    absl::string_view server_data_filename) {
  // Open file.
  std::ifstream server_data_file;
  server_data_file.open(std::string(server_data_filename));
  if (!server_data_file.is_open()) {
    return util::InvalidArgumentError(absl::StrCat(
        "ReadServerDatasetFromFile: Couldn't open server data file: ",
        server_data_filename));
  }
  // Read each line from file (unescaping and splitting columns). Verify that
  // each line contains a single column
  std::vector<std::string> server_data;
  std::string line;
  int64_t line_number = 0;
  while (std::getline(server_data_file, line)) {
    std::vector<std::string> columns = SplitCsvLine(line);
    if (columns.size() != 1) {
      return util::InvalidArgumentError(absl::StrCat(
          "ReadServerDatasetFromFile: Expected exactly 1 identifier per line, "
          "but line ",
          line_number, " has ", columns.size(),
          " comma-separated items (file: ", server_data_filename, ")"));
    }
    server_data.push_back(std::move(columns[0]));
    line_number++;
  }
  // The loop ends both at end-of-file and on a read error. Only the latter
  // sets badbit; report it instead of returning a silently truncated dataset.
  if (server_data_file.bad()) {
    return util::InternalError(absl::StrCat(
        "ReadServerDatasetFromFile: Error while reading server data file: ",
        server_data_filename));
  }
  // Close file.
  // NOTE(review): is_open() is expected to be false after close() even when
  // closing fails, so this branch is probably unreachable -- confirm.
  server_data_file.close();
  if (server_data_file.is_open()) {
    return util::InternalError(absl::StrCat(
        "ReadServerDatasetFromFile: Couldn't close server data file: ",
        server_data_filename));
  }
  return server_data;
}
// Reads a client dataset from `client_data_filename`: a CSV file with exactly
// two columns per line, an identifier and a non-negative integer associated
// value. Returns the unescaped identifiers and, in the same order, the
// associated values converted to BigNum through `context`.
//
// Returns InvalidArgumentError when the file cannot be opened, when a line
// does not split into exactly two columns, or when the second column does not
// parse as a non-negative int64_t. Line numbers in error messages are
// zero-based. Returns InternalError when the stream reports an unrecoverable
// read error or the file is still open after closing.
//
// `context` is dereferenced without a null check and must point to a valid
// Context.
util::StatusOr<std::pair<std::vector<std::string>, std::vector<BigNum>>>
ReadClientDatasetFromFile(absl::string_view client_data_filename,
                          Context* context) {
  // Open file.
  std::ifstream client_data_file;
  client_data_file.open(std::string(client_data_filename));
  if (!client_data_file.is_open()) {
    return util::InvalidArgumentError(absl::StrCat(
        "ReadClientDatasetFromFile: Couldn't open client data file: ",
        client_data_filename));
  }
  // Read each line from file (unescaping and splitting columns). Verify that
  // each line contains two columns, and parse the second column into an
  // associated value.
  std::vector<std::string> client_identifiers;
  std::vector<BigNum> client_associated_values;
  std::string line;
  int64_t line_number = 0;
  while (std::getline(client_data_file, line)) {
    std::vector<std::string> columns = SplitCsvLine(line);
    if (columns.size() != 2) {
      return util::InvalidArgumentError(absl::StrCat(
          "ReadClientDatasetFromFile: Expected exactly 2 items per line, "
          "but line ",
          line_number, " has ", columns.size(),
          " comma-separated items (file: ", client_data_filename, ")"));
    }
    int64_t parsed_associated_value;
    if (!absl::SimpleAtoi(columns[1], &parsed_associated_value) ||
        parsed_associated_value < 0) {
      return util::InvalidArgumentError(
          absl::StrCat("ReadClientDatasetFromFile: could not parse a "
                       "nonnegative associated value at line number ",
                       line_number));
    }
    client_identifiers.push_back(std::move(columns[0]));
    client_associated_values.push_back(
        context->CreateBigNum(parsed_associated_value));
    line_number++;
  }
  // The loop ends both at end-of-file and on a read error. Only the latter
  // sets badbit; report it instead of returning a silently truncated dataset.
  if (client_data_file.bad()) {
    return util::InternalError(absl::StrCat(
        "ReadClientDatasetFromFile: Error while reading client data file: ",
        client_data_filename));
  }
  // Close file.
  // NOTE(review): is_open() is expected to be false after close() even when
  // closing fails, so this branch is probably unreachable -- confirm.
  client_data_file.close();
  if (client_data_file.is_open()) {
    return util::InternalError(absl::StrCat(
        "ReadClientDatasetFromFile: Couldn't close client data file: ",
        client_data_filename));
  }
  return std::make_pair(std::move(client_identifiers),
                        std::move(client_associated_values));
}
} // namespace private_join_and_compute