| ## @package attention |
| # Module caffe2.python.attention |
| |
| from caffe2.python import brew |
| |
| |
| class AttentionType: |
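    """Enumerates the attention variants implemented in this module."""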
| Regular, Recurrent, Dot, SoftCoverage = tuple(range(4)) |
| |
| |
| def s(scope, name): |
| # We have to manually scope due to our internal/external blob |
| # relationships. |
| return "{}/{}".format(str(scope), str(name)) |
| |
| |
| # c_i = \sum_j w_{ij}\textbf{s}_j |
| def _calc_weighted_context( |
| model, |
| encoder_outputs_transposed, |
| encoder_output_dim, |
| attention_weights_3d, |
| scope, |
| ): |
| # [batch_size, encoder_output_dim, 1] |
| attention_weighted_encoder_context = brew.batch_mat_mul( |
| model, |
| [encoder_outputs_transposed, attention_weights_3d], |
| s(scope, 'attention_weighted_encoder_context'), |
| ) |
    # [1, batch_size, encoder_output_dim]
| attention_weighted_encoder_context, _ = model.net.Reshape( |
| attention_weighted_encoder_context, |
| [ |
| attention_weighted_encoder_context, |
| s(scope, 'attention_weighted_encoder_context_old_shape'), |
| ], |
| shape=[1, -1, encoder_output_dim], |
| ) |
| return attention_weighted_encoder_context |
| |
| |
# Calculate a softmax over the passed-in attention energy logits
| def _calc_attention_weights( |
| model, |
| attention_logits_transposed, |
| scope, |
| encoder_lengths=None, |
| ): |
| if encoder_lengths is not None: |
| attention_logits_transposed = model.net.SequenceMask( |
| [attention_logits_transposed, encoder_lengths], |
| ['masked_attention_logits'], |
| mode='sequence', |
| ) |
| |
| # [batch_size, encoder_length, 1] |
| attention_weights_3d = brew.softmax( |
| model, |
| attention_logits_transposed, |
| s(scope, 'attention_weights_3d'), |
| engine='CUDNN', |
| axis=1, |
| ) |
| return attention_weights_3d |
| |
| |
| # e_{ij} = \textbf{v}^T tanh \alpha(\textbf{h}_{i-1}, \textbf{s}_j) |
| def _calc_attention_logits_from_sum_match( |
| model, |
| decoder_hidden_encoder_outputs_sum, |
| encoder_output_dim, |
| scope, |
| ): |
| # [encoder_length, batch_size, encoder_output_dim] |
| decoder_hidden_encoder_outputs_sum = model.net.Tanh( |
| decoder_hidden_encoder_outputs_sum, |
| decoder_hidden_encoder_outputs_sum, |
| ) |
| |
| # [encoder_length, batch_size, 1] |
| attention_logits = brew.fc( |
| model, |
| decoder_hidden_encoder_outputs_sum, |
| s(scope, 'attention_logits'), |
| dim_in=encoder_output_dim, |
| dim_out=1, |
| axis=2, |
| freeze_bias=True, |
| ) |
| |
| # [batch_size, encoder_length, 1] |
| attention_logits_transposed = brew.transpose( |
| model, |
| attention_logits, |
| s(scope, 'attention_logits_transposed'), |
| axes=[1, 0, 2], |
| ) |
| return attention_logits_transposed |
| |
| |
| # \textbf{W}^\alpha used in the context of \alpha_{sum}(a,b) |
| def _apply_fc_weight_for_sum_match( |
| model, |
| input, |
| dim_in, |
| dim_out, |
| scope, |
| name, |
| ): |
| output = brew.fc( |
| model, |
| input, |
| s(scope, name), |
| dim_in=dim_in, |
| dim_out=dim_out, |
| axis=2, |
| ) |
| output = model.net.Squeeze( |
| output, |
| output, |
| dims=[0], |
| ) |
| return output |
| |
| |
# Implement RecAtt as described in section 4.1 of http://arxiv.org/abs/1601.03317
| def apply_recurrent_attention( |
| model, |
| encoder_output_dim, |
| encoder_outputs_transposed, |
| weighted_encoder_outputs, |
| decoder_hidden_state_t, |
| decoder_hidden_state_dim, |
| attention_weighted_encoder_context_t_prev, |
| scope, |
| encoder_lengths=None, |
| ): |
| weighted_prev_attention_context = _apply_fc_weight_for_sum_match( |
| model=model, |
| input=attention_weighted_encoder_context_t_prev, |
| dim_in=encoder_output_dim, |
| dim_out=encoder_output_dim, |
| scope=scope, |
| name='weighted_prev_attention_context', |
| ) |
| |
| weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match( |
| model=model, |
| input=decoder_hidden_state_t, |
| dim_in=decoder_hidden_state_dim, |
| dim_out=encoder_output_dim, |
| scope=scope, |
| name='weighted_decoder_hidden_state', |
| ) |
    # [batch_size, encoder_output_dim]
| decoder_hidden_encoder_outputs_sum_tmp = model.net.Add( |
| [ |
| weighted_prev_attention_context, |
| weighted_decoder_hidden_state, |
| ], |
| s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'), |
| ) |
| # [encoder_length, batch_size, encoder_output_dim] |
| decoder_hidden_encoder_outputs_sum = model.net.Add( |
| [ |
| weighted_encoder_outputs, |
| decoder_hidden_encoder_outputs_sum_tmp, |
| ], |
| s(scope, 'decoder_hidden_encoder_outputs_sum'), |
| broadcast=1, |
| ) |
| attention_logits_transposed = _calc_attention_logits_from_sum_match( |
| model=model, |
| decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum, |
| encoder_output_dim=encoder_output_dim, |
| scope=scope, |
| ) |
| |
| # [batch_size, encoder_length, 1] |
| attention_weights_3d = _calc_attention_weights( |
| model=model, |
| attention_logits_transposed=attention_logits_transposed, |
| scope=scope, |
| encoder_lengths=encoder_lengths, |
| ) |
| |
| # [batch_size, encoder_output_dim, 1] |
| attention_weighted_encoder_context = _calc_weighted_context( |
| model=model, |
| encoder_outputs_transposed=encoder_outputs_transposed, |
| encoder_output_dim=encoder_output_dim, |
| attention_weights_3d=attention_weights_3d, |
| scope=scope, |
| ) |
| return attention_weighted_encoder_context, attention_weights_3d, [ |
| decoder_hidden_encoder_outputs_sum, |
| ] |
| |
| |
| def apply_regular_attention( |
| model, |
| encoder_output_dim, |
| encoder_outputs_transposed, |
| weighted_encoder_outputs, |
| decoder_hidden_state_t, |
| decoder_hidden_state_dim, |
| scope, |
| encoder_lengths=None, |
| ): |
| weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match( |
| model=model, |
| input=decoder_hidden_state_t, |
| dim_in=decoder_hidden_state_dim, |
| dim_out=encoder_output_dim, |
| scope=scope, |
| name='weighted_decoder_hidden_state', |
| ) |
| |
| # [encoder_length, batch_size, encoder_output_dim] |
| decoder_hidden_encoder_outputs_sum = model.net.Add( |
| [weighted_encoder_outputs, weighted_decoder_hidden_state], |
| s(scope, 'decoder_hidden_encoder_outputs_sum'), |
| broadcast=1, |
| use_grad_hack=1, |
| ) |
| |
| attention_logits_transposed = _calc_attention_logits_from_sum_match( |
| model=model, |
| decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum, |
| encoder_output_dim=encoder_output_dim, |
| scope=scope, |
| ) |
| |
| # [batch_size, encoder_length, 1] |
| attention_weights_3d = _calc_attention_weights( |
| model=model, |
| attention_logits_transposed=attention_logits_transposed, |
| scope=scope, |
| encoder_lengths=encoder_lengths, |
| ) |
| |
| # [batch_size, encoder_output_dim, 1] |
| attention_weighted_encoder_context = _calc_weighted_context( |
| model=model, |
| encoder_outputs_transposed=encoder_outputs_transposed, |
| encoder_output_dim=encoder_output_dim, |
| attention_weights_3d=attention_weights_3d, |
| scope=scope, |
| ) |
| return attention_weighted_encoder_context, attention_weights_3d, [ |
| decoder_hidden_encoder_outputs_sum, |
| ] |
| |
| |
| def apply_dot_attention( |
| model, |
| encoder_output_dim, |
| # [batch_size, encoder_output_dim, encoder_length] |
| encoder_outputs_transposed, |
| # [1, batch_size, decoder_state_dim] |
| decoder_hidden_state_t, |
| decoder_hidden_state_dim, |
| scope, |
| encoder_lengths=None, |
| ): |
| if decoder_hidden_state_dim != encoder_output_dim: |
| weighted_decoder_hidden_state = brew.fc( |
| model, |
| decoder_hidden_state_t, |
| s(scope, 'weighted_decoder_hidden_state'), |
| dim_in=decoder_hidden_state_dim, |
| dim_out=encoder_output_dim, |
| axis=2, |
| ) |
| else: |
| weighted_decoder_hidden_state = decoder_hidden_state_t |
| |
    # [batch_size, encoder_output_dim]
| squeezed_weighted_decoder_hidden_state = model.net.Squeeze( |
| weighted_decoder_hidden_state, |
| s(scope, 'squeezed_weighted_decoder_hidden_state'), |
| dims=[0], |
| ) |
| |
    # [batch_size, encoder_output_dim, 1]
| expanddims_squeezed_weighted_decoder_hidden_state = model.net.ExpandDims( |
| squeezed_weighted_decoder_hidden_state, |
| squeezed_weighted_decoder_hidden_state, |
| dims=[2], |
| ) |
| |
    # [batch_size, encoder_length, 1]
| attention_logits_transposed = model.net.BatchMatMul( |
| [ |
| encoder_outputs_transposed, |
| expanddims_squeezed_weighted_decoder_hidden_state, |
| ], |
| s(scope, 'attention_logits'), |
| trans_a=1, |
| ) |
| |
| # [batch_size, encoder_length, 1] |
| attention_weights_3d = _calc_attention_weights( |
| model=model, |
| attention_logits_transposed=attention_logits_transposed, |
| scope=scope, |
| encoder_lengths=encoder_lengths, |
| ) |
| |
| # [batch_size, encoder_output_dim, 1] |
| attention_weighted_encoder_context = _calc_weighted_context( |
| model=model, |
| encoder_outputs_transposed=encoder_outputs_transposed, |
| encoder_output_dim=encoder_output_dim, |
| attention_weights_3d=attention_weights_3d, |
| scope=scope, |
| ) |
| return attention_weighted_encoder_context, attention_weights_3d, [] |
| |
| |
| def apply_soft_coverage_attention( |
| model, |
| encoder_output_dim, |
| encoder_outputs_transposed, |
| weighted_encoder_outputs, |
| decoder_hidden_state_t, |
| decoder_hidden_state_dim, |
| scope, |
| encoder_lengths, |
| coverage_t_prev, |
| coverage_weights, |
| ): |
| |
| weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match( |
| model=model, |
| input=decoder_hidden_state_t, |
| dim_in=decoder_hidden_state_dim, |
| dim_out=encoder_output_dim, |
| scope=scope, |
| name='weighted_decoder_hidden_state', |
| ) |
| |
| # [encoder_length, batch_size, encoder_output_dim] |
| decoder_hidden_encoder_outputs_sum_tmp = model.net.Add( |
| [weighted_encoder_outputs, weighted_decoder_hidden_state], |
| s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'), |
| broadcast=1, |
| ) |
| # [batch_size, encoder_length] |
| coverage_t_prev_2d = model.net.Squeeze( |
| coverage_t_prev, |
| s(scope, 'coverage_t_prev_2d'), |
| dims=[0], |
| ) |
| # [encoder_length, batch_size] |
| coverage_t_prev_transposed = brew.transpose( |
| model, |
| coverage_t_prev_2d, |
| s(scope, 'coverage_t_prev_transposed'), |
| ) |
| |
| # [encoder_length, batch_size, encoder_output_dim] |
| scaled_coverage_weights = model.net.Mul( |
| [coverage_weights, coverage_t_prev_transposed], |
| s(scope, 'scaled_coverage_weights'), |
| broadcast=1, |
| axis=0, |
| ) |
| |
| # [encoder_length, batch_size, encoder_output_dim] |
| decoder_hidden_encoder_outputs_sum = model.net.Add( |
| [decoder_hidden_encoder_outputs_sum_tmp, scaled_coverage_weights], |
| s(scope, 'decoder_hidden_encoder_outputs_sum'), |
| ) |
| |
| # [batch_size, encoder_length, 1] |
| attention_logits_transposed = _calc_attention_logits_from_sum_match( |
| model=model, |
| decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum, |
| encoder_output_dim=encoder_output_dim, |
| scope=scope, |
| ) |
| |
| # [batch_size, encoder_length, 1] |
| attention_weights_3d = _calc_attention_weights( |
| model=model, |
| attention_logits_transposed=attention_logits_transposed, |
| scope=scope, |
| encoder_lengths=encoder_lengths, |
| ) |
| |
| # [batch_size, encoder_output_dim, 1] |
| attention_weighted_encoder_context = _calc_weighted_context( |
| model=model, |
| encoder_outputs_transposed=encoder_outputs_transposed, |
| encoder_output_dim=encoder_output_dim, |
| attention_weights_3d=attention_weights_3d, |
| scope=scope, |
| ) |
| |
| # [batch_size, encoder_length] |
| attention_weights_2d = model.net.Squeeze( |
| attention_weights_3d, |
| s(scope, 'attention_weights_2d'), |
| dims=[2], |
| ) |
| |
| coverage_t = model.net.Add( |
| [coverage_t_prev, attention_weights_2d], |
| s(scope, 'coverage_t'), |
| broadcast=1, |
| ) |
| |
| return ( |
| attention_weighted_encoder_context, |
| attention_weights_3d, |
| [decoder_hidden_encoder_outputs_sum], |
| coverage_t, |
| ) |