|
|
|
|
|
""" |
|
@file : pinyin.py |
|
@author: zijun |
|
@contact : [email protected] |
|
@date : 2020/8/16 14:45 |
|
@version: 1.0 |
|
@desc : pinyin embedding |
|
""" |
|
import json |
|
import os |
|
|
|
from torch import nn |
|
from torch.nn import functional as F |
|
|
|
|
|
class PinyinEmbedding(nn.Module):
    """Pinyin embedding layer: per-character embedding + Conv1d + global max-pool.

    Each token is represented by a fixed number of pinyin-character ids
    (``pinyin_locs``); a 1-D convolution over those positions followed by a
    max-pool produces one ``pinyin_out_dim``-sized vector per token.
    """

    def __init__(self, embedding_size: int, pinyin_out_dim: int, config_path: str):
        """
        Pinyin Embedding Module

        Args:
            embedding_size: the size of each embedding vector
            pinyin_out_dim: kernel number of conv (output channels, i.e. the
                dimensionality of the returned pinyin embedding)
            config_path: directory containing ``pinyin_map.json``, a JSON
                object whose ``idx2char`` list defines the pinyin-character
                vocabulary (its length sets the embedding table size)
        """
        super().__init__()
        # pinyin_map.json contains tone-marked pinyin characters; pin the
        # encoding to UTF-8 so loading does not depend on the platform's
        # locale default (which may not be UTF-8, e.g. on Windows).
        with open(os.path.join(config_path, 'pinyin_map.json'), encoding='utf8') as fin:
            pinyin_dict = json.load(fin)
        self.pinyin_out_dim = pinyin_out_dim
        self.embedding = nn.Embedding(len(pinyin_dict['idx2char']), embedding_size)
        # kernel_size=2 with no padding: requires pinyin_locs >= 2 at forward time.
        self.conv = nn.Conv1d(in_channels=embedding_size, out_channels=self.pinyin_out_dim, kernel_size=2,
                              stride=1, padding=0)

    def forward(self, pinyin_ids):
        """
        Args:
            pinyin_ids: (bs*sentence_length*pinyin_locs) integer tensor of
                pinyin-character ids

        Returns:
            pinyin_embed: (bs,sentence_length,pinyin_out_dim)
        """
        embed = self.embedding(pinyin_ids)  # (bs, sentence_length, pinyin_locs, embed_size)
        bs, sentence_length, pinyin_locs, embed_size = embed.shape
        # Fold batch and sequence dims together so Conv1d sees one "sample"
        # per token; channels = embed_size, length = pinyin_locs.
        view_embed = embed.view(-1, pinyin_locs, embed_size)
        input_embed = view_embed.permute(0, 2, 1)  # (bs*sentence_length, embed_size, pinyin_locs)

        pinyin_conv = self.conv(input_embed)  # (bs*sentence_length, pinyin_out_dim, pinyin_locs-1)
        # Global max-pool over the remaining positions -> one vector per token.
        pinyin_embed = F.max_pool1d(pinyin_conv, pinyin_conv.shape[-1])
        return pinyin_embed.view(bs, sentence_length, self.pinyin_out_dim)
|
|