# Hugging Face Spaces app (Space status when this source was captured: "Runtime error").
import random
import re
import unicodedata

import gradio as gr
import nltk
import torch
import torch.nn as nn
import torch.nn.functional as F
from nltk.tokenize.treebank import TreebankWordDetokenizer
# Fetch NLTK's "punkt" resource at start-up (a network call on a fresh machine).
nltk.download(info_or_id='punkt')
class Encoder(nn.Module):
    """Bidirectional GRU encoder that also produces the decoder's initial state.

    Args:
        input_dim: source vocabulary size (rows of the embedding table).
        emb_dim: embedding width.
        enc_hid_dim: hidden size of each GRU direction.
        dec_hid_dim: size of the decoder hidden state this encoder initialises.
        dropout: dropout probability applied to the embeddings.
    """

    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        # Attribute names double as state_dict keys -- keep them stable so
        # saved checkpoints keep loading.
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)
        self.fc = nn.Linear(2 * enc_hid_dim, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode a batch of source token ids.

        ``src`` is presumably (src_len, batch): the GRU is built without
        ``batch_first`` -- TODO confirm against callers.

        Returns:
            (outputs, hidden) where ``outputs`` are the per-step GRU outputs
            (both directions concatenated) and ``hidden`` is
            tanh(fc([last forward state; last backward state])).
        """
        emb = self.dropout(self.embedding(src))
        enc_states, final = self.rnn(emb)
        # final[-2] / final[-1] are the last layer's forward / backward states.
        joined = torch.cat([final[-2], final[-1]], dim=1)
        return enc_states, torch.tanh(self.fc(joined))
class Attention(nn.Module):
    """Additive attention: score = v^T tanh(W [decoder_state; encoder_output]).

    Args:
        enc_hid_dim: per-direction encoder hidden size (encoder outputs are
            2 * enc_hid_dim wide).
        dec_hid_dim: decoder hidden size.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear(2 * enc_hid_dim + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights of shape (batch, src_len); each row sums to 1.

        Args:
            hidden: decoder state, (batch, dec_hid_dim).
            encoder_outputs: (src_len, batch, 2 * enc_hid_dim).
        """
        steps = encoder_outputs.shape[0]
        # Tile the decoder state once per source position: (batch, src_len, dec).
        queries = hidden.unsqueeze(1).repeat(1, steps, 1)
        # Batch-major view of the encoder outputs: (batch, src_len, 2 * enc).
        keys = encoder_outputs.transpose(0, 1)
        energy = torch.tanh(self.attn(torch.cat((queries, keys), dim=2)))
        scores = self.v(energy).squeeze(2)
        return F.softmax(scores, dim=1)
class Decoder(nn.Module):
    """One-step GRU decoder that attends over the encoder outputs.

    Args:
        output_dim: target vocabulary size (embedding rows and logit width).
        emb_dim: target embedding width.
        enc_hid_dim: per-direction encoder hidden size.
        dec_hid_dim: decoder GRU hidden size.
        dropout: dropout probability applied to the embeddings.
        attention: module mapping (hidden, encoder_outputs) to (batch, src_len)
            weights.
    """

    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU(2 * enc_hid_dim + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear(2 * enc_hid_dim + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """Decode a single time step.

        Args:
            input: previous token ids, (batch,).
            hidden: previous decoder state, (batch, dec_hid_dim).
            encoder_outputs: (src_len, batch, 2 * enc_hid_dim).

        Returns:
            (prediction, hidden): logits (batch, output_dim) and the new state
            (batch, dec_hid_dim).
        """
        # (1, batch, emb_dim): a length-1 sequence for the time-major GRU.
        emb = self.dropout(self.embedding(input.unsqueeze(0)))
        weights = self.attention(hidden, encoder_outputs).unsqueeze(1)
        # Attention-weighted sum of encoder outputs, back to (1, batch, 2 * enc).
        context = torch.bmm(weights, encoder_outputs.transpose(0, 1)).transpose(0, 1)
        out, new_hidden = self.rnn(torch.cat((emb, context), dim=2), hidden.unsqueeze(0))
        # Single-layer, unidirectional GRU run for one step: its output is the state.
        assert (out == new_hidden).all()
        features = torch.cat((out.squeeze(0), context.squeeze(0), emb.squeeze(0)), dim=1)
        return self.fc_out(features), new_hidden.squeeze(0)
class Seq2Seq(nn.Module):
    """Glue an encoder and a step-wise attention decoder into one model.

    Args:
        encoder: maps src to (encoder_outputs, initial_decoder_state).
        decoder: one-step decoder exposing ``output_dim``.
        device: device the output buffer is moved to.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio = 0.5):
        """Run ``trg.shape[0] - 1`` decoding steps.

        ``trg[0]`` seeds the decoder. At each later step the next input is
        ``trg[t]`` with probability ``teacher_forcing_ratio``, otherwise the
        argmax of that step's logits.

        Returns:
            Logits of shape (trg_len, batch, output_dim); row 0 stays all zeros.
        """
        steps, batch = trg.shape[0], trg.shape[1]
        logits = torch.zeros(steps, batch, self.decoder.output_dim).to(self.device)
        enc_states, state = self.encoder(src)
        token = trg[0, :]
        for t in range(1, steps):
            step_logits, state = self.decoder(token, state, enc_states)
            logits[t] = step_logits
            # random.random() is drawn once per step, whatever the ratio.
            token = trg[t] if random.random() < teacher_forcing_ratio else step_logits.argmax(1)
        return logits
def unicodeToAscii(s):
    """Remove nonspacing combining marks (category 'Mn') after NFD decomposition.

    Despite the name, the result is not necessarily ASCII: base characters of
    any script are kept and only the combining marks are dropped (e.g. Latin
    accents or Arabic short-vowel diacritics).
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = filter(lambda ch: unicodedata.category(ch) != 'Mn', decomposed)
    return ''.join(kept)
def tokenize_ar(text):
    """Tokenize Arabic text from a string into a list of string tokens.

    Combining marks (e.g. diacritics) are removed first with ``unicodeToAscii``,
    and the result is then split with NLTK's ``wordpunct_tokenize``. Token order
    is preserved; nothing is reversed.
    """
    return list(nltk.tokenize.wordpunct_tokenize(unicodeToAscii(text)))
# Source / target vocabularies stored next to the checkpoint.
# NOTE(review): torch.load unpickles these files -- only load trusted artifacts.
# Recent PyTorch releases default torch.load to weights_only=True, which may
# reject pickled vocab objects; verify against the pinned torch version.
src_vocab = torch.load("arabic_vocab.pth")
trg_vocab = torch.load("english_vocab.pth")

# Architecture hyper-parameters; they must match the checkpoint in model.pt.
INPUT_DIM = 9790   # presumably len(src_vocab) -- TODO confirm
OUTPUT_DIM = 5682  # presumably len(trg_vocab) -- TODO confirm
ENC_EMB_DIM = DEC_EMB_DIM = 256
ENC_HID_DIM = DEC_HID_DIM = 512
ENC_DROPOUT = DEC_DROPOUT = 0.5

attn = Attention(enc_hid_dim=ENC_HID_DIM, dec_hid_dim=DEC_HID_DIM)
enc = Encoder(
    input_dim=INPUT_DIM,
    emb_dim=ENC_EMB_DIM,
    enc_hid_dim=ENC_HID_DIM,
    dec_hid_dim=DEC_HID_DIM,
    dropout=ENC_DROPOUT,
)
dec = Decoder(
    output_dim=OUTPUT_DIM,
    emb_dim=DEC_EMB_DIM,
    enc_hid_dim=ENC_HID_DIM,
    dec_hid_dim=DEC_HID_DIM,
    dropout=DEC_DROPOUT,
    attention=attn,
)
model = Seq2Seq(encoder=enc, decoder=dec, device="cpu")
# Weights are remapped onto the CPU regardless of where they were saved.
model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
def infer(text, max_length=50):
    """Translate ``text`` with the loaded seq2seq model using greedy decoding.

    Args:
        text: input sentence (tokenized with ``tokenize_ar``).
        max_length: length of the decoder buffer; at most ``max_length - 1``
            tokens are generated (row 0 of the model output is a placeholder).

    Returns:
        The decoded sentence, cut at the first ``<eos>``, detokenized, with
        non-word/non-space characters removed and surrounding whitespace stripped.
    """
    tokens = tokenize_ar(text)

    # Map unseen words to <unk> when the vocab has one, rather than letting
    # the lookup raise on arbitrary user input.
    unk_idx = src_vocab['<unk>'] if '<unk>' in src_vocab else None
    ids = [src_vocab['<sos>']]
    for tok in tokens:
        if unk_idx is not None and tok not in src_vocab:
            ids.append(unk_idx)
        else:
            ids.append(src_vocab[tok])
    ids.append(src_vocab['<eos>'])
    # (src_len, batch=1), built directly as int64 (no float round-trip).
    sequence = torch.tensor(ids, dtype=torch.int64)[:, None]

    # Seq2Seq.forward feeds trg[0] as the first decoder input, so it must be
    # <sos> rather than index 0. The slice keeps max_length == 0 working.
    target = torch.zeros(max_length, 1, dtype=torch.int64)
    target[:1] = trg_vocab['<sos>'] if '<sos>' in trg_vocab else 0

    model.eval()
    with torch.no_grad():
        # teacher_forcing_ratio=0: every step feeds back its own argmax.
        output = model(sequence, target, 0)

    prediction = output[1:].argmax(dim=-1).view(-1).tolist()
    words = trg_vocab.lookup_tokens(prediction)
    # Drop everything the decoder produced after the end-of-sentence marker.
    if '<eos>' in words:
        words = words[:words.index('<eos>')]
    en = TreebankWordDetokenizer().detokenize(words)
    return re.sub(r'[^\w\s]', '', en).strip()
# Serve infer() as a text-in / text-out web UI and start the server.
iface = gr.Interface(infer, "text", "text")
iface.launch()