blob: e92938cd7b1d834250c4ae3f5eaaa55122d12642 [file] [log] [blame]
#include "wngrad_op.h"
namespace caffe2 {
// Registers the CPU implementation of the dense WnGrad optimizer step and
// declares its schema. Inputs are positional and must be supplied in the
// order declared by the .Input() entries below: (param, seq_b, grad, lr).
REGISTER_CPU_OPERATOR(Wngrad, WngradOp<float, CPUContext>);
OPERATOR_SCHEMA(Wngrad)
    .NumInputs(4)
    // Between 2 and 4 outputs: outputs 2 and 3 are documented as optional.
    .NumOutputs(2, 4)
    // Output 0 may alias input 0 (param); output 1 may alias input 1 (seq_b).
    .AllowInplace({{0, 0}, {1, 1}})
    .SetDoc(R"DOC(
Computes the WnGrad update for an input gradient and accumulated
history. This operator implement the optimization algorithm
in https://arxiv.org/abs/1803.02865 by Wu, Ward and Bottou.
Concretely, given inputs (param, seq_b, grad, learning_rate),
computes
new_seq_b = seq_b + 1 / seq_b * norm(grad)^2
effective_lr = learning_rate / (new_seq_b + epsilon)
update = learning_rate * grad / (new_seq_b + epsilon)
new_param = param + update
and returns (new_param, new_seq_b).
Optionally returns effective_lr and update as well.
)DOC")
    .Input(0, "param", "Parameters to be updated")
    .Input(1, "seq_b", "Seq_b history")
    .Input(2, "grad", "Gradient computed")
    .Input(3, "lr", "learning rate")
    .Output(0, "output_param", "Updated parameters")
    .Output(1, "output_seq_b", "Updated seq_b")
    .Output(2, "output_effective_lr", "(optional) Effective learning rate")
    .Output(3, "output_update", "(optional) Actual update that is applied.")
    .Arg("epsilon", "Default 1e-5");
// Registers the CPU implementation of the sparse WnGrad optimizer step and
// declares its schema. Inputs are positional, in the order declared by the
// .Input() entries below: (param, seq_b, indices, grad, lr).
REGISTER_CPU_OPERATOR(SparseWngrad, SparseWngradOp<float, CPUContext>);
OPERATOR_SCHEMA(SparseWngrad)
    .NumInputs(5)
    .NumOutputs(2)
    // NOTE(review): presumably requires output i to be written in place over
    // input i (param and seq_b) -- confirm against the OpSchema definition.
    .EnforceOneToOneInplace()
    .SetDoc(R"DOC(
This operator implement the optimization algorithm
in https://arxiv.org/abs/1803.02865 by Wu, Ward and Bottou.
Given inputs (param, seq_b, indices, grad, lr), runs the dense WnGrad
update on (param, seq_b, grad, lr), and returns (new_param,
new_seq_b) as in the dense case.
)DOC")
    .Input(0, "param", "Parameters to be updated")
    .Input(1, "seq_b", "seq_b history")
    .Input(2, "indices", "Sparse indices")
    .Input(3, "grad", "Gradient computed")
    .Input(4, "lr", "learning rate")
    .Output(0, "output_param", "Updated parameters")
    .Output(1, "output_seq_b", "Updated seq_b")
    .Arg("epsilon", "Default 1e-5");
// Both operators are parameter-update steps; mark them as operators for
// which no gradient should be generated.
// NOTE(review): exact behavior when a gradient is requested depends on the
// SHOULD_NOT_DO_GRADIENT macro definition, which is not visible here.
SHOULD_NOT_DO_GRADIENT(Wngrad);
SHOULD_NOT_DO_GRADIENT(SparseWngrad);
} // namespace caffe2