| # @package adaptive_weight |
# Module caffe2.python.layers.adaptive_weight
| |
| |
| import numpy as np |
| from caffe2.python import core, schema |
| from caffe2.python.layers.layers import ModelLayer |
| from caffe2.python.regularizer import BoundedGradientProjection, LogBarrier |
| |
| |
| """ |
| Implementation of adaptive weighting: https://arxiv.org/pdf/1705.07115.pdf |
| """ |
| |
| |
| class AdaptiveWeight(ModelLayer): |
| def __init__( |
| self, |
| model, |
| input_record, |
| name="adaptive_weight", |
| optimizer=None, |
| weights=None, |
| enable_diagnose=False, |
| estimation_method="log_std", |
| pos_optim_method="log_barrier", |
| reg_lambda=0.1, |
| **kwargs |
| ): |
| super(AdaptiveWeight, self).__init__(model, name, input_record, **kwargs) |
| self.output_schema = schema.Scalar( |
| np.float32, self.get_next_blob_reference("adaptive_weight") |
| ) |
| self.data = self.input_record.field_blobs() |
| self.num = len(self.data) |
| self.optimizer = optimizer |
| if weights is not None: |
            assert len(weights) == self.num, "weights must have one entry per input"
| else: |
| weights = [1. / self.num for _ in range(self.num)] |
| assert min(weights) > 0, "initial weights must be positive" |
| self.weights = np.array(weights).astype(np.float32) |
| self.estimation_method = str(estimation_method).lower() |
        # used in the positivity-constrained parameterization, e.g. when the
        # estimation method is inv_var; the optimization method is either
        # "log_barrier" or "pos_grad_proj"
| self.pos_optim_method = str(pos_optim_method).lower() |
| self.reg_lambda = float(reg_lambda) |
| self.enable_diagnose = enable_diagnose |
| self.init_func = getattr(self, self.estimation_method + "_init") |
| self.weight_func = getattr(self, self.estimation_method + "_weight") |
| self.reg_func = getattr(self, self.estimation_method + "_reg") |
| self.init_func() |
| if self.enable_diagnose: |
| self.weight_i = [ |
| self.get_next_blob_reference("adaptive_weight_%d" % i) |
| for i in range(self.num) |
| ] |
| for i in range(self.num): |
| self.model.add_ad_hoc_plot_blob(self.weight_i[i]) |
| |
| def concat_data(self, net): |
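        """
        Reshape each scalar input loss to shape [1] and concatenate them into
        a single 1-D blob of length self.num, so the learned weights can be
        applied with one elementwise Mul.
        """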
| reshaped = [net.NextScopedBlob("reshaped_data_%d" % i) for i in range(self.num)] |
| # coerce shape for single real values |
| for i in range(self.num): |
| net.Reshape( |
| [self.data[i]], |
| [reshaped[i], net.NextScopedBlob("new_shape_%d" % i)], |
| shape=[1], |
| ) |
| concated = net.NextScopedBlob("concated_data") |
| net.Concat( |
| reshaped, [concated, net.NextScopedBlob("concated_new_shape")], axis=0 |
| ) |
| return concated |
| |
| def log_std_init(self): |
| """ |
| mu = 2 log sigma, sigma = standard variance |
| per task objective: |
| min 1 / 2 / e^mu X + mu / 2 |
| """ |
| values = np.log(1. / 2. / self.weights) |
| initializer = ( |
| "GivenTensorFill", |
| {"values": values, "dtype": core.DataType.FLOAT}, |
| ) |
| self.mu = self.create_param( |
| param_name="mu", |
| shape=[self.num], |
| initializer=initializer, |
| optimizer=self.optimizer, |
| ) |
| |
| def log_std_weight(self, x, net, weight): |
| """ |
| min 1 / 2 / e^mu X + mu / 2 |
| """ |
| mu_neg = net.NextScopedBlob("mu_neg") |
| net.Negative(self.mu, mu_neg) |
| mu_neg_exp = net.NextScopedBlob("mu_neg_exp") |
| net.Exp(mu_neg, mu_neg_exp) |
| net.Scale(mu_neg_exp, weight, scale=0.5) |
| |
| def log_std_reg(self, net, reg): |
| net.Scale(self.mu, reg, scale=0.5) |
| |
| def inv_var_init(self): |
| """ |
| k = 1 / variance |
| per task objective: |
| min 1 / 2 * k X - 1 / 2 * log k |
| """ |
| values = 2. * self.weights |
| initializer = ( |
| "GivenTensorFill", |
| {"values": values, "dtype": core.DataType.FLOAT}, |
| ) |
| if self.pos_optim_method == "log_barrier": |
| regularizer = LogBarrier(reg_lambda=self.reg_lambda) |
| elif self.pos_optim_method == "pos_grad_proj": |
| regularizer = BoundedGradientProjection(lb=0, left_open=True) |
| else: |
            raise ValueError(
| "unknown positivity optimization method: {}".format( |
| self.pos_optim_method |
| ) |
| ) |
| self.k = self.create_param( |
| param_name="k", |
| shape=[self.num], |
| initializer=initializer, |
| optimizer=self.optimizer, |
| regularizer=regularizer, |
| ) |
| |
| def inv_var_weight(self, x, net, weight): |
| net.Scale(self.k, weight, scale=0.5) |
| |
| def inv_var_reg(self, net, reg): |
| log_k = net.NextScopedBlob("log_k") |
| net.Log(self.k, log_k) |
| net.Scale(log_k, reg, scale=-0.5) |
| |
| def _add_ops_impl(self, net, enable_diagnose): |
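        """
        Emit ops computing sum_i (weight_i * X_i + reg_i) into the output
        blob; optionally slice out each weight_i for diagnostic plotting.
        """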
| x = self.concat_data(net) |
| weight = net.NextScopedBlob("weight") |
| reg = net.NextScopedBlob("reg") |
| weighted_x = net.NextScopedBlob("weighted_x") |
| weighted_x_add_reg = net.NextScopedBlob("weighted_x_add_reg") |
| self.weight_func(x, net, weight) |
| self.reg_func(net, reg) |
| net.Mul([weight, x], weighted_x) |
| net.Add([weighted_x, reg], weighted_x_add_reg) |
| net.SumElements(weighted_x_add_reg, self.output_schema()) |
| if enable_diagnose: |
| for i in range(self.num): |
| net.Slice(weight, self.weight_i[i], starts=[i], ends=[i + 1]) |
| |
| def add_ops(self, net): |
| self._add_ops_impl(net, self.enable_diagnose) |
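

# A minimal usage sketch (hedged: assumes "model" is a LayerModelHelper and
# that this layer is reachable through the model's layer registry; the field
# names below are hypothetical):
#
#     losses = schema.Struct(
#         ("loss_a", schema.Scalar(np.float32)),
#         ("loss_b", schema.Scalar(np.float32)),
#     )
#     total_loss = model.AdaptiveWeight(losses, estimation_method="log_std")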