// Source blob: fb60c984b2e43eceb0f954147c0aadc9be56fee8
#include "channel_shuffle_op.h"
#include <array>
#include <string>
#include <vector>
#ifdef CAFFE2_USE_MKL
#include <mkl.h>
#endif // CAFFE2_USE_MKL
#include "caffe2/utils/math.h"
namespace caffe2 {
namespace {
// Channel shuffle for NCHW-ordered data.
//
// Each of the N samples holds C = G * K channel planes of HxW elements each.
// The planes are viewed as a (G, K) grid and written out as a (K, G) grid:
// output channel k * G + g receives input channel g * K + k.
//
// The per-sample transpose is expressed as strided CopyMatrix calls. The outer
// loop runs over min(G, K), so fewer, larger CopyMatrix calls are issued.
//
// Element offsets are formed in std::int64_t. The per-sample element count
// G * K * HxW can exceed INT_MAX even when G, K and HxW each fit in an int.
template <typename T>
void RunChannelShuffleNCHW(
    const int N,
    const int G,
    const int K,
    const int HxW,
    const T* X,
    T* Y,
    CPUContext* context) {
  const std::int64_t plane = HxW;
  const std::int64_t stride = static_cast<std::int64_t>(G) * K * plane;
  for (int i = 0; i < N; ++i) {
    if (G < K) {
      // Group j's K contiguous planes go to output channels j, j + G, j + 2G,
      // ... (row stride G * HxW in Y).
      for (int j = 0; j < G; ++j) {
        math::CopyMatrix<T, CPUContext>(
            K, HxW, X + j * (K * plane), HxW, Y + j * plane, G * HxW, context);
      }
    } else {
      // Input planes j, j + K, j + 2K, ... (row stride K * HxW in X) go to
      // the contiguous output channels j * G .. j * G + G - 1.
      for (int j = 0; j < K; ++j) {
        math::CopyMatrix<T, CPUContext>(
            G, HxW, X + j * plane, K * HxW, Y + j * (G * plane), HxW, context);
      }
    }
    X += stride;
    Y += stride;
  }
}
// Channel shuffle for NHWC-ordered data.
//
// X and Y hold N * HxW pixels, each with C = G * K contiguous channel values.
// For every pixel the channel vector is viewed as a G x K matrix and written
// transposed (K x G), so output channel k * G + g receives input channel
// g * K + k.
//
// The pixel count N * HxW is formed in std::int64_t so a large batch cannot
// overflow int, since signed overflow is undefined behavior.
template <typename T>
void RunChannelShuffleNHWC(
    const int N,
    const int G,
    const int K,
    const int HxW,
    const T* X,
    T* Y,
    CPUContext* context) {
  const std::array<std::int64_t, 2> dims = {G, K};
  const std::array<std::int32_t, 2> axes = {1, 0};
  const std::int64_t M = static_cast<std::int64_t>(N) * HxW;
  const int C = G * K;
  for (std::int64_t i = 0; i < M; ++i) {
    math::Transpose<std::int64_t, T, CPUContext>(
        2, dims.data(), axes.data(), X, Y, context);
    X += C;
    Y += C;
  }
}
} // namespace
// Forward channel shuffle for NCHW input X of shape (N, C, *spatial).
// The C channels are viewed as a (G, K) grid, where G = group_ and K = C / G,
// and are written to Y in (K, G) order. Y has the same shape as X.
// Throws (CAFFE_ENFORCE) if X has rank < 2, if group_ <= 0, or if C is not
// divisible by group_.
template <>
bool ChannelShuffleOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(0);
  // At least the N and C axes are required; dim 1 is read below.
  CAFFE_ENFORCE_GE(X.dim(), 2);
  auto* Y = Output(0, X.sizes(), at::dtype<float>());
  const int N = X.dim32(0);
  const int C = X.dim32(1);
  const int G = group_;
  // G == 0 would make `C % G` undefined behavior, and G < 0 would yield a
  // negative K.
  CAFFE_ENFORCE_GT(G, 0);
  CAFFE_ENFORCE_EQ(C % G, 0);
  const int K = C / G;
  // Number of elements per channel plane (product of all trailing dims).
  const int HxW = X.size_from_dim(2);
  const float* X_data = X.data<float>();
  float* Y_data = Y->mutable_data<float>();
  RunChannelShuffleNCHW<float>(N, G, K, HxW, X_data, Y_data, &context_);
  return true;
}
template <>
bool ChannelShuffleOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
const auto& X = Input(0);
auto* Y = Output(0, X.sizes(), at::dtype<float>());
const int ndim = X.dim();
const int N = X.dim32(0);
const int C = X.dim32(ndim - 1);
const int G = group_;
CAFFE_ENFORCE_EQ(C % G, 0);
const int K = C / G;
const int HxW = X.size_between_dim(0, ndim - 1);
const float* X_data = X.data<float>();
float* Y_data = Y->mutable_data<float>();
RunChannelShuffleNHWC<float>(N, G, K, HxW, X_data, Y_data, &context_);
return true;
}
template <>
bool ChannelShuffleGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
const auto& dY = Input(0);
auto* dX = Output(0, dY.sizes(), at::dtype<float>());
const int N = dY.dim32(0);
const int C = dY.dim32(1);
const int G = group_;
CAFFE_ENFORCE_EQ(C % G, 0);
const int K = C / G;
const int HxW = dY.size_from_dim(2);
const float* dY_data = dY.data<float>();
float* dX_data = dX->mutable_data<float>();
RunChannelShuffleNCHW<float>(N, K, G, HxW, dY_data, dX_data, &context_);
return true;
}
template <>
bool ChannelShuffleGradientOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
const auto& dY = Input(0);
auto* dX = Output(0, dY.sizes(), at::dtype<float>());
const int ndim = dY.dim();
const int N = dY.dim32(0);
const int C = dY.dim32(ndim - 1);
const int G = group_;
CAFFE_ENFORCE_EQ(C % G, 0);
const int K = C / G;
const int HxW = dY.size_between_dim(0, ndim - 1);
const float* dY_data = dY.data<float>();
float* dX_data = dX->mutable_data<float>();
RunChannelShuffleNHWC<float>(N, K, G, HxW, dY_data, dX_data, &context_);
return true;
}
// Register the float CPU implementations of the forward and gradient ops.
REGISTER_CPU_OPERATOR(ChannelShuffle, ChannelShuffleOp<float, CPUContext>);
REGISTER_CPU_GRADIENT_OPERATOR(
    ChannelShuffleGradient,
    ChannelShuffleGradientOp<float, CPUContext>);
// ChannelShuffle: exactly one input and one output. Only the channel order
// changes, so the output has the same type and shape as the input.
OPERATOR_SCHEMA(ChannelShuffle)
    .IdenticalTypeAndShape()
    .NumInputs(1)
    .NumOutputs(1)
    .InheritOnnxSchema();
// ChannelShuffleGradient: takes the output gradient (dY) and produces the
// input gradient (dX), which has the same type and shape.
GRADIENT_OPERATOR_SCHEMA(ChannelShuffleGradient)
    .IdenticalTypeAndShape()
    .NumInputs(1)
    .NumOutputs(1);
namespace {
// Gradient maker for ChannelShuffle.
// Emits a single ChannelShuffleGradient op (with an empty name) that reads
// the gradient of output 0 and writes the gradient of input 0.
class GetChannelShuffleGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;

  std::vector<OperatorDef> GetGradientDefs() override {
    const std::vector<std::string> grad_inputs{GO(0)};
    const std::vector<std::string> grad_outputs{GI(0)};
    return SingleGradientDef(
        "ChannelShuffleGradient", "", grad_inputs, grad_outputs);
  }
};
} // namespace
REGISTER_GRADIENT(ChannelShuffle, GetChannelShuffleGradient);
} // namespace caffe2