Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# coding=utf-8 | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence | |
class CharEmbedding(nn.Module): | |
def __init__(self, vocab_size: int, embedding_size: int, output_size: int): | |
super(CharEmbedding, self).__init__() | |
self.embedding = nn.Embedding(vocab_size, embedding_size, sparse=False) | |
self.layer_norm = nn.LayerNorm(embedding_size) | |
self.gru = nn.GRU(embedding_size, embedding_size, num_layers=1, bidirectional=True) | |
self.out_linear = nn.Linear(2*embedding_size, output_size) | |
self.layer_norm_2 = nn.LayerNorm(output_size) | |
def forward(self, words, sentence_lens, word_lens): | |
# input shape: (B, W, C) | |
n_words = words.size(1) | |
sentence_lens = sentence_lens.cpu() | |
sentence_packed = pack_padded_sequence(words, sentence_lens, batch_first=True) # shape: (B*W, C) | |
lens_packed = pack_padded_sequence(word_lens, sentence_lens, batch_first=True) # shape: (B*W) | |
word_packed = pack_padded_sequence(sentence_packed.data, lens_packed.data.cpu(), batch_first=True, enforce_sorted=False) # shape: (B*W*C) | |
embedded = self.embedding(word_packed.data) # shape: (B*W*C, D) | |
embedded = self.layer_norm(embedded) # shape: (B*W*C, D) | |
embedded_packed = PackedSequence(embedded, word_packed[1], word_packed[2], word_packed[3]) | |
_, embedded = self.gru(embedded_packed) # shape: (layers * 2, B*W, D) | |
embedded = embedded[-2:, :, :].transpose(0, 1).flatten(1, 2) # shape: (B*W, 2*D) | |
embedded = F.relu(embedded) | |
embedded = self.out_linear(embedded) | |
embedded = self.layer_norm_2(embedded) | |
embedded, _ = pad_packed_sequence( | |
PackedSequence(embedded, sentence_packed[1], sentence_packed[2], sentence_packed[3]), batch_first=True, total_length=n_words, | |
) # shape: (B, W, 2*D) | |
return embedded # shape: (B, W, 2*D) | |