| #ifndef CAFFE2_OPERATORS_DISTANCE_OP_H_ |
| #define CAFFE2_OPERATORS_DISTANCE_OP_H_ |
| |
| #include "caffe2/core/context.h" |
| #include "caffe2/core/operator.h" |
| #include "caffe2/utils/math.h" |
| #include "c10/util/irange.h" |
| |
| namespace caffe2 { |
| |
// Computes a per-row squared L2 distance between inputs X and Y.
// Input: X, Y (same shape); Output: Distance, one scalar per row of X.
// NOTE(review): the kernel body lives in the device-specific .cc/.cu file,
// so the exact scaling (e.g. a 1/2 factor) cannot be confirmed from this
// header -- see SquaredL2DistanceGradientOp below, whose gradient is
// dDistance_i * (X - Y), consistent with a 1/2 * ||X - Y||^2 forward.
template <typename T, class Context>
class SquaredL2DistanceOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit SquaredL2DistanceOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Implemented per device (CPU/CUDA) in the matching source file.
  bool RunOnDevice() override;

 protected:
  // Input: X, Y; Output: Distance
};
| |
| template <typename T, class Context> |
| class SquaredL2DistanceGradientOp final : public Operator<Context> { |
| public: |
| template <class... Args> |
| explicit SquaredL2DistanceGradientOp(Args&&... args) |
| : Operator<Context>(std::forward<Args>(args)...) {} |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| |
| bool RunOnDevice() override { |
| auto& X = Input(0); |
| auto& Y = Input(1); |
| auto& dDistance = Input(2); |
| |
| int N = X.dim() > 0 ? X.dim32(0) : 1; |
| int D = N > 0 ? X.numel() / N : 0; |
| CAFFE_ENFORCE(X.dim() == Y.dim()); |
| for (const auto i : c10::irange(X.dim())) { |
| CAFFE_ENFORCE(X.dim32(i) == Y.dim32(i)); |
| } |
| CAFFE_ENFORCE(dDistance.dim() == 1); |
| CAFFE_ENFORCE(dDistance.dim32(0) == N); |
| auto* dX = Output(0, X.sizes(), at::dtype<T>()); |
| auto* dY = Output(1, Y.sizes(), at::dtype<T>()); |
| math::Sub<T, Context>( |
| X.numel(), |
| X.template data<T>(), |
| Y.template data<T>(), |
| dX->template mutable_data<T>(), |
| &context_); |
| for (const auto i : c10::irange(N)) { |
| math::Scale<T, T, Context>( |
| D, |
| dDistance.template data<T>() + i, |
| dX->template data<T>() + i * D, |
| dX->template mutable_data<T>() + i * D, |
| &context_); |
| } |
| // The gradient of the other side is basically the negative. |
| math::Scale<T, T, Context>( |
| X.numel(), |
| -1, |
| dX->template data<T>(), |
| dY->template mutable_data<T>(), |
| &context_); |
| return true; |
| } |
| |
| protected: |
| // Input: X, Y, dDistance; Output: dX, dY |
| }; |
| |
// Computes a per-row L1 distance between inputs X and Y.
// Input: X, Y; Output: Distance, one scalar per row of X.
// NOTE(review): the kernel body is device-specific (.cc/.cu) and not visible
// from this header.
template <typename T, class Context>
class L1DistanceOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit L1DistanceOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Implemented per device (CPU/CUDA) in the matching source file.
  bool RunOnDevice() override;

 protected:
  // Input: X, Y; Output: Distance
};
| |
// Gradient of L1Distance.
// Input: X, Y, dDistance (upstream gradient, presumably one scalar per row
// as in the other distance gradients in this file -- the kernel is
// device-specific and not visible here); Output: dX, dY.
template <typename T, class Context>
class L1DistanceGradientOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit L1DistanceGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Implemented per device (CPU/CUDA) in the matching source file.
  bool RunOnDevice() override;

 protected:
  // Input: X, Y, dDistance; Output: dX, dY
};
| |
// Row-wise dot product of X and Y.
// Input: X_IN, Y_IN; Output: DOT_OUT, one scalar per row.
// NOTE(review): the kernel body is device-specific (.cc/.cu) and not visible
// from this header.
template <typename T, class Context>
class DotProductOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit DotProductOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Implemented per device (CPU/CUDA) in the matching source file.
  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(X_IN, Y_IN);
  OUTPUT_TAGS(DOT_OUT);
};
| |
// Gradient of DotProduct.
// Input: X_IN, Y_IN, DER_DOT_IN (upstream gradient of the dot output);
// Output: DER_X_OUT, DER_Y_OUT.
// NOTE(review): the kernel body is device-specific (.cc/.cu) and not visible
// from this header.
template <typename T, class Context>
class DotProductGradientOp final : public Operator<Context> {
 public:
  template <class... Args>
  explicit DotProductGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Implemented per device (CPU/CUDA) in the matching source file.
  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(X_IN, Y_IN, DER_DOT_IN);
  OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
};
| |
// Row-wise dot product of X and Y where the two rows may differ in width.
// Input: X_IN, Y_IN; Output: DOT_OUT, one scalar per row.
// NOTE(review): kernel body is device-specific (.cc/.cu); from this header,
// `pad_value` presumably fills the shorter row's missing tail and
// `replicate` tiles the shorter row across the longer one -- confirm against
// DotProductWithPaddingGradientOp in this file, which implements exactly
// that behavior for the backward pass.
template <typename T, class Context>
class DotProductWithPaddingOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit DotProductWithPaddingOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        // Operator arguments: "pad_value" defaults to 0.0,
        // "replicate" defaults to false.
        pad_value_(this->template GetSingleArgument<float>("pad_value", 0.0)),
        replicate_(this->template GetSingleArgument<bool>("replicate", false)) {
  }
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Implemented per device (CPU/CUDA) in the matching source file.
  bool RunOnDevice() override;

 protected:
  float pad_value_;
  bool replicate_;
  INPUT_TAGS(X_IN, Y_IN);
  OUTPUT_TAGS(DOT_OUT);
};
| |
// Row-wise cosine similarity between X and Y.
// Input: X_IN, Y_IN; Output: COS_OUT, one scalar per row.
// NOTE(review): the kernel body is device-specific (.cc/.cu) and not visible
// from this header.
template <typename T, class Context>
class CosineSimilarityOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit CosineSimilarityOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Implemented per device (CPU/CUDA) in the matching source file.
  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(X_IN, Y_IN);
  OUTPUT_TAGS(COS_OUT);

 private:
  // Scratch tensor; persists across calls (member lifetime). How it is used
  // is defined by the device-specific implementation.
  Tensor aux_;
};
| |
// Gradient of CosineSimilarity.
// Input: X_IN, Y_IN, DER_COS_IN (upstream gradient of the similarity);
// Output: DER_X_OUT, DER_Y_OUT.
// NOTE(review): the kernel body is device-specific (.cc/.cu) and not visible
// from this header.
template <typename T, class Context>
class CosineSimilarityGradientOp final : public Operator<Context> {
 public:
  template <class... Args>
  explicit CosineSimilarityGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Implemented per device (CPU/CUDA) in the matching source file.
  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(X_IN, Y_IN, DER_COS_IN);
  OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);

 private:
  // Scratch tensor; persists across calls (member lifetime). How it is used
  // is defined by the device-specific implementation.
  Tensor aux_;
};
| |
| template <typename T, class Context> |
| class DotProductWithPaddingGradientOp final : public Operator<Context> { |
| public: |
| template <class... Args> |
| explicit DotProductWithPaddingGradientOp(Args&&... args) |
| : Operator<Context>(std::forward<Args>(args)...), |
| pad_value_(this->template GetSingleArgument<float>("pad_value", 0.0)), |
| replicate_(this->template GetSingleArgument<bool>("replicate", false)) { |
| } |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| |
| bool RunOnDevice() override { |
| auto& X = Input(X_IN); |
| auto& Y = Input(Y_IN); |
| auto& dDot = Input(DER_DOT_IN); |
| |
| int N, D, DX, DY, restD; |
| if (X.numel() > 0) { |
| N = X.dim() > 0 ? X.dim32(0) : 1; |
| DX = X.numel() / N; |
| DY = Y.numel() / N; |
| } else { |
| N = 0; |
| DX = 0; |
| DY = 0; |
| } |
| CAFFE_ENFORCE(!replicate_ || DX % DY == 0 || DY % DX == 0); |
| D = std::min(DX, DY); |
| restD = std::max(DX, DY) - D; |
| CAFFE_ENFORCE_EQ(X.dim(), Y.dim()); |
| CAFFE_ENFORCE_EQ(X.dim32(0), Y.dim32(0)); |
| CAFFE_ENFORCE_EQ(dDot.dim(), 1); |
| CAFFE_ENFORCE_EQ(dDot.dim32(0), N); |
| auto* dX = Output(DER_X_OUT, X.sizes(), at::dtype<T>()); |
| auto* dY = Output(DER_Y_OUT, Y.sizes(), at::dtype<T>()); |
| |
| const auto* X_data = X.template data<T>(); |
| const auto* Y_data = Y.template data<T>(); |
| const auto* dDot_data = dDot.template data<T>(); |
| auto* dX_data = dX->template mutable_data<T>(); |
| auto* dY_data = dY->template mutable_data<T>(); |
| for (const auto i : c10::irange(N)) { // TODO: multithreading |
| auto offsetX = i * DX; |
| auto offsetY = i * DY; |
| if (replicate_) { |
| // L_ for longer vector and S_ for shorter vector |
| const T *L_data, *S_data; |
| T *dL_data, *dS_data; |
| int DL, DS; |
| if (DX > DY) { |
| L_data = X_data + offsetX; |
| S_data = Y_data + offsetY; |
| dL_data = dX_data + offsetX; |
| dS_data = dY_data + offsetY; |
| DL = DX; |
| DS = DY; |
| } else { |
| L_data = Y_data + offsetY; |
| S_data = X_data + offsetX; |
| dL_data = dY_data + offsetY; |
| dS_data = dX_data + offsetX; |
| DL = DY; |
| DS = DX; |
| } |
| |
| // TODO: get rid of temp memory use |
| std::vector<T> tmp_data(DS); |
| math::Set<T, Context>(DS, 0.0, dS_data, &context_); |
| for (int j = 0; j < DL / DS; j++) { |
| math::Scale<T, T, Context>( |
| DS, dDot_data[i], S_data, dL_data + j * DS, &context_); |
| math::Scale<T, T, Context>( |
| DS, dDot_data[i], L_data + j * DS, tmp_data.data(), &context_); |
| math::Axpy<float, T, Context>( |
| DS, 1.0, tmp_data.data(), dS_data, &context_); |
| } |
| } else { |
| math::Scale<T, T, Context>( |
| D, dDot_data[i], X_data + offsetX, dY_data + offsetY, &context_); |
| math::Scale<T, T, Context>( |
| D, dDot_data[i], Y_data + offsetY, dX_data + offsetX, &context_); |
| } |
| |
| if (!replicate_ && DX != DY) { |
| T* rest_data; |
| if (DX > DY) { |
| rest_data = dX_data + offsetX + D; |
| } else { |
| rest_data = dY_data + offsetY + D; |
| } |
| auto pad_gradient = dDot_data[i] * pad_value_; |
| math::Set<T, Context>(restD, pad_gradient, rest_data, &context_); |
| } |
| } |
| |
| return true; |
| } |
| |
| protected: |
| float pad_value_; |
| bool replicate_; |
| INPUT_TAGS(X_IN, Y_IN, DER_DOT_IN); |
| OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT); |
| }; |
| |
| } // namespace caffe2 |
| |
| #endif // CAFFE2_OPERATORS_DISTANCE_OP_H_ |