import math
import torch
import torch.nn as nn
from tencentpretrain.utils.rope import apply_rotary_emb


class MultiHeadedAttention(nn.Module):
    """
    Each head is a self-attention operation.
    Self-attention refers to "Attention Is All You Need" (https://arxiv.org/pdf/1706.03762.pdf).
    """

    def __init__(self, hidden_size, heads_num, attention_head_size, dropout, has_bias=True, with_scale=True):
        super(MultiHeadedAttention, self).__init__()
        self.heads_num = heads_num
        self.per_head_size = attention_head_size
        self.with_scale = with_scale
        self.inner_hidden_size = heads_num * attention_head_size

        # Separate projections for query, key, and value (in that order).
        self.linear_layers = nn.ModuleList(
            [nn.Linear(hidden_size, self.inner_hidden_size, bias=has_bias) for _ in range(3)]
        )
        self.dropout = nn.Dropout(dropout)
        self.final_linear = nn.Linear(self.inner_hidden_size, hidden_size, bias=has_bias)

    def forward(self, key, value, query, mask, position_bias=None, has_residual_attention=False, prev_attn=None, freqs_cis=None):
        """
        Args:
            key: [batch_size x seq_length x hidden_size]
            value: [batch_size x seq_length x hidden_size]
            query: [batch_size x seq_length x hidden_size]
            mask: [batch_size x 1 x seq_length x seq_length]
            position_bias: [1 x heads_num x seq_length x seq_length]
        Returns:
            output: [batch_size x seq_length x hidden_size]
            prev_attn_out: [batch_size x heads_num x seq_length x seq_length] or None
        """
        batch_size, seq_length, _ = query.size()
        heads_num = self.heads_num
        per_head_size = self.per_head_size

        def shape(x):
            return x. \
                   contiguous(). \
                   view(batch_size, seq_length, heads_num, per_head_size). \
                   transpose(1, 2)

        def unshape(x):
            return x. \
                   transpose(1, 2). \
                   contiguous(). \
                   view(batch_size, seq_length, self.inner_hidden_size)

        query, key, value = [l(x). \
                             view(batch_size, -1, heads_num, per_head_size). \
                             transpose(1, 2) \
                             for l, x in zip(self.linear_layers, (query, key, value))
                            ]

        if freqs_cis is not None:
            # Apply rotary position embeddings to the query and key heads.
            query, key = apply_rotary_emb(query.transpose(1, 2), key.transpose(1, 2), freqs_cis=freqs_cis)

        # Scaled dot-product attention with an additive attention mask.
        scores = torch.matmul(query, key.transpose(-2, -1))
        if position_bias is not None:
            scores = scores + position_bias
        if self.with_scale:
            scores = scores / math.sqrt(float(per_head_size))
        scores = scores + mask.type_as(scores)

        prev_attn_out = None
        if has_residual_attention:
            if prev_attn is not None:
                scores += prev_attn
            prev_attn_out = scores

        probs = nn.Softmax(dim=-1)(scores)
        probs = self.dropout(probs)
        output = unshape(torch.matmul(probs, value))
        output = self.final_linear(output)
        return output, prev_attn_out
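

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the sizes below are
    # illustrative only, and freqs_cis is left as None so the example does not
    # depend on the details of tencentpretrain.utils.rope.
    batch_size, seq_length, hidden_size, heads_num = 2, 8, 64, 4
    attn = MultiHeadedAttention(hidden_size, heads_num,
                                attention_head_size=hidden_size // heads_num,
                                dropout=0.1)
    x = torch.randn(batch_size, seq_length, hidden_size)
    # Additive mask: 0.0 keeps a position visible; a large negative value masks it out.
    mask = torch.zeros(batch_size, 1, seq_length, seq_length)
    output, prev_attn_out = attn(x, x, x, mask)
    print(output.shape)  # torch.Size([2, 8, 64])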