blob: 8fc81586f3cae18c249be7ccd740e09b84c211c2 [file] [log] [blame]
#include "caffe2/utils/proto_utils.h"
#include <c10/core/DeviceType.h>
#include <fcntl.h>
#include <cerrno>
#include <fstream>
#include <unordered_set>
#if defined(_MSC_VER)
#include <io.h>
#else
#include <unistd.h>
#endif
#include <google/protobuf/io/coded_stream.h>
#ifndef CAFFE2_USE_LITE_PROTO
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#else
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#endif // !CAFFE2_USE_LITE_PROTO
#include <c10/util/Logging.h>
using ::google::protobuf::MessageLite;
namespace caffe2 {
// Returns the name of the device type whose integer value is `d`, as reported
// by at::DeviceTypeName.
C10_EXPORT std::string DeviceTypeName(const int32_t& d) {
  const auto device_type = static_cast<at::DeviceType>(d);
  return at::DeviceTypeName(device_type);
}
// Sets the total-bytes limit on `stream`, hiding the protobuf API difference
// around version 3.11: older releases take (limit, warning threshold), newer
// ones take the limit only.
void setTotalBytesLimit(::google::protobuf::io::CodedInputStream& stream, int bytes_limit, int warning_threshold) {
#if GOOGLE_PROTOBUF_VERSION < 3011000
  // Pre-3.11 protobuf: both the hard limit and the warning threshold are
  // passed through.
  stream.SetTotalBytesLimit(bytes_limit, warning_threshold);
#else
  // Since protobuf 3.11 only the hard limit is accepted; warning_threshold is
  // not forwarded.
  stream.SetTotalBytesLimit(bytes_limit);
#endif
}
// Returns the id that identifies the device described by `option`: the NUMA
// node id for CPU and MKLDNN, the device id for CUDA and HIP. Any other device
// type raises via CAFFE_THROW.
C10_EXPORT int DeviceId(const DeviceOption& option) {
  switch (option.device_type()) {
    case PROTO_CPU:
    case PROTO_MKLDNN:
      return option.numa_node_id();
    case PROTO_CUDA:
    case PROTO_HIP:
      return option.device_id();
    default:
      CAFFE_THROW("Unknown device id for device type: ", option.device_type());
  }
}
// Two device options denote the same device when device type, device id, node
// name and NUMA node id all match. Fields are compared in that order and the
// first mismatch short-circuits.
C10_EXPORT bool IsSameDevice(const DeviceOption& lhs, const DeviceOption& rhs) {
  if (lhs.device_type() != rhs.device_type()) {
    return false;
  }
  if (lhs.device_id() != rhs.device_id()) {
    return false;
  }
  if (lhs.node_name() != rhs.node_name()) {
    return false;
  }
  return lhs.numa_node_id() == rhs.numa_node_id();
}
// True when `device_type` is one of the device types this file treats as CPU:
// PROTO_CPU, PROTO_MKLDNN or PROTO_IDEEP.
C10_EXPORT bool IsCPUDeviceType(int device_type) {
  static const std::unordered_set<int> kCpuDeviceTypes{
      PROTO_CPU,
      PROTO_MKLDNN,
      PROTO_IDEEP,
  };
  return kCpuDeviceTypes.find(device_type) != kCpuDeviceTypes.end();
}
// True when `device_type` is one of the device types this file treats as GPU:
// PROTO_CUDA or PROTO_HIP.
C10_EXPORT bool IsGPUDeviceType(int device_type) {
  static const std::unordered_set<int> kGpuDeviceTypes{
      PROTO_CUDA,
      PROTO_HIP,
  };
  return kGpuDeviceTypes.find(device_type) != kGpuDeviceTypes.end();
}
// Reads the whole content of `filename` into `*str`.
//
// Returns false when the file cannot be opened, when its size cannot be
// determined, or when the stream reports an unrecoverable I/O error while
// reading. On success `*str` holds exactly the characters that were read.
C10_EXPORT bool ReadStringFromFile(const char* filename, string* str) {
  std::ifstream ifs(filename, std::ios::in);
  if (!ifs) {
    VLOG(1) << "File cannot be opened: " << filename
            << " error: " << ifs.rdstate();
    return false;
  }
  ifs.seekg(0, std::ios::end);
  // tellg() yields -1 on failure; converting that straight to size_t would
  // request an enormous resize, so check it first.
  const std::streamoff end_pos = ifs.tellg();
  if (end_pos < 0) {
    VLOG(1) << "File size cannot be determined: " << filename
            << " error: " << ifs.rdstate();
    return false;
  }
  const size_t n = static_cast<size_t>(end_pos);
  str->resize(n);
  if (n == 0) {
    return true;
  }
  ifs.seekg(0);
  ifs.read(&(*str)[0], static_cast<std::streamsize>(n));
  // The stream may deliver fewer characters than the size reported above
  // (e.g. newline translation in text mode); drop the unused tail rather than
  // returning padding bytes.
  str->resize(static_cast<size_t>(ifs.gcount()));
  return !ifs.bad();
}
// Writes `str` to `filename`, truncating any existing content.
//
// Returns false when the file cannot be created or when writing/flushing the
// data fails; returns true only once the data has been handed to the
// underlying file without a stream error.
C10_EXPORT bool WriteStringToFile(const string& str, const char* filename) {
  std::ofstream ofs(filename, std::ios::out | std::ios::trunc);
  if (!ofs.is_open()) {
    VLOG(1) << "File cannot be created: " << filename
            << " error: " << ofs.rdstate();
    return false;
  }
  ofs << str;
  // Flush explicitly so that errors surfacing only when the buffer is drained
  // (e.g. a full disk) are reported to the caller instead of being lost in the
  // destructor.
  ofs.flush();
  if (!ofs) {
    VLOG(1) << "File cannot be written: " << filename
            << " error: " << ofs.rdstate();
    return false;
  }
  return true;
}
// IO-specific proto functions: we will deal with the protocol buffer lite and
// full versions differently.
#ifdef CAFFE2_USE_LITE_PROTO
// Lite runtime.
namespace {
class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
public:
explicit IfstreamInputStream(const string& filename)
: ifs_(filename.c_str(), std::ios::in | std::ios::binary) {}
~IfstreamInputStream() {
ifs_.close();
}
int Read(void* buffer, int size) {
if (!ifs_) {
return -1;
}
ifs_.read(static_cast<char*>(buffer), size);
return ifs_.gcount();
}
private:
std::ifstream ifs_;
};
} // namespace
// Lite runtime: returns the serialized bytes of `proto` with every byte
// outside [0x20, 0x7f) replaced by '?'.
C10_EXPORT string ProtoDebugString(const MessageLite& proto) {
  string printable = proto.SerializeAsString();
  for (size_t pos = 0; pos < printable.size(); ++pos) {
    const char ch = printable[pos];
    const bool keep = ch >= 0x20 && ch < 0x7f;
    if (!keep) {
      printable[pos] = '?';
    }
  }
  return printable;
}
// Lite runtime: parses `str` into `proto` through a CodedInputStream whose
// size limit is raised to 2G, so that large messages are accepted.
C10_EXPORT bool ParseProtoFromLargeString(
    const string& str,
    MessageLite* proto) {
  // Hard limit of 2G and warning threshold of 512MB.
  constexpr int kTotalBytesLimit = 2147483647;
  constexpr int kWarningThreshold = 512 << 20;
  ::google::protobuf::io::ArrayInputStream array_input(str.data(), str.size());
  ::google::protobuf::io::CodedInputStream coded_input(&array_input);
  setTotalBytesLimit(coded_input, kTotalBytesLimit, kWarningThreshold);
  const bool parsed = proto->ParseFromCodedStream(&coded_input);
  return parsed;
}
// Lite runtime: parses the binary-serialized message stored in `filename`
// into `proto`. Returns the result of ParseFromCodedStream.
C10_EXPORT bool ReadProtoFromBinaryFile(
    const char* filename,
    MessageLite* proto) {
  // The adaptor takes ownership of the copying stream and deletes it.
  auto* file_source = new IfstreamInputStream(filename);
  ::google::protobuf::io::CopyingInputStreamAdaptor adaptor(file_source);
  adaptor.SetOwnsCopyingStream(true);
  ::google::protobuf::io::CodedInputStream coded_input(&adaptor);
  // Hard limit of 2GB and warning threshold of 512MB.
  constexpr int kTotalBytesLimit = 2147483647;
  constexpr int kWarningThreshold = 512 << 20;
  setTotalBytesLimit(coded_input, kTotalBytesLimit, kWarningThreshold);
  const bool parsed = proto->ParseFromCodedStream(&coded_input);
  return parsed;
}
// Lite runtime: writing a proto to a binary file is not implemented. Both
// parameters are ignored and the call logs at FATAL severity.
C10_EXPORT void WriteProtoToBinaryFile(
const MessageLite& /*proto*/,
const char* /*filename*/) {
LOG(FATAL) << "Not implemented yet.";
}
#else // CAFFE2_USE_LITE_PROTO
// Full protocol buffer.
using ::google::protobuf::Message;
using ::google::protobuf::io::CodedInputStream;
using ::google::protobuf::io::CodedOutputStream;
using ::google::protobuf::io::FileInputStream;
using ::google::protobuf::io::FileOutputStream;
using ::google::protobuf::io::ZeroCopyInputStream;
using ::google::protobuf::io::ZeroCopyOutputStream;
namespace TextFormat {
// Parses a text-format proto from `spec` into `proto`.
//
// Before parsing, every occurrence of the deprecated field name 'cuda_gpu_id'
// is rewritten to 'device_id' in a private copy of the text; when at least one
// rewrite happened an error is logged that includes the original spec.
C10_EXPORT bool ParseFromString(const string& spec, Message* proto) {
  string patched_spec = spec;
  const auto replaced_count =
      c10::ReplaceAll(patched_spec, "cuda_gpu_id", "device_id");
  if (replaced_count) {
    LOG(ERROR) << "Your model was serialized in Protobuf TextFormat and "
               << "it has " << replaced_count
               << " places using the deprecated field name 'cuda_gpu_id'!\n"
               << spec
               << "\nPlease re-export your model in Protobuf binary format "
               << "to make it backward compatible for field renaming.";
  }
  return ::google::protobuf::TextFormat::ParseFromString(
      // NOLINTNEXTLINE(performance-move-const-arg)
      std::move(patched_spec),
      proto);
}
} // namespace TextFormat
// Full runtime: returns protobuf's short debug string for `proto`.
C10_EXPORT string ProtoDebugString(const Message& proto) {
  string debug_text = proto.ShortDebugString();
  return debug_text;
}
// Full runtime: parses `str` into `proto` through a CodedInputStream whose
// size limit is raised to 2G, so that large messages are accepted.
C10_EXPORT bool ParseProtoFromLargeString(const string& str, Message* proto) {
  // Hard limit of 2G and warning threshold of 512MB.
  constexpr int kTotalBytesLimit = 2147483647;
  constexpr int kWarningThreshold = 512 << 20;
  ::google::protobuf::io::ArrayInputStream array_input(str.data(), str.size());
  ::google::protobuf::io::CodedInputStream coded_input(&array_input);
  setTotalBytesLimit(coded_input, kTotalBytesLimit, kWarningThreshold);
  const bool parsed = proto->ParseFromCodedStream(&coded_input);
  return parsed;
}
// Parses the text-format proto stored in `filename` into `proto`.
//
// Raises (CAFFE_ENFORCE_NE) when the file cannot be opened; otherwise returns
// whether the text was parsed successfully.
C10_EXPORT bool ReadProtoFromTextFile(const char* filename, Message* proto) {
  const int fd = open(filename, O_RDONLY);
  CAFFE_ENFORCE_NE(fd, -1, "File not found: ", filename);
  bool success = false;
  {
    // Scoped so the stream is destroyed before the descriptor is closed; a
    // stack object replaces the manual new/delete pair.
    FileInputStream input(fd);
    success = google::protobuf::TextFormat::Parse(&input, proto);
  }
  close(fd);
  return success;
}
// Writes `proto` in protobuf text format to `filename`, creating or truncating
// the file.
//
// A failure to open the file, to print the message, or to flush the buffered
// output is reported according to `throwIfError`: raised via CAFFE_THROW when
// true, logged at ERROR severity when false. The stream and the file
// descriptor are released before the error is reported, so nothing leaks on
// the throwing path.
C10_EXPORT void WriteProtoToTextFile(
    const Message& proto,
    const char* filename,
    bool throwIfError) {
  bool success = false;
  const int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
  if (fd != -1) {
    {
      FileOutputStream output(fd);
      // Print only fills the stream's buffer; Flush is what pushes the bytes
      // to the descriptor and can therefore fail independently.
      success = google::protobuf::TextFormat::Print(proto, &output) &&
          output.Flush();
    }
    close(fd);
  }
  if (!success) {
    if (throwIfError) {
      CAFFE_THROW("Cannot write proto to text file: ", filename);
    } else {
      LOG(ERROR) << "Cannot write proto to text file: " << filename;
    }
  }
}
// Full runtime: parses the binary-serialized message stored in `filename`
// into `proto`.
//
// Raises (CAFFE_ENFORCE_NE) when the file cannot be opened; otherwise returns
// the result of ParseFromCodedStream.
C10_EXPORT bool ReadProtoFromBinaryFile(
    const char* filename,
    MessageLite* proto) {
#if defined(_MSC_VER) // for MSC compiler binary flag needs to be specified
  const int fd = open(filename, O_RDONLY | O_BINARY);
#else
  const int fd = open(filename, O_RDONLY);
#endif
  CAFFE_ENFORCE_NE(fd, -1, "File not found: ", filename);
  bool success = false;
  {
    // Locals are destroyed in reverse order of declaration, so the coded
    // stream goes away before the raw stream it reads from.
    FileInputStream raw_input(fd);
    CodedInputStream coded_input(&raw_input);
    // A hack to manually allow using very large protocol buffers: hard limit
    // of 2GB, warning threshold of 512MB. The helper hides the protobuf 3.11
    // signature change.
    setTotalBytesLimit(coded_input, 2147483647, 536870912);
    success = proto->ParseFromCodedStream(&coded_input);
  }
  close(fd);
  return success;
}
// Full runtime: serializes `proto` in binary format to `filename`, creating or
// truncating the file.
//
// Raises via CAFFE_ENFORCE when the file cannot be created or when
// serializing/flushing fails. The streams and the file descriptor are released
// before the failure is raised.
C10_EXPORT void WriteProtoToBinaryFile(
    const MessageLite& proto,
    const char* filename) {
#if defined(_MSC_VER) // for MSC compiler binary flag needs to be specified
  // Without O_BINARY the runtime would translate newline bytes and corrupt
  // the serialized message; the reader opens with O_BINARY as well.
  const int fd =
      open(filename, O_WRONLY | O_CREAT | O_TRUNC | O_BINARY, 0644);
#else
  const int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
#endif
  CAFFE_ENFORCE_NE(
      fd, -1, "File cannot be created: ", filename, " error number: ", errno);
  bool success = false;
  {
    FileOutputStream raw_output(fd);
    {
      CodedOutputStream coded_output(&raw_output);
      success = proto.SerializeToCodedStream(&coded_output);
    }
    // The coded stream is destroyed first; only then is the raw stream
    // flushed, which is where write errors become visible.
    success = raw_output.Flush() && success;
  }
  close(fd);
  CAFFE_ENFORCE(success, "Cannot write proto to binary file: ", filename);
}
#endif // CAFFE2_USE_LITE_PROTO
// Builds the name -> Argument map from the arguments of an operator def.
//
// A name that appears more than once is tolerated (with a warning) only when
// the duplicates serialize identically; duplicates with different contents
// raise via CAFFE_THROW.
C10_EXPORT ArgumentHelper::ArgumentHelper(const OperatorDef& def) {
  for (auto& arg : def.arg()) {
    const auto existing = arg_map_.find(arg.name());
    if (existing != arg_map_.end()) {
      if (arg.SerializeAsString() != existing->second.SerializeAsString()) {
        // If there are two arguments of the same name but different contents,
        // we will throw an error.
        CAFFE_THROW(
            "Found argument of the same name ",
            arg.name(),
            " but with different contents. ",
            ProtoDebugString(def));
      } else {
        LOG(WARNING) << "Duplicated argument name [" << arg.name()
                     << "] found in operator def: " << ProtoDebugString(def);
      }
    }
    arg_map_[arg.name()] = arg;
  }
}
// Builds the name -> Argument map from the arguments of a net def. Unlike the
// operator-def constructor, any repeated argument name is rejected via
// CAFFE_ENFORCE.
C10_EXPORT ArgumentHelper::ArgumentHelper(const NetDef& netdef) {
  const int arg_total = netdef.arg_size();
  for (int idx = 0; idx < arg_total; ++idx) {
    const auto& arg = netdef.arg(idx);
    CAFFE_ENFORCE(
        arg_map_.count(arg.name()) == 0,
        "Duplicated argument name [",
        arg.name(),
        "] found in net def: ",
        ProtoDebugString(netdef));
    arg_map_[arg.name()] = arg;
  }
}
// Reports whether an argument called `name` is present in the map. The lookup
// key is either the string_view itself or a std::string built from it,
// depending on CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP.
C10_EXPORT bool ArgumentHelper::HasArgument(c10::string_view name) const {
#ifdef CAFFE2_ENABLE_REDUCED_STRINGS_IN_ARGUMENT_LOOKUP
  return arg_map_.count(name) != 0;
#else
  return arg_map_.count(std::string(name)) != 0;
#endif
}
namespace {
// Checks that `value` survives a round trip through TargetType: it is
// converted to TargetType and back to InputType, and the result must compare
// equal to the original, i.e. the conversion does not lose any significant
// bit.
template <typename InputType, typename TargetType>
bool SupportsLosslessConversion(const InputType& value) {
  const TargetType converted = static_cast<TargetType>(value);
  return static_cast<InputType>(converted) == value;
}
} // namespace
// TensorProto equality is defined as equality of the serialized forms.
bool operator==(const TensorProto& l, const TensorProto& r) {
  const auto lhs_bytes = l.SerializeAsString();
  const auto rhs_bytes = r.SerializeAsString();
  return lhs_bytes == rhs_bytes;
}
// Streams the serialized form of the TensorProto to `output`.
std::ostream& operator<<(std::ostream& output, const TensorProto& n) {
  return output << n.SerializeAsString();
}
// QTensorProto equality is defined as equality of the serialized forms.
bool operator==(const QTensorProto& l, const QTensorProto& r) {
  const auto lhs_bytes = l.SerializeAsString();
  const auto rhs_bytes = r.SerializeAsString();
  return lhs_bytes == rhs_bytes;
}
// Streams the serialized form of the QTensorProto to `output`.
std::ostream& operator<<(std::ostream& output, const QTensorProto& n) {
  return output << n.SerializeAsString();
}
// NetDef equality is defined as equality of the serialized forms.
bool operator==(const NetDef& l, const NetDef& r) {
  const auto lhs_bytes = l.SerializeAsString();
  const auto rhs_bytes = r.SerializeAsString();
  return lhs_bytes == rhs_bytes;
}
// Streams the serialized form of the NetDef to `output`.
std::ostream& operator<<(std::ostream& output, const NetDef& n) {
  return output << n.SerializeAsString();
}
// Generates, for target type T stored in Argument field `fieldname`:
//  - ArgumentHelper::GetSingleArgument<T>: returns `default_value` when the
//    argument is absent, enforces that the expected field is set, and (when
//    `enforce_lossless_conversion` is true) enforces that the stored value
//    round-trips through T before returning it cast to T.
//  - ArgumentHelper::HasSingleArgumentOfType<T>: true when the argument exists
//    and has the expected field set.
#define INSTANTIATE_GET_SINGLE_ARGUMENT( \
    T, fieldname, enforce_lossless_conversion) \
  template <> \
  C10_EXPORT T ArgumentHelper::GetSingleArgument<T>( \
      c10::string_view name, const T& default_value) const { \
    auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \
    if (it == arg_map_.end()) { \
      VLOG(1) << "Using default parameter value " << default_value \
              << " for parameter " << name; \
      return default_value; \
    } \
    CAFFE_ENFORCE( \
        it->second.has_##fieldname(), \
        "Argument ", \
        name, \
        " does not have the right field: expected field " #fieldname); \
    const auto& value = it->second.fieldname(); \
    if (enforce_lossless_conversion) { \
      auto supportsConversion = \
          SupportsLosslessConversion<decltype(value), T>(value); \
      CAFFE_ENFORCE( \
          supportsConversion, \
          "Value ", \
          value, \
          " of argument ", \
          name, \
          " cannot be represented correctly in a target type"); \
    } \
    return static_cast<T>(value); \
  } \
  template <> \
  C10_EXPORT bool ArgumentHelper::HasSingleArgumentOfType<T>( \
      c10::string_view name) const { \
    auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \
    if (it == arg_map_.end()) { \
      return false; \
    } \
    return it->second.has_##fieldname(); \
  }
INSTANTIATE_GET_SINGLE_ARGUMENT(float, f, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(double, f, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(bool, i, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(int8_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(int16_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(int, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(int64_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(uint8_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(uint16_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(size_t, i, true)
INSTANTIATE_GET_SINGLE_ARGUMENT(string, s, false)
INSTANTIATE_GET_SINGLE_ARGUMENT(NetDef, n, false)
#undef INSTANTIATE_GET_SINGLE_ARGUMENT
// Generates ArgumentHelper::GetRepeatedArgument<T> for target type T stored in
// repeated Argument field `fieldname`: returns `default_value` when the
// argument is absent, otherwise a vector holding every element cast to T.
// When `enforce_lossless_conversion` is true each element must round-trip
// through T, else CAFFE_ENFORCE raises.
#define INSTANTIATE_GET_REPEATED_ARGUMENT( \
    T, fieldname, enforce_lossless_conversion) \
  template <> \
  C10_EXPORT std::vector<T> ArgumentHelper::GetRepeatedArgument<T>( \
      c10::string_view name, const std::vector<T>& default_value) const { \
    auto it = CAFFE2_ARG_MAP_FIND(arg_map_, name); \
    if (it == arg_map_.end()) { \
      return default_value; \
    } \
    std::vector<T> values; \
    values.reserve(static_cast<size_t>(it->second.fieldname().size())); \
    for (const auto& v : it->second.fieldname()) { \
      if (enforce_lossless_conversion) { \
        auto supportsConversion = \
            SupportsLosslessConversion<decltype(v), T>(v); \
        CAFFE_ENFORCE( \
            supportsConversion, \
            "Value ", \
            v, \
            " of argument ", \
            name, \
            " cannot be represented correctly in a target type"); \
      } \
      values.push_back(static_cast<T>(v)); \
    } \
    return values; \
  }
INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(double, floats, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(bool, ints, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(int8_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(int16_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(int64_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(uint8_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(NetDef, nets, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(TensorProto, tensors, false)
INSTANTIATE_GET_REPEATED_ARGUMENT(QTensorProto, qtensors, false)
#undef INSTANTIATE_GET_REPEATED_ARGUMENT
// Generates the MakeArgument specialization for a scalar of type T: the
// resulting Argument carries `name` and stores `value` in field `fieldname`.
#define CAFFE2_MAKE_SINGULAR_ARGUMENT(T, fieldname) \
  template <> \
  C10_EXPORT Argument MakeArgument(const string& name, const T& value) { \
    Argument result; \
    result.set_name(name); \
    result.set_##fieldname(value); \
    return result; \
  }
CAFFE2_MAKE_SINGULAR_ARGUMENT(bool, i)
CAFFE2_MAKE_SINGULAR_ARGUMENT(float, f)
CAFFE2_MAKE_SINGULAR_ARGUMENT(int, i)
CAFFE2_MAKE_SINGULAR_ARGUMENT(int16_t, i)
CAFFE2_MAKE_SINGULAR_ARGUMENT(int64_t, i)
CAFFE2_MAKE_SINGULAR_ARGUMENT(string, s)
#undef CAFFE2_MAKE_SINGULAR_ARGUMENT
// MakeArgument specialization for NetDef: the resulting Argument carries
// `name` and a copy of `value` in its `n` field.
template <>
C10_EXPORT Argument MakeArgument(const string& name, const NetDef& value) {
  Argument result;
  result.set_name(name);
  NetDef* const net_slot = result.mutable_n();
  *net_slot = value;
  return result;
}
// Declarations (no bodies here) of the ArgumentHelper::RemoveArgument
// specializations for OperatorDef and NetDef.
// NOTE(review): only the OperatorDef declaration carries C10_EXPORT; the
// definitions are not visible in this file section, so confirm whether that
// asymmetry is intended.
template <>
C10_EXPORT bool ArgumentHelper::RemoveArgument(OperatorDef& def, int index);
template <>
bool ArgumentHelper::RemoveArgument(NetDef& def, int index);
// MakeArgument specialization for an arbitrary MessageLite: the message is
// serialized and stored as a string in the Argument's `s` field.
template <>
C10_EXPORT Argument MakeArgument(const string& name, const MessageLite& value) {
  Argument result;
  result.set_name(name);
  const auto serialized = value.SerializeAsString();
  result.set_s(serialized);
  return result;
}
// Generates the MakeArgument specialization for std::vector<T>: the resulting
// Argument carries `name` and every element of `value`, appended in order to
// repeated field `fieldname`.
#define CAFFE2_MAKE_REPEATED_ARGUMENT(T, fieldname) \
  template <> \
  C10_EXPORT Argument MakeArgument( \
      const string& name, const std::vector<T>& value) { \
    Argument result; \
    result.set_name(name); \
    for (const auto& element : value) { \
      result.add_##fieldname(element); \
    } \
    return result; \
  }
CAFFE2_MAKE_REPEATED_ARGUMENT(float, floats)
CAFFE2_MAKE_REPEATED_ARGUMENT(int, ints)
CAFFE2_MAKE_REPEATED_ARGUMENT(int64_t, ints)
CAFFE2_MAKE_REPEATED_ARGUMENT(string, strings)
#undef CAFFE2_MAKE_REPEATED_ARGUMENT
// True when `output` is one of the outputs listed by operator `op`.
C10_EXPORT bool HasOutput(const OperatorDef& op, const std::string& output) {
  const int output_total = op.output_size();
  for (int idx = 0; idx < output_total; ++idx) {
    if (op.output(idx) == output) {
      return true;
    }
  }
  return false;
}
// True when `input` is one of the inputs listed by operator `op`.
C10_EXPORT bool HasInput(const OperatorDef& op, const std::string& input) {
  const int input_total = op.input_size();
  for (int idx = 0; idx < input_total; ++idx) {
    if (op.input(idx) == input) {
      return true;
    }
  }
  return false;
}
// Returns the position of the first argument in `args` whose name equals
// `name`, or -1 when there is no such argument.
C10_EXPORT int GetArgumentIndex(
    const google::protobuf::RepeatedPtrField<Argument>& args,
    c10::string_view name) {
  const int arg_total = args.size();
  for (int idx = 0; idx < arg_total; ++idx) {
    if (args.Get(idx).name() == name) {
      return idx;
    }
  }
  return -1;
}
// Returns the argument called `name` from operator `def`; raises via
// CAFFE_THROW when the operator has no such argument.
C10_EXPORT const Argument& GetArgument(
    const OperatorDef& def,
    c10::string_view name) {
  const int position = GetArgumentIndex(def.arg(), name);
  if (position == -1) {
    CAFFE_THROW(
        "Argument named ",
        name,
        " does not exist in operator ",
        ProtoDebugString(def));
  }
  return def.arg(position);
}
// Returns the argument called `name` from net `def`; raises via CAFFE_THROW
// when the net has no such argument.
C10_EXPORT const Argument& GetArgument(const NetDef& def, c10::string_view name) {
  const int position = GetArgumentIndex(def.arg(), name);
  if (position == -1) {
    CAFFE_THROW(
        "Argument named ",
        name,
        " does not exist in net ",
        ProtoDebugString(def));
  }
  return def.arg(position);
}
// Returns a pointer to the argument called `name` in operator `def`, or
// nullptr when the operator has no such argument.
C10_EXPORT const Argument* GetArgumentPtr(
    const OperatorDef& def,
    c10::string_view name) {
  const int position = GetArgumentIndex(def.arg(), name);
  return position == -1 ? nullptr : &def.arg(position);
}
// Returns a pointer to the argument called `name` in net `def`, or nullptr
// when the net has no such argument.
C10_EXPORT const Argument* GetArgumentPtr(
    const NetDef& def,
    c10::string_view name) {
  const int position = GetArgumentIndex(def.arg(), name);
  return position == -1 ? nullptr : &def.arg(position);
}
// Reads the argument called `name` from `args` as a boolean flag.
//
// Returns `default_value` when no such argument exists. Otherwise the argument
// must have its `i` field set (enforced via CAFFE_ENFORCE) and that integer is
// returned converted to bool.
C10_EXPORT bool GetFlagArgument(
    const google::protobuf::RepeatedPtrField<Argument>& args,
    c10::string_view name,
    bool default_value) {
  const int index = GetArgumentIndex(args, name);
  if (index == -1) {
    return default_value;
  }
  // Bind by reference: `args` outlives this call, so copying the whole
  // Argument message just to read one field is unnecessary.
  const Argument& arg = args.Get(index);
  CAFFE_ENFORCE(
      arg.has_i(), "Can't parse argument as bool: ", ProtoDebugString(arg));
  return arg.i();
}
// Convenience overload: looks the flag up among the arguments of operator
// `def`.
C10_EXPORT bool GetFlagArgument(
    const OperatorDef& def,
    c10::string_view name,
    bool default_value) {
  const auto& op_args = def.arg();
  return GetFlagArgument(op_args, name, default_value);
}
// Convenience overload: looks the flag up among the arguments of net `def`.
C10_EXPORT bool
GetFlagArgument(const NetDef& def, c10::string_view name, bool default_value) {
  const auto& net_args = def.arg();
  return GetFlagArgument(net_args, name, default_value);
}
// Shared implementation behind the GetMutableArgument overloads.
//
// Returns a mutable pointer to the first argument of `def` named `name`. When
// none exists, either appends a new argument carrying that name and returns it
// (`create_if_missing` true) or returns nullptr.
template <typename Def>
Argument* GetMutableArgumentImpl(
    const string& name,
    const bool create_if_missing,
    Def* def) {
  for (int idx = 0, total = def->arg_size(); idx != total; ++idx) {
    if (def->arg(idx).name() == name) {
      return def->mutable_arg(idx);
    }
  }
  // No argument of the right name was found.
  if (!create_if_missing) {
    return nullptr;
  }
  Argument* created = def->add_arg();
  created->set_name(name);
  return created;
}
// OperatorDef overload; forwards to GetMutableArgumentImpl.
C10_EXPORT Argument* GetMutableArgument(
    const string& name,
    const bool create_if_missing,
    OperatorDef* def) {
  return GetMutableArgumentImpl<OperatorDef>(name, create_if_missing, def);
}
// NetDef overload; forwards to GetMutableArgumentImpl.
C10_EXPORT Argument* GetMutableArgument(
    const string& name,
    const bool create_if_missing,
    NetDef* def) {
  return GetMutableArgumentImpl<NetDef>(name, create_if_missing, def);
}
// Rebuilds the external_input / external_output lists of `net`:
//  - duplicate external inputs are dropped (first occurrence kept);
//  - an operator input that is neither an external input nor the output of an
//    earlier operator is appended as an external input;
//  - external outputs are kept only when they name an external input or some
//    operator output, again without duplicates.
C10_EXPORT void cleanupExternalInputsAndOutputs(NetDef* net) {
  // Snapshot the current lists before clearing them.
  const std::vector<std::string> previousInputs(
      net->external_input().begin(), net->external_input().end());
  const std::vector<std::string> previousOutputs(
      net->external_output().begin(), net->external_output().end());
  net->clear_external_input();
  net->clear_external_output();

  std::set<std::string> seenInputs;
  for (const auto& blob : previousInputs) {
    // insert().second is false for a blob that was already recorded, which
    // filters out duplicate external inputs.
    if (seenInputs.insert(blob).second) {
      net->add_external_input(blob);
    }
  }

  // Blobs that are external inputs or outputs of some operator seen so far.
  std::set<std::string> knownBlobs(seenInputs.begin(), seenInputs.end());
  for (const auto& op : net->op()) {
    for (const auto& blob : op.input()) {
      if (knownBlobs.count(blob) != 0) {
        continue;
      }
      // Add missing external inputs (once each).
      if (seenInputs.insert(blob).second) {
        net->add_external_input(blob);
      }
    }
    for (const auto& blob : op.output()) {
      knownBlobs.insert(blob);
    }
  }

  std::set<std::string> seenOutputs;
  for (const auto& blob : previousOutputs) {
    if (knownBlobs.count(blob) != 0 && seenOutputs.insert(blob).second) {
      net->add_external_output(blob);
    }
  }
}
} // namespace caffe2