import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from openrec.modeling.common import Mlp


class NRTRDecoder(nn.Module):
    """A transformer model. The attributes can be modified as needed.

    The architecture is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
    Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is
    all you need. In Advances in Neural Information Processing Systems,
    pages 6000-6010.

    Args:
        in_channels: the number of expected features in the encoder/decoder
            inputs; used as the model dimension d_model.
        out_channels: the size of the output vocabulary, including the
            special <eos>, <bos>, and padding symbols.
        nhead: the number of heads in the multi-head attention modules
            (default: in_channels // 32).
        num_encoder_layers: the number of sub-encoder-layers in the encoder
            (default=6); set to 0 to skip the encoder.
        beam_size: beam width for decoding (default=0).
        num_decoder_layers: the number of sub-decoder-layers in the decoder
            (default=6).
        max_len: the maximum decoded sequence length (default=25).
        attention_dropout_rate: dropout applied to the attention weights
            (default=0.0).
        residual_dropout_rate: dropout applied to the residual branches and
            the positional encoding (default=0.1).
        scale_embedding: whether to scale token embeddings by sqrt(d_model)
            (default=True).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_encoder_layers=6,
        beam_size=0,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        super(NRTRDecoder, self).__init__()
        self.out_channels = out_channels
        # Special-token conventions: the last index is the padding/ignore
        # symbol, the second-to-last is <bos>, and index 0 is <eos>.
        self.ignore_index = out_channels - 1
        self.bos = out_channels - 2
        self.eos = 0
        self.max_len = max_len
        d_model = in_channels
        dim_feedforward = d_model * 4
        nhead = nhead if nhead is not None else d_model // 32
        self.embedding = Embeddings(
            d_model=d_model,
            vocab=self.out_channels,
            padding_idx=0,
            scale_embedding=scale_embedding,
        )
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate, dim=d_model)

        if num_encoder_layers > 0:
            self.encoder = nn.ModuleList([
                TransformerBlock(
                    d_model,
                    nhead,
                    dim_feedforward,
                    attention_dropout_rate,
                    residual_dropout_rate,
                    with_self_attn=True,
                    with_cross_attn=False,
                ) for i in range(num_encoder_layers)
            ])
        else:
            self.encoder = None

        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model,
                nhead,
                dim_feedforward,
                attention_dropout_rate,
                residual_dropout_rate,
                with_self_attn=True,
                with_cross_attn=True,
            ) for i in range(num_decoder_layers)
        ])

        self.beam_size = beam_size
        self.d_model = d_model
        self.nhead = nhead
        # Project decoder features to the vocabulary, excluding the two
        # special symbols (<bos> and padding).
        self.tgt_word_prj = nn.Linear(d_model,
                                      self.out_channels - 2,
                                      bias=False)
        w0 = np.random.normal(0.0, d_model**-0.5,
                              (d_model, self.out_channels - 2)).astype(
                                  np.float32)
        self.tgt_word_prj.weight.data = torch.from_numpy(w0.transpose())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward_train(self, src, tgt):
        # Teacher forcing: drop the last target token so the decoder input
        # (<bos> + characters) is shifted one step relative to the targets.
        tgt = tgt[:, :-1]
        tgt = self.embedding(tgt)
        tgt = self.positional_encoding(tgt)
        tgt_mask = self.generate_square_subsequent_mask(
            tgt.shape[1], device=src.device)

        if self.encoder is not None:
            src = self.positional_encoding(src)
            for encoder_layer in self.encoder:
                src = encoder_layer(src)
            memory = src  # B N C
        else:
            memory = src  # B N C
        for decoder_layer in self.decoder:
            tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
        output = tgt
        logit = self.tgt_word_prj(output)
        return logit

    def forward(self, src, data=None):
        """Take in and process masked source/target sequences.

        Args:
            src: the sequence to the encoder (required).
            data: training targets; data[0] holds the target token ids and
                data[1] the target lengths. Ignored at inference time.

        Shape:
            - src: :math:`(B, sN, C)`.
            - data[0]: :math:`(B, tN)`.

        Examples:
            >>> output = decoder(src, data)

        A fuller usage sketch is given at the end of this file.
        """
        if self.training:
            max_len = data[1].max()
            tgt = data[0][:, :2 + max_len]
            res = self.forward_train(src, tgt)
        else:
            res = self.forward_test(src)
        return res

    def forward_test(self, src):
        bs = src.shape[0]
        if self.encoder is not None:
            src = self.positional_encoding(src)
            for encoder_layer in self.encoder:
                src = encoder_layer(src)
            memory = src  # B N C
        else:
            memory = src
        # Start every sequence with <bos>; the remaining positions are padding.
        dec_seq = torch.full((bs, self.max_len + 1),
                             self.ignore_index,
                             dtype=torch.int64,
                             device=src.device)
        dec_seq[:, 0] = self.bos
        logits = []
        self.attn_maps = []
        for len_dec_seq in range(0, self.max_len):
            dec_seq_embed = self.embedding(dec_seq[:, :len_dec_seq + 1])
            dec_seq_embed = self.positional_encoding(dec_seq_embed)
            tgt_mask = self.generate_square_subsequent_mask(
                dec_seq_embed.shape[1], src.device)
            tgt = dec_seq_embed  # B, len_dec_seq + 1, C
            for decoder_layer in self.decoder:
                tgt = decoder_layer(tgt, memory, self_mask=tgt_mask)
            # Keep the last decoder layer's cross-attention map for visualization.
            self.attn_maps.append(
                self.decoder[-1].cross_attn.attn_map[0][:, -1:, :])
            dec_output = tgt
            # Only the prediction at the last position is needed for the next step.
            dec_output = dec_output[:, -1:, :]
            word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
            logits.append(word_prob)
            if len_dec_seq < self.max_len:
                # Greedy decoding: append the most probable token to the
                # decoder input for the next step.
                dec_seq[:, len_dec_seq + 1] = word_prob.squeeze().argmax(-1)
            # Efficient batch decoding: stop once every sequence contains at
            # least one <eos> token.
            if (dec_seq == self.eos).any(dim=-1).all():
                break
        logits = torch.cat(logits, dim=1)
        return logits

    def generate_square_subsequent_mask(self, sz, device):
        """Generate a square mask for the sequence.

        The masked positions are filled with float('-inf'). Unmasked
        positions are filled with float(0.0).
        """
        mask = torch.zeros([sz, sz], dtype=torch.float32)
        mask_inf = torch.triu(
            torch.full((sz, sz), dtype=torch.float32, fill_value=-torch.inf),
            diagonal=1,
        )
        mask = mask + mask_inf
        return mask.unsqueeze(0).unsqueeze(0).to(device)
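

# An illustrative sketch, not part of the original module: what the causal
# mask built by generate_square_subsequent_mask looks like for a length-3
# sequence. Positions above the diagonal are -inf, so once the mask is added
# to the attention scores (before the softmax) each step can only attend to
# itself and to earlier steps:
#
#     tensor([[[[0., -inf, -inf],
#               [0.,   0., -inf],
#               [0.,   0.,   0.]]]])   # shape (1, 1, 3, 3)
def _example_causal_mask(sz=3):
    # Same construction as above, shown standalone; the helper name is ours.
    mask = torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
    return mask.unsqueeze(0).unsqueeze(0)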


class MultiheadAttention(nn.Module):

    def __init__(self, embed_dim, num_heads, dropout=0.0, self_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert (self.head_dim * num_heads == self.embed_dim
                ), 'embed_dim must be divisible by num_heads'
        self.scale = self.head_dim**-0.5
        self.self_attn = self_attn
        if self_attn:
            # Single projection producing query, key, and value from one input.
            self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        else:
            # Separate projections: query from the target, key/value from memory.
            self.q = nn.Linear(embed_dim, embed_dim)
            self.kv = nn.Linear(embed_dim, embed_dim * 2)
        self.attn_drop = nn.Dropout(dropout)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key=None, attn_mask=None):
        B, qN = query.shape[:2]

        if self.self_attn:
            qkv = self.qkv(query)
            qkv = qkv.reshape(B, qN, 3, self.num_heads,
                              self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
        else:
            kN = key.shape[1]
            q = self.q(query)
            q = q.reshape(B, qN, self.num_heads, self.head_dim).transpose(1, 2)
            kv = self.kv(key)
            kv = kv.reshape(B, kN, 2, self.num_heads,
                            self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv.unbind(0)

        attn = (q.matmul(k.transpose(2, 3))) * self.scale
        if attn_mask is not None:
            attn += attn_mask
        attn = F.softmax(attn, dim=-1)
        if not self.training:
            # Keep the attention map for visualization at inference time.
            self.attn_map = attn
        attn = self.attn_drop(attn)

        x = (attn.matmul(v)).transpose(1, 2)
        x = x.reshape(B, qN, self.embed_dim)
        x = self.out_proj(x)
        return x
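

# An illustrative sketch, not part of the original module: expected tensor
# shapes for MultiheadAttention. With self_attn=True a single (B, N, C) input
# is projected to query/key/value; with self_attn=False the query and the
# key/value sequences may have different lengths. The concrete sizes and the
# helper name below are arbitrary assumptions.
def _example_multihead_attention():
    B, qN, kN, C = 2, 5, 30, 64
    self_attn = MultiheadAttention(embed_dim=C, num_heads=2, self_attn=True)
    cross_attn = MultiheadAttention(embed_dim=C, num_heads=2, self_attn=False)
    x = torch.randn(B, qN, C)
    memory = torch.randn(B, kN, C)
    y_self = self_attn(x)                # (B, qN, C)
    y_cross = cross_attn(x, key=memory)  # (B, qN, C)
    return y_self.shape, y_cross.shape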


class TransformerBlock(nn.Module):

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        with_self_attn=True,
        with_cross_attn=False,
        epsilon=1e-5,
    ):
        super(TransformerBlock, self).__init__()
        self.with_self_attn = with_self_attn
        if with_self_attn:
            self.self_attn = MultiheadAttention(d_model,
                                                nhead,
                                                dropout=attention_dropout_rate,
                                                self_attn=with_self_attn)
            self.norm1 = nn.LayerNorm(d_model, eps=epsilon)
            self.dropout1 = nn.Dropout(residual_dropout_rate)
        self.with_cross_attn = with_cross_attn
        if with_cross_attn:
            # Cross-attention of the decoder over the encoder memory.
            self.cross_attn = MultiheadAttention(
                d_model, nhead, dropout=attention_dropout_rate)
            self.norm2 = nn.LayerNorm(d_model, eps=epsilon)
            self.dropout2 = nn.Dropout(residual_dropout_rate)

        self.mlp = Mlp(
            in_features=d_model,
            hidden_features=dim_feedforward,
            act_layer=nn.ReLU,
            drop=residual_dropout_rate,
        )
        self.norm3 = nn.LayerNorm(d_model, eps=epsilon)
        self.dropout3 = nn.Dropout(residual_dropout_rate)

    def forward(self, tgt, memory=None, self_mask=None, cross_mask=None):
        # Post-norm residual blocks: self-attention, optional cross-attention,
        # then the feed-forward MLP.
        if self.with_self_attn:
            tgt1 = self.self_attn(tgt, attn_mask=self_mask)
            tgt = self.norm1(tgt + self.dropout1(tgt1))
        if self.with_cross_attn:
            tgt2 = self.cross_attn(tgt, key=memory, attn_mask=cross_mask)
            tgt = self.norm2(tgt + self.dropout2(tgt2))
        tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))
        return tgt
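

# An illustrative sketch, not part of the original module: a decoder-style
# TransformerBlock with both self- and cross-attention applied to a target
# sequence and an encoder memory, using a causal self-attention mask. The
# sizes and the helper name are arbitrary assumptions.
def _example_transformer_block():
    B, tN, sN, C = 2, 7, 30, 64
    block = TransformerBlock(d_model=C, nhead=2, dim_feedforward=4 * C,
                             with_self_attn=True, with_cross_attn=True)
    tgt = torch.randn(B, tN, C)
    memory = torch.randn(B, sN, C)
    causal_mask = torch.triu(torch.full((tN, tN), float('-inf')), diagonal=1)
    out = block(tgt, memory, self_mask=causal_mask)  # (B, tN, C)
    return out.shape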


class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of
    the tokens in the sequence. The positional encodings have the same
    dimension as the embeddings, so that the two can be summed. Here, we use
    sine and cosine functions of different frequencies.

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})

        \text{PosEncoder}(pos, 2i + 1) = \cos(pos / 10000^{2i / d_{model}})

    where :math:`pos` is the word position and :math:`i` is the embedding
    index.

    Args:
        dropout: the dropout value.
        dim: the embedding dimension (required).
        max_len: the max. length of the incoming sequence (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros([max_len, dim])
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = torch.unsqueeze(pe, 0)  # (1, max_len, dim), batch-first
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Inputs of forward function.

        Args:
            x: the sequence fed to the positional encoder model (required).

        Shape:
            x: [batch size, sequence length, embed dim]
            output: [batch size, sequence length, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)
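

# An illustrative sketch, not part of the original module: PositionalEncoding
# operates on batch-first inputs and adds the precomputed sinusoidal table to
# the first x.shape[1] positions before applying dropout. The sizes and the
# helper name are arbitrary assumptions.
def _example_positional_encoding():
    B, N, C = 2, 10, 64
    pos_encoder = PositionalEncoding(dropout=0.1, dim=C)
    x = torch.randn(B, N, C)
    return pos_encoder(x).shape  # (B, N, C)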


class PositionalEncoding_2d(nn.Module):
    r"""Inject some information about the relative or absolute position of
    the tokens in the sequence. The positional encodings have the same
    dimension as the embeddings, so that the two can be summed. Here, we use
    sine and cosine functions of different frequencies.

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})

        \text{PosEncoder}(pos, 2i + 1) = \cos(pos / 10000^{2i / d_{model}})

    where :math:`pos` is the word position and :math:`i` is the embedding
    index.

    Args:
        dropout: the dropout value.
        dim: the embedding dimension (required).
        max_len: the max. length of the incoming sequence (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding_2d(dropout=0.1, dim=d_model)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding_2d, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros([max_len, dim])
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = torch.permute(torch.unsqueeze(pe, 0), [1, 0, 2])
        self.register_buffer('pe', pe)

        self.avg_pool_1 = nn.AdaptiveAvgPool2d((1, 1))
        self.linear1 = nn.Linear(dim, dim)
        self.linear1.weight.data.fill_(1.0)
        self.avg_pool_2 = nn.AdaptiveAvgPool2d((1, 1))
        self.linear2 = nn.Linear(dim, dim)
        self.linear2.weight.data.fill_(1.0)

    def forward(self, x):
        """Inputs of forward function.

        Args:
            x: the feature map fed to the positional encoder model (required).

        Shape:
            x: [batch size, embed dim, height, width]
            output: [height * width, batch size, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """
        # Width-wise encoding, gated by a pooled linear projection of x.
        w_pe = self.pe[:x.shape[-1], :]
        w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
        w_pe = w_pe * w1
        w_pe = torch.permute(w_pe, [1, 2, 0])
        w_pe = torch.unsqueeze(w_pe, 2)

        # Height-wise encoding, gated analogously.
        h_pe = self.pe[:x.shape[-2], :]
        w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
        h_pe = h_pe * w2
        h_pe = torch.permute(h_pe, [1, 2, 0])
        h_pe = torch.unsqueeze(h_pe, 3)

        x = x + w_pe + h_pe
        # Flatten the spatial grid and move to a sequence-first layout.
        x = torch.permute(
            torch.reshape(x,
                          [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
            [2, 0, 1],
        )
        return self.dropout(x)
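

# An illustrative sketch, not part of the original module: PositionalEncoding_2d
# takes a 4-D feature map and returns a flattened, sequence-first tensor. The
# sizes and the helper name are arbitrary assumptions.
def _example_positional_encoding_2d():
    B, C, H, W = 2, 64, 8, 32
    pos_encoder_2d = PositionalEncoding_2d(dropout=0.1, dim=C)
    x = torch.randn(B, C, H, W)
    return pos_encoder_2d(x).shape  # (H * W, B, C)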


class Embeddings(nn.Module):

    def __init__(self, d_model, vocab, padding_idx=None, scale_embedding=True):
        super(Embeddings, self).__init__()
        self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
        self.embedding.weight.data.normal_(mean=0.0, std=d_model**-0.5)
        self.d_model = d_model
        self.scale_embedding = scale_embedding

    def forward(self, x):
        if self.scale_embedding:
            x = self.embedding(x)
            return x * math.sqrt(self.d_model)
        return self.embedding(x)
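

# An illustrative end-to-end usage sketch, not part of the original module.
# The vocabulary size, feature dimensions, token ids, and helper name below
# are arbitrary assumptions; in practice src comes from a visual encoder and
# data from the dataloader / label encoder.
def _example_nrtr_decoder():
    B, sN, C = 2, 64, 512    # batch size, number of visual tokens, channels
    out_channels = 40        # vocab incl. <eos>=0, <bos>=38, padding=39
    decoder = NRTRDecoder(in_channels=C, out_channels=out_channels)

    src = torch.randn(B, sN, C)

    # Training: data[0] holds target token ids, data[1] the text lengths.
    tgt_ids = torch.randint(1, out_channels - 2, (B, decoder.max_len + 2))
    tgt_ids[:, 0] = decoder.bos
    lengths = torch.tensor([5, 7])
    decoder.train()
    # logits: (B, lengths.max() + 1, out_channels - 2)
    logits = decoder(src, data=[tgt_ids, lengths])

    # Inference: greedy decoding for up to max_len steps.
    decoder.eval()
    with torch.no_grad():
        probs = decoder(src)  # (B, <= max_len, out_channels - 2)
    return logits.shape, probs.shape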