| #include "caffe2/operators/index_ops.h" |
| #include <atomic> |
| #include <limits> |
| #include <mutex> |
| #include <sstream> |
| #include <unordered_map> |
| #include <vector> |
| #include "caffe2/core/blob_serialization.h" |
| #include "caffe2/core/operator.h" |
| #include "caffe2/core/tensor.h" |
| |
| namespace caffe2 { |
| |
| // TODO(azzolini): support sizes larger than int32 |
| template <class T> |
| // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
| class IndexCreateOp : public Operator<CPUContext> { |
| public: |
| template <class... Args> |
| explicit IndexCreateOp(Args&&... args) |
| : Operator(std::forward<Args>(args)...), |
| maxElements_(OperatorBase::GetSingleArgument<int>( |
| "max_elements", |
| std::numeric_limits<int>::max())) {} |
| |
| bool RunOnDevice() override { |
| *OperatorBase::Output<std::unique_ptr<IndexBase>>(0) = |
| std::unique_ptr<IndexBase>(new Index<T>(maxElements_)); |
| return true; |
| } |
| |
| private: |
| int64_tValue maxElements_; |
| }; |
| |
| class IndexGetOp : public Operator<CPUContext> { |
| public: |
| template <class... Args> |
| explicit IndexGetOp(Args&&... args) : Operator(std::forward<Args>(args)...) {} |
| |
| bool RunOnDevice() override { |
| return DispatchHelper<IndexKeyTypes>::call(this, Input(1)); |
| } |
| template <typename T> |
| bool DoRunWithType() { |
| auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0); |
| auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get()); |
| CAFFE_ENFORCE(dict, "Wrong dictionary type given input keys."); |
| const auto& keys = Input(1); |
| |
| auto* values = Output(0, keys.sizes(), at::dtype<int64_tValue>()); |
| dict->Get( |
| keys.data<T>(), |
| values->template mutable_data<int64_tValue>(), |
| keys.numel()); |
| return true; |
| } |
| }; |
| |
| class IndexLoadOp : public Operator<CPUContext> { |
| public: |
| template <class... Args> |
| explicit IndexLoadOp(Args&&... args) |
| : Operator(std::forward<Args>(args)...), |
| skipFirstEntry_( |
| OperatorBase::GetSingleArgument<int>("skip_first_entry", 0)) {} |
| |
| bool RunOnDevice() override { |
| return DispatchHelper<IndexKeyTypes>::call(this, Input(1)); |
| } |
| template <typename T> |
| bool DoRunWithType() { |
| auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0); |
| auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get()); |
| CAFFE_ENFORCE(dict, "Wrong dictionary type given input keys."); |
| const auto& keys = Input(1); |
| const auto* keys_data = keys.data<T>(); |
| auto keys_size = keys.numel(); |
| if (skipFirstEntry_) { |
| CAFFE_ENFORCE(keys.numel() > 0); |
| ++keys_data; |
| --keys_size; |
| } |
| return dict->Load(keys_data, keys_size); |
| } |
| |
| private: |
| bool skipFirstEntry_; |
| }; |
| |
| class IndexStoreOp : public Operator<CPUContext> { |
| public: |
| template <class... Args> |
| explicit IndexStoreOp(Args&&... args) |
| : Operator(std::forward<Args>(args)...) {} |
| |
| bool RunOnDevice() override { |
| auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0); |
| return DispatchHelper<IndexKeyTypes>::call(this, base->Type()); |
| } |
| |
| template <typename T> |
| bool DoRunWithType() { |
| auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0); |
| auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get()); |
| CAFFE_ENFORCE(dict); |
| return dict->Store(Output(0)); |
| } |
| }; |
| |
| class IndexFreezeOp : public Operator<CPUContext> { |
| public: |
| template <class... Args> |
| explicit IndexFreezeOp(Args&&... args) |
| : Operator(std::forward<Args>(args)...) {} |
| |
| bool RunOnDevice() override { |
| auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0); |
| base->Freeze(); |
| return true; |
| } |
| }; |
| |
| class IndexSizeOp : public Operator<CPUContext> { |
| public: |
| template <class... Args> |
| explicit IndexSizeOp(Args&&... args) |
| : Operator(std::forward<Args>(args)...) {} |
| |
| bool RunOnDevice() override { |
| auto& base = OperatorBase::Input<std::unique_ptr<IndexBase>>(0); |
| |
| auto* out = Output(0, std::vector<int64_t>{}, at::dtype<int64_tValue>()); |
| *out->template mutable_data<int64_tValue>() = base->Size(); |
| return true; |
| } |
| }; |
| |
| REGISTER_CPU_OPERATOR(IntIndexCreate, IndexCreateOp<int32_t>); |
| REGISTER_CPU_OPERATOR(LongIndexCreate, IndexCreateOp<int64_t>); |
| REGISTER_CPU_OPERATOR(StringIndexCreate, IndexCreateOp<std::string>); |
| |
| REGISTER_CPU_OPERATOR(IndexGet, IndexGetOp); |
| REGISTER_CPU_OPERATOR(IndexLoad, IndexLoadOp); |
| REGISTER_CPU_OPERATOR(IndexStore, IndexStoreOp); |
| REGISTER_CPU_OPERATOR(IndexFreeze, IndexFreezeOp); |
| REGISTER_CPU_OPERATOR(IndexSize, IndexSizeOp); |
| |
| OPERATOR_SCHEMA(IntIndexCreate) |
| .NumInputs(0) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Creates a dictionary that maps int32 keys to consecutive integers |
| from 1 to max_elements. Zero is reserved for unknown keys. |
| )DOC") |
| .Arg("max_elements", "Max number of elements, including the zero entry.") |
| .Output(0, "handler", "Pointer to an Index instance.") |
| .ScalarType(TensorProto_DataType_UNDEFINED); |
| |
| OPERATOR_SCHEMA(LongIndexCreate) |
| .NumInputs(0) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Creates a dictionary that maps int64 keys to consecutive integers |
| from 1 to max_elements. Zero is reserved for unknown keys. |
| )DOC") |
| .Arg("max_elements", "Max number of elements, including the zero entry.") |
| .Output(0, "handler", "Pointer to an Index instance.") |
| .ScalarType(TensorProto_DataType_UNDEFINED); |
| |
| OPERATOR_SCHEMA(StringIndexCreate) |
| .NumInputs(0) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Creates a dictionary that maps string keys to consecutive integers |
| from 1 to max_elements. Zero is reserved for unknown keys. |
| )DOC") |
| .Arg("max_elements", "Max number of elements, including the zero entry.") |
| .Output(0, "handle", "Pointer to an Index instance.") |
| .ScalarType(TensorProto_DataType_UNDEFINED); |
| |
| OPERATOR_SCHEMA(IndexGet) |
| .NumInputs(2) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Given an index handle and a tensor of keys, return an Int tensor of same shape |
| containing the indices for each of the keys. If the index is frozen, unknown |
| entries are given index 0. Otherwise, new entries are added into the index. |
| If an insert is necessary but max_elements has been reached, fail. |
| )DOC") |
| .Input(0, "handle", "Pointer to an Index instance.") |
| .Input(1, "keys", "Tensor of keys to be looked up.") |
| .Output(0, "indices", "Indices for each of the keys.") |
| .ScalarType(TensorProto::INT64); |
| |
| OPERATOR_SCHEMA(IndexFreeze) |
| .NumInputs(1) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Freezes the given index, disallowing creation of new index entries. |
| Should not be called concurrently with IndexGet. |
| )DOC") |
| .Input(0, "handle", "Pointer to an Index instance.") |
| .Output(0, "handle", "The input handle.") |
| .EnforceInplace({{0, 0}}) |
| .ScalarType(TensorProto_DataType_UNDEFINED); |
| |
| OPERATOR_SCHEMA(IndexLoad) |
| .NumInputs(2) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Loads the index from the given 1-D tensor. Elements in the tensor will be given |
| consecutive indexes starting at 1. Fails if tensor contains repeated elements. |
| )DOC") |
| .Input(0, "handle", "Pointer to an Index instance.") |
| .Input(1, "items", "1-D tensor with elements starting with index 1.") |
| .Output(0, "handle", "The input handle.") |
| .EnforceInplace({{0, 0}}) |
| .Arg( |
| "skip_first_entry", |
| "If set, skips the first entry of the tensor. This allows " |
| "to load tensors that are aligned with an embedding, where the first " |
| "entry corresponds to the default 0 index entry.") |
| .ScalarType(TensorProto_DataType_UNDEFINED); |
| |
| OPERATOR_SCHEMA(IndexStore) |
| .NumInputs(1) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Stores the keys of this index in a 1-D tensor. Since element 0 is reserved |
| for unknowns, the first element of the output tensor will be element of index 1. |
| )DOC") |
| .Input(0, "handle", "Pointer to an Index instance.") |
| .Output(0, "items", "1-D tensor with elements starting with index 1."); |
| |
| OPERATOR_SCHEMA(IndexSize) |
| .NumInputs(1) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Returns the number of entries currently present in the index. |
| )DOC") |
| .Input(0, "handle", "Pointer to an Index instance.") |
| .Output(0, "items", "Scalar int64 tensor with number of entries."); |
| |
| NO_GRADIENT(IndexGetOp); |
| NO_GRADIENT(IntIndexCreate); |
| NO_GRADIENT(LongIndexCreate); |
| NO_GRADIENT(StringIndexCreate); |
| SHOULD_NOT_DO_GRADIENT(IndexFreeze); |
| SHOULD_NOT_DO_GRADIENT(IndexLoad); |
| SHOULD_NOT_DO_GRADIENT(IndexStore); |
| SHOULD_NOT_DO_GRADIENT(IndexSize); |
| |
| class IndexSerializer : public BlobSerializerBase { |
| public: |
| // NOLINTNEXTLINE(modernize-use-equals-default) |
| IndexSerializer() {} |
| // NOLINTNEXTLINE(modernize-use-equals-default) |
| ~IndexSerializer() override {} |
| |
| void Serialize( |
| const void* pointer, |
| TypeMeta typeMeta, |
| const string& name, |
| SerializationAcceptor acceptor) override { |
| CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<IndexBase>>()); |
| const auto& base = *static_cast<const std::unique_ptr<IndexBase>*>(pointer); |
| Blob tensor_blob; |
| auto* tensor_out = BlobGetMutableTensor(&tensor_blob, CPU); |
| |
| if (base->Type().Match<std::string>()) { |
| doStore<std::string>(base, tensor_out); |
| } else if (base->Type().Match<int32_t>()) { |
| doStore<int32_t>(base, tensor_out); |
| } else if (base->Type().Match<int64_t>()) { |
| doStore<int64_t>(base, tensor_out); |
| } else { |
| CAFFE_THROW("Index of this type can't be serialized."); |
| } |
| |
| CAFFE_ENFORCE( |
| tensor_out->numel() <= std::numeric_limits<int32_t>::max(), |
| "Index too large to be serialized."); |
| BlobProto blob_proto; |
| TensorSerializer ser; |
| ser.Serialize( |
| *tensor_out, name, blob_proto.mutable_tensor(), 0, tensor_out->numel()); |
| blob_proto.set_name(name); |
| blob_proto.set_type("std::unique_ptr<caffe2::IndexBase>"); |
| |
| std::ostringstream os; |
| os << base->maxElements() << " " << base->isFrozen(); |
| blob_proto.set_content(os.str()); |
| |
| acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto)); |
| } |
| |
| private: |
| template <typename T> |
| void doStore(const std::unique_ptr<IndexBase>& base, Tensor* tensor_out) { |
| auto* dict = dynamic_cast_if_rtti<Index<T>*>(base.get()); |
| CAFFE_ENFORCE(dict, "Wrong dictionary type."); |
| dict->Store(tensor_out); |
| } |
| }; |
| |
| class IndexDeserializer : public BlobDeserializerBase { |
| public: |
| void Deserialize(const BlobProto& proto, Blob* blob) override { |
| TensorDeserializer deser; |
| Blob tensor_blob; |
| deser.Deserialize(proto, &tensor_blob); |
| |
| std::istringstream is(proto.content()); |
| int64_t maxElements{std::numeric_limits<int64_t>::max()}; |
| bool isFrozen{false}; |
| is >> maxElements >> isFrozen; |
| |
| auto& tensor_in = tensor_blob.template Get<Tensor>(); |
| auto* base = blob->template GetMutable<std::unique_ptr<IndexBase>>(); |
| |
| if (tensor_in.IsType<std::string>()) { |
| doLoad<std::string>(base, maxElements, tensor_in); |
| } else if (tensor_in.IsType<int32_t>()) { |
| doLoad<int32_t>(base, maxElements, tensor_in); |
| } else if (tensor_in.IsType<int64_t>()) { |
| doLoad<int64_t>(base, maxElements, tensor_in); |
| } else { |
| CAFFE_THROW("Index of this type cannot be deserialized."); |
| } |
| |
| if (isFrozen) { |
| (*base)->Freeze(); |
| } |
| } |
| |
| private: |
| template <typename T> |
| void doLoad( |
| std::unique_ptr<IndexBase>* base, |
| int64_t maxElements, |
| const Tensor& tensor_in) { |
| base->reset(new Index<T>(maxElements)); |
| auto* dict = dynamic_cast_if_rtti<Index<T>*>(base->get()); |
| dict->Load(tensor_in.data<T>(), tensor_in.numel()); |
| } |
| }; |
| |
| CAFFE_KNOWN_TYPE(std::unique_ptr<caffe2::IndexBase>); |
| |
| REGISTER_BLOB_SERIALIZER( |
| (TypeMeta::Id<std::unique_ptr<caffe2::IndexBase>>()), |
| IndexSerializer); |
| REGISTER_BLOB_DESERIALIZER( |
| std::unique_ptr<caffe2::IndexBase>, |
| IndexDeserializer); |
| |
| } // namespace caffe2 |