## @package batch_lr_loss
# Module caffe2.python.layers.batch_lr_loss

from caffe2.python import core, schema
from caffe2.python.layers.layers import (
    ModelLayer,
)
from caffe2.python.layers.tags import (
    Tags
)
import numpy as np


class BatchLRLoss(ModelLayer):
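    """Batch logistic-regression (sigmoid cross-entropy) loss layer.

    Expects an input record containing at least scalar 'label' and 'logit'
    fields and emits a single float32 loss blob. Optional behaviors, all
    controlled by constructor arguments, include label smoothing
    (pos_label_target / neg_label_target), fusing the cross entropy with a
    Bernoulli JSD term (jsd_weight, homotopy_weighting; both also require a
    'prediction' field), focal-loss reweighting (focal_gamma), exponential
    per-iteration loss decay (task_gamma, task_gamma_lb), an uncertainty
    penalty driven by an optional 'log_variance' field, and per-example
    weighting via an optional 'weight' field.

    Usage sketch (assuming a layer model helper that exposes registered
    layers by name; field names here are illustrative):

        loss = model.BatchLRLoss(
            schema.Struct(('label', label), ('logit', logit))
        )
    """
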
    def __init__(
        self,
        model,
        input_record,
        name='batch_lr_loss',
        average_loss=True,
        jsd_weight=0.0,
        pos_label_target=1.0,
        neg_label_target=0.0,
        homotopy_weighting=False,
        log_D_trick=False,
        unjoined_lr_loss=False,
        uncertainty_penalty=1.0,
        focal_gamma=0.0,
        stop_grad_in_focal_factor=False,
        task_gamma=1.0,
        task_gamma_lb=0.1,
        **kwargs
    ):
        super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs)

        self.average_loss = average_loss

        assert (schema.is_schema_subset(
            schema.Struct(
                ('label', schema.Scalar()),
                ('logit', schema.Scalar())
            ),
            input_record
        ))

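        # Optionally fuse the cross entropy with a Bernoulli Jensen-Shannon
        # divergence between 'prediction' and 'label'. With homotopy
        # weighting the two terms start at weights (1, 0) and are gradually
        # shifted toward the JSD term during training (see update_weight).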
        self.jsd_fuse = False
        assert jsd_weight >= 0 and jsd_weight <= 1
        if jsd_weight > 0 or homotopy_weighting:
            assert 'prediction' in input_record
            self.init_weight(jsd_weight, homotopy_weighting)
            self.jsd_fuse = True
        self.homotopy_weighting = homotopy_weighting

        assert pos_label_target <= 1 and pos_label_target >= 0
        assert neg_label_target <= 1 and neg_label_target >= 0
        assert pos_label_target >= neg_label_target
        self.pos_label_target = pos_label_target
        self.neg_label_target = neg_label_target

        assert not (log_D_trick and unjoined_lr_loss)
        self.log_D_trick = log_D_trick
        self.unjoined_lr_loss = unjoined_lr_loss
        assert uncertainty_penalty >= 0
        self.uncertainty_penalty = uncertainty_penalty

        self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])

        self.output_schema = schema.Scalar(
            np.float32,
            self.get_next_blob_reference('output')
        )

        self.focal_gamma = focal_gamma
        self.stop_grad_in_focal_factor = stop_grad_in_focal_factor

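        # When task_gamma < 1, create constant (non-optimized) blobs that
        # track a per-iteration exponential decay factor for this task's
        # loss, bounded below by task_gamma_lb (applied in add_ops).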
        self.apply_exp_decay = False
        if task_gamma < 1.0:
            self.apply_exp_decay = True
            self.task_gamma_cur = self.create_param(
                param_name=('%s_task_gamma_cur' % self.name),
                shape=[1],
                initializer=(
                    'ConstantFill', {
                        'value': 1.0,
                        'dtype': core.DataType.FLOAT
                    }
                ),
                optimizer=self.model.NoOptim,
            )

            self.task_gamma = self.create_param(
                param_name=('%s_task_gamma' % self.name),
                shape=[1],
                initializer=(
                    'ConstantFill', {
                        'value': task_gamma,
                        'dtype': core.DataType.FLOAT
                    }
                ),
                optimizer=self.model.NoOptim,
            )

            self.task_gamma_lb = self.create_param(
                param_name=('%s_task_gamma_lb' % self.name),
                shape=[1],
                initializer=(
                    'ConstantFill', {
                        'value': task_gamma_lb,
                        'dtype': core.DataType.FLOAT
                    }
                ),
                optimizer=self.model.NoOptim,
            )

    def init_weight(self, jsd_weight, homotopy_weighting):
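        """Create the cross-entropy / JSD mixing weights.

        With homotopy weighting, the weights are mutable parameter blobs
        initialized to (xent=1, jsd=0), together with a mutex-protected
        iteration counter that drives their schedule in update_weight.
        Otherwise they are global constants fixed at (1 - jsd_weight,
        jsd_weight).
        """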
        if homotopy_weighting:
            self.mutex = self.create_param(
                param_name=('%s_mutex' % self.name),
                shape=None,
                initializer=('CreateMutex', ),
                optimizer=self.model.NoOptim,
            )
            self.counter = self.create_param(
                param_name=('%s_counter' % self.name),
                shape=[1],
                initializer=(
                    'ConstantFill', {
                        'value': 0,
                        'dtype': core.DataType.INT64
                    }
                ),
                optimizer=self.model.NoOptim,
            )
            self.xent_weight = self.create_param(
                param_name=('%s_xent_weight' % self.name),
                shape=[1],
                initializer=(
                    'ConstantFill', {
                        'value': 1.,
                        'dtype': core.DataType.FLOAT
                    }
                ),
                optimizer=self.model.NoOptim,
            )
            self.jsd_weight = self.create_param(
                param_name=('%s_jsd_weight' % self.name),
                shape=[1],
                initializer=(
                    'ConstantFill', {
                        'value': 0.,
                        'dtype': core.DataType.FLOAT
                    }
                ),
                optimizer=self.model.NoOptim,
            )
        else:
            self.jsd_weight = self.model.add_global_constant(
                '%s_jsd_weight' % self.name, jsd_weight
            )
            self.xent_weight = self.model.add_global_constant(
                '%s_xent_weight' % self.name, 1. - jsd_weight
            )

    def update_weight(self, net):
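        """Advance the homotopy schedule by one step.

        Atomically increments the counter and recomputes the mixing
        weights: xent_weight follows an inverse-decay schedule from 1
        toward 0, and jsd_weight is set to 1 - xent_weight, so the loss
        gradually shifts from cross entropy toward JSD.
        """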
        net.AtomicIter([self.mutex, self.counter], [self.counter])
        # iter = 0:   xent_weight = 1
        # iter = 1e6: xent_weight = 0.5^0.1 ~= 0.93
        # iter = 1e9: xent_weight = 1e-3^0.1 ~= 0.50
        net.LearningRate([self.counter], [self.xent_weight], base_lr=1.0,
                         policy='inv', gamma=1e-6, power=0.1,)
        net.Sub(
            [self.model.global_constants['ONE'], self.xent_weight],
            [self.jsd_weight]
        )
        return self.xent_weight, self.jsd_weight

    def add_ops(self, net):
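        """Emit the loss computation into `net`.

        Builds the label preprocessing, the sigmoid cross entropy, and the
        optional focal, exponential-decay, JSD, uncertainty, and
        per-example-weight terms, then reduces to a scalar loss.
        """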
        # numerically stable sigmoid cross entropy with logits
        label = self.input_record.label()
        # mandatory cast to float32
        # self.input_record.label.field_type().base is np.float32 but
        # label type is actually int
        label = net.Cast(
            label,
            net.NextScopedBlob('label_float32'),
            to=core.DataType.FLOAT)
        label = net.ExpandDims(label, net.NextScopedBlob('expanded_label'),
                               dims=[1])
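        # optional label smoothing: binary labels are mapped to
        # neg_label_target / pos_label_target via a step function at 0.5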
        if self.pos_label_target != 1.0 or self.neg_label_target != 0.0:
            label = net.StumpFunc(
                label,
                net.NextScopedBlob('smoothed_label'),
                threshold=0.5,
                low_value=self.neg_label_target,
                high_value=self.pos_label_target,
            )
        xent = net.SigmoidCrossEntropyWithLogits(
            [self.input_record.logit(), label],
            net.NextScopedBlob('cross_entropy'),
            log_D_trick=self.log_D_trick,
            unjoined_lr_loss=self.unjoined_lr_loss
        )

        if self.focal_gamma != 0:
            label = net.StopGradient(
                [label],
                [net.NextScopedBlob('label_stop_gradient')],
            )

            prediction = self.input_record.prediction()
            # focal loss = (y(1-p) + p(1-y))^gamma * original LR loss
            # y(1-p) + p(1-y) = y + p - 2 * yp
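            # for binary 0/1 labels the modulating factor is (1-p)^gamma
            # for positives and p^gamma for negatives, so well-classified
            # examples are down-weighted relative to plain cross entropy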
            y_plus_p = net.Add(
                [prediction, label],
                net.NextScopedBlob("y_plus_p"),
            )
            yp = net.Mul([prediction, label], net.NextScopedBlob("yp"))
            two_yp = net.Scale(yp, net.NextScopedBlob("two_yp"), scale=2.0)
            y_plus_p_sub_two_yp = net.Sub(
                [y_plus_p, two_yp], net.NextScopedBlob("y_plus_p_sub_two_yp")
            )
            focal_factor = net.Pow(
                y_plus_p_sub_two_yp,
                net.NextScopedBlob("y_plus_p_sub_two_yp_power"),
                exponent=float(self.focal_gamma),
            )
            if self.stop_grad_in_focal_factor:
                focal_factor = net.StopGradient(
                    [focal_factor],
                    [net.NextScopedBlob("focal_factor_stop_gradient")],
                )
            xent = net.Mul(
                [xent, focal_factor], net.NextScopedBlob("focallossxent")
            )

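        # exponential task decay: task_gamma_cur *= task_gamma each
        # iteration, and the loss is scaled by
        # max(task_gamma_cur, task_gamma_lb)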
        if self.apply_exp_decay:
            net.Mul(
                [self.task_gamma_cur, self.task_gamma],
                self.task_gamma_cur
            )

            task_gamma_multiplier = net.Max(
                [self.task_gamma_cur, self.task_gamma_lb],
                net.NextScopedBlob("task_gamma_cur_multiplier")
            )

            xent = net.Mul(
                [xent, task_gamma_multiplier], net.NextScopedBlob("expdecayxent")
            )

        # fuse with JSD
        if self.jsd_fuse:
            jsd = net.BernoulliJSD(
                [self.input_record.prediction(), label],
                net.NextScopedBlob('jsd'),
            )
            if self.homotopy_weighting:
                self.update_weight(net)
            loss = net.WeightedSum(
                [xent, self.xent_weight, jsd, self.jsd_weight],
                net.NextScopedBlob('loss'),
            )
        else:
            loss = xent

        if 'log_variance' in self.input_record.fields:
            # mean (0.5 * exp(-s) * loss + 0.5 * penalty * s)
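            # s is the predicted log variance: exp(-s) attenuates the loss
            # on examples with high predicted variance, while the
            # penalty * s term keeps s from growing without bound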
            log_variance_blob = self.input_record.log_variance()

            log_variance_blob = net.ExpandDims(
                log_variance_blob, net.NextScopedBlob('expanded_log_variance'),
                dims=[1]
            )

            neg_log_variance_blob = net.Negative(
                [log_variance_blob],
                net.NextScopedBlob('neg_log_variance')
            )

            # clip -s at 88 so the subsequent Exp does not overflow float32
            neg_log_variance_blob = net.Clip(
                [neg_log_variance_blob],
                net.NextScopedBlob('clipped_neg_log_variance'),
                max=88.0
            )

            exp_neg_log_variance_blob = net.Exp(
                [neg_log_variance_blob],
                net.NextScopedBlob('exp_neg_log_variance')
            )

            exp_neg_log_variance_loss_blob = net.Mul(
                [exp_neg_log_variance_blob, loss],
                net.NextScopedBlob('exp_neg_log_variance_loss')
            )

            penalized_uncertainty = net.Scale(
                log_variance_blob, net.NextScopedBlob("penalized_uncertainty"),
                scale=float(self.uncertainty_penalty)
            )

            loss_2x = net.Add(
                [exp_neg_log_variance_loss_blob, penalized_uncertainty],
                net.NextScopedBlob('loss')
            )
            loss = net.Scale(loss_2x, net.NextScopedBlob("loss"), scale=0.5)

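        # optional per-example weights, cast to float32 and treated as
        # constants (gradients are not propagated into the weights)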
        if 'weight' in self.input_record.fields:
            weight_blob = self.input_record.weight()
            if self.input_record.weight.field_type().base != np.float32:
                weight_blob = net.Cast(
                    weight_blob,
                    weight_blob + '_float32',
                    to=core.DataType.FLOAT
                )
            weight_blob = net.StopGradient(
                [weight_blob],
                [net.NextScopedBlob('weight_stop_gradient')],
            )
            loss = net.Mul(
                [loss, weight_blob],
                net.NextScopedBlob('weighted_cross_entropy'),
            )

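        # reduce per-example losses to a scalar: batch mean if
        # average_loss, otherwise the batch sum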
        if self.average_loss:
            net.AveragedLoss(loss, self.output_schema.field_blobs())
        else:
            net.ReduceFrontSum(loss, self.output_schema.field_blobs())