import os from copy import deepcopy import numpy as np import opencc import pypinyin import torch from PIL import ImageFont from torch import nn from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import MaskedLMOutput from transformers import BertPreTrainedModel, BertModel def _is_chinese_char(cp): if ((cp >= 0x4E00 and cp <= 0x9FFF) or # (cp >= 0x3400 and cp <= 0x4DBF) or # (cp >= 0x20000 and cp <= 0x2A6DF) or # (cp >= 0x2A700 and cp <= 0x2B73F) or # (cp >= 0x2B740 and cp <= 0x2B81F) or # (cp >= 0x2B820 and cp <= 0x2CEAF) or (cp >= 0xF900 and cp <= 0xFAFF) or # (cp >= 0x2F800 and cp <= 0x2FA1F)): # return True return False class Pinyin2(object): def __init__(self): super(Pinyin2, self).__init__() pho_vocab = ['P'] pho_vocab += [chr(x) for x in range(ord('1'), ord('5') + 1)] pho_vocab += [chr(x) for x in range(ord('a'), ord('z') + 1)] pho_vocab += ['U'] assert len(pho_vocab) == 33 self.pho_vocab_size = len(pho_vocab) self.pho_vocab = {c: idx for idx, c in enumerate(pho_vocab)} def get_pho_size(self): return self.pho_vocab_size @staticmethod def get_pinyin(c): if len(c) > 1: return 'U' s = pypinyin.pinyin( c, style=pypinyin.Style.TONE3, neutral_tone_with_five=True, errors=lambda x: ['U' for _ in x], )[0][0] if s == 'U': return s assert isinstance(s, str) assert s[-1] in '12345' s = s[-1] + s[:-1] return s def convert(self, chars): pinyins = list(map(self.get_pinyin, chars)) pinyin_ids = [list(map(self.pho_vocab.get, pinyin)) for pinyin in pinyins] pinyin_lens = [len(pinyin) for pinyin in pinyins] pinyin_ids = torch.nn.utils.rnn.pad_sequence( [torch.tensor(x) for x in pinyin_ids], batch_first=True, padding_value=0, ) return pinyin_ids, pinyin_lens pho2_convertor = Pinyin2() class CharResNet(torch.nn.Module): def __init__(self, in_channels=1): super().__init__() # input_image: bxcx32x32, output_image: bx768x1x1 self.res_block1 = BasicBlock(in_channels, 64, stride=2) # channels: 64, size: 16x16 self.res_block2 = BasicBlock(64, 128, stride=2) # channels: 128, size: 8x8 self.res_block3 = BasicBlock(128, 256, stride=2) # channels: 256, size: 4x4 self.res_block4 = BasicBlock(256, 512, stride=2) # channels: 512, size: 2x2 self.res_block5 = BasicBlock(512, 768, stride=2) # channels: 768, size: 1x1 def forward(self, x): # input_shape: bxcx32x32, output_image: bx768 # x = x.unsqueeze(1) h = self.res_block1(x) h = self.res_block2(h) h = self.res_block3(h) h = self.res_block4(h) h = self.res_block5(h) h = h.squeeze(-1).squeeze(-1) return h class BasicBlock(nn.Module): expansion = 1 def __init__(self, in_channels, out_channels, stride=1): super().__init__() self.residual_function = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels * BasicBlock.expansion) ) self.shortcut = nn.Sequential() if stride != 1 or in_channels != BasicBlock.expansion * out_channels: self.shortcut = nn.Sequential( nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(out_channels * BasicBlock.expansion) ) def forward(self, x): return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) class CharResNet1(torch.nn.Module): def __init__(self, in_channels=1): super().__init__() self.res_block1 = BasicBlock(in_channels, 64, stride=2) # channels: 64, size: 16x16 self.res_block2 = BasicBlock(64, 128, stride=2) # channels: 128, size: 8x8 self.res_block3 = BasicBlock(128, 192, stride=2) # channels: 256, size: 4x4 self.res_block4 = BasicBlock(192, 192, stride=2) def forward(self, x): # input_shape: bxcx32x32, output_shape: bx128x8x8 h = x h = self.res_block1(h) h = self.res_block2(h) h = self.res_block3(h) h = self.res_block4(h) h = h.view(h.shape[0], -1) return h class ReaLiseForCSC(BertPreTrainedModel): def __init__(self, config): super(ReaLiseForCSC, self).__init__(config) self.config = config self.vocab_size = config.vocab_size self.bert = BertModel(config) self.pho_embeddings = nn.Embedding(pho2_convertor.get_pho_size(), config.hidden_size, padding_idx=0) self.pho_gru = nn.GRU( input_size=config.hidden_size, hidden_size=config.hidden_size, num_layers=1, batch_first=True, dropout=0, bidirectional=False, ) pho_config = deepcopy(config) pho_config.num_hidden_layers = 4 self.pho_model = BertModel(pho_config) self.char_images_multifonts = torch.nn.Parameter(torch.rand(21128, 3, 32, 32)) self.char_images_multifonts.requires_grad = False self.resnet = CharResNet(in_channels=3) self.resnet_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.gate_net = nn.Linear(4 * config.hidden_size, 3) out_config = deepcopy(config) out_config.num_hidden_layers = 3 self.output_block = BertModel(out_config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, config.vocab_size) self.init_weights() self.loss_fnt = CrossEntropyLoss(ignore_index=0) self.tokenizer = None def tie_cls_weight(self): self.classifier.weight = self.bert.embeddings.word_embeddings.weight def build_glyce_embed(self, vocab_dir, font_path, font_size=32): vocab_path = os.path.join(vocab_dir, 'vocab.txt') with open(vocab_path, 'r', encoding='utf-8') as f: vocab = [s.strip() for s in f] font = ImageFont.truetype(font_path, size=font_size) char_images = [] for char in vocab: if len(char) != 1 or (not _is_chinese_char(ord(char))): char_images.append(np.zeros((font_size, font_size)).astype(np.float32)) continue image = font.getmask(char) image = np.asarray(image).astype(np.float32).reshape(image.size[::-1]) # Must be [::-1] # Crop image = image[:font_size, :font_size] # Pad if image.size != (font_size, font_size): back_image = np.zeros((font_size, font_size)).astype(np.float32) offset0 = (font_size - image.shape[0]) // 2 offset1 = (font_size - image.shape[1]) // 2 back_image[offset0:offset0 + image.shape[0], offset1:offset1 + image.shape[1]] = image image = back_image char_images.append(image) char_images = np.array(char_images) char_images = (char_images - np.mean(char_images)) / np.std(char_images) char_images = torch.from_numpy(char_images).reshape(char_images.shape[0], -1) assert char_images.shape == (21128, 1024) self.char_images.weight.data.copy_(char_images) # Add by hengdaxu def build_glyce_embed_multifonts(self, vocab_dir, num_fonts, use_traditional_font, font_size=32): font_paths = [ ('simhei.ttf', False), ('xiaozhuan.ttf', False), ('simhei.ttf', True), ] font_paths = font_paths[:num_fonts] if use_traditional_font: font_paths = font_paths[:-1] font_paths.append(('simhei.ttf', True)) self.converter = opencc.OpenCC('s2t.json') images_list = [] for font_path, use_traditional in font_paths: images = self.build_glyce_embed_onefont( vocab_dir=vocab_dir, font_path=font_path, font_size=font_size, use_traditional=use_traditional, ) images_list.append(images) char_images = torch.stack(images_list, dim=1).contiguous() self.char_images_multifonts.data.copy_(char_images) # Add by hengdaxu def build_glyce_embed_onefont(self, vocab_dir, font_path, font_size, use_traditional): vocab_path = os.path.join(vocab_dir, 'vocab.txt') with open(vocab_path, encoding='utf-8') as f: vocab = [s.strip() for s in f.readlines()] if use_traditional: vocab = [self.converter.convert(c) if len(c) == 1 else c for c in vocab] font = ImageFont.truetype(font_path, size=font_size) char_images = [] for char in vocab: if len(char) > 1: char_images.append(np.zeros((font_size, font_size)).astype(np.float32)) continue image = font.getmask(char) image = np.asarray(image).astype(np.float32).reshape(image.size[::-1]) # Must be [::-1] # Crop image = image[:font_size, :font_size] # Pad if image.size != (font_size, font_size): back_image = np.zeros((font_size, font_size)).astype(np.float32) offset0 = (font_size - image.shape[0]) // 2 offset1 = (font_size - image.shape[1]) // 2 back_image[offset0:offset0 + image.shape[0], offset1:offset1 + image.shape[1]] = image image = back_image char_images.append(image) char_images = np.array(char_images) char_images = (char_images - np.mean(char_images)) / np.std(char_images) char_images = torch.from_numpy(char_images).contiguous() return char_images @staticmethod def build_batch(batch, tokenizer): src_idx = batch['src_idx'].flatten().tolist() chars = tokenizer.convert_ids_to_tokens(src_idx) pho_idx, pho_lens = pho2_convertor.convert(chars) batch['pho_idx'] = pho_idx batch['pho_lens'] = pho_lens return batch def forward(self, input_ids=None, pho_idx=None, pho_lens=None, attention_mask=None, labels=None, **kwargs): input_shape = input_ids.size() bert_hiddens = self.bert(input_ids, attention_mask=attention_mask)[0] pho_embeddings = self.pho_embeddings(pho_idx) if torch.is_tensor(pho_lens): pho_lens = pho_lens.tolist() pho_embeddings = torch.nn.utils.rnn.pack_padded_sequence( input=pho_embeddings, lengths=pho_lens, batch_first=True, enforce_sorted=False, ) _, pho_hiddens = self.pho_gru(pho_embeddings) pho_hiddens = pho_hiddens.squeeze(0).reshape(input_shape[0], input_shape[1], -1).contiguous() pho_hiddens = self.pho_model(inputs_embeds=pho_hiddens, attention_mask=attention_mask)[0] src_idxs = input_ids.view(-1) if self.config.num_fonts == 1: images = self.char_images(src_idxs).reshape(src_idxs.shape[0], 1, 32, 32).contiguous() else: images = self.char_images_multifonts.index_select(dim=0, index=src_idxs) res_hiddens = self.resnet(images) res_hiddens = res_hiddens.reshape(input_shape[0], input_shape[1], -1).contiguous() res_hiddens = self.resnet_layernorm(res_hiddens) bert_hiddens_mean = (bert_hiddens * attention_mask.to(torch.float).unsqueeze(2)).sum(dim=1) / attention_mask.to( torch.float).sum(dim=1, keepdim=True) bert_hiddens_mean = bert_hiddens_mean.unsqueeze(1).expand(-1, bert_hiddens.size(1), -1) concated_outputs = torch.cat((bert_hiddens, pho_hiddens, res_hiddens, bert_hiddens_mean), dim=-1) gated_values = self.gate_net(concated_outputs) # B * S * 3 g0 = torch.sigmoid(gated_values[:, :, 0].unsqueeze(-1)) g1 = torch.sigmoid(gated_values[:, :, 1].unsqueeze(-1)) g2 = torch.sigmoid(gated_values[:, :, 2].unsqueeze(-1)) hiddens = g0 * bert_hiddens + g1 * pho_hiddens + g2 * res_hiddens outputs = self.output_block(inputs_embeds=hiddens, position_ids=torch.zeros(input_ids.size(), dtype=torch.long, device=input_ids.device), attention_mask=attention_mask) sequence_output = outputs[0] sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) outputs = MaskedLMOutput( logits=logits, hidden_states=outputs.last_hidden_state, ) if labels is not None: # Only keep active parts of the loss labels[labels == 101] = 0 labels[labels == 102] = 0 loss = self.loss_fnt(logits.view(-1, logits.size(-1)), labels.view(-1)) outputs.loss = loss return outputs def set_tokenizer(self, tokenizer): self.tokenizer = tokenizer def predict(self, sentences): if self.tokenizer is None: raise RuntimeError("Please init tokenizer by `set_tokenizer(tokenizer)` before predict.") str_flag = False if type(sentences) == str: sentences = [sentences] str_flag = True inputs = self.tokenizer(sentences, padding=True, return_tensors="pt") outputs = self.forward(**inputs).logits ids_list = outputs.argmax(-1) preds = [] for i, ids in enumerate(ids_list): ids = ids[inputs['attention_mask'][i].bool()] pred_tokens = self.tokenizer.convert_ids_to_tokens(ids) pred_tokens = [t if not t.startswith('##') else t[2:] for t in pred_tokens] pred_tokens = [t if t != self.tokenizer.unk_token else '×' for t in pred_tokens] offsets = inputs[i].offsets src_tokens = list(sentences[i]) for (start, end), pred_token in zip(offsets, pred_tokens): if end - start <= 0: continue if (end - start) != len(pred_token): continue if pred_token == '×': continue if (end - start) == 1 and not _is_chinese_char(ord(src_tokens[start])): continue src_tokens[start:end] = pred_token pred = ''.join(src_tokens) preds.append(pred) if str_flag: return preds[0] return preds