# Copyright 2018 Dong-Hyun Lee, Kakao Brain.
# (Strongly inspired by original Google BERT code and Hugging Face's code)

""" Transformer Model Classes & Config Class """

import math
import json
from typing import NamedTuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def split_last(x, shape):
    "split the last dimension to given shape"
    shape = list(shape)
    assert shape.count(-1) <= 1
    if -1 in shape:
        shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape))
    return x.view(*x.size()[:-1], *shape)


def merge_last(x, n_dims):
    "merge the last n_dims to a dimension"
    s = x.size()
    assert n_dims > 1 and n_dims < len(s)
    return x.view(*s[:-n_dims], -1)


def gelu(x):
    "Implementation of the gelu activation function by Hugging Face"
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


class LayerNorm(nn.Module):
    "A layernorm module in the TF style (epsilon inside the square root)."
    def __init__(self, cfg, variance_epsilon=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(cfg.dim))
        self.beta = nn.Parameter(torch.zeros(cfg.dim))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.gamma * x + self.beta


class Embeddings(nn.Module):
    "Adds learned position embeddings to an already-embedded input sequence."
    def __init__(self, cfg):
        super().__init__()
        self.pos_embed = nn.Embedding(cfg.p_dim, cfg.dim)  # position embedding
        self.norm = LayerNorm(cfg)
        self.drop = nn.Dropout(cfg.p_drop_hidden)

    def forward(self, x):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand(x.size(0), -1)  # (S,) -> (B, S)
        e = x + self.pos_embed(pos)
        return self.drop(self.norm(e))


class MultiHeadedSelfAttention(nn.Module):
    """ Multi-Headed Dot Product Attention """
    def __init__(self, cfg):
        super().__init__()
        self.proj_q = nn.Linear(cfg.dim, cfg.dim)
        self.proj_k = nn.Linear(cfg.dim, cfg.dim)
        self.proj_v = nn.Linear(cfg.dim, cfg.dim)
        self.drop = nn.Dropout(cfg.p_drop_attn)
        self.scores = None  # for visualization
        self.n_heads = cfg.n_heads

    def forward(self, x, mask):
        """
        x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
        mask : (B(batch_size) x S(seq_len))
        * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
        """
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        q, k, v = (split_last(t, (self.n_heads, -1)).transpose(1, 2) for t in [q, k, v])
        # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
        scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
        if mask is not None:
            mask = mask[:, None, None, :].float()
            scores -= 10000.0 * (1.0 - mask)
        scores = self.drop(F.softmax(scores, dim=-1))
        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
        h = (scores @ v).transpose(1, 2).contiguous()
        # -merge-> (B, S, D)
        h = merge_last(h, 2)
        self.scores = scores
        return h


class PositionWiseFeedForward(nn.Module):
    """ FeedForward Neural Networks for each position """
    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg.dim, cfg.dim_ff)
        self.fc2 = nn.Linear(cfg.dim_ff, cfg.dim)
        # self.activ = lambda x: activ_fn(cfg.activ_fn, x)

    def forward(self, x):
        # (B, S, D) -> (B, S, D_ff) -> (B, S, D)
        return self.fc2(gelu(self.fc1(x)))


class Block(nn.Module):
    """ Transformer Block """
    def __init__(self, cfg):
        super().__init__()
        self.attn = MultiHeadedSelfAttention(cfg)
        self.proj = nn.Linear(cfg.dim, cfg.dim)
        self.norm1 = LayerNorm(cfg)
        self.pwff = PositionWiseFeedForward(cfg)
        self.norm2 = LayerNorm(cfg)
        self.drop = nn.Dropout(cfg.p_drop_hidden)

    def forward(self, x, mask):
        h = self.attn(x, mask)
        h = self.norm1(x + self.drop(self.proj(h)))
        h = self.norm2(h + self.drop(self.pwff(h)))
        return h


class Transformer(nn.Module):
    """ Transformer with Self-Attentive Blocks """
    def __init__(self, cfg, n_layers):
        super().__init__()
        self.embed = Embeddings(cfg)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(n_layers)])

    def forward(self, x, mask):
        h = self.embed(x)
        for block in self.blocks:
            h = block(h, mask)
        return h


class Parallel_Attention(nn.Module):
    '''
    The Parallel Attention Module for 2D attention.
    Reference (original paper): https://arxiv.org/abs/1906.05708
    '''
    def __init__(self, cfg):
        super().__init__()
        self.atten_w1 = nn.Linear(cfg.dim_c, cfg.dim_c)
        self.atten_w2 = nn.Linear(cfg.dim_c, cfg.max_vocab_size)
        self.activ_fn = nn.Tanh()
        self.soft = nn.Softmax(dim=1)
        self.drop = nn.Dropout(0.1)

    def forward(self, origin_I, bert_out, mask=None):
        bert_out = self.activ_fn(self.drop(self.atten_w1(bert_out)))
        atten_w = self.soft(self.atten_w2(bert_out))      # (B, S, max_vocab_size)
        x = torch.bmm(origin_I.transpose(1, 2), atten_w)  # (B, D, max_vocab_size)
        return x


class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''
    def __init__(self, n_head=8, d_k=64, d_model=128, max_vocab_size=94, dropout=0.1):
        '''
        d_k: the attention dim
        d_model: the encoder output feature dim
        max_vocab_size: the maximum output sequence length
        '''
        super(MultiHeadAttention, self).__init__()
        self.n_head, self.d_k = n_head, d_k
        self.temperature = np.power(d_k, 0.5)
        self.max_vocab_size = max_vocab_size

        self.w_encoder = nn.Linear(d_model, n_head * d_k)
        self.w_atten = nn.Linear(d_model, n_head * max_vocab_size)
        self.w_out = nn.Linear(n_head * d_k, d_model)

        self.activ_fn = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)  # over the d_in (sequence) dimension
        self.dropout = nn.Dropout(dropout)

        nn.init.normal_(self.w_encoder.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_atten.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.xavier_normal_(self.w_out.weight)

    def forward(self, encoder_feature, bert_out, mask=None):
        d_k, n_head, max_vocab_size = self.d_k, self.n_head, self.max_vocab_size
        sz_b, d_in, _ = encoder_feature.size()

        # original encoder features: (B, S, D) -> (n_head*B, S, d_k)
        encoder_feature = encoder_feature.view(sz_b, d_in, n_head, d_k)
        encoder_feature = encoder_feature.permute(2, 0, 1, 3).contiguous().view(-1, d_in, d_k)

        # compute the attention weights: (n_head*B, S, max_vocab_size)
        alpha = self.activ_fn(self.dropout(self.w_encoder(bert_out)))
        alpha = self.w_atten(alpha).view(sz_b, d_in, n_head, max_vocab_size)
        alpha = alpha.permute(2, 0, 1, 3).contiguous().view(-1, d_in, max_vocab_size)
        alpha = alpha / self.temperature
        alpha = self.dropout(self.softmax(alpha))

        # weighted sum of features: (n_head*B, d_k, max_vocab_size) -> (B, max_vocab_size, n_head*d_k)
        output = torch.bmm(encoder_feature.transpose(1, 2), alpha)
        output = output.view(n_head, sz_b, d_k, max_vocab_size)
        output = output.permute(1, 3, 0, 2).contiguous().view(sz_b, max_vocab_size, -1)
        output = self.dropout(self.w_out(output))
        output = output.transpose(1, 2)  # (B, d_model, max_vocab_size)
        return output


class Two_Stage_Decoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.out_w = nn.Linear(cfg.dim_c, cfg.len_alphabet)
        self.relation_attention = Transformer(cfg, cfg.decoder_atten_layers)
        self.out_w1 = nn.Linear(cfg.dim_c, cfg.len_alphabet)

    def forward(self, x):
        x1 = self.out_w(x)
        x2 = self.relation_attention(x, mask=None)
        x2 = self.out_w1(x2)  # the two branches use separate output layers
        return x1, x2

class Bert_Ocr(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.transformer = Transformer(cfg, cfg.attention_layers)
        self.attention = Parallel_Attention(cfg)
        # self.attention = MultiHeadAttention(d_model=cfg.dim, max_vocab_size=cfg.max_vocab_size)
        self.decoder = Two_Stage_Decoder(cfg)

    def forward(self, encoder_feature, mask=None):
        bert_out = self.transformer(encoder_feature, mask)          # self-attention over the encoder features: (B, S, D)
        glimpses = self.attention(encoder_feature, bert_out, mask)  # map the source sequence onto the target sequence: (B, D, max_vocab_size)
        res = self.decoder(glimpses.transpose(1, 2))
        return res


class Config(object):
    ''' Hyper-parameter settings '''

    # Relation Attention Module
    p_drop_attn = 0.1
    p_drop_hidden = 0.1
    dim = 512               # dimension of the encoder output features
    p_dim = 200             # size of the position-embedding table, >= max sequence length (required by Embeddings; 200 matches the demo input below)
    attention_layers = 2    # number of transformer layers
    n_heads = 8
    dim_ff = 1024 * 2       # hidden dimension of the position-wise feed-forward network

    # Parallel Attention Module
    dim_c = dim
    max_vocab_size = 26     # maximum number of characters in one image

    # Two-Stage Decoder
    len_alphabet = 39       # number of character classes
    decoder_atten_layers = 2


def numel(model):
    return sum(p.numel() for p in model.parameters())


if __name__ == '__main__':
    cfg = Config()
    mask = None
    x = torch.randn(4, 200, cfg.dim)
    net = Bert_Ocr(cfg)
    res1, res2 = net(x, mask)
    print(res1.shape, res2.shape)
    print('Total number of parameters:', numel(net))
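
    # A minimal sketch (not part of the original demo) of how a padding mask
    # might be constructed for variable-length inputs: MultiHeadedSelfAttention
    # expects `mask` of shape (B, S) with 1 for real positions and 0 for padding,
    # and broadcasts it to (B, 1, 1, S) before masking the attention scores.
    # The `lengths` values below are hypothetical.
    lengths = torch.tensor([200, 150, 180, 120])
    pad_mask = (torch.arange(x.size(1))[None, :] < lengths[:, None]).float()  # (B, S)
    res1, res2 = net(x, pad_mask)
    print(res1.shape, res2.shape)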