# VISOR-GPT/train/tencentpretrain/layers/multi_headed_attn.py
import math
import torch
import torch.nn as nn
from tencentpretrain.utils.rope import apply_rotary_emb


class MultiHeadedAttention(nn.Module):
    """
    Multi-headed attention: each head runs an independent scaled dot-product
    self-attention operation, as described in https://arxiv.org/pdf/1706.03762.pdf.
    """

    def __init__(self, hidden_size, heads_num, attention_head_size, dropout, has_bias=True, with_scale=True):
        super(MultiHeadedAttention, self).__init__()
        self.heads_num = heads_num
        self.per_head_size = attention_head_size
        self.with_scale = with_scale
        self.inner_hidden_size = heads_num * attention_head_size

        # Query, key, and value projections, each mapping
        # hidden_size -> heads_num * attention_head_size.
        self.linear_layers = nn.ModuleList(
            [nn.Linear(hidden_size, self.inner_hidden_size, bias=has_bias) for _ in range(3)]
        )
        self.dropout = nn.Dropout(dropout)
        self.final_linear = nn.Linear(self.inner_hidden_size, hidden_size, bias=has_bias)

    def forward(self, key, value, query, mask, position_bias=None, has_residual_attention=False, prev_attn=None, freqs_cis=None):
        """
        Args:
            key: [batch_size x seq_length x hidden_size]
            value: [batch_size x seq_length x hidden_size]
            query: [batch_size x seq_length x hidden_size]
            mask: [batch_size x 1 x seq_length x seq_length], additive attention mask
            position_bias: [1 x heads_num x seq_length x seq_length], optional relative position bias
            has_residual_attention: if True, add prev_attn to the attention scores and return them
            prev_attn: attention scores from the previous layer, or None
            freqs_cis: precomputed rotary embedding frequencies, or None
        Returns:
            output: [batch_size x seq_length x hidden_size]
            prev_attn_out: attention scores passed to the next layer, or None
        """
        batch_size, seq_length, _ = query.size()
        heads_num = self.heads_num
        per_head_size = self.per_head_size

        def shape(x):
            # [batch, seq, inner_hidden] -> [batch, heads, seq, per_head]
            return (x.contiguous()
                    .view(batch_size, seq_length, heads_num, per_head_size)
                    .transpose(1, 2))

        def unshape(x):
            # [batch, heads, seq, per_head] -> [batch, seq, inner_hidden]
            return (x.transpose(1, 2)
                    .contiguous()
                    .view(batch_size, seq_length, self.inner_hidden_size))
        # Project inputs and split into heads:
        # [batch, seq, hidden] -> [batch, heads, seq, per_head].
        query, key, value = [
            linear(x).view(batch_size, -1, heads_num, per_head_size).transpose(1, 2)
            for linear, x in zip(self.linear_layers, (query, key, value))
        ]
        if freqs_cis is not None:
            # Apply rotary position embeddings to the query and key heads.
            query, key = apply_rotary_emb(query.transpose(1, 2), key.transpose(1, 2), freqs_cis=freqs_cis)
        # Raw attention scores: [batch, heads, query_len, key_len].
        scores = torch.matmul(query, key.transpose(-2, -1))
        if position_bias is not None:
            scores = scores + position_bias
        if self.with_scale:
            scores = scores / math.sqrt(float(per_head_size))
        scores = scores + mask.type_as(scores)
        # Residual attention: add the previous layer's scores and
        # pass the updated scores on to the next layer.
        prev_attn_out = None
        if has_residual_attention:
            if prev_attn is not None:
                scores += prev_attn
            prev_attn_out = scores
        probs = nn.functional.softmax(scores, dim=-1)
        probs = self.dropout(probs)
        # Weighted sum over values, merged back to [batch, seq, hidden].
        output = unshape(torch.matmul(probs, value))
        output = self.final_linear(output)
        return output, prev_attn_out
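

# A minimal usage sketch (an illustration, not part of the original file): the
# hyperparameters below are assumed, and the additive causal mask (0 for
# visible positions, a large negative value for masked ones) follows the
# [batch_size x 1 x seq_length x seq_length] shape documented in forward().
if __name__ == "__main__":
    batch_size, seq_length, hidden_size, heads_num, head_size = 2, 8, 64, 4, 16
    attn = MultiHeadedAttention(hidden_size, heads_num, head_size, dropout=0.1)

    hidden = torch.randn(batch_size, seq_length, hidden_size)
    causal = torch.tril(torch.ones(seq_length, seq_length))
    mask = (1.0 - causal) * -10000.0
    mask = mask.view(1, 1, seq_length, seq_length).expand(batch_size, -1, -1, -1)

    # Self-attention: key, value, and query are the same tensor.
    output, prev_attn = attn(hidden, hidden, hidden, mask)
    print(output.shape)  # torch.Size([2, 8, 64])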