Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import numpy as np | |
import pandas as pd | |
import random | |
import sys | |
import os | |
import json | |
import enum | |
import traceback | |
import re | |
# Root folder for transliteration assets. XlitEngine reads its lineup config
# from F_DIR/<config_path> and, when this folder is writable, stores the
# downloaded models under F_DIR/translit_models.
# NOTE(review): presumably pinned for the hosted deployment; the
# module-relative alternative is os.path.dirname(os.path.realpath(__file__)).
F_DIR = "/home/user/app/ttsv/checkpoints/"
class XlitError(enum.Enum):
    """Failure sentinels used by XlitEngine.

    ``translit_word``/``translit_sentence`` return (rather than raise) these
    members on failure. Each member's ``value`` is the human-readable message
    that is printed alongside the failure.
    """

    # Requested language code is not loaded.
    lang_err = (
        "Unsupported langauge ID requested ;( "
        "Please check available languages."
    )
    # Input string cannot be processed.
    string_err = "String passed is incompatable ;("
    # Unexpected exception during inference.
    internal_err = "Internal crash ;("
    # Catch-all failure.
    unknown_err = "Unknown Failure"
    # A language model could not be loaded.
    loading_err = (
        "Loading failed ;( "
        "Check if metadata/paths are correctly configured."
    )
##=================== Network ================================================== | |
class Encoder(nn.Module):
    """Character-level RNN encoder (GRU or LSTM) over padded index batches.

    Args:
        input_dim: source vocabulary size (rows of the embedding table).
        embed_dim: embedding width.
        hidden_dim: RNN hidden size per direction.
        rnn_type: "gru" or "lstm".
        layers: number of stacked RNN layers.
        bidirectional: run the RNN in both directions.
        dropout: dropout between stacked RNN layers. It is only forwarded when
            ``layers > 1``, because PyTorch ignores it (with a warning) for a
            single layer.
        device: kept as ``self.device``. Tensors built internally follow the
            embedding weights' device instead.

    Raises:
        Exception: if ``rnn_type`` is neither "gru" nor "lstm".
    """

    def __init__(
        self,
        input_dim,
        embed_dim,
        hidden_dim,
        rnn_type="gru",
        layers=1,
        bidirectional=False,
        dropout=0,
        device="cpu",
    ):
        super(Encoder, self).__init__()
        self.input_dim = input_dim  # src_vocab_sz
        self.enc_embed_dim = embed_dim
        self.enc_hidden_dim = hidden_dim
        self.enc_rnn_type = rnn_type
        self.enc_layers = layers
        self.enc_directions = 2 if bidirectional else 1
        self.device = device
        self.embedding = nn.Embedding(self.input_dim, self.enc_embed_dim)

        rnn_classes = {"gru": nn.GRU, "lstm": nn.LSTM}
        if self.enc_rnn_type not in rnn_classes:
            raise Exception("XlitError: unknown RNN type mentioned")
        self.enc_rnn = rnn_classes[self.enc_rnn_type](
            input_size=self.enc_embed_dim,
            hidden_size=self.enc_hidden_dim,
            num_layers=self.enc_layers,
            bidirectional=bidirectional,
            # Inter-layer dropout only exists for stacked RNNs.
            dropout=dropout if self.enc_layers > 1 else 0,
        )

    def forward(self, x, x_sz, hidden=None):
        """Encode a padded batch of index sequences.

        Args:
            x: (batch_size, max_length) token indices.
            x_sz: unpadded length of each sequence (list or 1-D CPU tensor),
                used to pack away the padding.
            hidden: optional initial RNN state: a tensor for GRU, or an
                ``(h_0, c_0)`` tuple for LSTM. None means zeros.

        Returns:
            output: (batch_size, padded_length, hidden_dim * directions),
                padded to the longest sequence in the batch.
            hidden: final state, (layers * directions, batch_size, hidden_dim),
                or the ``(h_n, c_n)`` tuple for LSTM.
        """
        # x: batch_size, max_length, enc_embed_dim
        x = self.embedding(x)
        # Time-major for packing: max_length, batch_size, enc_embed_dim
        x = x.permute(1, 0, 2)
        # Packing skips padded steps. enforce_sorted=False lets the caller pass
        # an unsorted batch and still get results back in its own order.
        x = nn.utils.rnn.pack_padded_sequence(x, x_sz, enforce_sorted=False)
        if hidden is None:
            output, hidden = self.enc_rnn(x)
        else:
            output, hidden = self.enc_rnn(x, hidden)
        # output: padded_length, batch_size, hidden_dim * directions
        output, _ = nn.utils.rnn.pad_packed_sequence(output)
        # output: batch_size, padded_length, hidden_dim * directions
        output = output.permute(1, 0, 2)
        return output, hidden

    def get_word_embedding(self, x):
        """Return the RNN's final-state vector(s) for a single word.

        Args:
            x: token indices of one word (list, array or 1-D tensor).

        Returns:
            ``hidden[0].squeeze()``:
            - GRU: the first layer/direction's final state, shape (hidden_dim,).
            - LSTM: the squeezed ``h_n``, shape (layers * directions, hidden_dim),
              or (hidden_dim,) when that product is 1.
        """
        device = self.embedding.weight.device
        x_sz = torch.tensor([len(x)])  # packing lengths must stay on CPU
        # The (1, length) index batch must live on the model's device.
        x_ = torch.as_tensor(x).to(device=device, dtype=torch.long).unsqueeze(0)
        # x: 1, length, enc_embed_dim -> length, 1, enc_embed_dim
        x = self.embedding(x_).permute(1, 0, 2)
        x = nn.utils.rnn.pack_padded_sequence(x, x_sz, enforce_sorted=False)
        _, hidden = self.enc_rnn(x)
        return hidden[0].squeeze()
