| #include "caffe2/operators/string_ops.h" |
| #include "caffe2/core/operator.h" |
| |
| namespace caffe2 { |
| |
| template <> |
| template <typename T> |
| bool StringJoinOp<CPUContext>::DoRunWithType() { |
| const auto& input = Input(0); |
| |
| CAFFE_ENFORCE_GT(input.numel(), 0); |
| CAFFE_ENFORCE_LE(input.dim(), 2, "Only 1-D and 2-D tensors are supported"); |
| |
| const auto* inputData = input.data<T>(); |
| // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
| int rowSize = (input.dim() == 2) ? input.size(1) : 1; |
| if (this->axis_ == 0) { |
| auto* output = Output(0, {input.size(0)}, at::dtype<std::string>()); |
| auto* outputData = output->template mutable_data<std::string>(); |
| |
| int offset = 0; |
| for (int i = 0; i < input.size(0); ++i) { |
| std::stringstream stream; |
| std::copy( |
| inputData + offset, |
| inputData + offset + rowSize, |
| std::ostream_iterator<T>(stream, delimiter_.c_str())); |
| outputData[i] = stream.str(); |
| offset += rowSize; |
| } |
| } else if (this->axis_ == 1) { |
| auto* output = Output(0, {input.size(1)}, at::dtype<std::string>()); |
| auto* outputData = output->template mutable_data<std::string>(); |
| |
| for (int j = 0; j < input.size(1); ++j) { |
| std::stringstream stream; |
| for (int i = 0; i < input.size(0); ++i) { |
| stream << inputData[i * rowSize + j] << delimiter_; |
| } |
| outputData[j] = stream.str(); |
| } |
| } else { |
| CAFFE_ENFORCE(false, "Not supported"); |
| } |
| |
| return true; |
| } |
| |
| namespace { |
| |
| struct StartsWith { |
| explicit StartsWith(OperatorBase& op) |
| : prefix_(op.GetSingleArgument<std::string>("prefix", "")) {} |
| bool operator()(const std::string& str) { |
| return std::mismatch(prefix_.begin(), prefix_.end(), str.begin()).first == |
| prefix_.end(); |
| } |
| |
| private: |
| std::string prefix_; |
| }; |
| |
| struct EndsWith { |
| explicit EndsWith(OperatorBase& op) |
| : suffix_(op.GetSingleArgument<std::string>("suffix", "")) {} |
| bool operator()(const std::string& str) { |
| return std::mismatch(suffix_.rbegin(), suffix_.rend(), str.rbegin()) |
| .first == suffix_.rend(); |
| } |
| |
| private: |
| std::string suffix_; |
| }; |
| |
| struct StrEquals { |
| explicit StrEquals(OperatorBase& op) |
| : text_(op.GetSingleArgument<std::string>("text", "")) {} |
| bool operator()(const std::string& str) { |
| return str == text_; |
| } |
| |
| private: |
| std::string text_; |
| }; |
| |
| struct Prefix { |
| explicit Prefix(OperatorBase& op) |
| : length_(op.GetSingleArgument<int>("length", 3)) {} |
| std::string operator()(const std::string& str) { |
| return std::string(str.begin(), std::min(str.end(), str.begin() + length_)); |
| } |
| |
| private: |
| int length_; |
| }; |
| |
| struct Suffix { |
| explicit Suffix(OperatorBase& op) |
| : length_(op.GetSingleArgument<int>("length", 3)) {} |
| std::string operator()(const std::string& str) { |
| return std::string(std::max(str.begin(), str.end() - length_), str.end()); |
| } |
| |
| private: |
| int length_; |
| }; |
| |
| template <typename ScalarFunctor, typename TypeMap = FixedType<std::string>> |
| using StringElementwiseOp = UnaryElementwiseWithArgsOp< |
| TensorTypes<std::string>, |
| CPUContext, |
| ForEach<ScalarFunctor>, |
| TypeMap>; |
| |
| REGISTER_CPU_OPERATOR(StringPrefix, StringElementwiseOp<Prefix>); |
| REGISTER_CPU_OPERATOR(StringSuffix, StringElementwiseOp<Suffix>); |
| REGISTER_CPU_OPERATOR( |
| StringStartsWith, |
| StringElementwiseOp<StartsWith, FixedType<bool>>); |
| REGISTER_CPU_OPERATOR( |
| StringEndsWith, |
| StringElementwiseOp<EndsWith, FixedType<bool>>); |
| REGISTER_CPU_OPERATOR( |
| StringEquals, |
| StringElementwiseOp<StrEquals, FixedType<bool>>); |
| REGISTER_CPU_OPERATOR(StringJoin, StringJoinOp<CPUContext>); |
| |
| OPERATOR_SCHEMA(StringPrefix) |
| .NumInputs(1) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Computes the element-wise string prefix of the string tensor. |
| Input strings that are shorter than prefix length will be returned unchanged. |
| NOTE: Prefix is computed on number of bytes, which may lead to wrong behavior |
| and potentially invalid strings for variable-length encodings such as utf-8. |
| )DOC") |
| .Arg("length", "Maximum size of the prefix, in bytes.") |
| .Input(0, "strings", "Tensor of std::string.") |
| .Output( |
| 0, |
| "prefixes", |
| "Tensor of std::string containing prefixes for each input."); |
| |
| OPERATOR_SCHEMA(StringSuffix) |
| .NumInputs(1) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Computes the element-wise string suffix of the string tensor. |
| Input strings that are shorter than suffix length will be returned unchanged. |
| NOTE: Prefix is computed on number of bytes, which may lead to wrong behavior |
| and potentially invalid strings for variable-length encodings such as utf-8. |
| )DOC") |
| .Input(0, "strings", "Tensor of std::string.") |
| .Output( |
| 0, |
| "suffixes", |
| "Tensor of std::string containing suffixes for each output.") |
| .Arg("length", "Maximum size of the suffix, in bytes."); |
| |
| OPERATOR_SCHEMA(StringStartsWith) |
| .NumInputs(1) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Performs the starts-with check on each string in the input tensor. |
| Returns tensor of boolean of the same dimension of input. |
| )DOC") |
| .Arg("prefix", "The prefix to check input strings against.") |
| .Input(0, "strings", "Tensor of std::string.") |
| .Output(0, "bools", "Tensor of bools of same shape as input."); |
| |
| OPERATOR_SCHEMA(StringEndsWith) |
| .NumInputs(1) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Performs the ends-with check on each string in the input tensor. |
| Returns tensor of boolean of the same dimension of input. |
| )DOC") |
| .Arg("suffix", "The suffix to check input strings against.") |
| .Input(0, "strings", "Tensor of std::string.") |
| .Output(0, "bools", "Tensor of bools of same shape as input."); |
| |
| OPERATOR_SCHEMA(StringEquals) |
| .NumInputs(1) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Performs equality check on each string in the input tensor. |
| Returns tensor of booleans of the same dimension as input. |
| )DOC") |
| .Arg("text", "The text to check input strings equality against.") |
| .Input(0, "strings", "Tensor of std::string.") |
| .Output(0, "bools", "Tensor of bools of same shape as input."); |
| |
| OPERATOR_SCHEMA(StringJoin) |
| .NumInputs(1) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Takes a 1-D or a 2-D tensor as input and joins elements in each row with the |
| provided delimiter. Output is a 1-D tensor of size equal to the first dimension |
| of the input. Each element in the output tensor is a string of concatenated |
| elements corresponding to each row in the input tensor. For 1-D input, each |
| element is treated as a row. |
| )DOC") |
| .Arg("delimiter", "Delimiter for join (Default: \",\").") |
| .Arg("axis", "Axis for the join (either 0 or 1)") |
| .Input(0, "input", "1-D or 2-D tensor") |
| .Output( |
| 0, |
| "strings", |
| "1-D tensor of strings created by joining row elements from the " |
| "input tensor."); |
| |
| SHOULD_NOT_DO_GRADIENT(StringPrefix); |
| SHOULD_NOT_DO_GRADIENT(StringSuffix); |
| SHOULD_NOT_DO_GRADIENT(StringStartsWith); |
| SHOULD_NOT_DO_GRADIENT(StringEndsWith); |
| SHOULD_NOT_DO_GRADIENT(StringEquals); |
| SHOULD_NOT_DO_GRADIENT(StringJoin); |
| } |
| } // namespace caffe2 |