// blob: 3f52a906a276589e9114eea16367bf5b306b7f73 [file] [log] [blame]
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CAFFE2_OPERATORS_FUNHASH_OP_H_
#define CAFFE2_OPERATORS_FUNHASH_OP_H_
#include <xxhash.h>
#include <array>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
// Fourth word of the hashed tuple when the hash decides the sign (-1/+1)
// applied to a weight.
#define SIGN_MAGIC 0x9e3779b97f4a7c15
// Fourth word of the hashed tuple when the hash selects the index into the
// weight vector.
#define INDEX_MAGIC 0xf39cc0605cedc834
// When defined, the operators below compute a second hash (with SIGN_MAGIC)
// per weight lookup and negate the contribution whenever that hash is odd.
#define USE_SIGN
namespace caffe2 {
template <typename T, class Context>
class FunHashOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
FunHashOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
num_outputs_(
OperatorBase::GetSingleArgument<int64_t>("num_outputs", -1)),
num_segments_(
OperatorBase::GetSingleArgument<int64_t>("num_segments", -1)),
seed_(OperatorBase::GetSingleArgument<uint64_t>("seed", 0)) {
CAFFE_ENFORCE(
OperatorBase::HasArgument("num_outputs"),
"Argument `num_outputs` is missing.");
// If alpha is provided, use adaptive hashing parameterized by alpha.
adaptive_ = (InputSize() == 5);
}
bool RunOnDevice() override {
const auto& val = Input(0);
const auto& key = Input(1);
const auto& seg = Input(2);
const auto& weight = Input(3);
int64_t num_alpha = 1;
if (adaptive_) {
const auto& alpha = Input(4);
num_alpha = alpha.size(0);
}
const auto* seg_data = seg.template data<int>();
int64_t num_weight = weight.size(0);
int64_t num_nz_ent = seg.size(0);
int64_t n_segments = num_segments_;
if (num_segments_ == -1) {
for (const auto i : c10::irange(num_nz_ent)) {
if (seg_data[i] > n_segments) {
n_segments = seg_data[i];
}
}
++n_segments;
}
auto* output = Output(0, {n_segments, num_outputs_}, at::dtype<T>());
T* output_data = output->template mutable_data<T>();
memset(output_data, 0, sizeof(T) * n_segments * num_outputs_);
const auto* weight_data = weight.template data<T>();
const auto* alpha_data = adaptive_ ? Input(4).template data<T>() : 0;
const auto* val_data = val.template data<T>();
const auto* key_data = key.template data<int64_t>();
for (const auto j : c10::irange(num_nz_ent)) {
int64_t cur_seg = seg_data[j];
int64_t cur_key = key_data[j];
T cur_val = val_data[j];
int64_t output_stride = cur_seg * num_outputs_;
for (const auto i : c10::irange(num_outputs_)) {
T sum = 0;
for (const auto k : c10::irange(num_alpha)) {
uint64_t hash;
// The hash function takes as input four integers:
// 1. feature index
// 2. output index
// 3. alpha index
// 4. magic number: SIGN_MAGIC for sign (-1/+1)
// INDEX_MAGIC for weight index
hash_data[0] = cur_key;
hash_data[1] = i;
hash_data[2] = k;
hash_data[3] = INDEX_MAGIC;
hash = XXH64(hash_data.data(), hash_data.size(), seed_);
int64_t index = hash % num_weight;
T cur_weight = weight_data[index];
#ifdef USE_SIGN
hash_data[3] = SIGN_MAGIC;
hash = XXH64(hash_data.data(), hash_data.size(), seed_);
if (hash % 2) {
cur_weight = -cur_weight;
}
#endif // USE_SIGN
if (adaptive_) {
sum += cur_weight * alpha_data[k];
} else {
sum += cur_weight;
}
}
output_data[output_stride + i] += sum * cur_val;
}
}
return true;
}
protected:
int64_t num_outputs_;
int64_t num_segments_;
uint64_t seed_;
std::array<uint64_t, 4> hash_data;
bool adaptive_;
};
template <typename T, class Context>
class FunHashGradientOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
FunHashGradientOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
num_outputs_(
OperatorBase::GetSingleArgument<int64_t>("num_outputs", -1)),
seed_(OperatorBase::GetSingleArgument<uint64_t>("seed", 0)) {
adaptive_ = (InputSize() == 6);
}
bool RunOnDevice() override {
const auto& grad_out = Input(0);
const auto& val = Input(1);
const auto& key = Input(2);
const auto& seg = Input(3);
const auto& weight = Input(4);
int64_t num_alpha = 1;
T* grad_alpha_data = 0;
if (adaptive_) {
const auto& alpha = Input(5);
num_alpha = alpha.size(0);
auto* grad_alpha = Output(1, alpha.sizes(), at::dtype<T>());
grad_alpha_data = grad_alpha->template mutable_data<T>();
memset(grad_alpha_data, 0, sizeof(T) * num_alpha);
}
const auto* seg_data = seg.template data<int>();
int64_t num_weight = weight.size(0);
int64_t num_nz_ent = seg.size(0);
auto* grad_weight = Output(0, weight.sizes(), at::dtype<T>());
T* grad_weight_data = grad_weight->template mutable_data<T>();
const auto* grad_out_data = grad_out.template data<T>();
const auto* weight_data = weight.template data<T>();
const auto* alpha_data = adaptive_ ? Input(5).template data<T>() : 0;
const auto* val_data = val.template data<T>();
const auto* key_data = key.template data<int64_t>();
memset(grad_weight_data, 0, sizeof(T) * num_weight);
for (const auto j : c10::irange(num_nz_ent)) {
int64_t cur_seg = seg_data[j];
int64_t cur_key = key_data[j];
T cur_val = val_data[j];
int64_t grad_out_stride = cur_seg * num_outputs_;
for (const auto i : c10::irange(num_outputs_)) {
T grad_out_scale = grad_out_data[grad_out_stride + i] * cur_val;
for (const auto k : c10::irange(num_alpha)) {
uint64_t hash;
hash_data[0] = cur_key;
hash_data[1] = i;
hash_data[2] = k;
hash_data[3] = INDEX_MAGIC;
hash = XXH64(hash_data.data(), hash_data.size(), seed_);
int64_t index = hash % num_weight;
T cur_grad_out_scale = grad_out_scale;
#ifdef USE_SIGN
hash_data[3] = SIGN_MAGIC;
hash = XXH64(hash_data.data(), hash_data.size(), seed_);
if (hash % 2) {
cur_grad_out_scale = -cur_grad_out_scale;
}
#endif // USE_SIGN
if (adaptive_) {
grad_alpha_data[k] += cur_grad_out_scale * weight_data[index];
grad_weight_data[index] += alpha_data[k] * cur_grad_out_scale;
} else {
grad_weight_data[index] += cur_grad_out_scale;
}
}
}
}
return true;
}
protected:
int64_t num_outputs_;
uint64_t seed_;
std::array<uint64_t, 4> hash_data;
bool adaptive_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_FUNHASH_OP_H_