import gradio as gr
import random
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import unicodedata
import nltk
from nltk.tokenize.treebank import TreebankWordDetokenizer

nltk.download('punkt')


class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, hidden = self.rnn(embedded)
        # Concatenate the final forward and backward hidden states and project
        # them down to the decoder's hidden size.
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)))
        return outputs, hidden


class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]
        # Repeat the decoder hidden state over the source length so it can be
        # scored against every encoder output.
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        attention = self.v(energy).squeeze(2)
        return F.softmax(attention, dim=1)


class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        input = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(input))
        # Attention weights over the source sequence.
        a = self.attention(hidden, encoder_outputs)
        a = a.unsqueeze(1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        # Context vector: attention-weighted sum of encoder outputs.
        weighted = torch.bmm(a, encoder_outputs)
        weighted = weighted.permute(1, 0, 2)
        rnn_input = torch.cat((embedded, weighted), dim=2)
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
        # For a single-step, single-layer GRU the output and hidden state match.
        assert (output == hidden).all()
        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)
        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1))
        return prediction, hidden.squeeze(0)


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = trg.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        encoder_outputs, hidden = self.encoder(src)
        # The first decoder input is the first target token; after that, either
        # the ground-truth token (teacher forcing) or the model's own prediction.
        input = trg[0, :]
        for t in range(1, trg_len):
            output, hidden = self.decoder(input, hidden, encoder_outputs)
            outputs[t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input = trg[t] if teacher_force else top1
        return outputs


def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )


def tokenize_ar(text):
    """
    Tokenizes Arabic text from a string into a list of strings (tokens),
    stripping combining marks (diacritics) first.
    """
    return [tok for tok in nltk.tokenize.wordpunct_tokenize(unicodeToAscii(text))]


src_vocab = torch.load("arabic_vocab.pth")
trg_vocab = torch.load("english_vocab.pth")

INPUT_DIM = 9790
OUTPUT_DIM = 5682
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

model = Seq2Seq(enc, dec, "cpu")
model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))


def infer(text, max_length=50):
    text = tokenize_ar(text)
    # Numericalise the source sentence and wrap it in start/end markers.
    # The token names '<sos>' and '<eos>' assume the vocab was built with
    # these special tokens.
    sequence = []
    sequence.append(src_vocab['<sos>'])
    sequence.extend([src_vocab[token] for token in text])
    sequence.append(src_vocab['<eos>'])
    sequence = torch.Tensor(sequence)
    sequence = sequence[:, None].to(torch.int64)
    # Dummy target of zeros; with teacher_forcing_ratio=0 the model feeds back
    # its own predictions for up to max_length steps.
    target = torch.zeros(max_length, 1).to(torch.int64)
    with torch.no_grad():
        model.eval()
        output = model(sequence, target, 0)
    output_dim = output.shape[-1]
    output = output[1:].view(-1, output_dim)
    prediction = []
    for i in output:
        prediction.append(torch.argmax(i).item())
    tokens = trg_vocab.lookup_tokens(prediction)
    en = TreebankWordDetokenizer().detokenize(tokens).replace('<eos>', "")
    return re.sub(r'[^\w\s]', '', en).strip()


iface = gr.Interface(fn=infer, inputs="text", outputs="text")
iface.launch()
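# Optional sanity check (a sketch, not part of the original app): with
# arabic_vocab.pth, english_vocab.pth and model.pt in place, infer() can be
# exercised directly from a Python shell before starting the UI, e.g.
#
#     >>> infer("كيف حالك؟")   # illustrative Arabic input; returns the model's English string
#
# Note that iface.launch() above blocks the script while the server runs;
# pass share=True to launch() if a temporary public URL is needed.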