# coding:utf-8
# chenjun
# date:2020-04-18
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np


# def get_non_pad_mask(seq, PAD):
#     assert seq.dim() == 2
#     return seq.ne(PAD).type(torch.float).unsqueeze(-1)
def get_pad_mask(seq, pad_idx):
    return (seq == pad_idx).unsqueeze(-2)


def get_subsequent_mask(seq):
    ''' For masking out the subsequent info. '''
    sz_b, len_s = seq.size()
    subsequent_mask = torch.triu(
        torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)  # upper-triangular matrix: 1 marks future positions
    subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1)  # b x ls x ls
    return subsequent_mask
def get_attn_key_pad_mask(seq_k, seq_q, PAD):
    ''' For masking out the padding part of the key sequence.
    seq_k: src_seq
    seq_q: tgt_seq
    '''
    # Expand to fit the shape of the key-query attention matrix.
    len_q = seq_q.size(1)         # target sequence length
    padding_mask = seq_k.eq(PAD)  # padding positions in the source sequence
    padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1)  # b x lq x lk
    return padding_mask
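
# Illustrative note (these helpers are not combined anywhere in this file): for
# decoder self-attention the padding mask and the subsequent mask are typically
# OR-ed together, e.g.
#     slf_attn_mask = get_attn_key_pad_mask(tgt_seq, tgt_seq, PAD) | \
#                     get_subsequent_mask(tgt_seq).bool()
# which hides both padding positions and future positions. The names tgt_seq
# and PAD are placeholders for whatever the caller provides.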
class PositionalEncoding(nn.Module):

    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()
        # Not a parameter
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''
        # TODO: make it with torch instead of numpy
        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        return x + self.pos_table[:, :x.size(1)].clone().detach()
class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature
        if mask is not None:
            # print(mask.shape, attn.shape, v.shape)
            attn = attn.masked_fill(mask, -1e9)
        attn = self.softmax(attn)  # normalise over the key dimension (dim=2)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)
        return output, attn
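
# Shape reference (assumed typical usage from MultiHeadAttention below):
# q, k, v arrive as (n*b) x len x d, the optional boolean/uint8 mask is
# (n*b) x len_q x len_k, and the module returns output of shape
# (n*b) x len_q x d_v together with the attention map (n*b) x len_q x len_k.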
class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)

        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)

        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()

        residual = q

        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)  # 4*21*512 ---- 4*21*8*64
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k)  # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k)  # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v)  # (n*b) x lv x dv

        mask = mask.repeat(n_head, 1, 1) if mask is not None else None  # (n*b) x .. x ..
        output, attn = self.attention(q, k, v, mask=mask)

        output = output.view(n_head, sz_b, len_q, d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)  # b x lq x (n*dv)

        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)
        return output, attn
class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Conv1d(d_in, d_hid, 1)  # position-wise
        self.w_2 = nn.Conv1d(d_hid, d_in, 1)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        output = x.transpose(1, 2)
        output = self.w_2(F.relu(self.w_1(output)))
        output = output.transpose(1, 2)
        output = self.dropout(output)
        output = self.layer_norm(output + residual)
        return output
class EncoderLayer(nn.Module):
    ''' Compose with two layers '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn
class Torch_transformer_encoder(nn.Module):
    '''
    Use the built-in PyTorch transformer encoder for sequence learning.
    '''

    def __init__(self, d_word_vec=512, n_layers=2, n_head=8, d_model=512, dim_feedforward=1024, n_position=256):
        super(Torch_transformer_encoder, self).__init__()
        self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head, dim_feedforward=dim_feedforward)
        self.layer_norm = nn.LayerNorm(d_model)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers, norm=self.layer_norm)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, cnn_feature, src_mask=None, return_attns=False):
        enc_slf_attn_list = []
        # -- Forward
        enc_output = self.dropout(self.position_enc(cnn_feature))  # position embedding
        # nn.TransformerEncoder expects (seq_len, batch, feature) by default,
        # so transpose the batch-first CNN feature around the call.
        enc_output = self.encoder(enc_output.transpose(0, 1)).transpose(0, 1)
        enc_output = self.layer_norm(enc_output)

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output,
class Transforme_Encoder(nn.Module):
    ''' Transformer encoder: captures the global spatial dependencies.

    d_word_vec: feature dimension used by the positional encoding
    n_layers:   number of transformer layers
    n_head:     number of attention heads
    d_k: 64
    d_v: 64
    d_model: 512
    d_inner: 1024
    n_position: maximum sequence length supported by the positional encoding
    '''

    def __init__(
            self, d_word_vec=512, n_layers=2, n_head=8, d_k=64, d_v=64,
            d_model=512, d_inner=1024, dropout=0.1, n_position=256):
        super().__init__()
        self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, cnn_feature, src_mask, return_attns=False):
        enc_slf_attn_list = []
        # -- Forward
        enc_output = self.dropout(self.position_enc(cnn_feature))  # position embedding

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
            enc_slf_attn_list += [enc_slf_attn] if return_attns else []

        enc_output = self.layer_norm(enc_output)

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output,
class PVAM(nn.Module):
    ''' Parallel Visual Attention Module (decodes all characters in parallel).

    n_dim:           512, feature dimension of the reading-order embedding
    N_max_character: 25, maximum number of characters per image
    n_position:      length of the feature sequence produced by the CNN
    '''

    def __init__(self, n_dim=512, N_max_character=25, n_position=256):
        super(PVAM, self).__init__()
        self.character_len = N_max_character

        self.f0_embedding = nn.Embedding(N_max_character, n_dim)
        self.w0 = nn.Linear(N_max_character, n_position)
        self.wv = nn.Linear(n_dim, n_dim)
        # first linear(512, 25)
        self.we = nn.Linear(n_dim, N_max_character)

        self.active = nn.Tanh()
        self.softmax = nn.Softmax(dim=2)

    def forward(self, enc_output):
        reading_order = torch.arange(self.character_len, dtype=torch.long, device=enc_output.device)
        reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1)  # (S,) -> (B, S)
        reading_order_embed = self.f0_embedding(reading_order)     # b,25,512
        t = self.w0(reading_order_embed.permute(0, 2, 1))          # b,512,256
        t = self.active(t.permute(0, 2, 1) + self.wv(enc_output))  # b,256,512

        # first linear(512, 25)
        attn = self.we(t)                           # b,256,25
        attn = self.softmax(attn.permute(0, 2, 1))  # b,25,256
        g_output = torch.bmm(attn, enc_output)      # b,25,512
        return g_output
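
# Note: the attention map `attn` has shape (b, N_max_character, n_position);
# each of the N_max_character rows is a softmax distribution over the CNN
# feature positions, so every character slot reads out its own weighted sum
# of visual features, which is what makes the decoding parallel.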
class GSRM(nn.Module):
    ''' Global Semantic Reasoning Module.

    n_dim:   feature dimension of the character embedding
    n_class: number of character classes (embedding vocabulary size)
    PAD:     padding index, used when computing the mask
    '''

    def __init__(self, n_dim=512, n_class=37, PAD=37 - 1, n_layers=4, n_position=25):
        super(GSRM, self).__init__()
        self.PAD = PAD
        self.argmax_embed = nn.Embedding(n_class, n_dim)

        self.transformer_units = Transforme_Encoder(n_layers=n_layers, n_position=n_position)  # for global context information
        # self.transformer_units = Torch_transformer_encoder(n_layers=n_layers, n_position=n_position)

    def forward(self, e_out):
        '''
        e_out: b,25,37 | the output from PVAM
        '''
        e_argmax = e_out.argmax(dim=-1)            # b, 25
        e = self.argmax_embed(e_argmax)            # b,25,512
        e_mask = get_pad_mask(e_argmax, self.PAD)  # b,1,25 (computed but not passed to the transformer below)
        s = self.transformer_units(e, None)        # tuple whose first element is b,25,512

        return s
class SRN_Decoder(nn.Module):
    ''' Wrapper around the SRN decoder stages.

    n_dim:           feature dimension
    n_class:         number of character classes
    N_max_character: maximum number of characters per image (25 by default)
    n_position:      length of the CNN output feature sequence
    The module produces three outputs (visual, semantic, fused).
    '''

    def __init__(self, n_dim=512, n_class=37, N_max_character=25, n_position=256, GSRM_layer=4):
        super(SRN_Decoder, self).__init__()

        self.pvam = PVAM(N_max_character=N_max_character, n_position=n_position)
        self.w_e = nn.Linear(n_dim, n_class)  # output layer
        self.GSRM = GSRM(n_class=n_class, PAD=n_class - 1, n_dim=n_dim, n_position=N_max_character, n_layers=GSRM_layer)
        self.w_s = nn.Linear(n_dim, n_class)  # output layer
        self.w_f = nn.Linear(n_dim, n_class)  # output layer

    def forward(self, cnn_feature):
        '''cnn_feature: b,256,512 | the output from the CNN'''
        g_output = self.pvam(cnn_feature)  # b,25,512
        e_out = self.w_e(g_output)         # b,25,37 ----> cross entropy loss | first output
        s = self.GSRM(e_out)[0]            # b,25,512
        s_out = self.w_s(s)                # b,25,37

        # TODO: change the add to a gated unit
        f = g_output + s                   # b,25,512
        f_out = self.w_f(f)

        return e_out, s_out, f_out
def cal_performance(preds, gold, mask=None, smoothing='1'):
    ''' Apply label smoothing if needed '''
    loss = 0.
    n_correct = 0
    weights = [1.0, 0.15, 2.0]
    for ori_pred, weight in zip(preds, weights):
        pred = ori_pred.view(-1, ori_pred.shape[-1])
        # debug show
        t_gold = gold.view(ori_pred.shape[0], -1)
        t_pred_index = ori_pred.max(2)[1]

        # Flatten the mask only when one is provided.
        non_pad_mask = mask.view(-1).ne(0) if mask is not None else None
        tloss = cal_loss(pred, gold, non_pad_mask, smoothing)
        if torch.isnan(tloss):
            print('have nan loss')
            continue
        else:
            loss += tloss * weight

        pred = pred.max(1)[1]
        gold = gold.contiguous().view(-1)
        # n_correct is overwritten each iteration, so it reflects the last
        # element of preds (the fused output).
        n_correct = pred.eq(gold)
        n_correct = n_correct.masked_select(non_pad_mask).sum().item() if mask is not None else None

    return loss, n_correct
def cal_loss(pred, gold, mask, smoothing):
    ''' Calculate cross entropy loss, apply label smoothing if needed. '''
    gold = gold.contiguous().view(-1)

    if smoothing == '0':
        eps = 0.1
        n_class = pred.size(1)
        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)
        non_pad_mask = gold.ne(0)
        loss = -(one_hot * log_prb).sum(dim=1)
        loss = loss.masked_select(non_pad_mask).sum()  # average later
    elif smoothing == '1':
        if mask is not None:
            loss = F.cross_entropy(pred, gold, reduction='none')
            loss = loss.masked_select(mask)
            loss = loss.sum() / mask.sum()
        else:
            loss = F.cross_entropy(pred, gold)
    else:
        # loss = F.cross_entropy(pred, gold, ignore_index=PAD)
        loss = F.cross_entropy(pred, gold)
    return loss
def cal_performance2(preds, gold, PAD, smoothing='1'):
    ''' Apply label smoothing if needed '''
    loss = 0.
    n_correct = 0
    weights = [1.0, 0.15, 2.0]
    for ori_pred, weight in zip(preds, weights):
        pred = ori_pred.view(-1, ori_pred.shape[-1])
        # debug show
        t_gold = gold.view(ori_pred.shape[0], -1)
        t_pred_index = ori_pred.max(2)[1]

        tloss = cal_loss2(pred, gold, PAD, smoothing=smoothing)
        if torch.isnan(tloss):
            print('have nan loss')
            continue
        else:
            loss += tloss * weight

        pred = pred.max(1)[1]
        gold = gold.contiguous().view(-1)
        n_correct = pred.eq(gold)
        non_pad_mask = gold.ne(PAD)
        n_correct = n_correct.masked_select(non_pad_mask).sum().item()

    return loss, n_correct
def cal_loss2(pred, gold, PAD, smoothing='1'):
    ''' Calculate cross entropy loss, apply label smoothing if needed. '''
    gold = gold.contiguous().view(-1)

    if smoothing == '0':
        eps = 0.1
        n_class = pred.size(1)
        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)
        non_pad_mask = gold.ne(0)
        loss = -(one_hot * log_prb).sum(dim=1)
        loss = loss.masked_select(non_pad_mask).sum()  # average later
    elif smoothing == '1':
        loss = F.cross_entropy(pred, gold, ignore_index=PAD)
    else:
        # loss = F.cross_entropy(pred, gold, ignore_index=PAD)
        loss = F.cross_entropy(pred, gold)
    return loss
if __name__ == '__main__':
    cnn_feature = torch.rand((2, 256, 512))
    model1 = Transforme_Encoder()
    image = model1(cnn_feature, src_mask=None)[0]

    model = SRN_Decoder(N_max_character=30)
    outs = model(image)
    for out in outs:
        print(out.shape)
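
    # Illustrative loss check (a minimal sketch, not part of the original test):
    # random integer targets in [0, n_class) with PAD = n_class - 1 = 36, the
    # same padding convention SRN_Decoder uses for its GSRM.
    gold = torch.randint(0, 37, (2, 30))
    loss, n_correct = cal_performance2(outs, gold, PAD=36, smoothing='1')
    print(loss, n_correct)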
    # image = torch.rand((4, 3, 32, 60))
    # tgt_seq = torch.tensor([[2, 24, 2176, 882, 2480, 612, 1525, 480, 875, 147, 1700, 715, 1465, 3],
    #                         [2, 369, 1781, 882, 703, 879, 2855, 2415, 502, 1154, 833, 1465, 3, 0],
    #                         [2, 2943, 334, 328, 480, 330, 1644, 1449, 163, 147, 1823, 1184, 1465, 3],
    #                         [2, 24, 396, 480, 703, 1646, 897, 1711, 1508, 703, 2321, 147, 642, 1465]], device='cuda:0')
    # tgt_pos = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
    #                         [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0],
    #                         [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
    #                         [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]], device='cuda:0')
    # src_seq = torch.tensor([[2, 598, 2088, 822, 2802, 1156, 157, 1099, 1000, 598, 1707, 1345, 3, 0, 0, 0],
    #                         [2, 598, 2348, 822, 598, 1222, 471, 948, 986, 423, 1345, 3, 0, 0, 0, 0],
    #                         [2, 2437, 2470, 901, 2473, 598, 1735, 84, 1, 2277, 1979, 499, 962, 1345, 3, 0],
    #                         [2, 598, 186, 1904, 598, 868, 1339, 1604, 84, 598, 608, 1728, 1345, 3, 0, 0]], device='cuda:0')
    # device = torch.device('cuda')
    # image = image.cuda()

    # transformer = Transformer()
    # transformer = transformer.to(device)
    # transformer.train()
    # out = transformer(image, tgt_seq, tgt_pos, src_seq)
    # gold = tgt_seq[:, 1:]  # start from the second column
    # # backward
    # loss, n_correct = cal_performance(out, gold, smoothing=True)
    # print(loss, n_correct)
    # a = 1