# coding:utf-8
# chenjun
# date:2020-04-18
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
# def get_non_pad_mask(seq, PAD):
# assert seq.dim() == 2
# return seq.ne(PAD).type(torch.float).unsqueeze(-1)
def get_pad_mask(seq, pad_idx):
return (seq == pad_idx).unsqueeze(-2)
def get_subsequent_mask(seq):
''' For masking out the subsequent info. '''
sz_b, len_s = seq.size()
subsequent_mask = torch.triu(
        torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)  # upper-triangular matrix of future positions
subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls
return subsequent_mask
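# A small illustration, assuming a batch of two length-4 sequences: entries equal to 1
# mark the "future" positions a decoding step is not allowed to attend to.
#   get_subsequent_mask(torch.zeros(2, 4, dtype=torch.long))[0]
#   # tensor([[0, 1, 1, 1],
#   #         [0, 0, 1, 1],
#   #         [0, 0, 0, 1],
#   #         [0, 0, 0, 0]], dtype=torch.uint8)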
def get_attn_key_pad_mask(seq_k, seq_q, PAD):
''' For masking out the padding part of key sequence.
seq_k:src_seq
seq_q:tgt_seq
'''
# Expand to fit the shape of key query attention matrix.
    len_q = seq_q.size(1)  # target sequence length
    padding_mask = seq_k.eq(PAD)  # True at padded positions of the source (key) sequence
padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk
return padding_mask
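# A small illustration, assuming PAD=0: the mask is True wherever the key sequence is
# padding, broadcast over every query position.
#   seq_k = torch.tensor([[5, 7, 0, 0]])                   # b=1, lk=4, last two padded
#   seq_q = torch.zeros(1, 3, dtype=torch.long)            # b=1, lq=3
#   get_attn_key_pad_mask(seq_k, seq_q, PAD=0).shape       # torch.Size([1, 3, 4])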
class PositionalEncoding(nn.Module):
def __init__(self, d_hid, n_position=200):
super(PositionalEncoding, self).__init__()
# Not a parameter
self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
def _get_sinusoid_encoding_table(self, n_position, d_hid):
''' Sinusoid position encoding table '''
# TODO: make it with torch instead of numpy
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
def forward(self, x):
return x + self.pos_table[:, :x.size(1)].clone().detach()
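# A minimal usage sketch, assuming the sequence length does not exceed n_position:
#   pe = PositionalEncoding(d_hid=512, n_position=256)
#   pe(torch.rand(2, 100, 512)).shape                      # torch.Size([2, 100, 512]); no learned parameters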
class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention '''
def __init__(self, temperature, attn_dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
self.softmax = nn.Softmax(dim=2)
def forward(self, q, k, v, mask=None):
attn = torch.bmm(q, k.transpose(1, 2))
attn = attn / self.temperature
if mask is not None:
# print(mask.shape, attn.shape, v.shape)
attn = attn.masked_fill(mask, -1e9)
        attn = self.softmax(attn)  # attention weights normalized over the key dimension
attn = self.dropout(attn)
output = torch.bmm(attn, v)
return output, attn
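# A minimal shape sketch, assuming a per-head dimension of 64 with the head and batch
# dimensions already folded together, as done in MultiHeadAttention below:
#   attn_fn = ScaledDotProductAttention(temperature=np.power(64, 0.5))
#   q = k = v = torch.rand(16, 25, 64)                     # (n_head*b) x len x d_k
#   out, attn = attn_fn(q, k, v)                           # out: 16x25x64, attn: 16x25x25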
class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module '''
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
super(MultiHeadAttention, self).__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
self.w_qs = nn.Linear(d_model, n_head * d_k)
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)
nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
self.layer_norm = nn.LayerNorm(d_model)
self.fc = nn.Linear(n_head * d_v, d_model)
nn.init.xavier_normal_(self.fc.weight)
self.dropout = nn.Dropout(dropout)
def forward(self, q, k, v, mask=None):
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()
residual = q
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) # 4*21*512 ---- 4*21*8*64
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
mask = mask.repeat(n_head, 1, 1) if mask is not None else None # (n*b) x .. x ..
output, attn = self.attention(q, k, v, mask=mask)
output = output.view(n_head, sz_b, len_q, d_v)
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
output = self.dropout(self.fc(output))
output = self.layer_norm(output + residual)
return output, attn
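# A minimal self-attention sketch with the default sizes (d_model=512 split into 8 heads of 64):
#   mha = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64)
#   x = torch.rand(2, 25, 512)
#   out, attn = mha(x, x, x)                               # out: 2x25x512, attn: 16x25x25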
class PositionwiseFeedForward(nn.Module):
''' A two-feed-forward-layer module '''
def __init__(self, d_in, d_hid, dropout=0.1):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise
self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise
self.layer_norm = nn.LayerNorm(d_in)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
residual = x
output = x.transpose(1, 2)
output = self.w_2(F.relu(self.w_1(output)))
output = output.transpose(1, 2)
output = self.dropout(output)
output = self.layer_norm(output + residual)
return output
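# A minimal shape sketch: the two 1x1 convolutions act per position, so only the channel
# dimension changes internally and the output keeps the input shape:
#   ffn = PositionwiseFeedForward(d_in=512, d_hid=1024)
#   ffn(torch.rand(2, 25, 512)).shape                      # torch.Size([2, 25, 512])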
class EncoderLayer(nn.Module):
''' Compose with two layers '''
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
super(EncoderLayer, self).__init__()
self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
def forward(self, enc_input, slf_attn_mask=None):
enc_output, enc_slf_attn = self.slf_attn(
enc_input, enc_input, enc_input, mask=slf_attn_mask)
enc_output = self.pos_ffn(enc_output)
return enc_output, enc_slf_attn
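# A minimal sketch of one encoder block (self-attention followed by the position-wise FFN):
#   layer = EncoderLayer(d_model=512, d_inner=1024, n_head=8, d_k=64, d_v=64)
#   out, attn = layer(torch.rand(2, 256, 512))             # out: 2x256x512, attn: 16x256x256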
class Torch_transformer_encoder(nn.Module):
    '''
    Sequence encoder built on top of PyTorch's nn.TransformerEncoder
    (an alternative to the hand-written Transforme_Encoder below).
    '''
def __init__(self, d_word_vec=512, n_layers=2, n_head=8, d_model=512, dim_feedforward=1024, n_position=256):
super(Torch_transformer_encoder, self).__init__()
self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head, dim_feedforward=dim_feedforward)
self.layer_norm = nn.LayerNorm(d_model)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers, norm=self.layer_norm)
self.dropout = nn.Dropout(p=0.1)
def forward(self, cnn_feature, src_mask=None, return_attns=False):
enc_slf_attn_list = []
# -- Forward
        enc_output = self.dropout(self.position_enc(cnn_feature))  # positional embedding
        # nn.TransformerEncoder expects (seq_len, batch, d_model) by default, while the rest of
        # this file uses (batch, seq_len, d_model), so transpose going in and coming back out.
        enc_output = self.encoder(enc_output.transpose(0, 1)).transpose(0, 1)
enc_output = self.layer_norm(enc_output)
if return_attns:
return enc_output, enc_slf_attn_list
return enc_output,
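# A minimal usage sketch, assuming (batch, seq_len, d_model) CNN features:
#   enc = Torch_transformer_encoder(n_layers=2, n_position=256)
#   enc(torch.rand(2, 256, 512))[0].shape                  # torch.Size([2, 256, 512])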
class Transforme_Encoder(nn.Module):
    ''' Transformer encoder used to capture global spatial dependencies.
    d_word_vec: feature dimension used by the positional encoding
    n_layers: number of transformer encoder layers
    n_head: number of attention heads
    d_k: 64, per-head key dimension
    d_v: 64, per-head value dimension
    d_model: 512
    d_inner: 1024
    n_position: maximum sequence length supported by the positional encoding
    '''
def __init__(
self, d_word_vec=512, n_layers=2, n_head=8, d_k=64, d_v=64,
d_model=512, d_inner=1024, dropout=0.1, n_position=256):
super().__init__()
self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
self.dropout = nn.Dropout(p=dropout)
self.layer_stack = nn.ModuleList([
EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)])
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
def forward(self, cnn_feature, src_mask, return_attns=False):
enc_slf_attn_list = []
# -- Forward
        enc_output = self.dropout(self.position_enc(cnn_feature))  # positional embedding
for enc_layer in self.layer_stack:
enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
enc_slf_attn_list += [enc_slf_attn] if return_attns else []
enc_output = self.layer_norm(enc_output)
if return_attns:
return enc_output, enc_slf_attn_list
return enc_output,
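# A minimal usage sketch (see also the __main__ block at the bottom of this file);
# src_mask may be None when no feature positions need masking:
#   enc = Transforme_Encoder(n_layers=2, n_position=256)
#   enc(torch.rand(2, 256, 512), src_mask=None)[0].shape   # torch.Size([2, 256, 512])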
class PVAM(nn.Module):
    ''' Parallel Visual Attention Module (parallel decoding).
    n_dim: 512, feature dimension of the reading-order embedding
    N_max_character: 25, maximum number of characters in a single image
    n_position: length of the feature sequence produced by the CNN
    '''
def __init__(self, n_dim=512, N_max_character=25, n_position=256):
super(PVAM, self).__init__()
self.character_len = N_max_character
self.f0_embedding = nn.Embedding(N_max_character, n_dim)
self.w0 = nn.Linear(N_max_character, n_position)
self.wv = nn.Linear(n_dim, n_dim)
# first linear(512,25)
self.we = nn.Linear(n_dim, N_max_character)
self.active = nn.Tanh()
self.softmax = nn.Softmax(dim=2)
def forward(self, enc_output):
reading_order = torch.arange(self.character_len, dtype=torch.long, device=enc_output.device)
reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S)
reading_order_embed = self.f0_embedding(reading_order) # b,25,512
t = self.w0(reading_order_embed.permute(0,2,1)) # b,512,256
t = self.active(t.permute(0,2,1) + self.wv(enc_output)) # b,256,512
# first linear(512,25)
attn = self.we(t) # b,256,25
attn = self.softmax(attn.permute(0,2,1)) # b,25,256
g_output = torch.bmm(attn, enc_output) # b,25,512
return g_output
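# A minimal sketch of the parallel attention step, assuming the defaults
# (25 reading-order steps attending over 256 visual positions):
#   pvam = PVAM(n_dim=512, N_max_character=25, n_position=256)
#   pvam(torch.rand(2, 256, 512)).shape                    # torch.Size([2, 25, 512])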
class GSRM(nn.Module):
    ''' Global Semantic Reasoning Module.
    n_dim: feature dimension of the character embedding
    n_class: number of character classes (needed by the embedding)
    PAD: padding index used when computing the mask
    '''
def __init__(self, n_dim=512, n_class=37, PAD=37-1, n_layers=4, n_position=25):
super(GSRM, self).__init__()
self.PAD = PAD
self.argmax_embed = nn.Embedding(n_class, n_dim)
self.transformer_units = Transforme_Encoder(n_layers=n_layers, n_position=n_position) # for global context information
# self.transformer_units = Torch_transformer_encoder(n_layers=n_layers, n_position=n_position)
def forward(self, e_out):
'''
        e_out: b,25,37 | the logits computed from the PVAM output
'''
e_argmax = e_out.argmax(dim=-1) # b, 25
e = self.argmax_embed(e_argmax) # b,25,512
        e_mask = get_pad_mask(e_argmax, self.PAD)  # b,1,25 (computed but not used: the encoder below is called with mask=None)
s = self.transformer_units(e, None) # b,25,512
return s
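# A minimal sketch, assuming 37 classes with the last index reserved for PAD; note the
# module returns the encoder's tuple, so callers index [0] as SRN_Decoder does:
#   gsrm = GSRM(n_dim=512, n_class=37, PAD=36, n_layers=4, n_position=25)
#   gsrm(torch.rand(2, 25, 37))[0].shape                   # torch.Size([2, 25, 512])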
class SRN_Decoder(nn.Module):
    ''' Wrapper around the SRN decoder branches.
    n_dim: feature dimension
    n_class: number of character classes
    N_max_character: at most 25 characters per image
    n_position: length of the CNN feature sequence
    The forward pass returns three outputs, one per branch (PVAM, GSRM, fusion).
    '''
def __init__(self, n_dim=512, n_class=37, N_max_character=25, n_position=256, GSRM_layer=4 ):
super(SRN_Decoder, self).__init__()
self.pvam = PVAM(N_max_character=N_max_character, n_position=n_position)
self.w_e = nn.Linear(n_dim, n_class) # output layer
self.GSRM = GSRM(n_class=n_class, PAD=n_class-1, n_dim=n_dim, n_position=N_max_character, n_layers=GSRM_layer)
self.w_s = nn.Linear(n_dim, n_class) # output layer
self.w_f = nn.Linear(n_dim, n_class) # output layer
def forward(self, cnn_feature ):
'''cnn_feature: b,256,512 | the output from cnn'''
g_output = self.pvam(cnn_feature) # b,25,512
        e_out = self.w_e(g_output)  # b,25,37 ----> cross-entropy loss | first output
        s = self.GSRM(e_out)[0]  # b,25,512
        s_out = self.w_s(s)  # b,25,37 | second output
# TODO:change the add to gated unit
f = g_output + s # b,25,512
f_out = self.w_f(f)
return e_out, s_out, f_out
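# A minimal end-to-end sketch (see also the __main__ block at the bottom of this file):
#   decoder = SRN_Decoder(n_dim=512, n_class=37, N_max_character=25, n_position=256)
#   e_out, s_out, f_out = decoder(torch.rand(2, 256, 512)) # each: b x 25 x 37 logits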
def cal_performance(preds, gold, mask=None, smoothing='1'):
    ''' Weighted loss over the three decoder outputs; label smoothing is applied in cal_loss if requested. '''
loss = 0.
n_correct = 0
weights = [1.0, 0.15, 2.0]
for ori_pred, weight in zip(preds, weights):
pred = ori_pred.view(-1, ori_pred.shape[-1])
# debug show
t_gold = gold.view(ori_pred.shape[0], -1)
t_pred_index = ori_pred.max(2)[1]
        if mask is not None:
            mask = mask.view(-1)
        non_pad_mask = mask.ne(0) if mask is not None else None
tloss = cal_loss(pred, gold, non_pad_mask, smoothing)
if torch.isnan(tloss):
            print('NaN loss encountered; skipping this output branch')
continue
else:
loss += tloss * weight
pred = pred.max(1)[1]
gold = gold.contiguous().view(-1)
n_correct = pred.eq(gold)
        n_correct = n_correct.masked_select(non_pad_mask).sum().item() if mask is not None else n_correct.sum().item()
return loss, n_correct
def cal_loss(pred, gold, mask, smoothing):
''' Calculate cross entropy loss, apply label smoothing if needed. '''
gold = gold.contiguous().view(-1)
if smoothing=='0':
eps = 0.1
n_class = pred.size(1)
one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, dim=1)
non_pad_mask = gold.ne(0)
loss = -(one_hot * log_prb).sum(dim=1)
loss = loss.masked_select(non_pad_mask).sum() # average later
elif smoothing == '1':
if mask is not None:
loss = F.cross_entropy(pred, gold, reduction='none')
loss = loss.masked_select(mask)
loss = loss.sum() / mask.sum()
else:
loss = F.cross_entropy(pred, gold)
else:
# loss = F.cross_entropy(pred, gold, ignore_index=PAD)
loss = F.cross_entropy(pred, gold)
return loss
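# A minimal training-loss sketch, assuming gold labels of shape (b, 25) and a 0/1
# validity mask of the same shape; the three decoder outputs are weighted 1.0 / 0.15 / 2.0:
#   preds = SRN_Decoder()(torch.rand(2, 256, 512))         # (e_out, s_out, f_out)
#   gold = torch.randint(0, 37, (2, 25))
#   mask = torch.ones(2, 25)
#   loss, n_correct = cal_performance(preds, gold, mask=mask, smoothing='1')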
def cal_performance2(preds, gold, PAD, smoothing='1'):
    ''' Weighted loss over the three decoder outputs, ignoring PAD positions; label smoothing is applied in cal_loss2 if requested. '''
loss = 0.
n_correct = 0
weights = [1.0, 0.15, 2.0]
for ori_pred, weight in zip(preds, weights):
pred = ori_pred.view(-1, ori_pred.shape[-1])
# debug show
t_gold = gold.view(ori_pred.shape[0], -1)
t_pred_index = ori_pred.max(2)[1]
tloss = cal_loss2(pred, gold, PAD, smoothing=smoothing)
if torch.isnan(tloss):
            print('NaN loss encountered; skipping this output branch')
continue
else:
loss += tloss * weight
pred = pred.max(1)[1]
gold = gold.contiguous().view(-1)
n_correct = pred.eq(gold)
non_pad_mask = gold.ne(PAD)
n_correct = n_correct.masked_select(non_pad_mask).sum().item()
return loss, n_correct
def cal_loss2(pred, gold, PAD, smoothing='1'):
''' Calculate cross entropy loss, apply label smoothing if needed. '''
gold = gold.contiguous().view(-1)
if smoothing=='0':
eps = 0.1
n_class = pred.size(1)
one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, dim=1)
non_pad_mask = gold.ne(0)
loss = -(one_hot * log_prb).sum(dim=1)
loss = loss.masked_select(non_pad_mask).sum() # average later
elif smoothing == '1':
loss = F.cross_entropy(pred, gold, ignore_index=PAD)
else:
# loss = F.cross_entropy(pred, gold, ignore_index=PAD)
loss = F.cross_entropy(pred, gold)
return loss
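# A minimal sketch of the PAD-index variant, assuming class 36 is the padding symbol
# (preds and gold as in the cal_performance sketch above):
#   loss, n_correct = cal_performance2(preds, gold, PAD=36, smoothing='1')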
if __name__=='__main__':
cnn_feature = torch.rand((2,256,512))
model1 = Transforme_Encoder()
image = model1(cnn_feature,src_mask=None)[0]
model = SRN_Decoder(N_max_character=30)
outs = model(image)
for out in outs:
print(out.shape)
# image = torch.rand((4,3,32,60))
# tgt_seq = torch.tensor([[ 2, 24, 2176, 882, 2480, 612, 1525, 480, 875, 147, 1700, 715,
# 1465, 3],
# [ 2, 369, 1781, 882, 703, 879, 2855, 2415, 502, 1154, 833, 1465,
# 3, 0],
# [ 2, 2943, 334, 328, 480, 330, 1644, 1449, 163, 147, 1823, 1184,
# 1465, 3],
# [ 2, 24, 396, 480, 703, 1646, 897, 1711, 1508, 703, 2321, 147,
# 642, 1465]], device='cuda:0')
# tgt_pos = torch.tensor([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
# [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0],
# [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
# [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]],
# device='cuda:0')
# src_seq = torch.tensor([[ 2, 598, 2088, 822, 2802, 1156, 157, 1099, 1000, 598, 1707, 1345,
# 3, 0, 0, 0],
# [ 2, 598, 2348, 822, 598, 1222, 471, 948, 986, 423, 1345, 3,
# 0, 0, 0, 0],
# [ 2, 2437, 2470, 901, 2473, 598, 1735, 84, 1, 2277, 1979, 499,
# 962, 1345, 3, 0],
# [ 2, 598, 186, 1904, 598, 868, 1339, 1604, 84, 598, 608, 1728,
# 1345, 3, 0, 0]], device='cuda:0')
# device = torch.device('cuda')
# image = image.cuda()
# transformer = Transformer()
# transformer = transformer.to(device)
# transformer.train()
# out = transformer(image, tgt_seq, tgt_pos, src_seq)
    # gold = tgt_seq[:, 1:]  # start from the second column
# # backward
# loss, n_correct = cal_performance(out, gold, smoothing=True)
# print(loss, n_correct)
# a = 1