ChineseBERT-base / models /pinyin_embedding.py
iioSnail's picture
Upload 6 files
a2e10c0
raw
history blame
1.98 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@file : pinyin_embedding.py
@author: zijun
@contact : [email protected]
@date : 2020/8/16 14:45
@version: 1.0
@desc : pinyin embedding
"""
import json
import os
from torch import nn
from torch.nn import functional as F
class PinyinEmbedding(nn.Module):
    """Turn per-token pinyin id sequences into one fixed-size vector per token.

    Each pinyin id is embedded, a width-2 1-D convolution is run along the
    pinyin positions of every token, and the result is max-pooled over those
    positions so each token ends up with ``pinyin_out_dim`` features.
    """

    def __init__(self, embedding_size: int, pinyin_out_dim: int, config_path):
        """
        Pinyin Embedding Module
        Args:
            embedding_size: the size of each embedding vector
            pinyin_out_dim: kernel number of conv (also the size of the
                per-token output vector)
            config_path: directory containing ``pinyin_map.json``; the length
                of its ``idx2char`` entry sets the number of embedding rows
        Raises:
            OSError: if ``pinyin_map.json`` cannot be opened
            KeyError: if the JSON has no ``idx2char`` entry
        """
        super().__init__()
        # JSON is UTF-8 by specification; pin the encoding so loading does not
        # depend on the platform's locale default.
        map_file = os.path.join(config_path, 'pinyin_map.json')
        with open(map_file, encoding='utf-8') as fin:
            pinyin_dict = json.load(fin)
        self.pinyin_out_dim = pinyin_out_dim
        self.embedding = nn.Embedding(len(pinyin_dict['idx2char']), embedding_size)
        self.conv = nn.Conv1d(
            in_channels=embedding_size,
            out_channels=self.pinyin_out_dim,
            kernel_size=2,
            stride=1,
            padding=0,
        )

    def forward(self, pinyin_ids):
        """
        Args:
            pinyin_ids: integer ids of shape (bs, sentence_length, pinyin_locs);
                pinyin_locs must be at least 2 because the conv uses
                kernel_size=2 without padding
        Returns:
            pinyin_embed: (bs, sentence_length, pinyin_out_dim)
        """
        # input pinyin ids for 1-D conv
        embed = self.embedding(pinyin_ids)  # [bs, sentence_length, pinyin_locs, embed_size]
        bs, sentence_length, pinyin_locs, embed_size = embed.shape
        # Fold batch and sentence dims together so every token is one conv sample.
        view_embed = embed.view(-1, pinyin_locs, embed_size)  # [(bs*sentence_length), pinyin_locs, embed_size]
        # Conv1d expects channels before length.
        input_embed = view_embed.permute(0, 2, 1)  # [(bs*sentence_length), embed_size, pinyin_locs]
        # conv + max_pooling
        pinyin_conv = self.conv(input_embed)  # [(bs*sentence_length), pinyin_out_dim, H]
        # Pool window == full remaining length, i.e. a global max over pinyin positions.
        pinyin_embed = F.max_pool1d(pinyin_conv, pinyin_conv.shape[-1])  # [(bs*sentence_length), pinyin_out_dim, 1]
        return pinyin_embed.view(bs, sentence_length, self.pinyin_out_dim)  # [bs, sentence_length, pinyin_out_dim]