class Decoder(nn.Module):
    """Single-step RNN decoder with optional additive attention.

    Args:
        output_dim: target vocabulary size.
        embed_dim: target embedding width.
        hidden_dim: RNN hidden size.
        rnn_type: "gru" or "lstm".
        layers: number of stacked RNN layers.
        use_attention: attend over encoder outputs at every step.
        enc_outstate_dim: width of each encoder output vector
            (enc_directions * enc_hidden_dim). Defaults to ``hidden_dim``.
        dropout: inter-layer RNN dropout, forwarded only when ``layers > 1``.
        device: kept as ``self.device``.

    Raises:
        Exception: if ``rnn_type`` is neither "gru" nor "lstm".
    """

    def __init__(
        self,
        output_dim,
        embed_dim,
        hidden_dim,
        rnn_type="gru",
        layers=1,
        use_attention=True,
        enc_outstate_dim=None,  # enc_directions * enc_hidden_dim
        dropout=0,
        device="cpu",
    ):
        super(Decoder, self).__init__()
        self.output_dim = output_dim  # tgt_vocab_sz
        self.dec_hidden_dim = hidden_dim
        self.dec_embed_dim = embed_dim
        self.dec_rnn_type = rnn_type
        self.dec_layers = layers
        self.use_attention = use_attention
        self.device = device
        if self.use_attention:
            self.enc_outstate_dim = enc_outstate_dim if enc_outstate_dim else hidden_dim
        else:
            self.enc_outstate_dim = 0

        self.embedding = nn.Embedding(self.output_dim, self.dec_embed_dim)
        rnn_classes = {"gru": nn.GRU, "lstm": nn.LSTM}
        if self.dec_rnn_type not in rnn_classes:
            raise Exception("XlitError: unknown RNN type mentioned")
        self.dec_rnn = rnn_classes[self.dec_rnn_type](
            # The attention context is concatenated onto the token embedding.
            input_size=self.dec_embed_dim + self.enc_outstate_dim,
            hidden_size=self.dec_hidden_dim,
            num_layers=self.dec_layers,
            batch_first=True,
            dropout=dropout if self.dec_layers > 1 else 0,
        )
        # Keep this layout: state_dict keys (fc.0 / fc.2) depend on these
        # Sequential indices.
        self.fc = nn.Sequential(
            nn.Linear(self.dec_hidden_dim, self.dec_embed_dim),
            nn.LeakyReLU(),
            nn.Linear(self.dec_embed_dim, self.output_dim),
        )
        if self.use_attention:
            self.W1 = nn.Linear(self.enc_outstate_dim, self.dec_hidden_dim)
            self.W2 = nn.Linear(self.dec_hidden_dim, self.dec_hidden_dim)
            self.V = nn.Linear(self.dec_hidden_dim, 1)

    def attention(self, x, hidden, enc_output):
        """Compute the attention context and prepend it to the embedded input.

        Args:
            x: (batch_size, 1, dec_embed_dim) embedded decoder input.
            hidden: decoder state, (n_layers, batch_size, hidden_dim). For LSTM
                this is the ``(h_n, c_n)`` tuple, and only ``h_n`` is used.
            enc_output: (batch_size, max_length, enc_outstate_dim).

        Returns:
            attend_out: (batch_size, 1, enc_outstate_dim + dec_embed_dim),
                with the context vector first and ``x`` second.
            attention_weights: (batch_size, max_length, 1), a softmax over the
                encoder time axis.
        """
        # NOTE: padded encoder positions are not masked here.
        h = hidden[0] if self.dec_rnn_type == "lstm" else hidden
        # Collapse the layer axis by summing: batch_size, 1, hidden_dim
        query = torch.sum(h, axis=0).unsqueeze(1)
        # score: batch_size, max_length, hidden_dim
        score = torch.tanh(self.W1(enc_output) + self.W2(query))
        attention_weights = torch.softmax(self.V(score), dim=1)
        # context_vector: batch_size, 1, enc_outstate_dim
        context_vector = torch.sum(attention_weights * enc_output, dim=1).unsqueeze(1)
        attend_out = torch.cat((context_vector, x), -1)
        return attend_out, attention_weights

    def forward(self, x, hidden, enc_output):
        """Run one decoding step.

        Args:
            x: (batch_size, 1) previous target-token indices.
            hidden: decoder state (a tensor, or an ``(h, c)`` tuple for LSTM),
                or None to start from zeros.
            enc_output: (batch_size, max_length, enc_outstate_dim). Only read
                when attention is enabled.

        Returns:
            output: (batch_size, output_dim) unnormalised scores.
            hidden: the updated decoder state.
            aw: attention weights, (batch_size, max_length, 1), or 0 when
                attention is disabled.

        Raises:
            Exception: if ``hidden`` is None and attention is disabled.
        """
        if (hidden is None) and (self.use_attention is False):
            raise Exception(
                "XlitError: No use of a decoder with No attention and No Hidden"
            )
        # x: batch_size, 1, dec_embed_dim
        x = self.embedding(x)
        if self.use_attention:
            hid_for_att = hidden
            if hid_for_att is None:
                # The RNN below starts from zeros when hidden is None, so the
                # attention query must use that same zero state (in the
                # tuple form attention() expects for LSTM).
                zeros = torch.zeros(
                    (self.dec_layers, x.shape[0], self.dec_hidden_dim),
                    dtype=x.dtype,
                    device=x.device,
                )
                hid_for_att = (zeros, zeros) if self.dec_rnn_type == "lstm" else zeros
            # x: batch_size, 1, enc_outstate_dim + dec_embed_dim
            x, aw = self.attention(x, hid_for_att, enc_output)
        else:
            aw = 0
        # output: batch_size, 1, hidden_dim (batch_first, a single time step)
        if hidden is None:
            output, hidden = self.dec_rnn(x)
        else:
            output, hidden = self.dec_rnn(x, hidden)
        # batch_size * 1, hidden_dim -> batch_size * 1, output_dim
        output = self.fc(output.view(-1, output.size(2)))
        return output, hidden, aw
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper with beam-search inference.

    Class dependency: Encoder, Decoder

    Args:
        encoder: an Encoder-like module.
        decoder: a Decoder-like module.
        pass_enc2dec_hid: seed the decoder with the encoder's final state.
        dropout: accepted for signature compatibility; unused.
        device: kept as ``self.device``.
    """

    def __init__(
        self, encoder, decoder, pass_enc2dec_hid=False, dropout=0, device="cpu"
    ):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.pass_enc2dec_hid = pass_enc2dec_hid
        _force_en2dec_hid_conv = False

        if self.pass_enc2dec_hid:
            assert (
                decoder.dec_hidden_dim == encoder.enc_hidden_dim
            ), "Hidden Dimension of encoder and decoder must be same, or unset `pass_enc2dec_hid`"
        if decoder.use_attention:
            assert (
                decoder.enc_outstate_dim
                == encoder.enc_directions * encoder.enc_hidden_dim
            ), "Set `enc_out_dim` correctly in decoder"
        assert (
            self.pass_enc2dec_hid or decoder.use_attention
        ), "No use of a decoder with No attention and No Hidden from Encoder"

        self.use_conv_4_enc2dec_hid = False
        enc_state_rows = encoder.enc_directions * encoder.enc_layers
        if (
            self.pass_enc2dec_hid and enc_state_rows != decoder.dec_layers
        ) or _force_en2dec_hid_conv:
            # The 1x1 conv can only remap a single hidden tensor, so an LSTM
            # (h, c) state on EITHER side cannot be converted.
            if encoder.enc_rnn_type == "lstm" or decoder.dec_rnn_type == "lstm":
                raise Exception(
                    "XlitError: conv for enc2dec_hid not implemented; Change the layer numbers appropriately"
                )
            self.use_conv_4_enc2dec_hid = True
            self.enc_hid_1ax = enc_state_rows
            self.dec_hid_1ax = decoder.dec_layers
            self.e2d_hidden_conv = nn.Conv1d(self.enc_hid_1ax, self.dec_hid_1ax, 1)

    def enc2dec_hidden(self, enc_hidden):
        """Map the encoder's final state to the decoder's layer count.

        Args:
            enc_hidden: (enc_layers * enc_directions, batch_size, hidden_dim).

        Returns:
            A (dec_layers, batch_size, hidden_dim) tensor. LSTM state is not
            supported.
        """
        # Conv1d expects [N, C, L]: layers act as channels, hidden as length.
        hidden = enc_hidden.permute(1, 0, 2).contiguous()
        hidden = self.e2d_hidden_conv(hidden)
        # dec_layers, batch_size, dec_hidden_dim
        return hidden.permute(1, 0, 2).contiguous()

    def active_beam_inference(self, src, beam_width=3, max_tgt_sz=50):
        """Beam-search decode a single source sequence.

        Args:
            src: 1-D tensor of source token indices, including the start and
                end tokens. ``src[0]`` and ``src[-1]`` are reused as the target
                start and end token ids.
            beam_width: number of hypotheses kept per step. Per-step expansion
                is clamped to the target vocabulary size.
            max_tgt_sz: maximum number of decoding steps.

        Returns:
            A list of up to ``beam_width`` 1-D index tensors, each beginning
            with the start token, ordered by descending summed log-probability.
        """

        def _avg_score(p_tup):
            # TODO: divide by len(sequence) ** alpha as a hyperparameter.
            return p_tup[0]

        start_tok = src[0]
        end_tok = src[-1]
        src_sz = torch.tensor([len(src)])
        src_ = src.unsqueeze(0)

        # Pure inference: skip autograd bookkeeping for every beam step.
        with torch.no_grad():
            # enc_output: (1, seq_len, enc_hidden_dim * num_directions)
            enc_output, enc_hidden = self.encoder(src_, src_sz)
            if not self.pass_enc2dec_hid:
                init_dec_hidden = None  # the decoder starts from zeros
            elif self.use_conv_4_enc2dec_hid:
                init_dec_hidden = self.enc2dec_hidden(enc_hidden)
            else:
                init_dec_hidden = enc_hidden

            # Each hypothesis: (sum of log-softmax, token tensor, decoder state)
            top_pred_list = [(0, start_tok.unsqueeze(0), init_dec_hidden)]
            for _ in range(max_tgt_sz):
                cur_pred_list = []
                for p_tup in top_pred_list:
                    score, seq, dec_state = p_tup
                    if seq[-1] == end_tok:
                        cur_pred_list.append(p_tup)  # finished; carry over
                        continue
                    # dec_output: (1, output_dim)
                    dec_output, dec_hidden, _ = self.decoder(
                        x=seq[-1].view(1, 1),
                        hidden=dec_state,
                        enc_output=enc_output,
                    )
                    # Sum log-probabilities instead of multiplying probabilities
                    # to avoid underflow.
                    dec_output = nn.functional.log_softmax(dec_output, dim=1)
                    k = min(beam_width, dec_output.size(1))
                    pred_topk = torch.topk(dec_output, k=k, dim=1)
                    for i in range(k):
                        cur_pred_list.append(
                            (
                                score + pred_topk.values[0][i],
                                torch.cat((seq, pred_topk.indices[0][i].view(1))),
                                dec_hidden,
                            )
                        )
                cur_pred_list.sort(key=_avg_score, reverse=True)  # best first
                top_pred_list = cur_pred_list[:beam_width]
                if all(hyp[1][-1] == end_tok for hyp in top_pred_list):
                    break

        return [hyp[1] for hyp in top_pred_list]
##===================== Glyph handlers ======================================= | |
class GlyphStrawboss:
    """Character <-> index vocabulary for one script.

    Indices 0-6 are reserved tokens:
        "_" pad, "$" start, "#" end, "*" mask, "'" apostrophe,
        "%" and "!" unused.
    The script's glyphs follow, starting at index 7.
    """

    def __init__(self, glyphs="en"):
        """
        Args:
            glyphs: "en" for lowercase a-z. Otherwise, a path to a UTF-8 JSON
                file with a "glyphs" list and a "numsym_map" mapping (stored as
                ``self.dossier`` / ``self.numsym_map``).
        """
        if glyphs == "en":
            # Lowercase only: 'a' (97) .. 'z' (122)
            self.glyphs = [chr(alpha) for alpha in range(97, 122 + 1)]
        else:
            with open(glyphs, encoding="utf-8") as fh:
                self.dossier = json.load(fh)
            self.glyphs = self.dossier["glyphs"]
            self.numsym_map = self.dossier["numsym_map"]
        self.char2idx = {}
        self.idx2char = {}
        self._create_index()

    def _create_index(self):
        """Fill char2idx / idx2char: reserved tokens first, then glyphs from 7."""
        self.char2idx["_"] = 0  # pad
        self.char2idx["$"] = 1  # start
        self.char2idx["#"] = 2  # end
        self.char2idx["*"] = 3  # Mask
        self.char2idx["'"] = 4  # apostrophe U+0027
        self.char2idx["%"] = 5  # unused
        self.char2idx["!"] = 6  # unused
        for idx, char in enumerate(self.glyphs):
            self.char2idx[char] = idx + 7  # after the 7 reserved tokens
        for char, idx in self.char2idx.items():
            self.idx2char[idx] = char

    def size(self):
        """Vocabulary size, including the reserved tokens."""
        return len(self.char2idx)

    def word2xlitvec(self, word):
        """Encode ``word`` as an int64 numpy vector wrapped in start/end tokens.

        Raises:
            ValueError: if ``word`` contains a character outside this vocabulary.
        """
        try:
            body = [self.char2idx[ch] for ch in word]
        except KeyError as error:
            # Raise instead of sys.exit(): one bad character must not
            # terminate the whole process (e.g. a serving application).
            raise ValueError(
                "XlitError: In word: {} Error Char not in Token: {}".format(word, error)
            ) from error
        vec = [self.char2idx["$"]] + body + [self.char2idx["#"]]
        return np.asarray(vec, dtype=np.int64)

    def xlitvec2word(self, vector):
        """Decode an index vector to a string, dropping the "$", "#", "_" and "*" tokens."""
        word = "".join(self.idx2char[i] for i in vector)
        return word.translate(str.maketrans("", "", "$#_*"))
class VocabSanitizer:
    """Reorders candidate words so that those in a known vocabulary come first."""

    def __init__(self, data_file):
        """
        Args:
            data_file: path to the vocabulary. Supported formats:
                - ``.json``: a top-level collection (e.g. a list) of words.
                - ``.csv``: a file with a "WORD" column.
                Any other extension is reported and leaves the vocabulary empty.
        """
        extension = os.path.splitext(data_file)[-1]
        if extension == ".json":
            with open(data_file, encoding="utf-8") as fh:
                self.vocab_set = set(json.load(fh))
        elif extension == ".csv":
            self.vocab_df = pd.read_csv(data_file).set_index("WORD")
            self.vocab_set = set(self.vocab_df.index)
        else:
            print("XlitError: Only Json/CSV file extension supported")
            # Keep the object usable: reposition() becomes a no-op instead of
            # failing later with AttributeError.
            self.vocab_set = set()

    def reposition(self, word_list):
        """Return ``word_list`` with in-vocabulary words moved to the front.

        This is a stable partition: relative order is kept within both groups,
        and duplicates are preserved.
        """
        known = [w for w in word_list if w in self.vocab_set]
        unknown = [w for w in word_list if w not in self.vocab_set]
        return known + unknown
##=============== INSTANTIATION ================================================ | |
class XlitPiston:
    """
    For handling prediction & post-processing of transliteration for a single language
    Class dependency: Seq2Seq, GlyphStrawboss, VocabSanitizer
    """

    def __init__(
        self,
        weight_path,
        vocab_file,
        tglyph_cfg_file,
        iglyph_cfg_file="en",
        device="cpu",
    ):
        """Build the Seq2Seq model for one language and load its weights.

        Args:
            weight_path: path to a saved ``state_dict`` for the model below.
            vocab_file: .json/.csv vocabulary, passed to VocabSanitizer.
            tglyph_cfg_file: JSON script description of the target language
                (its "glyphs" and "numsym_map").
            iglyph_cfg_file: input glyph set ("en" = lowercase a-z).
            device: torch device for the model.
        """
        self.device = device
        self.in_glyph_obj = GlyphStrawboss(iglyph_cfg_file)
        self.tgt_glyph_obj = GlyphStrawboss(glyphs=tglyph_cfg_file)
        self.voc_sanity = VocabSanitizer(vocab_file)

        # Reuse the already-parsed script config instead of reopening the file.
        numsym_map = self.tgt_glyph_obj.numsym_map
        # Characters handled by numsym_model (the map's keys).
        self._numsym_set = set(numsym_map.keys())
        self._inchar_set = set("abcdefghijklmnopqrstuvwxyz")
        # Native-script characters: the glyphs plus every mapped replacement.
        self._natscr_set = set(self.tgt_glyph_obj.glyphs).union(*numsym_map.values())

        ## Model Config Static TODO: add defining in json support
        input_dim = self.in_glyph_obj.size()
        output_dim = self.tgt_glyph_obj.size()
        enc_emb_dim = 300
        dec_emb_dim = 300
        enc_hidden_dim = 512
        dec_hidden_dim = 512
        rnn_type = "lstm"
        enc2dec_hid = True
        attention = True
        enc_layers = 1
        dec_layers = 2
        m_dropout = 0
        enc_bidirect = True
        enc_outstate_dim = enc_hidden_dim * (2 if enc_bidirect else 1)

        enc = Encoder(
            input_dim=input_dim,
            embed_dim=enc_emb_dim,
            hidden_dim=enc_hidden_dim,
            rnn_type=rnn_type,
            layers=enc_layers,
            dropout=m_dropout,
            device=self.device,
            bidirectional=enc_bidirect,
        )
        dec = Decoder(
            output_dim=output_dim,
            embed_dim=dec_emb_dim,
            hidden_dim=dec_hidden_dim,
            rnn_type=rnn_type,
            layers=dec_layers,
            dropout=m_dropout,
            use_attention=attention,
            enc_outstate_dim=enc_outstate_dim,
            device=self.device,
        )
        self.model = Seq2Seq(enc, dec, pass_enc2dec_hid=enc2dec_hid, device=self.device)
        self.model = self.model.to(self.device)
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Newer PyTorch also accepts weights_only=True.
        weights = torch.load(weight_path, map_location=torch.device(self.device))
        self.model.load_state_dict(weights)
        self.model.eval()

    def character_model(self, word, beam_width=1):
        """Transliterate one input-glyph word and return its candidates, in-vocabulary first."""
        in_vec = torch.from_numpy(self.in_glyph_obj.word2xlitvec(word)).to(self.device)
        with torch.no_grad():  # inference only; no autograd graph needed
            p_out_list = self.model.active_beam_inference(in_vec, beam_width=beam_width)
        p_result = [
            self.tgt_glyph_obj.xlitvec2word(out.cpu().numpy()) for out in p_out_list
        ]
        return self.voc_sanity.reposition(p_result)

    def numsym_model(self, seg):
        """Return ``[seg, native rendering(s)]`` for a digit/symbol run.

        A single character yields all of its mapped alternatives. A longer run
        yields one string built from each character's first mapping.
        """
        numsym_map = self.tgt_glyph_obj.numsym_map
        if len(seg) == 1:
            return [seg] + numsym_map[seg]
        return [seg, "".join(numsym_map[ch][0] for ch in seg)]

    def _word_segementer(self, sequence):
        """Split the lowercased ``sequence`` into runs of a single character class.

        Each pass greedily takes, in this order:
        1. a run of number/symbol characters,
        2. a run of native-script characters,
        3. a run of roman a-z characters,
        4. a run of characters belonging to none of these sets.
        """
        sequence = sequence.lower()
        accepted = set().union(self._numsym_set, self._inchar_set, self._natscr_set)
        # The order matters when a character belongs to more than one set.
        run_classes = (
            lambda ch: ch in self._numsym_set,
            lambda ch: ch in self._natscr_set,
            lambda ch: ch in self._inchar_set,
            lambda ch: ch not in accepted,
        )
        segments = []
        pos, end = 0, len(sequence)
        # Index scanning avoids the O(n^2) cost of repeated list.pop(0).
        while pos < end:
            for in_run in run_classes:
                start = pos
                while pos < end and in_run(sequence[pos]):
                    pos += 1
                if pos > start:
                    segments.append(sequence[start:pos])
        return segments

    def inferencer(self, sequence, beam_width=10):
        """Transliterate ``sequence``, truncated to its first 120 characters.

        Returns a list of candidate strings:
        - one segment: that segment's candidates;
        - two segments: every combination of the two segments' candidates;
        - otherwise (including none): a single string joining each segment's
          top candidate.
        """
        lit_seg = []
        for segment in self._word_segementer(sequence[:120]):
            head = segment[0]
            if head in self._natscr_set:
                lit_seg.append([segment])  # already native script
            elif head in self._inchar_set:
                lit_seg.append(self.character_model(segment, beam_width=beam_width))
            elif head in self._numsym_set:  # num & punc
                lit_seg.append(self.numsym_model(segment))
            else:
                lit_seg.append([segment])  # unknown characters pass through

        if len(lit_seg) == 1:
            return lit_seg[0]
        if len(lit_seg) == 2:
            combos = [""]
            for options in lit_seg:
                combos = [prefix + opt for opt in options for prefix in combos]
            return combos
        return ["".join(options[0] for options in lit_seg)]
from collections.abc import Iterable | |
from pydload import dload | |
import zipfile | |
# Per-language model archives are fetched from
# MODEL_DOWNLOAD_URL_PREFIX + "<eng_name>.zip".
MODEL_DOWNLOAD_URL_PREFIX = (
    "https://github.com/AI4Bharat/IndianNLP-Transliteration"
    "/releases/download/xlit_v0.5.0/"
)
def is_folder_writable(folder):
    """Return True if ``folder`` exists (or can be created) and accepts a new file.

    Side effects: creates ``folder`` and its parents if missing, then writes
    and deletes a ``.write_test`` probe file inside it.
    """
    try:
        os.makedirs(folder, exist_ok=True)
        tmp_file = os.path.join(folder, ".write_test")
        with open(tmp_file, "w") as f:
            f.write("Permission Check")
        os.remove(tmp_file)
        return True
    except (OSError, ValueError):
        # Filesystem failures and invalid path strings mean "not writable".
        # A bare except would also swallow KeyboardInterrupt/SystemExit.
        return False
def is_directory_writable(path):
    """Report whether ``path`` can be written to.

    On Windows a real probe file is written via ``is_folder_writable``
    (presumably because ``os.access`` does not reflect Windows ACLs).
    Elsewhere the write and search permission bits are checked with
    ``os.access``.
    """
    if os.name != "nt":
        return os.access(path, os.W_OK | os.X_OK)
    return is_folder_writable(path)
class XlitEngine:
    """
    For Managing the top level tasks and applications of transliteration
    Global Variables: F_DIR

    Public translit_* methods return XlitError members (instead of raising)
    when a language code is unknown or inference fails.
    """

    def __init__(
        self, lang2use="all", config_path="translit_models/default_lineup.json"
    ):
        """Load the language lineup, download missing models, and build one
        XlitPiston per requested language.

        Args:
            lang2use: one of:
                - "all";
                - a single language code;
                - an iterable of codes (unknown codes are reported and skipped).
            config_path: lineup JSON file, relative to F_DIR.

        Raises:
            Exception: for an unknown single code, or an unsupported
                ``lang2use`` type.
        """
        with open(os.path.join(F_DIR, config_path), encoding="utf-8") as fh:
            lineup = json.load(fh)

        self.lang_config = {}
        if isinstance(lang2use, str):
            if lang2use == "all":
                self.lang_config = lineup
            elif lang2use in lineup:
                self.lang_config[lang2use] = lineup[lang2use]
            else:
                raise Exception(
                    "XlitError: The entered Langauge code not found. Available are {}".format(
                        lineup.keys()
                    )
                )
        elif isinstance(lang2use, Iterable):
            for code in lang2use:
                try:
                    self.lang_config[code] = lineup[code]
                except (KeyError, TypeError):
                    # Unknown or unhashable codes are reported and skipped.
                    print(
                        "XlitError: Language code {} not found, Skipping...".format(code)
                    )
        else:
            raise Exception(
                "XlitError: lang2use must be a list of language codes (or) string of single language code"
            )

        # Keep models next to the config when possible; else use the home folder.
        if is_directory_writable(F_DIR):
            models_path = os.path.join(F_DIR, "translit_models")
        else:
            user_home = os.path.expanduser("~")
            models_path = os.path.join(user_home, ".AI4Bharat_Xlit_Models")
        os.makedirs(models_path, exist_ok=True)
        self.download_models(models_path)

        self.langs = {}
        self.lang_model = {}
        for la in self.lang_config:
            try:
                print("Loading {}...".format(la))
                cfg = self.lang_config[la]
                self.lang_model[la] = XlitPiston(
                    weight_path=os.path.join(models_path, cfg["weight"]),
                    vocab_file=os.path.join(models_path, cfg["vocab"]),
                    tglyph_cfg_file=os.path.join(models_path, cfg["script"]),
                    iglyph_cfg_file="en",
                )
                self.langs[la] = cfg["name"]
            except Exception as error:
                # A failing language is reported and skipped; the others still load.
                print("XlitError: Failure in loading {} \n".format(la), error)
                print(XlitError.loading_err.value)

    def download_models(self, models_path):
        """
        Download models from GitHub Releases if not exists

        For each configured language, fetch ``<eng_name>.zip`` and unzip it
        into ``models_path``, unless ``models_path/<eng_name>`` already exists.
        Exits the process (SystemExit) if a download or extraction fails.
        """
        for code in self.lang_config:
            lang_name = self.lang_config[code]["eng_name"]
            lang_model_path = os.path.join(models_path, lang_name)
            if os.path.isdir(lang_model_path):
                continue
            print("Downloading model for language: %s" % lang_name)
            remote_url = MODEL_DOWNLOAD_URL_PREFIX + lang_name + ".zip"
            downloaded_zip_path = os.path.join(models_path, lang_name + ".zip")
            dload(url=remote_url, save_to_path=downloaded_zip_path, max_time=None)
            if not os.path.isfile(downloaded_zip_path):
                # Use sys.exit rather than the site-injected builtin exit(),
                # which is missing under `python -S` and some embedded runtimes.
                sys.exit(
                    f"ERROR: Unable to download model from {remote_url} into {models_path}"
                )
            # NOTE(review): extractall trusts member paths; acceptable only
            # for archives from the trusted release URL above.
            with zipfile.ZipFile(downloaded_zip_path, "r") as zip_ref:
                zip_ref.extractall(models_path)
            if os.path.isdir(lang_model_path):
                os.remove(downloaded_zip_path)
            else:
                sys.exit(
                    f"ERROR: Unable to find models in {lang_model_path} after download"
                )
        return

    def translit_word(self, eng_word, lang_code="default", topk=7, beam_width=10):
        """Return up to ``topk`` transliteration candidates for one word.

        Returns:
            - ``[]`` for an empty word;
            - a list of candidates for a loaded ``lang_code``;
            - ``{code: candidates}`` for every loaded language when
              ``lang_code == "default"``;
            - ``XlitError.internal_err`` on an inference failure;
            - ``XlitError.lang_err`` for an unknown code.
        """
        if eng_word == "":
            return []
        if lang_code in self.langs:
            try:
                res_list = self.lang_model[lang_code].inferencer(
                    eng_word, beam_width=beam_width
                )
                return res_list[:topk]
            except Exception:
                print("XlitError:", traceback.format_exc())
                print(XlitError.internal_err.value)
                return XlitError.internal_err
        elif lang_code == "default":
            try:
                return {
                    la: model.inferencer(eng_word, beam_width=beam_width)[:topk]
                    for la, model in self.lang_model.items()
                }
            except Exception:
                print("XlitError:", traceback.format_exc())
                print(XlitError.internal_err.value)
                return XlitError.internal_err
        else:
            print("XlitError: Unknown Langauge requested", lang_code)
            print(XlitError.lang_err.value)
            return XlitError.lang_err

    def _translit_sentence_one(self, lang_code, eng_sentence, beam_width):
        """Transliterate each whitespace-separated word with its top-1 candidate
        and rejoin the results with single spaces."""
        model = self.lang_model[lang_code]
        return " ".join(
            model.inferencer(word, beam_width=beam_width)[0]
            for word in eng_sentence.split()
        )

    def translit_sentence(self, eng_sentence, lang_code="default", beam_width=10):
        """Transliterate a sentence word by word, keeping each word's top candidate.

        Returns:
            - ``[]`` for an empty sentence;
            - a string for a loaded ``lang_code``;
            - ``{code: string}`` when ``lang_code == "default"``;
            - an XlitError member on failure (see ``translit_word``).
        """
        if eng_sentence == "":
            return []
        if lang_code in self.langs:
            try:
                return self._translit_sentence_one(lang_code, eng_sentence, beam_width)
            except Exception:
                print("XlitError:", traceback.format_exc())
                print(XlitError.internal_err.value)
                return XlitError.internal_err
        elif lang_code == "default":
            try:
                return {
                    la: self._translit_sentence_one(la, eng_sentence, beam_width)
                    for la in self.lang_model
                }
            except Exception:
                print("XlitError:", traceback.format_exc())
                print(XlitError.internal_err.value)
                return XlitError.internal_err
        else:
            print("XlitError: Unknown Langauge requested", lang_code)
            print(XlitError.lang_err.value)
            return XlitError.lang_err
if __name__ == "__main__":
    # Language codes listed for reference (presumably those with released models).
    available_lang = [
        "bn", "gu", "hi", "kn", "gom", "mai", "ml",
        "mr", "pa", "sd", "si", "ta", "te", "ur",
    ]
    # A token is transliterated only when its FIRST character is a Latin letter.
    reg = re.compile(r"[a-zA-Z]")
    lang = "hi"
    # Omitting the language code would load every available language.
    engine = XlitEngine(lang)
    sent = "Hello World! ABCD क्या हाल है आपका?"
    words = []
    for word in sent.split():
        if not reg.match(word):
            words.append(word)  # non-English tokens are kept unchanged
            continue
        # With the default lang_code the result is {code: [best, ...]}.
        words.append(engine.translit_word(word, topk=1)[lang][0])
    updated_sent = " ".join(words)
    print(updated_sent)
    # output : हेलो वर्ल्ड! क्या हाल है आपका?
    # Sentence-level alternative:
    #   y = engine.translit_sentence("Hello World !")['hi']
    #   print(y)