#include "ftrl_op.h"

namespace caffe2 {

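// Sign of x: -1 for negative, 0 for zero, +1 for positive.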
template <class T>
inline T sgn(const T x) {
  return (x == 0 ? 0 : (x < 0 ? -1 : 1));
}

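// Per-coordinate FTRL-Proximal update (McMahan et al., 2013):
//   n' = n + g^2
//   sigma = (sqrt(n') - sqrt(n)) / alpha
//   z' = z + g - sigma * w
//   w' = (lambda1 * sgn(z') - z') / ((beta + sqrt(n')) / alpha + lambda2)
//        if |z'| > lambda1, otherwise 0.
// Here n accumulates squared gradients and z the adjusted gradient sum;
// params.alphaInv stores 1 / alpha so the update multiplies rather than
// divides.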
template <typename T>
inline void ftrl_compute(
    const T w,
    const T n,
    const T z,
    const T g,
    T& nw,
    T& nn,
    T& nz,
    const FtrlParams<T>& params) {
  auto new_n = n + g * g;
  auto sigma = (sqrt(new_n) - sqrt(n)) * params.alphaInv;
  nn = new_n;
  nz = z + g - sigma * w;
  // update the weight
  if (std::abs(nz) > params.lambda1) {
    nw = (params.lambda1 * sgn(nz) - nz) /
        ((params.beta + sqrt(new_n)) * params.alphaInv + params.lambda2);
  } else {
    nw = 0.0;
  }
}

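// Dense FTRL update over N coordinates. The nz buffer interleaves the two
// accumulators per coordinate: nz[2 * i] holds n_i and nz[2 * i + 1] holds
// z_i, so it has twice as many elements as w and g.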
// TODO(dzhulgakov): implement SIMD-based version
template <typename Context, typename T>
void ftrl_update(
    int N,
    const T* w,
    const T* nz,
    const T* g,
    T* new_w,
    T* new_nz,
    const FtrlParams<T>& params,
    Context* /*context*/) {
  // TODO(cxj): use OMP when it is reliable
  // #pragma omp parallel for
  for (auto i = 0; i < N; ++i) {
    ftrl_compute(
        w[i],
        nz[i * 2],
        nz[i * 2 + 1],
        g[i],
        new_w[i],
        new_nz[i * 2],
        new_nz[i * 2 + 1],
        params);
  }
}

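// Dense FTRL operator: applies the update to every element of VAR and its
// (n, z) accumulators in N_Z; the schema below allows the outputs to share
// storage with the inputs.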
template <typename T, typename Context>
bool FtrlOp<T, Context>::RunOnDevice() {
  // Run-time learning rate override: if an ALPHA input is provided, use it
  // as alpha.
  if (ALPHA < InputSize()) {
    CAFFE_ENFORCE_EQ(Input(ALPHA).numel(), 1, "alpha should be a scalar");
    params_.alphaInv = 1.0 / *(Input(ALPHA).template data<T>());
  }
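  // N_Z holds an (n, z) pair per weight, hence twice as many elements.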
  CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(VAR).numel());
  CAFFE_ENFORCE_EQ(Input(GRAD).numel() * 2, Input(N_Z).numel());
  Output(OUTPUT_VAR)->ResizeLike(Input(VAR));
  Output(OUTPUT_N_Z)->ResizeLike(Input(N_Z));
  ftrl_update<Context>(
      Input(GRAD).numel(),
      Input(VAR).template data<T>(),
      Input(N_Z).template data<T>(),
      Input(GRAD).template data<T>(),
      Output(OUTPUT_VAR)->template mutable_data<T>(),
      Output(OUTPUT_N_Z)->template mutable_data<T>(),
      params_,
      &context_);
  return true;
}

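// Sparse FTRL operator: VAR is viewed as N = size(0) rows of
// block_size = numel() / N contiguous parameters; only the rows selected by
// INDICES are updated, and the update runs in place on VAR and N_Z.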
template <typename T>
template <typename SIndex>
void SparseFtrlOp<T>::DoRun() {
  auto* var = Output(OUTPUT_VAR);
  auto* n_z = Output(OUTPUT_N_Z);
  auto& indices = Input(INDICES);
  auto& grad = Input(GRAD);
  CAFFE_ENFORCE_EQ(&Input(VAR), var, "In-place operation is required");
  CAFFE_ENFORCE_EQ(&Input(N_Z), n_z, "In-place operation is required");
  int64_t M = var->numel();
  int64_t N = var->size(0);
  int64_t block_size = M / N;
  int64_t K = indices.numel();
  TORCH_DCHECK_EQ(M * 2, n_z->numel());
  TORCH_DCHECK_EQ(grad.numel(), K * block_size);
  T* w = var->template mutable_data<T>();
  T* nz = n_z->template mutable_data<T>();
  const SIndex* idxs = indices.template data<SIndex>();
  const T* g = grad.template data<T>();

  // TODO(cxj): use OMP when it is reliable
  // #pragma omp parallel for
  for (int64_t i = 0; i < K; ++i) {
    SIndex idx = idxs[i];
    DCHECK(0 <= idx && idx < N)
        << "Index out of bounds: " << idx << ", valid range [0, " << N << ")";
    if (block_size == 1) {
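      // block_size == 1: each row is a single coordinate; update it in place.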
      ftrl_compute(
          w[idx],
          nz[idx * 2],
          nz[idx * 2 + 1],
          g[i],
          w[idx],
          nz[idx * 2],
          nz[idx * 2 + 1],
          params_);
    } else {
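      // block_size > 1: the row is a contiguous block of block_size
      // coordinates starting at block_size * idx (and twice that in nz).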
      int64_t x = block_size * idx;
      ftrl_update(
          block_size,
          w + x,
          nz + x * 2,
          g + i * block_size,
          w + x,
          nz + x * 2,
          params_,
          &context_);
    }
  }
}

namespace {
REGISTER_CPU_OPERATOR(Ftrl, FtrlOp<float, CPUContext>);
OPERATOR_SCHEMA(Ftrl)
    .NumInputs(3, 4)
    .NumOutputs(2)
    .AllowInplace({{0, 0}, {1, 1}});
SHOULD_NOT_DO_GRADIENT(Ftrl);

REGISTER_CPU_OPERATOR(SparseFtrl, SparseFtrlOp<float>);
OPERATOR_SCHEMA(SparseFtrl)
    .NumInputs(4, 5)
    .NumOutputs(2)
    .EnforceInplace({{0, 0}, {1, 1}});
SHOULD_NOT_DO_GRADIENT(SparseFtrl);
} // namespace

} // namespace caffe2