# Copyright 2018 Dong-Hyun Lee, Kakao Brain.
# (Strongly inspired by original Google BERT code and Hugging Face's code)
""" Transformer Model Classes & Config Class """
import math
import json
from typing import NamedTuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def split_last(x, shape):
    "split the last dimension to given shape"
    shape = list(shape)
    assert shape.count(-1) <= 1
    if -1 in shape:
        shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape))
    return x.view(*x.size()[:-1], *shape)
def merge_last(x, n_dims):
    "merge the last n_dims to a dimension"
    s = x.size()
    assert n_dims > 1 and n_dims < len(s)
    return x.view(*s[:-n_dims], -1)
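# Shape sketch (illustrative example, not part of the original functions): for
# x = torch.randn(4, 200, 512) and 8 heads,
#   split_last(x, (8, -1)).shape                 -> torch.Size([4, 200, 8, 64])
#   merge_last(split_last(x, (8, -1)), 2).shape  -> torch.Size([4, 200, 512])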
def gelu(x):
    "Implementation of the gelu activation function by Hugging Face"
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
class LayerNorm(nn.Module):
    "A layernorm module in the TF style (epsilon inside the square root)."
    def __init__(self, cfg, variance_epsilon=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(cfg.dim))
        self.beta = nn.Parameter(torch.zeros(cfg.dim))
        self.variance_epsilon = variance_epsilon
    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.gamma * x + self.beta
class Embeddings(nn.Module):
    "Adds learned position embeddings to the input feature sequence."
    def __init__(self, cfg):
        super().__init__()
        self.pos_embed = nn.Embedding(cfg.p_dim, cfg.dim)  # position embedding
        self.norm = LayerNorm(cfg)
        self.drop = nn.Dropout(cfg.p_drop_hidden)
    def forward(self, x):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand(x.size(0), -1)  # (S,) -> (B, S)
        e = x + self.pos_embed(pos)
        return self.drop(self.norm(e))
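# Note: unlike the BERT embedding layer this module was adapted from, there are
# no word/token_type embeddings here; the input x is already a continuous
# feature sequence of shape (B, S, dim), so only position embeddings are added.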
class MultiHeadedSelfAttention(nn.Module):
    """ Multi-Headed Dot Product Attention """
    def __init__(self, cfg):
        super().__init__()
        self.proj_q = nn.Linear(cfg.dim, cfg.dim)
        self.proj_k = nn.Linear(cfg.dim, cfg.dim)
        self.proj_v = nn.Linear(cfg.dim, cfg.dim)
        self.drop = nn.Dropout(cfg.p_drop_attn)
        self.scores = None  # for visualization
        self.n_heads = cfg.n_heads
    def forward(self, x, mask):
        """
        x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
        mask : (B(batch_size) x S(seq_len))
        * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
        """
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        q, k, v = (split_last(t, (self.n_heads, -1)).transpose(1, 2)
                   for t in [q, k, v])
        # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
        scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
        if mask is not None:
            mask = mask[:, None, None, :].float()
            scores -= 10000.0 * (1.0 - mask)
        scores = self.drop(F.softmax(scores, dim=-1))
        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
        h = (scores @ v).transpose(1, 2).contiguous()
        # -merge-> (B, S, D)
        h = merge_last(h, 2)
        self.scores = scores
        return h
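# Mask convention: mask is (B, S) with 1 for valid positions and 0 for padding;
# padded positions get a large negative bias (-10000) before the softmax so
# they receive (near-)zero attention weight.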
class PositionWiseFeedForward(nn.Module):
    """ FeedForward Neural Networks for each position """
    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg.dim, cfg.dim_ff)
        self.fc2 = nn.Linear(cfg.dim_ff, cfg.dim)
        # self.activ = lambda x: activ_fn(cfg.activ_fn, x)
    def forward(self, x):
        # (B, S, D) -> (B, S, D_ff) -> (B, S, D)
        return self.fc2(gelu(self.fc1(x)))
class Block(nn.Module):
    """ Transformer Block """
    def __init__(self, cfg):
        super().__init__()
        self.attn = MultiHeadedSelfAttention(cfg)
        self.proj = nn.Linear(cfg.dim, cfg.dim)
        self.norm1 = LayerNorm(cfg)
        self.pwff = PositionWiseFeedForward(cfg)
        self.norm2 = LayerNorm(cfg)
        self.drop = nn.Dropout(cfg.p_drop_hidden)
    def forward(self, x, mask):
        h = self.attn(x, mask)
        h = self.norm1(x + self.drop(self.proj(h)))
        h = self.norm2(h + self.drop(self.pwff(h)))
        return h
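# Post-LN residual layout, as in the original BERT encoder:
# attention -> output projection -> dropout -> residual add -> LayerNorm,
# then feed-forward -> dropout -> residual add -> LayerNorm.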
class Transformer(nn.Module):
    """ Transformer with Self-Attentive Blocks """
    def __init__(self, cfg, n_layers):
        super().__init__()
        self.embed = Embeddings(cfg)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(n_layers)])
    def forward(self, x, mask):
        h = self.embed(x)
        for block in self.blocks:
            h = block(h, mask)
        return h
class Parallel_Attention(nn.Module):
    ''' Parallel Attention Module for 2D attention.
    Reference: the original paper, https://arxiv.org/abs/1906.05708
    '''
    def __init__(self, cfg):
        super().__init__()
        self.atten_w1 = nn.Linear(cfg.dim_c, cfg.dim_c)
        self.atten_w2 = nn.Linear(cfg.dim_c, cfg.max_vocab_size)
        self.activ_fn = nn.Tanh()
        self.soft = nn.Softmax(dim=1)  # softmax over the sequence dimension S
        self.drop = nn.Dropout(0.1)
    def forward(self, origin_I, bert_out, mask=None):
        bert_out = self.activ_fn(self.drop(self.atten_w1(bert_out)))
        atten_w = self.soft(self.atten_w2(bert_out))      # (B, S, max_vocab_size)
        x = torch.bmm(origin_I.transpose(1, 2), atten_w)  # (B, dim_c, max_vocab_size)
        return x
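# The parallel attention maps all S encoder positions onto max_vocab_size output
# slots in one shot: each column of atten_w is a distribution over positions, so
# every output slot ("glimpse") is a weighted sum of the original features.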
class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''
    def __init__(self, n_head=8, d_k=64, d_model=128, max_vocab_size=94, dropout=0.1):
        ''' d_k: the attention dim per head (d_model must equal n_head * d_k)
        d_model: the encoder output feature dim
        max_vocab_size: the maximum output sequence length
        '''
        super(MultiHeadAttention, self).__init__()
        self.n_head, self.d_k = n_head, d_k
        self.temperature = np.power(d_k, 0.5)
        self.max_vocab_size = max_vocab_size
        self.w_encoder = nn.Linear(d_model, n_head * d_k)
        self.w_atten = nn.Linear(d_model, n_head * max_vocab_size)
        self.w_out = nn.Linear(n_head * d_k, d_model)
        self.activ_fn = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)  # over the d_in (sequence) dimension
        self.dropout = nn.Dropout(dropout)
        nn.init.normal_(self.w_encoder.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_atten.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.xavier_normal_(self.w_out.weight)
    def forward(self, encoder_feature, bert_out, mask=None):
        d_k, n_head, max_vocab_size = self.d_k, self.n_head, self.max_vocab_size
        sz_b, d_in, _ = encoder_feature.size()
        # original encoder features, split per head
        encoder_feature = encoder_feature.view(sz_b, d_in, n_head, d_k)
        encoder_feature = encoder_feature.permute(2, 0, 1, 3).contiguous().view(-1, d_in, d_k)  # (n_head*B, S, d_k)
        # compute the attention weights
        alpha = self.activ_fn(self.dropout(self.w_encoder(bert_out)))
        alpha = self.w_atten(alpha).view(sz_b, d_in, n_head, max_vocab_size)           # (B, S, n_head, max_vocab_size)
        alpha = alpha.permute(2, 0, 1, 3).contiguous().view(-1, d_in, max_vocab_size)  # (n_head*B, S, max_vocab_size)
        alpha = alpha / self.temperature
        alpha = self.dropout(self.softmax(alpha))                                      # (n_head*B, S, max_vocab_size)
        # output
        output = torch.bmm(encoder_feature.transpose(1, 2), alpha)                     # (n_head*B, d_k, max_vocab_size)
        output = output.view(n_head, sz_b, d_k, max_vocab_size)
        output = output.permute(1, 3, 0, 2).contiguous().view(sz_b, max_vocab_size, -1)  # (B, max_vocab_size, n_head*d_k)
        output = self.dropout(self.w_out(output))
        output = output.transpose(1, 2)                                                # (B, d_model, max_vocab_size)
        return output
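# Alternative multi-head variant of Parallel_Attention; it is unused by default
# (see the commented-out line in Bert_Ocr.__init__) and, as written, assumes
# d_model == n_head * d_k (e.g. 512 == 8 * 64).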
class Two_Stage_Decoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.out_w = nn.Linear(cfg.dim_c, cfg.len_alphabet)
        self.relation_attention = Transformer(cfg, cfg.decoder_atten_layers)
        self.out_w1 = nn.Linear(cfg.dim_c, cfg.len_alphabet)
    def forward(self, x):
        x1 = self.out_w(x)
        x2 = self.relation_attention(x, mask=None)
        x2 = self.out_w1(x2)  # the two branches use separate output layers
        return x1, x2
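# Two-stage decoding: x1 classifies each glimpse directly, while x2 first
# refines the glimpses with another Transformer (relation attention) and then
# classifies them with a separate linear head.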
class Bert_Ocr(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.transformer = Transformer(cfg, cfg.attention_layers)
        self.attention = Parallel_Attention(cfg)
        # self.attention = MultiHeadAttention(d_model=cfg.dim, max_vocab_size=cfg.max_vocab_size)
        self.decoder = Two_Stage_Decoder(cfg)
    def forward(self, encoder_feature, mask=None):
        bert_out = self.transformer(encoder_feature, mask)          # self-attention over the encoder features: (B, S, dim)
        glimpses = self.attention(encoder_feature, bert_out, mask)  # map the input sequence to output slots: (B, dim_c, max_vocab_size)
        res = self.decoder(glimpses.transpose(1, 2))
        return res
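# Pipeline: encoder features (B, S, dim) -> Transformer self-attention ->
# Parallel_Attention glimpses (B, dim_c, max_vocab_size) -> Two_Stage_Decoder,
# which returns two (B, max_vocab_size, len_alphabet) prediction tensors.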
class Config(object):
    """ Parameter settings """
    # Relation Attention Module
    p_drop_attn = 0.1
    p_drop_hidden = 0.1
    dim = 512              # the encoder output feature dim
    p_dim = 200            # max sequence length for the position embedding (must cover the input S; 200 matches the demo below)
    attention_layers = 2   # number of transformer layers
    n_heads = 8
    dim_ff = 1024 * 2      # hidden dim of the position-wise feed-forward layer
    # Parallel Attention Module
    dim_c = dim
    max_vocab_size = 26    # maximum number of characters in one image
    # Two-stage Decoder
    len_alphabet = 39      # number of character classes
    decoder_atten_layers = 2
def numel(model):
    return sum(p.numel() for p in model.parameters())
if __name__ == '__main__':
    cfg = Config()
    mask = None
    x = torch.randn(4, 200, cfg.dim)
    net = Bert_Ocr(cfg)
    res1, res2 = net(x, mask)
    print(res1.shape, res2.shape)
    print('Total number of parameters:', numel(net))
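    # A minimal, illustrative sketch of masked inference: pad_mask is a
    # hypothetical (B, S) padding mask with 1 = valid, 0 = padded, following
    # the convention used by MultiHeadedSelfAttention above.
    pad_mask = torch.ones(4, 200, dtype=torch.long)
    pad_mask[:, 150:] = 0  # pretend the last 50 positions are padding
    res1_m, res2_m = net(x, pad_mask)
    print(res1_m.shape, res2_m.shape)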