| #include "caffe2/onnx/backend.h" |
| #include "caffe2/core/logging.h" |
| #include "caffe2/core/operator.h" |
| #include "caffe2/onnx/device.h" |
| #include "caffe2/onnx/helper.h" |
| #include "caffe2/utils/map_utils.h" |
| #include "caffe2/utils/proto_utils.h" |
| |
| #ifndef C10_MOBILE |
| #include "onnx/checker.h" |
| #endif |
| |
| #include "google/protobuf/io/coded_stream.h" |
| #include "google/protobuf/io/zero_copy_stream_impl_lite.h" |
| |
| #include <cmath> |
| #include <iostream> |
| #include <limits> |
| #include <sstream> |
| #include <unordered_map> |
| #include <unordered_set> |
| |
| namespace caffe2 { |
| namespace onnx { |
| |
| namespace { |
| |
| bool AlmostEqual(double a, double b) { |
| constexpr static double kEps = 1e-15; |
| return (fabs(a - b) < kEps); |
| } |
| |
| template <class T> |
| bool TryConvertingTensorRawValues( |
| const TensorProto& onnx_tensor, |
| ::google::protobuf::RepeatedField<T>* field) { |
| if (!onnx_tensor.has_raw_data()) { |
| return false; |
| } |
| |
| size_t raw_size = onnx_tensor.raw_data().size(); |
| CAFFE_ENFORCE_EQ(raw_size % sizeof(T), 0); |
| |
| size_t num_elements = raw_size / sizeof(T); |
| const void* src_ptr = static_cast<const void*>(onnx_tensor.raw_data().data()); |
| field->Resize(num_elements, 0); |
| void* target_ptr = static_cast<void*>(field->mutable_data()); |
| memcpy(target_ptr, src_ptr, raw_size); |
| |
| return true; |
| } |
| |
| bool IsOperator(const std::string& op_type) { |
  // Pull in all the operators upon first invocation.
  // Intentionally leaky.
| static std::set<std::string>* ops_ = |
| new std::set<std::string>(caffe2::GetRegisteredOperators()); |
| return ops_->count(caffe2::OpRegistryKey(op_type, "DEFAULT")); |
| } |
| |
| caffe2::DeviceOption GetDeviceOption(const Device& onnx_device) { |
| static const std::unordered_map<DeviceType, caffe2::DeviceType> m = { |
| {DeviceType::CPU, caffe2::DeviceType::CPU}, |
| {DeviceType::CUDA, caffe2::DeviceType::CUDA}}; |
| caffe2::DeviceOption d; |
| d.set_device_type(static_cast<int32_t>(m.at(onnx_device.type))); |
| d.set_device_id(onnx_device.device_id); |
| return d; |
| } |
| |
| template <class T, class U> |
| U LookUpWithDefault( |
| const std::unordered_map<T, U>& map, |
| const T& key, |
| const U& default_value) { |
| const auto it = map.find(key); |
| if (it == map.end()) { |
| return default_value; |
| } else { |
| return it->second; |
| } |
| } |
| |
| void UpdateNames( |
| std::shared_ptr<DummyName> dummy, |
| const caffe2::OperatorDef& op) { |
| for (const auto& n : op.input()) { |
| dummy->AddName(n); |
| } |
| for (const auto& n : op.output()) { |
| dummy->AddName(n); |
| } |
| } |
| |
| void BuildOperator( |
| caffe2::OperatorDef* c2_op, |
| const std::string& op_type, |
| const std::vector<std::string>& inputs, |
| const std::vector<std::string>& outputs, |
| const std::vector<caffe2::Argument>& args) { |
| c2_op->set_name(""); |
| c2_op->set_type(op_type); |
| for (const auto& input : inputs) { |
| c2_op->add_input(input); |
| } |
| for (const auto& output : outputs) { |
| c2_op->add_output(output); |
| } |
| for (const auto& arg : args) { |
| auto* tmp = c2_op->add_arg(); |
| tmp->CopyFrom(arg); |
| } |
| } |
| |
| void BuildOperator( |
| caffe2::OperatorDef* c2_op, |
| const std::string& op_type, |
| const std::vector<std::string>& inputs, |
| const std::vector<std::string>& outputs) { |
| std::vector<caffe2::Argument> empty; |
| BuildOperator(c2_op, op_type, inputs, outputs, empty); |
| } |
| |
| void CopyOnnxAttrValueToCaffe2Arg( |
| caffe2::Argument* arg, |
| const AttributeProto& attr) { |
| if (attr.has_f()) { |
| arg->set_f(attr.f()); |
| } else if (attr.has_i()) { |
| arg->set_i(attr.i()); |
| } else if (attr.has_s()) { |
| arg->set_s(attr.s()); |
| } else if (attr.has_t()) { |
    // For a tensor proto attribute, we convert it to a serialized string
| std::string buffer; |
| attr.t().SerializeToString(&buffer); |
| arg->set_s(buffer); |
| } else if (attr.floats_size()) { |
| arg->mutable_floats()->CopyFrom(attr.floats()); |
| } else if (attr.ints_size()) { |
| arg->mutable_ints()->CopyFrom(attr.ints()); |
| } else if (attr.strings_size()) { |
| arg->mutable_strings()->CopyFrom(attr.strings()); |
| } else { |
| CAFFE_THROW("Unsupported ONNX attribute: ", attr.name()); |
| } |
| } |
| } // namespace |
| |
| OnnxAttributes::OnnxAttributes(const NodeProto& node) { |
| for (const auto& attr : node.attribute()) { |
| onnx_attrs_.emplace(attr.name(), &attr); |
| } |
| } |
| |
| template <> |
| int64_t OnnxAttributes::get(const std::string& key) const { |
| int64_t value = 0; |
| const auto it = onnx_attrs_.find(key); |
| if (it != onnx_attrs_.end()) { |
| const AttributeProto& attr = *it->second; |
| value = attr.i(); |
| } |
| return value; |
| } |
| |
| template <> |
| float OnnxAttributes::get(const std::string& key) const { |
| float value = 0.0; |
| const auto it = onnx_attrs_.find(key); |
| if (it != onnx_attrs_.end()) { |
| const AttributeProto& attr = *it->second; |
| value = attr.f(); |
| } |
| return value; |
| } |
| |
| template <> |
| ::google::protobuf::RepeatedPtrField<std::string> OnnxAttributes::get( |
| const std::string& key) const { |
| ::google::protobuf::RepeatedPtrField<std::string> value; |
| const auto it = onnx_attrs_.find(key); |
| if (it != onnx_attrs_.end()) { |
| const AttributeProto& attr = *it->second; |
| value.CopyFrom(attr.strings()); |
| } |
| return value; |
| } |
| |
| template <> |
| ::google::protobuf::RepeatedField<::google::protobuf::int64> |
| OnnxAttributes::get(const std::string& key) const { |
| ::google::protobuf::RepeatedField<::google::protobuf::int64> value; |
| const auto it = onnx_attrs_.find(key); |
| if (it != onnx_attrs_.end()) { |
| const AttributeProto& attr = *it->second; |
| value.CopyFrom(attr.ints()); |
| } |
| return value; |
| } |
| |
| template <> |
| ::google::protobuf::RepeatedField<float> OnnxAttributes::get( |
| const std::string& key) const { |
| ::google::protobuf::RepeatedField<float> value; |
| const auto it = onnx_attrs_.find(key); |
| if (it != onnx_attrs_.end()) { |
| const AttributeProto& attr = *it->second; |
| value.CopyFrom(attr.floats()); |
| } |
| return value; |
| } |
| |
| template <> |
| const TensorProto* OnnxAttributes::get(const std::string& key) const { |
| const TensorProto* value = nullptr; |
| const auto it = onnx_attrs_.find(key); |
| if (it != onnx_attrs_.end()) { |
| const AttributeProto& attr = *it->second; |
| value = &attr.t(); |
| } |
| return value; |
| } |
| |
| ::google::protobuf::RepeatedPtrField<caffe2::Argument> |
| OnnxAttributes::OnnxAttrToCaffe2Arg( |
| std::function<std::string(const std::string&)> mapper) const { |
| ::google::protobuf::RepeatedPtrField<caffe2::Argument> args; |
| for (const auto& kv : onnx_attrs_) { |
| // If the attribute was rewritten, we use it instead. Note that the |
| // rewritten attribute still has the unmapped name |
| const auto& attr = rewritten_onnx_attrs_.count(kv.first) |
| ? rewritten_onnx_attrs_.at(kv.first) |
| : (*kv.second); |
| auto* arg = args.Add(); |
| arg->set_name(mapper(attr.name())); |
| CopyOnnxAttrValueToCaffe2Arg(arg, attr); |
| } |
| for (const auto& kv : rewritten_onnx_attrs_) { |
    // If a rewritten attribute doesn't appear in the original attributes, it
    // is a newly added one and we need to add it to the arguments too
| if (!onnx_attrs_.count(kv.first)) { |
| const auto& attr = kv.second; |
| auto* arg = args.Add(); |
| arg->set_name(mapper(attr.name())); |
| CopyOnnxAttrValueToCaffe2Arg(arg, attr); |
| } |
| } |
| |
| return args; |
| } |
| |
| const std::unordered_map<std::string, int>& |
| Caffe2Backend::get_broken_operators() const { |
| const static std::unordered_map<std::string, int> kBrokenOperators{}; |
| return kBrokenOperators; |
| } |
| |
// Temporary hack for RNN-related operators, as we don't have a C++ interface
// in C2 to build those operators yet
| const std::unordered_set<std::string>& Caffe2Backend::get_rnn_operators() |
| const { |
| const static std::unordered_set<std::string> kRNNOperators{ |
| "LSTM", "GRU", "RNN"}; |
| return kRNNOperators; |
| } |
| |
| // Operators that are different between Caffe2 and |
| // ONNX but only in their name. |
| // In most cases, this should be empty - as the effort of ONNX is |
| // to unify the operator definitions. |
| const std::unordered_map<std::string, std::string>& |
| Caffe2Backend::get_renamed_operators() const { |
| const static std::unordered_map<std::string, std::string> kRenamedOperators{ |
| {"Caffe2ConvTranspose", "ConvTranspose"}, |
| {"GlobalMaxPool", "MaxPool"}, |
| {"GlobalAveragePool", "AveragePool"}, |
| {"Pad", "PadImage"}, |
| {"Neg", "Negative"}, |
| {"BatchNormalization", "SpatialBN"}, |
| {"InstanceNormalization", "InstanceNorm"}, |
| {"MatMul", "BatchMatMul"}, |
| {"Upsample", "ResizeNearest"}, |
| {"Identity", "Copy"}, |
| {"InstanceNormalization", "InstanceNorm"}, |
| {"Equal", "EQ"}, |
| {"Less", "LT"}, |
| {"Greater", "GT"}, |
| {"Unsqueeze", "ExpandDims"}, |
| {"Tile", "NumpyTile"}, |
| {"DynamicSlice", "Slice"}, |
| {"ConstantOfShape", "ConstantFill"}, |
| {"RandomNormal", "GaussianFill"}, |
| {"RandomNormalLike", "GaussianFill"}}; |
| return kRenamedOperators; |
| } |
| |
| const std::unordered_map<std::string, std::string>& |
| Caffe2Backend::get_renamed_attrs() const { |
| const static std::unordered_map<std::string, std::string> kRenamedAttrs{ |
| {"kernel_shape", "kernels"}}; |
| return kRenamedAttrs; |
| } |
| |
| const std:: |
| unordered_map<std::string, std::unordered_map<std::string, std::string>>& |
| Caffe2Backend::get_per_op_renamed_attrs() const { |
| const static std:: |
| unordered_map<std::string, std::unordered_map<std::string, std::string>> |
| kPerOpRenamedAttrs = { |
| {"Squeeze", {{"axes", "dims"}}}, |
| {"Unsqueeze", {{"axes", "dims"}}}, |
| {"Transpose", {{"perm", "axes"}}}, |
| {"ConvTranspose", {{"output_padding", "adjs"}}}, |
| {"Selu", {{"gamma", "scale"}}}}; |
| |
| return kPerOpRenamedAttrs; |
| } |
| |
// Operators whose behavior is different beyond renaming.
// The value is a member function of this class that converts
// an ONNX node_def into Caffe2 op_def(s).
| const std::unordered_map<std::string, Caffe2Backend::SpecialOpConverter>& |
| Caffe2Backend::get_special_operators() const { |
| const static std:: |
| unordered_map<std::string, Caffe2Backend::SpecialOpConverter> |
| kSpecialOperators = { |
| {"ArgMax", &Caffe2Backend::CreateArgMaxMin}, |
| {"ArgMin", &Caffe2Backend::CreateArgMaxMin}, |
| {"Cast", &Caffe2Backend::CreateCast}, |
| {"Constant", &Caffe2Backend::CreateConstant}, |
| {"ConstantOfShape", &Caffe2Backend::CreateConstantOfShape}, |
| {"Conv", &Caffe2Backend::CreateConvPoolOpBase}, |
| {"AveragePool", &Caffe2Backend::CreateConvPoolOpBase}, |
| {"GlobalAveragePool", &Caffe2Backend::CreateConvPoolOpBase}, |
| {"GlobalMaxPool", &Caffe2Backend::CreateConvPoolOpBase}, |
| {"MaxPool", &Caffe2Backend::CreateConvPoolOpBase}, |
| {"Reshape", &Caffe2Backend::CreateReshape}, |
| {"Int8Reshape", &Caffe2Backend::CreateReshape}, |
| {"Gather", &Caffe2Backend::CreateGather}, |
| {"Gemm", &Caffe2Backend::CreateGemm}, |
| {"Pad", &Caffe2Backend::CreatePad}, |
| {"Concat", &Caffe2Backend::CreateConcat}, |
| {"Int8Concat", &Caffe2Backend::CreateConcat}, |
| {"LogSoftmax", &Caffe2Backend::CreateLogSoftmax}, |
| {"Slice", &Caffe2Backend::CreateSlice}, |
| {"Split", &Caffe2Backend::CreateSplit}, |
| {"Reciprocal", &Caffe2Backend::CreateReciprocal}, |
| {"BatchNormalization", &Caffe2Backend::CreateBatchNormalization}, |
| {"MatMul", &Caffe2Backend::CreateMatMul}, |
| {"Upsample", &Caffe2Backend::CreateUpsample}, |
| {"Dropout", &Caffe2Backend::CreateDropout}, |
| {"LRN", &Caffe2Backend::CreateLRN}, |
| {"DynamicSlice", &Caffe2Backend::CreateDynamicSlice}, |
| {"RandomNormal", &Caffe2Backend::CreateRandomNormal}, |
| {"RandomNormalLike", &Caffe2Backend::CreateRandomNormal}, |
| {"Where", &Caffe2Backend::CreateWhereOp}, |
| {"NonZero", &Caffe2Backend::CreateNonZeroOp}, |
| {"Multinomial", &Caffe2Backend::CreateMultinomialOp}}; |
| return kSpecialOperators; |
| } |
| |
| //============================ |
| // Special Operator Converters |
| //============================ |
| |
| Caffe2Ops Caffe2Backend::CreateArgMaxMin( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| auto& attributes = onnx_node->attributes; |
| if (!attributes.HasAttribute("axis")) { |
| auto* attr = attributes.AddRewrittenAttribute("axis"); |
| attr->set_i(0); |
| } |
| return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateCast( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| |
| auto onnx_dtype = |
| onnx_node->attributes.get<int64_t>("to", TensorProto::UNDEFINED); |
| auto c2_dtype = caffe2::TensorProto::UNDEFINED; |
| switch (onnx_dtype) { |
| case ::ONNX_NAMESPACE::TensorProto::FLOAT: |
| c2_dtype = caffe2::TensorProto::FLOAT; |
| break; |
| case ::ONNX_NAMESPACE::TensorProto::UINT8: |
| c2_dtype = caffe2::TensorProto::UINT8; |
| break; |
| case ::ONNX_NAMESPACE::TensorProto::INT8: |
| c2_dtype = caffe2::TensorProto::INT8; |
| break; |
| case ::ONNX_NAMESPACE::TensorProto::UINT16: |
| c2_dtype = caffe2::TensorProto::UINT16; |
| break; |
| case ::ONNX_NAMESPACE::TensorProto::INT16: |
| c2_dtype = caffe2::TensorProto::INT16; |
| break; |
| case ::ONNX_NAMESPACE::TensorProto::INT32: |
| c2_dtype = caffe2::TensorProto::INT32; |
| break; |
| case ::ONNX_NAMESPACE::TensorProto::INT64: |
| c2_dtype = caffe2::TensorProto::INT64; |
| break; |
| case ::ONNX_NAMESPACE::TensorProto::STRING: |
| c2_dtype = caffe2::TensorProto::STRING; |
| break; |
| case ::ONNX_NAMESPACE::TensorProto::BOOL: |
| c2_dtype = caffe2::TensorProto::BOOL; |
| break; |
| case ::ONNX_NAMESPACE::TensorProto::FLOAT16: |
| c2_dtype = caffe2::TensorProto::FLOAT16; |
| break; |
| case ::ONNX_NAMESPACE::TensorProto::DOUBLE: |
| c2_dtype = caffe2::TensorProto::DOUBLE; |
| break; |
| case ::ONNX_NAMESPACE::TensorProto::UINT32: |
| case ::ONNX_NAMESPACE::TensorProto::UINT64: |
| case ::ONNX_NAMESPACE::TensorProto::COMPLEX64: |
| case ::ONNX_NAMESPACE::TensorProto::COMPLEX128: |
| case ::ONNX_NAMESPACE::TensorProto::UNDEFINED: |
| c2_dtype = caffe2::TensorProto::UNDEFINED; |
| break; |
| }; |
| |
| CAFFE_ENFORCE_NE( |
| c2_dtype, |
| caffe2::TensorProto::UNDEFINED, |
| "Casting to '", |
| onnx_dtype, |
| "' dtype is not supported"); |
| |
| CAFFE_ENFORCE_EQ( |
| c2_op.ops.Get(0).arg().size(), |
| 1, |
| "Unexpected number of attributes in 'Cast'"); |
| c2_op.ops.Mutable(0)->mutable_arg(0)->set_i(c2_dtype); |
| |
| return c2_op; |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateConstant( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| CAFFE_ENFORCE_EQ(onnx_node->node.output_size(), 1); |
| |
| Caffe2Ops ret; |
| auto* c2_op = ret.ops.Add(); |
| const auto* value = onnx_node->attributes.get<const TensorProto*>("value"); |
| BuildTensorFillingOp(c2_op, *value, onnx_node->node.output(0)); |
| |
| return ret; |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateConstantOfShape( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| CAFFE_ENFORCE_EQ(onnx_node->node.input_size(), 1); |
| CAFFE_ENFORCE_EQ(onnx_node->node.output_size(), 1); |
| |
| Caffe2Ops ret; |
| auto* c2_op = ret.ops.Add(); |
| const auto* value = onnx_node->attributes.get<const TensorProto*>("value"); |
| if (value) { |
| BuildTensorFillingOp( |
| c2_op, *value, onnx_node->node.output(0), onnx_node->node.input(0)); |
| } else { |
| c2_op->set_type("ConstantFill"); |
| c2_op->add_input(onnx_node->node.input(0)); |
| c2_op->add_output(onnx_node->node.output(0)); |
| auto c2_input_as_shape = c2_op->add_arg(); |
| c2_input_as_shape->set_name("input_as_shape"); |
| c2_input_as_shape->set_i(1); |
| } |
| |
| return ret; |
| } |
| |
| // Note [Caffe2 ConvPoolOpBase] |
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| // To understand what is going on here, we have to talk a little bit about |
| // Caffe2's internals. |
| // |
| // First, it's important to know that all of Caffe2's pooling and convolution |
| // operators inherit from "ConvPoolOpBase", which is an abstract class that |
| // defines all of the attributes (kernels, dilations, strides, etc) which one |
| // sees on these operators. Unfortunately, Caffe2's documentation generator |
| // doesn't know how to handle cases like this, so for example, if you look at |
| // the docs for MaxPool at |
| // <https://caffe2.ai/docs/operators-catalogue.html#maxpool> you won't see any |
| // of the attributes. You have to go source diving to find the information; in |
| // particular, you want to look at: |
| // https://github.com/caffe2/caffe2/blob/master/caffe2/operators/conv_pool_op_base.h |
| // This class handles *global* pooling as well. |
| // |
| // Second, it's important to know what Caffe2 expects for padding, which can |
| // be somewhat difficult to understand from the code because Caffe2 handles |
| // both singular/pluralized spellings of padding, and there is also legacy |
| // padding business. The short version of the story is that, for NON-legacy |
| // padding (which is what we want to output), padding is expected to be |
| // *twice* the size of kernels. So if you have a 2D convolution, Caffe2 |
| // will accept two values in 'kernels', but FOUR values in 'pads'; |
| // furthermore, this is *mandatory.* |
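//
// For illustration (hypothetical values): if an ONNX 2D pooling node arrives
// with kernel_shape = [3, 3] and a pads list of the same length, say [1, 1],
// the converter below simply duplicates it to pads = [1, 1, 1, 1] so that the
// "twice the size of kernels" requirement is met.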
| // |
// Finally, ConvPoolOpBase is not the only class of its kind; there is
// also ConvTransposeUnpoolBase, which backs ConvTranspose. So don't
// be tricked by the fact that Conv and ConvTranspose have similar
// parameters; they exercise different codepaths and need to be handled
// differently.
| Caffe2Ops Caffe2Backend::CreateConvPoolOpBase( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| const auto& node = onnx_node->node; |
| auto& attributes = onnx_node->attributes; |
| if (node.op_type().find("Global") == 0) { |
| auto* attr = attributes.AddRewrittenAttribute("global_pooling"); |
| attr->set_i(1); |
| } |
| |
| if (attributes.HasAttribute("kernel_shape") && |
| attributes.HasAttribute("pads")) { |
| auto kernel_shape = |
| attributes |
| .get<::google::protobuf::RepeatedField<::google::protobuf::int64>>( |
| "kernel_shape"); |
| auto pads = |
| attributes |
| .get<::google::protobuf::RepeatedField<::google::protobuf::int64>>( |
| "pads"); |
| if (kernel_shape.size() == pads.size()) { |
| // Caffe2 requires pads to be twice the size of kernels. |
| auto* attr = attributes.AddRewrittenAttribute("pads"); |
| attr->mutable_ints()->CopyFrom(pads); |
| attr->mutable_ints()->MergeFrom(pads); |
| } |
| } |
| |
| return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateReshape( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| CAFFE_ENFORCE_EQ(c2_op.ops.size(), 1); |
| auto* op = c2_op.ops.Mutable(0); |
| op->add_output(dummy_->NewDummyName()); |
| |
| return c2_op; |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateRandomNormal( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| auto& attributes = onnx_node->attributes; |
| |
| if (attributes.HasAttribute("seed")) { |
| CAFFE_THROW("Caffe2 GaussianFill does not support random seed"); |
| } |
| |
| if (attributes.HasAttribute("dtype")) { |
| if (attributes.get<int64_t>("dtype") != TensorProto::FLOAT) { |
| CAFFE_THROW("Caffe2 GaussianFill only support FLOAT dtype"); |
| } |
| attributes.remove("dtype"); |
| } |
| if (attributes.HasAttribute("scale")) { |
| auto scale = attributes.get<float>("scale"); |
| auto* attr = attributes.AddRewrittenAttribute("std"); |
| attr->set_f(scale); |
| attributes.remove("scale"); |
| } |
| return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateWhereOp( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| // The native Caffe2 op doesn't support broadcasting, so we defer the handling |
| // of this op to the ATen library that does. |
| onnx::NodeProto converted; |
| converted.CopyFrom(onnx_node->node); |
| converted.set_op_type("ATen"); |
| onnx::AttributeProto* attr = converted.add_attribute(); |
| attr->set_name("operator"); |
| attr->set_s("where"); |
| OnnxNode new_node(converted); |
| return CommonOnnxNodeToCaffe2Ops(&new_node, ctx); |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateNonZeroOp( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
  // Native Caffe2 doesn't support NonZero, so we fall back to ATen.
  // ATen nonzero is equivalent to Transpose(ONNX::NonZero).
| onnx::NodeProto converted; |
| converted.CopyFrom(onnx_node->node); |
| |
| auto nonzero_output = dummy_->NewDummyName(); |
| converted.set_output(0, nonzero_output); |
| converted.set_op_type("ATen"); |
| onnx::AttributeProto* attr = converted.add_attribute(); |
| attr->set_name("operator"); |
| attr->set_s("nonzero"); |
| OnnxNode new_node(converted); |
| auto ret = CommonOnnxNodeToCaffe2Ops(&new_node, ctx); |
| |
| auto* c2_transpose = ret.ops.Add(); |
| BuildOperator( |
| c2_transpose, "Transpose", {nonzero_output}, {onnx_node->node.output(0)}); |
| return ret; |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateMultinomialOp( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
  // Fall back to ATen.
  // ATen::Multinomial takes probabilities as input, whereas ONNX Multinomial
  // expects its input to be log probabilities.
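  // The sequence generated below is Exp -> ATen(operator="multinomial"),
  // followed by a Cast to int32 when the ONNX dtype attribute asks for it.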
| Caffe2Ops ret; |
| auto c2_exp_output = dummy_->NewDummyName(); |
| auto* c2_exp = ret.ops.Add(); |
| BuildOperator(c2_exp, "Exp", {onnx_node->node.input(0)}, {c2_exp_output}); |
| |
| auto* c2_multinomial = ret.ops.Add(); |
| caffe2::Argument c2_arg_op; |
| c2_arg_op.set_name("operator"); |
| c2_arg_op.set_s("multinomial"); |
| // ONNX Multinomial only supports replacement=True. |
| caffe2::Argument c2_arg_rep; |
| c2_arg_rep.set_name("replacement"); |
| c2_arg_rep.set_i(1); |
| auto& onnx_attributes = onnx_node->attributes; |
| caffe2::Argument c2_arg_num; |
| c2_arg_num.set_name("num_samples"); |
| c2_arg_num.set_i(onnx_attributes.get<int64_t>("sample_size")); |
| |
| // ONNX Multinomial has attribute dtype in {int64, int32}, which specifies |
| // output datatype. ATen::Multinomial output dtype is always int64. |
| auto onnx_dtype = |
| onnx_attributes.get<int64_t>("dtype", TensorProto::UNDEFINED); |
| if (onnx_dtype == ::ONNX_NAMESPACE::TensorProto::INT64) { |
| BuildOperator( |
| c2_multinomial, |
| "ATen", |
| {c2_exp_output}, |
| {onnx_node->node.output(0)}, |
| {c2_arg_op, c2_arg_rep, c2_arg_num}); |
| } else if (onnx_dtype == ::ONNX_NAMESPACE::TensorProto::INT32) { |
| auto c2_multinomial_output = dummy_->NewDummyName(); |
| BuildOperator( |
| c2_multinomial, |
| "ATen", |
| {c2_exp_output}, |
| {c2_multinomial_output}, |
| {c2_arg_op, c2_arg_rep, c2_arg_num}); |
| |
| auto* c2_cast = ret.ops.Add(); |
| caffe2::Argument to; |
| to.set_name("to"); |
| to.set_i(caffe2::TensorProto::INT32); |
| BuildOperator( |
| c2_cast, |
| "Cast", |
| {c2_multinomial_output}, |
| {onnx_node->node.output(0)}, |
| {to}); |
| } else { |
| CAFFE_THROW( |
| "ONNX does not support dtype other than int32/int64 in Multinomial, but get ", |
| onnx_dtype); |
| } |
| |
| return ret; |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateReciprocal( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| const auto& node = onnx_node->node; |
| if (node.input_size() != 1 || node.output_size() != 1) { |
| CAFFE_THROW("Caffe2 Reciprocal should have 1 input and 1 output"); |
| } |
| |
| Caffe2Ops ret; |
| auto* c2_op = ret.ops.Add(); |
| |
| BuildOperator(c2_op, "Reciprocal", {node.input(0)}, {node.output(0)}, {}); |
| return ret; |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateGather( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| const auto& node = onnx_node->node; |
| if (node.input_size() < 2 || node.output_size() < 1) { |
| CAFFE_THROW("Caffe2 Gather should have 2 inputs and 1 output"); |
| } |
| |
| Caffe2Ops ret; |
| auto* c2_op = ret.ops.Add(); |
| |
| std::vector<std::string> inputs; |
| inputs.emplace_back(node.input(0)); |
| inputs.emplace_back(node.input(1)); |
| std::vector<std::string> outputs; |
| outputs.emplace_back(node.output(0)); |
| |
| auto axis = onnx_node->attributes.get<int64_t>("axis", 0L); |
| if (axis == 0) { |
| BuildOperator(c2_op, "Gather", inputs, outputs); |
| } else if (axis == 1) { |
| BuildOperator(c2_op, "BatchGather", inputs, outputs); |
| } else { |
| CAFFE_THROW( |
| "Caffe2 only supports Gather with axis being 0 or 1, ", |
| "whereas axis is ", |
| axis); |
| } |
| |
| return ret; |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateGemm( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| const auto& node = onnx_node->node; |
| if (node.input_size() < 3 || node.output_size() < 1) { |
| CAFFE_THROW("Caffe2 Gemm should have 3 inputs and 1 output"); |
| } |
| |
| Caffe2Ops ret; |
| auto input_a = node.input(0); |
| auto input_b = node.input(1); |
| auto input_c = node.input(2); |
| auto output = node.output(0); |
| |
| auto alpha = onnx_node->attributes.get<float>("alpha", 1.0); |
| auto beta = onnx_node->attributes.get<float>("beta", 1.0); |
| if (!AlmostEqual(alpha, 1)) { |
| auto scaled_a = dummy_->NewDummyName(); |
| caffe2::Argument scale; |
| scale.set_name("scale"); |
| scale.set_f(alpha); |
| |
| auto* c2_op = ret.ops.Add(); |
| BuildOperator(c2_op, "Scale", {input_a}, {scaled_a}, {scale}); |
| input_a = scaled_a; |
| } |
| if (!AlmostEqual(beta, 1)) { |
| auto scaled_c = dummy_->NewDummyName(); |
| caffe2::Argument scale; |
| scale.set_name("scale"); |
| scale.set_f(beta); |
| |
| auto* c2_op = ret.ops.Add(); |
| BuildOperator(c2_op, "Scale", {input_c}, {scaled_c}, {scale}); |
| input_c = scaled_c; |
| } |
| |
| auto trans_a = onnx_node->attributes.get<int64_t>("transA", 0L); |
| auto trans_b = onnx_node->attributes.get<int64_t>("transB", 0L); |
| // Support broadcast by default when opset_version > 6. |
| auto broadcast = onnx_node->attributes.get<int64_t>( |
| "broadcast", (ctx.opset_version() > 6) ? 1L : 0L); |
| |
  // If C's shape information is available and C is a 1-D tensor (with an
  // extra check on B when C is a scalar), use FC aggressively.
| auto check_fc = [&]() -> bool { |
| const auto input_c_vi_iter = ctx.value_infos().find(node.input(2)); |
| |
| if (input_c_vi_iter == ctx.value_infos().end()) { |
| return false; |
| } |
| |
| const auto input_c_shape = |
| input_c_vi_iter->second.type().tensor_type().shape(); |
| |
| if (input_c_shape.dim_size() != 1) { |
| return false; |
| } |
| |
| // c is a scalar. |
| if (input_c_shape.dim(0).dim_value() == 1) { |
| const auto input_b_vi_iter = ctx.value_infos().find(node.input(1)); |
| |
| // If the b's shape is not available, skip FC. |
| if (input_b_vi_iter == ctx.value_infos().end()) { |
| return false; |
| } |
| const auto input_b_shape = |
| input_b_vi_iter->second.type().tensor_type().shape(); |
| int input_b_last_dim_index = (trans_b) ? 0 : 1; |
| // If b's last dim is not 1, skip FC. |
| if (input_b_shape.dim_size() <= input_b_last_dim_index || |
| input_b_shape.dim(input_b_last_dim_index).dim_value() != 1) { |
| return false; |
| } |
| } |
| |
| return true; |
| }; |
| |
| if (!trans_a && broadcast && check_fc()) { |
| auto* c2_op = ret.ops.Add(); |
| if (trans_b) { |
| BuildOperator(c2_op, "FC", {input_a, input_b, input_c}, {output}); |
| } else { |
| BuildOperator( |
| c2_op, "FCTransposed", {input_a, input_b, input_c}, {output}); |
| } |
| } else { |
| auto ab = dummy_->NewDummyName(); |
| caffe2::Argument arg_trans_a; |
| arg_trans_a.set_name("trans_a"); |
| arg_trans_a.set_i(trans_a); |
| caffe2::Argument arg_trans_b; |
| arg_trans_b.set_name("trans_b"); |
| arg_trans_b.set_i(trans_b); |
| |
| auto* c2_op = ret.ops.Add(); |
| BuildOperator( |
| c2_op, "MatMul", {input_a, input_b}, {ab}, {arg_trans_a, arg_trans_b}); |
| c2_op = ret.ops.Add(); |
| if (ctx.opset_version() >= 7) { |
| BuildOperator(c2_op, "Add", {ab, input_c}, {output}); |
| } else { |
| caffe2::Argument arg_broadcast; |
| arg_broadcast.set_name("broadcast"); |
| arg_broadcast.set_i(broadcast); |
| BuildOperator(c2_op, "Add", {ab, input_c}, {output}, {arg_broadcast}); |
| } |
| } |
| |
| return ret; |
| } |
| |
| Caffe2Ops Caffe2Backend::CreatePad( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| auto& attributes = onnx_node->attributes; |
| ::google::protobuf::RepeatedField<::google::protobuf::int64> pads; |
| std::string pad_name = ctx.opset_version() < 2 ? "paddings" : "pads"; |
| pads = attributes |
| .get<::google::protobuf::RepeatedField<::google::protobuf::int64>>( |
| pad_name); |
| std::string str; |
| std::stringstream ss; |
| ss << "["; |
| for (const auto& i : pads) { |
| ss << i << ", "; |
| } |
| ss << "]"; |
| str = ss.str(); |
| |
  // Guard against invalid (negative) values in the pads attribute.
| for (const auto i : pads) { |
| if (i < 0) { |
| CAFFE_THROW("ONNX does not support negative pads in Pad, but get ", str); |
| } |
| } |
| |
  // The first two dims are for batch and channel. Note that at this point all
  // the values are non-negative.
| if (!(pads.size() == 8 && |
| (pads.Get(0) + pads.Get(1) + pads.Get(4) + pads.Get(5) == 0))) { |
| CAFFE_THROW( |
| "Caffe2 only supports padding 2D Tensor, whereas padding is ", str); |
| } |
| |
| // rewrite the padding info |
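  // For illustration (assuming NCHW input): ONNX pads =
  //   [0, 0, pad_t, pad_l, 0, 0, pad_b, pad_r]
  // is rewritten into the 4-element Caffe2 form [pad_t, pad_l, pad_b, pad_r].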
| auto* attr = attributes.AddRewrittenAttribute(pad_name); |
| attr->add_ints(pads.Get(2)); |
| attr->add_ints(pads.Get(3)); |
| attr->add_ints(pads.Get(6)); |
| attr->add_ints(pads.Get(7)); |
| |
| return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| } |
| |
| // TODO: Caffe2 Concat has an extra output. It should be only |
| // used when doing training, so we should change Caffe2 to allow |
| // 1 output. |
| Caffe2Ops Caffe2Backend::CreateConcat( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| CAFFE_ENFORCE_EQ(c2_op.ops.size(), 1); |
| auto* op = c2_op.ops.Mutable(0); |
| op->add_output(dummy_->NewDummyName()); |
| |
| return c2_op; |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateLogSoftmax( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| const auto& node = onnx_node->node; |
| if (node.input_size() < 1 || node.output_size() < 1) { |
| CAFFE_THROW("LogSoftmax should have 1 input and 1 output"); |
| } |
| auto axis = onnx_node->attributes.get<int64_t>("axis", 1L); |
| caffe2::Argument arg_axis; |
| arg_axis.set_name("axis"); |
| arg_axis.set_i(axis); |
| auto softmax_a = dummy_->NewDummyName(); |
| |
| Caffe2Ops ret; |
| auto* c2_op = ret.ops.Add(); |
| BuildOperator(c2_op, "Softmax", {node.input(0)}, {softmax_a}, {arg_axis}); |
| c2_op = ret.ops.Add(); |
| BuildOperator(c2_op, "Log", {softmax_a}, {node.output(0)}); |
| |
| return ret; |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateSlice( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| auto op_tmp = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| CAFFE_ENFORCE_EQ(op_tmp.ops.size(), 1); |
| auto* op = op_tmp.ops.Mutable(0); |
| std::unordered_map<std::string, caffe2::Argument*> args; |
| for (auto& arg : *op->mutable_arg()) { |
| args.emplace(arg.name(), &arg); |
| } |
| |
| caffe2::Argument starts_vals; |
| starts_vals.set_name("values"); |
| auto pos = args.find("starts"); |
| if (pos != args.end()) { |
| for (auto i : pos->second->ints()) { |
| starts_vals.add_ints(i < 0 ? i - 1 : i); |
| } |
| args.erase(pos); |
| } |
| |
| caffe2::Argument ends_vals; |
| ends_vals.set_name("values"); |
| pos = args.find("ends"); |
| if (pos != args.end()) { |
| for (auto i : pos->second->ints()) { |
| if (i == std::numeric_limits<int64_t>::max()) { |
| ends_vals.add_ints(-1); |
| } else { |
| ends_vals.add_ints(i < 0 ? i - 1 : i); |
| } |
| } |
| args.erase(pos); |
| } |
| |
| caffe2::Argument axes_vals; |
| axes_vals.set_name("values"); |
| pos = args.find("axes"); |
| if (pos != args.end()) { |
| for (auto i : pos->second->ints()) { |
| axes_vals.add_ints(i); |
| } |
| args.erase(pos); |
| } else { |
| auto ndim = starts_vals.ints_size(); |
| for (int64_t i = 0; i < ndim; ++i) { |
| axes_vals.add_ints(i); |
| } |
| } |
| |
| CAFFE_ENFORCE_GE(op->input_size(), 1); |
| auto data = op->input(0); |
| auto shape_tensor = dummy_->NewDummyName(); |
| Caffe2Ops ret; |
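  // Roughly, the ops emitted below implement the slice as:
  //   shape  = Shape(data)
  //   starts = ConstantFill(shape, value=0);  ScatterAssign(starts, axes, starts_vals)
  //   ends   = ConstantFill(shape, value=-1); ScatterAssign(ends, axes, ends_vals)
  //   output = Slice(data, starts, ends)
  // where axes, starts_vals and ends_vals are materialized with GivenTensor*Fill ops.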
| |
| auto* c2_op = ret.ops.Add(); |
| BuildOperator(c2_op, "Shape", {data}, {shape_tensor}); |
| |
| auto axes_tensor = dummy_->NewDummyName(); |
| c2_op = ret.ops.Add(); |
| { |
| caffe2::Argument shape; |
| shape.set_name("shape"); |
| shape.add_ints(axes_vals.ints_size()); |
| BuildOperator( |
| c2_op, "GivenTensorIntFill", {}, {axes_tensor}, {shape, axes_vals}); |
| } |
| |
| auto starts_vals_tensor = dummy_->NewDummyName(); |
| auto starts_tensor = dummy_->NewDummyName(); |
| c2_op = ret.ops.Add(); |
| { |
| caffe2::Argument shape_starts; |
| shape_starts.set_name("shape"); |
| shape_starts.add_ints(starts_vals.ints_size()); |
| BuildOperator( |
| c2_op, |
| "GivenTensorInt64Fill", |
| {}, |
| {starts_vals_tensor}, |
| {shape_starts, starts_vals}); |
| } |
| |
| caffe2::Argument dtype; |
| dtype.set_name("dtype"); |
| dtype.set_i(static_cast<int64_t>(caffe2::TensorProto::INT64)); |
| caffe2::Argument constant; |
| constant.set_name("value"); |
| constant.set_i(0); |
| c2_op = ret.ops.Add(); |
| BuildOperator( |
| c2_op, |
| "ConstantFill", |
| {shape_tensor}, |
| {starts_tensor}, |
| {dtype, constant}); |
| c2_op = ret.ops.Add(); |
| BuildOperator( |
| c2_op, |
| "ScatterAssign", |
| {starts_tensor, axes_tensor, starts_vals_tensor}, |
| {starts_tensor}); |
| // Slice only accepts starts as int |
| caffe2::Argument to; |
| to.set_name("to"); |
| to.set_i(static_cast<int64_t>(caffe2::TensorProto::INT32)); |
| |
| auto ends_vals_tensor = dummy_->NewDummyName(); |
| auto ends_tensor = dummy_->NewDummyName(); |
| c2_op = ret.ops.Add(); |
| { |
| caffe2::Argument shape_ends; |
| shape_ends.set_name("shape"); |
| shape_ends.add_ints(ends_vals.ints_size()); |
| BuildOperator( |
| c2_op, |
| "GivenTensorInt64Fill", |
| {}, |
| {ends_vals_tensor}, |
| {shape_ends, ends_vals}); |
| } |
| |
| constant.set_i(-1); |
| c2_op = ret.ops.Add(); |
| BuildOperator( |
| c2_op, "ConstantFill", {shape_tensor}, {ends_tensor}, {dtype, constant}); |
| c2_op = ret.ops.Add(); |
| BuildOperator( |
| c2_op, |
| "ScatterAssign", |
| {ends_tensor, axes_tensor, ends_vals_tensor}, |
| {ends_tensor}); |
| |
| // attach the original op at the end |
| c2_op = ret.ops.Add(); |
| c2_op->CopyFrom(*op); |
| c2_op->mutable_input()->Clear(); |
| c2_op->add_input(data); |
| c2_op->add_input(starts_tensor); |
| c2_op->add_input(ends_tensor); |
| c2_op->mutable_arg()->Clear(); |
| for (const auto& kv : args) { |
| c2_op->add_arg()->CopyFrom(*kv.second); |
| } |
| |
| return ret; |
| } |
| |
| // Do the following: |
| // for a given index tensor (i.e. `starts` or `ends`): |
// 1) Hilariously subtract 1 from the value if it is negative. This is due to
| // the behavior of Caffe2's slice operator not matching that of ONNX's slice |
| // 2) Fully expand the index tensor out to the rank of the data tensor. |
| // pseudocode: indices_full = zeros(rank); indices_full[axes] = indices.int() |
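// For illustration (hypothetical values): with rank = 3, axes = [1],
// indices = [-2] and default_value = 0, the returned tensor holds [0, -3, 0]:
// the -2 becomes -3 via the negative-index adjustment and the untouched
// positions keep the default value.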
| std::string Caffe2Backend::PreprocessSliceIndexTensor( |
| OnnxNode* onnx_node, |
| Caffe2Ops& ret, |
| std::string indices_tensor, |
| std::string axes_tensor, |
| std::string rank_tensor, |
| std::string zero_tensor, |
| std::string one_tensor, |
| int default_value) { |
| auto indices_tensor_full = dummy_->NewDummyName(); |
| |
| { |
| caffe2::Argument value; |
| value.set_name("value"); |
| value.set_i(default_value); |
| caffe2::Argument dtype; |
| dtype.set_name("dtype"); |
| dtype.set_i(static_cast<int64_t>(caffe2::TensorProto::INT64)); |
| caffe2::Argument input_as_shape; |
| input_as_shape.set_name("input_as_shape"); |
| input_as_shape.set_i(1); |
| auto c2_op = ret.ops.Add(); |
| BuildOperator( |
| c2_op, |
| "ConstantFill", |
| {rank_tensor}, |
| {indices_tensor_full}, |
| {value, dtype, input_as_shape}); |
| } |
| |
| // Subtract 1 from each element of the indices tensor that is negative |
| auto lt_tensor = dummy_->NewDummyName(); |
| { |
| caffe2::Argument broadcast; |
| broadcast.set_name("broadcast"); |
| broadcast.set_i(1); |
| auto c2_op = ret.ops.Add(); |
| BuildOperator( |
| c2_op, "LT", {indices_tensor, zero_tensor}, {lt_tensor}, {broadcast}); |
| } |
| |
| auto sub_one_tensor = dummy_->NewDummyName(); |
| { |
| caffe2::Argument broadcast; |
| broadcast.set_name("broadcast"); |
| broadcast.set_i(1); |
| auto c2_op = ret.ops.Add(); |
| BuildOperator( |
| c2_op, |
| "Sub", |
| {indices_tensor, one_tensor}, |
| {sub_one_tensor}, |
| {broadcast}); |
| } |
| |
| auto indices_tensor_adjusted = dummy_->NewDummyName(); |
| auto c2_op = ret.ops.Add(); |
| BuildOperator( |
| c2_op, |
| "Conditional", |
| {lt_tensor, sub_one_tensor, indices_tensor}, |
| {indices_tensor_adjusted}, |
| {}); |
| |
| // Fill in values specified from the partially-specified ONNX indices tensor |
| c2_op = ret.ops.Add(); |
| BuildOperator( |
| c2_op, |
| "ScatterAssign", |
| {indices_tensor_full, axes_tensor, indices_tensor_adjusted}, |
| {indices_tensor_full}); |
| |
| return indices_tensor_full; |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateDynamicSlice( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| auto op_tmp = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| CAFFE_ENFORCE_EQ(op_tmp.ops.size(), 1); |
| auto* op = op_tmp.ops.Mutable(0); |
| std::unordered_map<std::string, caffe2::Argument*> args; |
| for (auto& arg : *op->mutable_arg()) { |
| args.emplace(arg.name(), &arg); |
| } |
| |
| CAFFE_ENFORCE_GE(op->input_size(), 1); |
| auto data = op->input(0); |
| Caffe2Ops ret; |
| |
| // First get the shape of the input tensor |
| auto* c2_op = ret.ops.Add(); |
| auto size_tensor = dummy_->NewDummyName(); |
| BuildOperator(c2_op, "Shape", {data}, {size_tensor}); |
| |
| // Now get the rank of the tensor by getting the shape of the shape of |
| // the input tensor |
| c2_op = ret.ops.Add(); |
| auto rank_tensor = dummy_->NewDummyName(); |
| BuildOperator(c2_op, "Shape", {size_tensor}, {rank_tensor}); |
| |
| // Axes tensor will be used to populate the fully-specified starts and ends |
| // arguments to the caffe2 Slice operator. |
| std::string axes_tensor; |
| if (onnx_node->node.input_size() > 3) { |
| axes_tensor = onnx_node->node.input(3); |
| } else { |
| axes_tensor = dummy_->NewDummyName(); |
| auto* c2_op = ret.ops.Add(); |
| BuildOperator(c2_op, "Range", {rank_tensor}, {axes_tensor}, {}); |
| } |
| |
| // Useful int tensors |
| auto define_integer_constant = [this, &ret](int val) { |
| caffe2::Argument value; |
| value.set_name("value"); |
| value.set_i(val); |
| caffe2::Argument dtype; |
| dtype.set_name("dtype"); |
| dtype.set_i(static_cast<int64_t>(caffe2::TensorProto::INT64)); |
| caffe2::Argument shape; |
| shape.set_name("shape"); |
| shape.add_ints(1); |
| auto c2_op = ret.ops.Add(); |
| auto name = dummy_->NewDummyName(); |
| BuildOperator(c2_op, "ConstantFill", {}, {name}, {value, dtype, shape}); |
| return name; |
| }; |
| |
| auto zero_tensor = define_integer_constant(0); |
| auto one_tensor = define_integer_constant(1); |
| |
| auto starts_tensor_full = PreprocessSliceIndexTensor( |
| onnx_node, |
| ret, |
| onnx_node->node.input(1), // starts |
| axes_tensor, |
| rank_tensor, |
| zero_tensor, |
| one_tensor, |
| 0); |
| |
| auto ends_tensor_full = PreprocessSliceIndexTensor( |
| onnx_node, |
| ret, |
| onnx_node->node.input(2), // ends |
| axes_tensor, |
| rank_tensor, |
| zero_tensor, |
| one_tensor, |
| -1); |
| |
| // attach the original op at the end |
| c2_op = ret.ops.Add(); |
| c2_op->CopyFrom(*op); |
| c2_op->mutable_input()->Clear(); |
| c2_op->add_input(data); |
| c2_op->add_input(starts_tensor_full); |
| c2_op->add_input(ends_tensor_full); |
| c2_op->mutable_arg()->Clear(); |
| for (const auto& kv : args) { |
| c2_op->add_arg()->CopyFrom(*kv.second); |
| } |
| |
| return ret; |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateBatchNormalization( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| auto& attributes = onnx_node->attributes; |
| |
| if (ctx.opset_version() < 6) { |
| attributes.remove("consumed_inputs"); |
| } |
| |
| if (ctx.opset_version() >= 7) { |
| auto* attr = attributes.AddRewrittenAttribute("is_test"); |
| attr->set_i(1); |
| } |
| |
| if (attributes.HasAttribute("spatial") && |
| attributes.get<int64_t>("spatial") == 1) { |
| attributes.remove("spatial"); |
| } |
| |
| return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateSplit( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| auto& attributes = onnx_node->attributes; |
| if (!attributes.HasAttribute("axis")) { |
| auto* attr = attributes.AddRewrittenAttribute("axis"); |
| attr->set_i(0); |
| } |
| |
| return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateMatMul( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| const auto& node = onnx_node->node; |
| if (node.input_size() != 2) { |
| CAFFE_THROW("MatMul should have 2 inputs"); |
| } |
| |
| auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| CAFFE_ENFORCE_EQ(c2_op.ops.size(), 1); |
| auto* op = c2_op.ops.Mutable(0); |
| auto* broadcast_arg = op->add_arg(); |
| broadcast_arg->set_name("broadcast"); |
| broadcast_arg->set_i(1); |
| |
| return c2_op; |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateUpsample( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| auto& attributes = onnx_node->attributes; |
| attributes.remove("mode"); |
| |
| if (ctx.opset_version() >= 7 && ctx.opset_version() < 9) { |
| const auto& scales = |
| attributes.get<::google::protobuf::RepeatedField<float>>("scales"); |
| if (scales.size() != 4) { |
| CAFFE_THROW("The scales argument should have size 4"); |
| } else if ( |
| !AlmostEqual(scales.Get(0), 1) || !AlmostEqual(scales.Get(1), 1)) { |
| CAFFE_THROW("The first two elements in the scales argument must be 1"); |
| } |
| attributes.remove("scales"); |
| auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| auto* op = c2_op.ops.Mutable(0); |
| auto* c2_height = op->add_arg(); |
| c2_height->set_name("height_scale"); |
| c2_height->set_f(scales.Get(2)); |
| auto* c2_width = op->add_arg(); |
| c2_width->set_name("width_scale"); |
| c2_width->set_f(scales.Get(3)); |
| return c2_op; |
| } else if (ctx.opset_version() >= 9) { |
| const auto& node = onnx_node->node; |
| if (node.input_size() != 2) { |
| CAFFE_THROW("Expects 2 input in upsample after onnx version 9"); |
| } |
| Caffe2Ops ret; |
| |
    // Slice the scales input {1, 1, height_scale, width_scale} ->
    // {height_scale, width_scale}
| auto* c2_op = ret.ops.Add(); |
| auto sliced_input = dummy_->NewDummyName(); |
| caffe2::Argument arg_starts, arg_ends; |
| arg_starts.set_name("starts"); |
| arg_starts.add_ints(2); |
| arg_ends.set_name("ends"); |
| arg_ends.add_ints(-1); |
| BuildOperator( |
| c2_op, |
| "Slice", |
| {node.input(1)}, |
| {sliced_input}, |
| {arg_starts, arg_ends}); |
| |
| // Upsample |
| c2_op = ret.ops.Add(); |
| BuildOperator( |
| c2_op, |
| "ResizeNearest", |
| {node.input(0), sliced_input}, |
| {node.output(0)}, |
| {}); |
| return ret; |
| } |
| return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateDropout( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| if (ctx.opset_version() >= 7) { |
| auto& attributes = onnx_node->attributes; |
| auto* attr = attributes.AddRewrittenAttribute("is_test"); |
| attr->set_i(1); |
| } |
| |
| return CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| } |
| |
| Caffe2Ops Caffe2Backend::CreateLRN( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| auto c2_op = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| const auto& attributes = onnx_node->attributes; |
| if (!attributes.HasAttribute("alpha")) { |
| auto* arg = c2_op.ops.Mutable(0)->add_arg(); |
| arg->set_name("alpha"); |
| arg->set_f(1e-4); |
| } |
| if (!attributes.HasAttribute("beta")) { |
| auto* arg = c2_op.ops.Mutable(0)->add_arg(); |
| arg->set_name("beta"); |
| arg->set_f(0.75); |
| } |
| return c2_op; |
| } |
| |
| //============================================== |
| // Rest of the member functions for Caffe2Backend |
| //============================================== |
| std::unordered_set<std::string> Caffe2Backend::AllNamesInGraph( |
| const GraphProto& graph) { |
| std::unordered_set<std::string> names; |
| |
| for (const auto& input : graph.input()) { |
| names.emplace(input.name()); |
| } |
| for (const auto& output : graph.output()) { |
| names.emplace(output.name()); |
| } |
| for (const auto& node : graph.node()) { |
| for (const auto& n : node.input()) { |
| names.emplace(n); |
| } |
| for (const auto& n : node.output()) { |
| names.emplace(n); |
| } |
| } |
| |
| return names; |
| } |
| |
| // This translator performs the basic translation of ONNX nodes into |
| // Caffe2 operators. Besides doing a straightforward marshalling from |
| // one format to another, it also does these extra things: |
| // |
| // - Renames operators based on 'renamed_operators' |
| // - Renames attributes based on 'renamed_attrs' and |
| // 'get_per_op_renamed_attrs' |
| // |
| // If you're writing a custom translator, consider calling this first, |
| // and then fixing things up further. |
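//
// For example, an ONNX Unsqueeze node with attribute axes = [0] becomes a
// Caffe2 ExpandDims op with argument dims = [0]: the op type is renamed via
// get_renamed_operators() and the attribute via get_per_op_renamed_attrs().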
| Caffe2Ops Caffe2Backend::CommonOnnxNodeToCaffe2Ops( |
| OnnxNode* onnx_node, |
| const ConversionContext& ctx) { |
| Caffe2Ops ret; |
| auto* c2_op = ret.ops.Add(); |
| |
| const auto& node = onnx_node->node; |
| c2_op->mutable_input()->MergeFrom(node.input()); |
| c2_op->mutable_output()->MergeFrom(node.output()); |
| c2_op->set_name(node.name()); |
| |
| // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
| const auto onnx_op_type = node.op_type(); |
| auto broken_version = caffe2::get_default( |
| get_broken_operators(), onnx_op_type, std::numeric_limits<int>::max()); |
| if (broken_version <= ctx.opset_version()) { |
| CAFFE_THROW( |
| "Don't know how to translate op ", |
| onnx_op_type, |
| " in ONNX operator set v", |
| ctx.opset_version(), |
| " (I only support prior to v", |
| broken_version); |
| } |
| c2_op->set_type( |
| caffe2::get_default(get_renamed_operators(), onnx_op_type, onnx_op_type)); |
| if (!IsOperator(c2_op->type())) { |
| CAFFE_THROW("Don't know how to translate op ", onnx_op_type); |
| } |
| |
| auto mapper = [&, this](const std::string& k) { |
| const auto it = get_per_op_renamed_attrs().find(onnx_op_type); |
| if (it != get_per_op_renamed_attrs().end()) { |
| const auto it_op = it->second.find(k); |
| if (it_op != it->second.end()) { |
| return it_op->second; |
| } |
| } |
| const auto it_global = get_renamed_attrs().find(k); |
| if (it_global != get_renamed_attrs().end()) { |
| return it_global->second; |
| } |
| return k; |
| }; |
| c2_op->mutable_arg()->MergeFrom( |
| onnx_node->attributes.OnnxAttrToCaffe2Arg(mapper)); |
| |
| return ret; |
| } |
| |
| Caffe2Ops Caffe2Backend::ConvertNode( |
| const std::string& node_str, |
| const ConversionContext& ctx) { |
| ::google::protobuf::RepeatedPtrField<NodeProto> nodes; |
| auto* n = nodes.Add(); |
| ParseProtoFromLargeString(node_str, n); |
| ModelProto init_model; |
| ModelProto pred_model; |
| OnnxNode onnx_node = OnnxNode(nodes.Get(0)); |
| return OnnxNodeToCaffe2Ops(init_model, pred_model, ctx, &onnx_node); |
| } |
| |
| void Caffe2Backend::CheckOpSchemaArguments( |
| const caffe2::OpSchema& schema, |
| const caffe2::OperatorDef& op) { |
| const auto& schema_args = schema.args(); |
| if (schema_args.size() > 0) { |
| std::vector<std::string> argnames; |
| std::transform( |
| schema_args.begin(), |
| schema_args.end(), |
| std::back_inserter(argnames), |
| [](caffe2::OpSchema::Argument elem) { return elem.name(); }); |
| |
| for (const auto& arg : op.arg()) { |
| if (std::count(argnames.begin(), argnames.end(), arg.name()) == 0) { |
| CAFFE_THROW( |
| "Don't know how to map unexpected argument ", |
| arg.name(), |
| " (from operator ", |
| op.type(), |
| ")"); |
| } |
| } |
| } else { |
| // A number of C2 operators do not declare proper arguments. Let's log the |
| // error |
| VLOG(2) |
| << "Operator " << op.type() |
| << " does not declare arguments in its schema. Please file a Caffe2 issue."; |
| } |
| } |
| |
| Caffe2Ops Caffe2Backend::OnnxNodeToCaffe2Ops( |
| const ModelProto& init_model, |
| const ModelProto& pred_model, |
| const ConversionContext& ctx, |
| OnnxNode* onnx_node) { |
| Caffe2Ops res; |
| if (get_special_operators().count(onnx_node->node.op_type())) { |
| res = (this->*get_special_operators().at(onnx_node->node.op_type()))( |
| onnx_node, ctx); |
| } else { |
| res = CommonOnnxNodeToCaffe2Ops(onnx_node, ctx); |
| } |
| |
| for (const auto& result_op : res.ops) { |
| const auto* schema = OpSchemaRegistry::Schema(result_op.type()); |
| if (schema) { |
| CheckOpSchemaArguments(*schema, result_op); |
| } else { |
| CAFFE_THROW( |
| "Caffe2 has no such operator, could not find schema for ", |
| result_op.type()); |
| } |
| } |
| return res; |
| } |
| |
| void Caffe2Backend::OnnxToCaffe2( |
| caffe2::NetDef* init_net, |
| caffe2::NetDef* pred_net, |
| const ModelProto& onnx_model, |
| const std::string& device, |
| int opset_version, |
| bool include_initializers, |
| const std::vector<Caffe2Ops>& extras) { |
| auto device_option = GetDeviceOption(Device(device)); |
| |
| ModelProto init_model = ModelProto(); |
| ModelProto pred_model = onnx_model; |
| pred_model.mutable_graph()->mutable_initializer()->Clear(); |
| |
| init_net->set_name(onnx_model.graph().name() + "_init"); |
| pred_net->set_name(onnx_model.graph().name() + "_predict"); |
| |
| // Convert initializer if necessary |
| if (include_initializers) { |
| for (const auto& tp : onnx_model.graph().initializer()) { |
| auto* c2_op = init_net->add_op(); |
| BuildTensorFillingOp(c2_op, tp); |
| } |
| } |
| |
| auto name_set = AllNamesInGraph(init_model.graph()); |
| auto name_set_pred = AllNamesInGraph(pred_model.graph()); |
| name_set.insert(name_set_pred.begin(), name_set_pred.end()); |
| dummy_->Reset(name_set); |
| |
| ValueInfoMap graph_value_infos{}; |
| for (const auto& vi : pred_model.graph().input()) { |
| graph_value_infos[vi.name()].CopyFrom(vi); |
| } |
| for (const auto& vi : pred_model.graph().output()) { |
| graph_value_infos[vi.name()].CopyFrom(vi); |
| } |
| for (const auto& vi : pred_model.graph().value_info()) { |
| graph_value_infos[vi.name()].CopyFrom(vi); |
| } |
| |
| size_t idx_extra = 0; |
| auto converter = [&](const ModelProto& model, caffe2::NetDef* net) mutable { |
| net->mutable_device_option()->CopyFrom(device_option); |
| for (const auto& node : model.graph().node()) { |
| auto* init_net_tmp = include_initializers ? init_net : net; |
      // For RNN operators, we rely on Python code to convert them for us, and
      // we simply deserialize the string. This is a hack and eventually we
      // want to get rid of it to have one flow. Note that we need to update
      // the dummy name generator to avoid having duplicated names between
      // Python and C++ generated dummies
| if (get_rnn_operators().count(node.op_type())) { |
| if (idx_extra < extras.size()) { |
| const auto& c2ops = extras[idx_extra++]; |
| for (const auto& op : c2ops.init_ops) { |
| UpdateNames(dummy_, op); |
| } |
| init_net_tmp->mutable_op()->MergeFrom(c2ops.init_ops); |
| for (const auto& op : c2ops.ops) { |
| UpdateNames(dummy_, op); |
| } |
| net->mutable_op()->MergeFrom(c2ops.ops); |
| for (const auto& input : c2ops.interface_blobs) { |
| dummy_->AddName(input); |
| } |
| net->mutable_external_input()->MergeFrom(c2ops.interface_blobs); |
| } else { |
| CAFFE_THROW( |
| "Don't know how to convert ", |
| node.op_type(), |
| " without enough extra preconverted string"); |
| } |
| } else { |
| ValueInfoMap value_infos{}; |
| for (const auto& name : node.input()) { |
| auto iter = graph_value_infos.find(name); |
| if (iter != graph_value_infos.end()) { |
| value_infos[name].CopyFrom(iter->second); |
| } |
| } |
| auto onnx_node = OnnxNode(node); |
| auto c2ops = OnnxNodeToCaffe2Ops( |
| init_model, pred_model, {value_infos, opset_version}, &onnx_node); |
| init_net_tmp->mutable_op()->MergeFrom(c2ops.init_ops); |
| net->mutable_op()->MergeFrom(c2ops.ops); |
| net->mutable_external_input()->MergeFrom(c2ops.interface_blobs); |
| } |
| } |
| |
| for (const auto& value : model.graph().output()) { |
| net->add_external_output(value.name()); |
| } |
| for (const auto& value : model.graph().input()) { |
| net->add_external_input(value.name()); |
| } |
| }; |
| |
| converter(init_model, init_net); |
| converter(pred_model, pred_net); |
| } |
| |
| Caffe2BackendRep* Caffe2Backend::Prepare( |
| const std::string& onnx_model_str, |
| const std::string& device, |
| const std::vector<Caffe2Ops>& extras) { |
| Caffe2BackendRep* rep = new Caffe2BackendRep(); |
| ModelProto onnx_model; |
| ParseProtoFromLargeString(onnx_model_str, &onnx_model); |
| |
| #ifndef C10_MOBILE |
| ::ONNX_NAMESPACE::checker::check_model(onnx_model); |
| #endif |
| |
| int opset_version = -1; |
| for (const auto& imp : onnx_model.opset_import()) { |
| if ((!imp.has_domain()) || imp.domain().empty()) { |
| opset_version = imp.version(); |
| if (opset_version > kKnownOpsetVersion) { |
| std::cout |
| << "This version of onnx-caffe2 targets ONNX operator set version " |
| << kKnownOpsetVersion |
| << ", but the model we are trying to import uses version " |
| << opset_version << ". We will try to import it anyway, " |
| << "but if the model uses operators which had BC-breaking changes " |
| "in the intervening versions, import will fail." |
| << std::endl; |
| } |
| } else { |
| std::cout << "Unrecognized operator set " << opset_version << std::endl; |
| } |
| } |
| if (opset_version < 0) { |
| if (onnx_model.ir_version() >= 0x00000003) { |
| CAFFE_THROW( |
| "Model with IR version >= 3 did not specify ONNX operator set " |
| "version (onnx-caffe2 requires it)"); |
| } else { |
| opset_version = 1; |
| } |
| } |
| |
  // TODO: avoid extra copy by directly feeding initializers to backend blobs
| OnnxToCaffe2( |
| &rep->init_net(), |
| &rep->pred_net(), |
| onnx_model, |
| device, |
| opset_version, |
| true, |
| extras); |
| |
| // Get a list of uninitialized inputs to help with the inference setup |
| auto& uninitialized_inputs = rep->uninitialized_inputs(); |
| std::unordered_set<std::string> initialized_inputs; |
| for (const auto& tp : onnx_model.graph().initializer()) { |
| initialized_inputs.emplace(tp.name()); |
| } |
| for (const auto& input : onnx_model.graph().input()) { |
| if (!initialized_inputs.count(input.name())) { |
| uninitialized_inputs.emplace_back(input.name()); |
| } |
| } |
| |
| return rep; |
| } |
| |
| template <typename T> |
| void ConvertIntegralValueToCaffe2( |
| caffe2::OperatorDef* c2_op, |
| caffe2::Argument* c2_values, |
| const TensorProto& onnx_tensor) { |
| c2_op->set_type( |
| onnx_tensor.data_type() == TensorProto::BOOL ? "GivenTensorBoolFill" |
| : "GivenTensorIntFill"); |
| ::google::protobuf::RepeatedField<T> tmp; |
| const ::google::protobuf::RepeatedField<T>* src = &tmp; |
| bool converted = TryConvertingTensorRawValues<T>(onnx_tensor, &tmp); |
| if (converted) { |
| for (const auto i : *src) { |
| c2_values->add_ints(i); |
| } |
| } else { |
| const ::google::protobuf::RepeatedField<::google::protobuf::int32>* |
| int32_src = &onnx_tensor.int32_data(); |
| for (const auto i : *int32_src) { |
| c2_values->add_ints(i); |
| } |
| } |
| } |
| |
| template <> |
| void ConvertIntegralValueToCaffe2<::google::protobuf::int64>( |
| caffe2::OperatorDef* c2_op, |
| caffe2::Argument* c2_values, |
| const TensorProto& onnx_tensor) { |
| c2_op->set_type("GivenTensorInt64Fill"); |
| auto* ints = c2_values->mutable_ints(); |
| if (!TryConvertingTensorRawValues<::google::protobuf::int64>( |
| onnx_tensor, ints)) { |
| ints->CopyFrom(onnx_tensor.int64_data()); |
| } |
| } |
| |
| template <> |
| void ConvertIntegralValueToCaffe2<::google::protobuf::uint64>( |
| caffe2::OperatorDef* c2_op, |
| caffe2::Argument* c2_values, |
| const TensorProto& onnx_tensor) { |
| c2_op->set_type("GivenTensorInt64Fill"); |
| ::google::protobuf::RepeatedField<::google::protobuf::uint64> tmp; |
| const ::google::protobuf::RepeatedField<::google::protobuf::uint64>* src = |
| &tmp; |
| if (!TryConvertingTensorRawValues<::google::protobuf::uint64>( |
| onnx_tensor, &tmp)) { |
| src = &onnx_tensor.uint64_data(); |
| } |
| for (const auto i : *src) { |
| c2_values->add_ints(i); |
| } |
| } |
| |
| void Caffe2Backend::BuildTensorFillingOp( |
| caffe2::OperatorDef* c2_op, |
| const TensorProto& onnx_tensor, |
| const std::string& output_name, |
| const std::string& shape_name) { |
| auto fill_name = output_name.empty() ? onnx_tensor.name() : output_name; |
| CAFFE_ENFORCE(!fill_name.empty()); |
| |
| if (onnx_tensor.has_segment()) { |
| CAFFE_THROW("Currently not supporting loading segments."); |
| } |
| |
| auto* c2_values = c2_op->add_arg(); |
  // If shape_name is empty, we generate GivenTensor*Fill;
  // otherwise we generate ConstantFill, which accepts the shape as an input.
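  // For illustration (hypothetical tensor): a FLOAT initializer with
  // dims = [2, 3] becomes
  //   GivenTensorFill(values=<6 floats>, shape=[2, 3]) -> fill_name
  // while a single-element tensor paired with a non-empty shape_name becomes
  //   ConstantFill(shape_name, value=<scalar>, dtype=FLOAT, input_as_shape=1) -> fill_name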
| if (shape_name.empty()) { |
| // GivenTensor*Fill uses values |
| c2_values->set_name("values"); |
| if (onnx_tensor.data_type() == TensorProto::FLOAT) { |
| c2_op->set_type("GivenTensorFill"); |
| auto* floats = c2_values->mutable_floats(); |
| if (!TryConvertingTensorRawValues<float>(onnx_tensor, floats)) { |
| floats->CopyFrom(onnx_tensor.float_data()); |
| } |
| } else if (onnx_tensor.data_type() == TensorProto::DOUBLE) { |
| c2_op->set_type("GivenTensorDoubleFill"); |
| ::google::protobuf::RepeatedField<double> tmp; |
| const ::google::protobuf::RepeatedField<double>* src = &tmp; |
| if (!TryConvertingTensorRawValues<double>(onnx_tensor, &tmp)) { |
| src = &onnx_tensor.double_data(); |
| } |
| for (const auto i : *src) { |
| c2_values->add_floats(i); |
| } |
| } else if (onnx_tensor.data_type() == TensorProto::INT64) { |
| ConvertIntegralValueToCaffe2<::google::protobuf::int64>( |
| c2_op, c2_values, onnx_tensor); |
| } else if (onnx_tensor.data_type() == TensorProto::UINT32) { |
| ConvertIntegralValueToCaffe2<::google::protobuf::uint64>( |
| c2_op, c2_values, onnx_tensor); |
| // NOLINTNEXTLINE(bugprone-branch-clone) |
| } else if (onnx_tensor.data_type() == TensorProto::BOOL) { |
| ConvertIntegralValueToCaffe2<::google::protobuf::int8>( |
| c2_op, c2_values, onnx_tensor); |
| } else if (onnx_tensor.data_type() == TensorProto::UINT8) { |
| ConvertIntegralValueToCaffe2<::google::protobuf::uint8>( |
| c2_op, c2_values, onnx_tensor); |
| } else if (onnx_tensor.data_type() == TensorProto::INT8) { |
| ConvertIntegralValueToCaffe2<::google::protobuf::int8>( |
| c2_op, c2_values, onnx_tensor); |
| } else if (onnx_tensor.data_type() == TensorProto::UINT16) { |
| ConvertIntegralValueToCaffe2<::google::protobuf::uint16>( |
| c2_op, c2_values, onnx_tensor); |
| } else if (onnx_tensor.data_type() == TensorProto::INT16) { |
| ConvertIntegralValueToCaffe2<::google::protobuf::int16>( |
| c2_op, c2_values, onnx_tensor); |
| } else if (onnx_tensor.data_type() == TensorProto::INT32) { |
| ConvertIntegralValueToCaffe2<::google::protobuf::int32>( |
| c2_op, c2_values, onnx_tensor); |
| } else if (onnx_tensor.data_type() == TensorProto::STRING) { |
| c2_op->set_type("GivenTensorStringFill"); |
| auto* strings = c2_values->mutable_strings(); |
| strings->CopyFrom(onnx_tensor.string_data()); |
| } else { |
| CAFFE_THROW("unrecognized tensor type: ", onnx_tensor.data_type()); |
| } |
| auto* c2_shape = c2_op->add_arg(); |
| c2_shape->set_name("shape"); |
| for (const auto d : onnx_tensor.dims()) { |
| c2_shape->add_ints(d); |
| } |
| } else { |
| int value_size = 1; |
| for (const auto d : onnx_tensor.dims()) { |
| value_size *= d; |
| } |
| CAFFE_ENFORCE(value_size == 1); |
| auto c2_input_as_shape = c2_op->add_arg(); |
| c2_input_as_shape->set_name("input_as_shape"); |
| c2_input_as_shape->set_i(1); |
| c2_values->set_name("value"); |
| auto* c2_dtype = c2_op->add_arg(); |
| c2_dtype->set_name("dtype"); |
| if (onnx_tensor.data_type() == TensorProto::FLOAT) { |
| c2_dtype->set_i(caffe2::TensorProto::FLOAT); |
| if (onnx_tensor.float_data_size() > 0) { |
| c2_values->set_f(onnx_tensor.float_data(0)); |
| } else { |
| CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(float)); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| float f; |
| memcpy(&f, onnx_tensor.raw_data().c_str(), sizeof(float)); |
| c2_values->set_f(f); |
| } |
| } else if (onnx_tensor.data_type() == TensorProto::DOUBLE) { |
| c2_dtype->set_i(caffe2::TensorProto::DOUBLE); |
| if (onnx_tensor.double_data_size() > 0) { |
| c2_values->set_f(static_cast<float>(onnx_tensor.double_data(0))); |
| } else { |
| CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(double)); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| double d; |
| memcpy(&d, onnx_tensor.raw_data().c_str(), sizeof(double)); |
| c2_values->set_f(static_cast<float>(d)); |
| } |
| } else if (onnx_tensor.data_type() == TensorProto::INT64) { |
| c2_dtype->set_i(caffe2::TensorProto::INT64); |
| if (onnx_tensor.int64_data_size() > 0) { |
| c2_values->set_i(onnx_tensor.int64_data(0)); |
| } else { |
| CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(int64_t)); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| int64_t i; |
| memcpy(&i, onnx_tensor.raw_data().c_str(), sizeof(int64_t)); |
| c2_values->set_i(i); |
| } |
| } else if (onnx_tensor.data_type() == TensorProto::INT32) { |
| c2_dtype->set_i(caffe2::TensorProto::INT32); |
| if (onnx_tensor.int32_data_size() > 0) { |
| c2_values->set_i(onnx_tensor.int32_data(0)); |
| } else { |
| CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(int32_t)); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| int32_t i; |
| memcpy(&i, onnx_tensor.raw_data().c_str(), sizeof(int32_t)); |
| c2_values->set_i(i); |
| } |
| } else { |
      // TODO: support more data types
| std::stringstream oss; |
| oss << "Unsupported dtype: " << onnx_tensor.data_type(); |
| CAFFE_THROW(oss.str()); |
| } |
| // ConstantFill uses value |
| c2_op->set_type("ConstantFill"); |
| c2_op->add_input(shape_name); |
| } |
| |
| c2_op->add_output(fill_name); |
| } |
| |
| bool Caffe2Backend::SupportOp(const std::string type) const { |
| return get_special_operators().count(type); |
| } |
| |
| } // namespace onnx |
| } // namespace caffe2 |