blob: 46b0e4d42cdfd2e4174d2e9d3581b01dcb65bbc3 [file] [log] [blame]
## @package batch_lr_loss
# Module caffe2.python.layers.batch_lr_loss
from caffe2.python import core, schema
from caffe2.python.layers.layers import (
ModelLayer,
)
from caffe2.python.layers.tags import (
Tags
)
import numpy as np
class BatchLRLoss(ModelLayer):
    """Binary logistic-regression (sigmoid cross-entropy) loss layer.

    Requires ``label`` and ``logit`` scalars in ``input_record``; depending on
    the options it also reads ``prediction``, ``log_variance`` and ``weight``.
    The output is a single float32 scalar: the per-example loss averaged
    (``average_loss=True``, via ``AveragedLoss``) or summed (via
    ``ReduceFrontSum``).

    Optional modifiers, applied in this order in ``add_ops``:

    * label smoothing: labels are mapped to ``neg_label_target`` /
      ``pos_label_target`` with a 0.5 threshold (``StumpFunc``);
    * focal loss: cross-entropy scaled by ``(y + p - 2yp) ** focal_gamma``;
    * exponential task decay (``task_gamma < 1``): cross-entropy scaled by
      ``max(task_gamma_cur, task_gamma_lb)``, where ``task_gamma_cur`` is
      multiplied in place by ``task_gamma`` every time the ops run;
    * JSD fusion: weighted sum of cross-entropy and Bernoulli JSD, using
      either fixed weights or homotopy weights driven by an iteration
      counter;
    * uncertainty weighting when ``log_variance`` (s) is present:
      ``0.5 * (exp(-s) * loss + uncertainty_penalty * s)``;
    * multiplication by a per-example ``weight`` (gradient-stopped).
    """

    def __init__(
        self,
        model,
        input_record,
        name='batch_lr_loss',
        average_loss=True,
        jsd_weight=0.0,
        pos_label_target=1.0,
        neg_label_target=0.0,
        homotopy_weighting=False,
        log_D_trick=False,
        unjoined_lr_loss=False,
        uncertainty_penalty=1.0,
        focal_gamma=0.0,
        stop_grad_in_focal_factor=False,
        task_gamma=1.0,
        task_gamma_lb=0.1,
        **kwargs
    ):
        """Validate the options and create the layer's params/constants.

        Args:
            model: the model helper this layer is added to.
            input_record: schema.Struct containing at least ``label`` and
                ``logit`` scalars. ``prediction`` is also required when JSD
                fusion, homotopy weighting or focal loss is enabled.
            name: layer name prefix.
            average_loss: average the loss over examples if True, else sum.
            jsd_weight: fixed JSD weight in [0, 1]; the cross-entropy term
                gets ``1 - jsd_weight``. Ignored if ``homotopy_weighting``.
            pos_label_target: target used for positive labels, in [0, 1].
            neg_label_target: target used for negative labels, in [0, 1],
                and not greater than ``pos_label_target``.
            homotopy_weighting: use counter-driven xent/JSD weights.
            log_D_trick: forwarded to ``SigmoidCrossEntropyWithLogits``.
            unjoined_lr_loss: forwarded to ``SigmoidCrossEntropyWithLogits``;
                mutually exclusive with ``log_D_trick``.
            uncertainty_penalty: non-negative scale of the log-variance
                penalty (used only if ``log_variance`` is in the record).
            focal_gamma: focal-loss exponent; 0 disables focal loss.
            stop_grad_in_focal_factor: if truthy, block gradients through
                the focal factor.
            task_gamma: per-run decay multiplier; values < 1 enable decay.
            task_gamma_lb: lower bound applied to the decayed multiplier.
            **kwargs: forwarded to ``ModelLayer``.

        Raises:
            AssertionError: if the record or option combination is invalid.
        """
        super(BatchLRLoss, self).__init__(model, name, input_record, **kwargs)
        self.average_loss = average_loss
        assert schema.is_schema_subset(
            schema.Struct(
                ('label', schema.Scalar()),
                ('logit', schema.Scalar()),
            ),
            input_record,
        )
        self.jsd_fuse = False
        assert 0 <= jsd_weight <= 1
        if focal_gamma != 0:
            # The focal factor is built from input_record.prediction in
            # add_ops; fail here with a clear error instead of much later
            # while the net is being constructed.
            assert 'prediction' in input_record
        if jsd_weight > 0 or homotopy_weighting:
            assert 'prediction' in input_record
            self.init_weight(jsd_weight, homotopy_weighting)
            self.jsd_fuse = True
        self.homotopy_weighting = homotopy_weighting
        assert 0 <= pos_label_target <= 1
        assert 0 <= neg_label_target <= 1
        assert pos_label_target >= neg_label_target
        self.pos_label_target = pos_label_target
        self.neg_label_target = neg_label_target
        assert not (log_D_trick and unjoined_lr_loss)
        self.log_D_trick = log_D_trick
        self.unjoined_lr_loss = unjoined_lr_loss
        assert uncertainty_penalty >= 0
        self.uncertainty_penalty = uncertainty_penalty
        self.tags.update([Tags.EXCLUDE_FROM_PREDICTION])
        self.output_schema = schema.Scalar(
            np.float32,
            self.get_next_blob_reference('output')
        )
        self.focal_gamma = focal_gamma
        self.stop_grad_in_focal_factor = stop_grad_in_focal_factor
        self.apply_exp_decay = False
        if task_gamma < 1.0:
            self.apply_exp_decay = True
            # Current decayed multiplier; starts at 1 and is multiplied by
            # task_gamma in place each time the ops run.
            self.task_gamma_cur = self._create_constant_param(
                'task_gamma_cur', 1.0, core.DataType.FLOAT)
            self.task_gamma = self._create_constant_param(
                'task_gamma', task_gamma, core.DataType.FLOAT)
            self.task_gamma_lb = self._create_constant_param(
                'task_gamma_lb', task_gamma_lb, core.DataType.FLOAT)

    def _create_constant_param(self, suffix, value, dtype):
        """Create a non-optimized shape-[1] param named ``<name>_<suffix>``.

        The param is initialized with ``ConstantFill`` to ``value`` of type
        ``dtype`` and registered with ``self.model.NoOptim``.
        """
        return self.create_param(
            param_name=('%s_%s' % (self.name, suffix)),
            shape=[1],
            initializer=(
                'ConstantFill', {
                    'value': value,
                    'dtype': dtype,
                }
            ),
            optimizer=self.model.NoOptim,
        )

    def init_weight(self, jsd_weight, homotopy_weighting):
        """Create the blobs holding the cross-entropy and JSD weights.

        With ``homotopy_weighting``, this also creates a mutex and an int64
        counter. The weights start at xent=1 and jsd=0 and are recomputed by
        ``update_weight``. Otherwise the weights are the global constants
        ``1 - jsd_weight`` (xent) and ``jsd_weight`` (JSD).
        """
        if homotopy_weighting:
            self.mutex = self.create_param(
                param_name=('%s_mutex' % self.name),
                shape=None,
                initializer=('CreateMutex', ),
                optimizer=self.model.NoOptim,
            )
            self.counter = self._create_constant_param(
                'counter', 0, core.DataType.INT64)
            self.xent_weight = self._create_constant_param(
                'xent_weight', 1.0, core.DataType.FLOAT)
            self.jsd_weight = self._create_constant_param(
                'jsd_weight', 0.0, core.DataType.FLOAT)
        else:
            self.jsd_weight = self.model.add_global_constant(
                '%s_jsd_weight' % self.name, jsd_weight
            )
            self.xent_weight = self.model.add_global_constant(
                '%s_xent_weight' % self.name, 1. - jsd_weight
            )

    def update_weight(self, net):
        """Add ops that advance the counter and recompute homotopy weights.

        ``xent_weight`` is produced by ``LearningRate`` with the ``inv``
        policy (base_lr=1, gamma=1e-6, power=0.1) from the counter, and
        ``jsd_weight`` is ``ONE - xent_weight``:

            iter = 0:   xent_weight = 1
            iter = 1e6: xent_weight = 0.5^0.1  ~= 0.93
            iter = 1e9: xent_weight = 1e-3^0.1 ~= 0.50

        Requires the mutex/counter params, which ``init_weight`` creates only
        when ``homotopy_weighting`` is enabled.

        Returns:
            The ``(xent_weight, jsd_weight)`` blobs.
        """
        net.AtomicIter([self.mutex, self.counter], [self.counter])
        net.LearningRate([self.counter], [self.xent_weight], base_lr=1.0,
                         policy='inv', gamma=1e-6, power=0.1,)
        net.Sub(
            [self.model.global_constants['ONE'], self.xent_weight],
            [self.jsd_weight]
        )
        return self.xent_weight, self.jsd_weight

    def _add_label_ops(self, net):
        """Return the label cast to float, expanded and optionally smoothed."""
        label = self.input_record.label()
        # mandatory cast to float32:
        # self.input_record.label.field_type().base is np.float32 but the
        # label blob is actually int
        label = net.Cast(
            label,
            net.NextScopedBlob('label_float32'),
            to=core.DataType.FLOAT)
        # add an axis at position 1 (presumably to match the logit layout;
        # NOTE(review): confirm the expected logit shape)
        label = net.ExpandDims(label, net.NextScopedBlob('expanded_label'),
                               dims=[1])
        if self.pos_label_target != 1.0 or self.neg_label_target != 0.0:
            # label smoothing: <= 0.5 -> neg target, > 0.5 -> pos target
            label = net.StumpFunc(
                label,
                net.NextScopedBlob('smoothed_label'),
                threshold=0.5,
                low_value=self.neg_label_target,
                high_value=self.pos_label_target,
            )
        return label

    def _add_focal_factor(self, net, xent, label):
        """Return ``xent * (y + p - 2yp) ** focal_gamma``.

        ``label`` is y and should already be gradient-stopped; p is
        ``input_record.prediction``.
        """
        prediction = self.input_record.prediction()
        # focal loss = (y(1-p) + p(1-y))^gamma * original LR loss
        # y(1-p) + p(1-y) = y + p - 2 * yp
        y_plus_p = net.Add(
            [prediction, label],
            net.NextScopedBlob("y_plus_p"),
        )
        yp = net.Mul([prediction, label], net.NextScopedBlob("yp"))
        two_yp = net.Scale(yp, net.NextScopedBlob("two_yp"), scale=2.0)
        y_plus_p_sub_two_yp = net.Sub(
            [y_plus_p, two_yp], net.NextScopedBlob("y_plus_p_sub_two_yp")
        )
        focal_factor = net.Pow(
            y_plus_p_sub_two_yp,
            net.NextScopedBlob("y_plus_p_sub_two_yp_power"),
            exponent=float(self.focal_gamma),
        )
        # Truthiness rather than ``is True``: numpy booleans or 1 must not be
        # silently ignored.
        if self.stop_grad_in_focal_factor:
            focal_factor = net.StopGradient(
                [focal_factor],
                [net.NextScopedBlob("focal_factor_stop_gradient")],
            )
        return net.Mul(
            [xent, focal_factor], net.NextScopedBlob("focallossxent")
        )

    def _add_exp_decay(self, net, xent):
        """Decay ``task_gamma_cur`` in place; return xent times the multiplier.

        The multiplier is ``max(task_gamma_cur, task_gamma_lb)``.
        """
        # task_gamma_cur *= task_gamma (in place, every run)
        net.Mul(
            [self.task_gamma_cur, self.task_gamma],
            self.task_gamma_cur
        )
        task_gamma_multiplier = net.Max(
            [self.task_gamma_cur, self.task_gamma_lb],
            net.NextScopedBlob("task_gamma_cur_multiplier")
        )
        return net.Mul(
            [xent, task_gamma_multiplier], net.NextScopedBlob("expdecayxent")
        )

    def _add_uncertainty_weighting(self, net, loss):
        """Return ``0.5 * (exp(-s) * loss + uncertainty_penalty * s)``.

        s is the ``log_variance`` field with an axis added at position 1, and
        -s is clipped to at most 88 before ``Exp``.
        """
        log_variance_blob = self.input_record.log_variance()
        log_variance_blob = net.ExpandDims(
            log_variance_blob, net.NextScopedBlob('expanded_log_variance'),
            dims=[1]
        )
        neg_log_variance_blob = net.Negative(
            [log_variance_blob],
            net.NextScopedBlob('neg_log_variance')
        )
        # enforce less than 88 to avoid float32 overflow in Exp
        neg_log_variance_blob = net.Clip(
            [neg_log_variance_blob],
            net.NextScopedBlob('clipped_neg_log_variance'),
            max=88.0
        )
        exp_neg_log_variance_blob = net.Exp(
            [neg_log_variance_blob],
            net.NextScopedBlob('exp_neg_log_variance')
        )
        exp_neg_log_variance_loss_blob = net.Mul(
            [exp_neg_log_variance_blob, loss],
            net.NextScopedBlob('exp_neg_log_variance_loss')
        )
        # NOTE: the blob-name typo is kept on purpose; renaming it would
        # change the generated net.
        penalized_uncertainty = net.Scale(
            log_variance_blob, net.NextScopedBlob("penalized_unceratinty"),
            scale=float(self.uncertainty_penalty)
        )
        loss_2x = net.Add(
            [exp_neg_log_variance_loss_blob, penalized_uncertainty],
            net.NextScopedBlob('loss')
        )
        return net.Scale(loss_2x, net.NextScopedBlob("loss"), scale=0.5)

    def _add_example_weighting(self, net, loss):
        """Return ``loss`` times the gradient-stopped float ``weight`` field."""
        weight_blob = self.input_record.weight()
        if self.input_record.weight.field_type().base != np.float32:
            weight_blob = net.Cast(
                weight_blob,
                weight_blob + '_float32',
                to=core.DataType.FLOAT
            )
        weight_blob = net.StopGradient(
            [weight_blob],
            [net.NextScopedBlob('weight_stop_gradient')],
        )
        return net.Mul(
            [loss, weight_blob],
            net.NextScopedBlob('weighted_cross_entropy'),
        )

    def add_ops(self, net):
        """Add the loss ops to ``net`` and write the result to the output.

        The order of op and blob creation is significant (scoped blob names
        depend on it), so the helper calls below must stay in this order.
        """
        label = self._add_label_ops(net)
        # numerically stable sigmoid cross-entropy computed from logits
        xent = net.SigmoidCrossEntropyWithLogits(
            [self.input_record.logit(), label],
            net.NextScopedBlob('cross_entropy'),
            log_D_trick=self.log_D_trick,
            unjoined_lr_loss=self.unjoined_lr_loss
        )
        if self.focal_gamma != 0:
            # the stopped label is also what the JSD term below consumes
            label = net.StopGradient(
                [label],
                [net.NextScopedBlob('label_stop_gradient')],
            )
            xent = self._add_focal_factor(net, xent, label)
        if self.apply_exp_decay:
            xent = self._add_exp_decay(net, xent)
        # fuse with JSD
        if self.jsd_fuse:
            jsd = net.BernoulliJSD(
                [self.input_record.prediction(), label],
                net.NextScopedBlob('jsd'),
            )
            if self.homotopy_weighting:
                self.update_weight(net)
            loss = net.WeightedSum(
                [xent, self.xent_weight, jsd, self.jsd_weight],
                net.NextScopedBlob('loss'),
            )
        else:
            loss = xent
        if 'log_variance' in self.input_record.fields:
            loss = self._add_uncertainty_weighting(net, loss)
        if 'weight' in self.input_record.fields:
            loss = self._add_example_weighting(net, loss)
        if self.average_loss:
            net.AveragedLoss(loss, self.output_schema.field_blobs())
        else:
            net.ReduceFrontSum(loss, self.output_schema.field_blobs())