| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
import math

import torch
from torch import nn
| |
| class BertSelfAttention(nn.Module): |
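    """Multi-head self-attention layer in the BERT style.

    The input hidden states are projected to queries, keys, and values, scaled
    dot-product attention is computed per head, and the per-head contexts are
    merged back into the hidden dimension. Optionally, relative position
    embeddings ("relative_key" / "relative_key_query") are added to the
    attention scores, and previously cached key/value states can be reused via
    ``past_key_value``.

    Args:
        hidden_size: size of the hidden states; must be divisible by
            num_attention_heads.
        num_attention_heads: number of attention heads.
        attention_probs_dropout_prob: dropout probability applied to the
            attention probabilities.
        position_embedding_type: None, "relative_key", or "relative_key_query".
        max_position_embeddings: required when position_embedding_type is set;
            used to size the relative distance embedding table.
    """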
| def __init__(self, hidden_size, num_attention_heads, |
| attention_probs_dropout_prob, |
| position_embedding_type=None, max_position_embeddings=None): |
| super().__init__() |
| if hidden_size % num_attention_heads != 0: |
| raise ValueError( |
| f"The hidden size ({hidden_size}) is not a multiple of the number of attention " |
| f"heads ({num_attention_heads})" |
| ) |
| |
| self.num_attention_heads = num_attention_heads |
        self.attention_head_size = hidden_size // num_attention_heads
| self.all_head_size = self.num_attention_heads * self.attention_head_size |
| |
| self.query = nn.Linear(hidden_size, self.all_head_size) |
| self.key = nn.Linear(hidden_size, self.all_head_size) |
| self.value = nn.Linear(hidden_size, self.all_head_size) |
| |
| self.dropout = nn.Dropout(attention_probs_dropout_prob) |
| self.position_embedding_type = position_embedding_type |
| |
| if self.position_embedding_type is not None: |
            if max_position_embeddings is None:
                raise ValueError(
                    "max_position_embeddings must be provided when position_embedding_type is set"
                )
| self.max_position_embeddings = max_position_embeddings |
| self.distance_embedding = nn.Embedding(2 * max_position_embeddings - 1, self.attention_head_size) |
| |
| def transpose_for_scores(self, x): |
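        """Reshape (batch, seq_len, all_head_size) to (batch, num_heads, seq_len, head_size)."""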
| new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) |
| x = x.view(*new_x_shape) |
| return x.permute(0, 2, 1, 3) |
| |
| def forward( |
| self, |
| hidden_states, |
| past_key_value=None, |
| ): |
| q = self.query(hidden_states) |
| k = self.key(hidden_states) |
| v = self.value(hidden_states) |
| |
| q = self.transpose_for_scores(q) |
| k = self.transpose_for_scores(k) |
| v = self.transpose_for_scores(v) |
| |
| if past_key_value is not None: |
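            # Cached keys/values from earlier steps are prepended along the
            # sequence dimension (dim=2), so attention also spans past tokens.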
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

| # Take the dot product between "query" and "key" to get the raw attention scores. |
| attention_scores = torch.matmul(q, k.transpose(-1, -2)) |
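        # Scale by the square root of the head size so the dot products do not grow with head dimension.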
| attention_scores = attention_scores / math.sqrt(self.attention_head_size) |
| |
        if self.position_embedding_type is not None:
            # Relative position embeddings add a score that depends on the distance
            # between each query position and each key position. Query and key lengths
            # are taken from q/k so that cached key/value states are handled correctly.
            query_length, key_length = q.shape[2], k.shape[2]
            position_ids_l = torch.arange(
                key_length - query_length, key_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=q.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", q, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", q, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", k, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
| |
        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Weight the value vectors by the attention probabilities.
        context_layer = torch.matmul(attention_probs, v)

        # Merge the heads back: (batch, heads, seq_len, head_size) -> (batch, seq_len, all_head_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
| new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) |
| context_layer = context_layer.view(*new_context_layer_shape) |
| return context_layer |
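

if __name__ == "__main__":
    # Minimal smoke test sketch. The sizes below are illustrative only and are
    # not taken from any particular pretrained configuration.
    torch.manual_seed(0)
    attn = BertSelfAttention(
        hidden_size=64,
        num_attention_heads=4,
        attention_probs_dropout_prob=0.1,
        position_embedding_type="relative_key",
        max_position_embeddings=32,
    )
    attn.eval()  # disable dropout for a deterministic check

    hidden_states = torch.randn(2, 16, 64)  # (batch, seq_len, hidden_size)
    context = attn(hidden_states)
    print(context.shape)  # torch.Size([2, 16, 64])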