ssa-perin / model /module /char_embedding.py
larkkin's picture
Add code
991f07c
raw
history blame
2.01 kB
#!/usr/bin/env python3
# coding=utf-8
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
class CharEmbedding(nn.Module):
def __init__(self, vocab_size: int, embedding_size: int, output_size: int):
super(CharEmbedding, self).__init__()
self.embedding = nn.Embedding(vocab_size, embedding_size, sparse=False)
self.layer_norm = nn.LayerNorm(embedding_size)
self.gru = nn.GRU(embedding_size, embedding_size, num_layers=1, bidirectional=True)
self.out_linear = nn.Linear(2*embedding_size, output_size)
self.layer_norm_2 = nn.LayerNorm(output_size)
def forward(self, words, sentence_lens, word_lens):
# input shape: (B, W, C)
n_words = words.size(1)
sentence_lens = sentence_lens.cpu()
sentence_packed = pack_padded_sequence(words, sentence_lens, batch_first=True) # shape: (B*W, C)
lens_packed = pack_padded_sequence(word_lens, sentence_lens, batch_first=True) # shape: (B*W)
word_packed = pack_padded_sequence(sentence_packed.data, lens_packed.data.cpu(), batch_first=True, enforce_sorted=False) # shape: (B*W*C)
embedded = self.embedding(word_packed.data) # shape: (B*W*C, D)
embedded = self.layer_norm(embedded) # shape: (B*W*C, D)
embedded_packed = PackedSequence(embedded, word_packed[1], word_packed[2], word_packed[3])
_, embedded = self.gru(embedded_packed) # shape: (layers * 2, B*W, D)
embedded = embedded[-2:, :, :].transpose(0, 1).flatten(1, 2) # shape: (B*W, 2*D)
embedded = F.relu(embedded)
embedded = self.out_linear(embedded)
embedded = self.layer_norm_2(embedded)
embedded, _ = pad_packed_sequence(
PackedSequence(embedded, sentence_packed[1], sentence_packed[2], sentence_packed[3]), batch_first=True, total_length=n_words,
) # shape: (B, W, 2*D)
return embedded # shape: (B, W, 2*D)