# Bangla-to-English translation inference script (greedy decoding with a
# Transformer model; see Bn2EnTranslation below).
import os | |
import torch | |
import argparse | |
import sentencepiece as spm | |
from utils import utils_cls | |
from model import BanglaTransformer | |
from config import config as cfg | |
# Fix the RNG seed so decoding is reproducible across runs.
torch.manual_seed(0)
# Prefer GPU when available; all tensors and the model are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
# Project helper bundle (vocab loading, subsequent-mask generation) bound to the device.
uobj = utils_cls(device=device)
__MODULE__ = "Bangla Language Translation"
__MAIL__ = "[email protected]"
__MODIFICAIOTN__ = "28/03/2023"  # NOTE(review): name misspelled ("MODIFICATION") — kept as-is, it is a public module attribute
__LICENSE__ = "MIT"
# NOTE(review): this is an HTML page URL ("/tree/main"), not a filesystem path or a
# raw-file URL; os.path.join on it below will not yield loadable files. The artifacts
# presumably must be downloaded locally first — confirm against the HF repo.
BASE_URL = "https://huggingface.co/saiful9379/Bangla-to-English-Translation/tree/main"
class Bn2EnTranslation:
    """Bangla→English translation: loads tokenizer/vocab/model artifacts and
    translates a sentence via greedy decoding on a Transformer model."""

    def __init__(self):
        # NOTE(review): BASE_URL is an HTML page URL ("/tree/main"); joining file
        # names onto it does not produce loadable paths. These artifacts need to be
        # available locally (e.g. via huggingface_hub) — confirm with the caller.
        self.bn_tokenizer = os.path.join(BASE_URL, "bn_model.model")
        self.en_tokenizer = os.path.join(BASE_URL, "en_model.model")
        self.bn_vocab = os.path.join(BASE_URL, "bn_vocab.pkl")
        self.en_vocab = os.path.join(BASE_URL, "en_vocab.pkl")
        self.model = os.path.join(BASE_URL, "pytorch_model.pt")

    def read_data(self, data_path):
        """Read a tab-separated parallel corpus.

        Returns a list of [source, target] pairs, one per line, with the
        trailing newline stripped from the target side.
        """
        # Explicit UTF-8: the corpus contains Bangla text and must not depend
        # on the platform's default encoding.
        with open(data_path, "r", encoding="utf-8") as f:
            data = f.readlines()
        data = list(map(lambda x: [x.split("\t")[0], x.split("\t")[1].replace("\n", "")], data))
        return data

    def load_tokenizer(self, tokenizer_path: str = "") -> object:
        """Load a SentencePiece tokenizer from a .model file."""
        _tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
        return _tokenizer

    def get_vocab(self, BN_VOCAL_PATH: str = "", EN_VOCAL_PATH: str = ""):
        """Load the pickled Bangla and English vocabularies via the utils helper."""
        bn_vocal, en_vocal = uobj.load_bn_vocal(BN_VOCAL_PATH), uobj.load_en_vocal(EN_VOCAL_PATH)
        return bn_vocal, en_vocal

    def load_model(self, model_path: str = "", SRC_VOCAB_SIZE: int = 0, TGT_VOCAB_SIZE: int = 0):
        """Build the BanglaTransformer, load checkpoint weights, and set eval mode."""
        model = BanglaTransformer(
            cfg.NUM_ENCODER_LAYERS, cfg.NUM_DECODER_LAYERS, cfg.EMB_SIZE, SRC_VOCAB_SIZE,
            TGT_VOCAB_SIZE, cfg.FFN_HID_DIM, nhead=cfg.NHEAD)
        model.to(device)
        # map_location is required: without it, a checkpoint saved on GPU raises a
        # RuntimeError when loaded on a CPU-only machine.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model

    def greedy_decode(self, model, src, src_mask, max_len, start_symbol, eos_index):
        """Greedily decode up to max_len tokens; stops early at eos_index.

        Returns the generated token-id tensor of shape (T, 1), including the
        start symbol (and the EOS token when one was produced).
        """
        src = src.to(device)
        src_mask = src_mask.to(device)
        # Encode once; memory is loop-invariant, so move it to the device here
        # rather than on every decoding step.
        memory = model.encode(src, src_mask).to(device)
        ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
        for _ in range(max_len - 1):
            # Causal mask over the tokens generated so far.
            tgt_mask = (uobj.generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(device)
            out = model.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)
            # Project only the last position: greedy choice of the next token.
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()
            ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
            if next_word == eos_index:
                break
        return ys

    def get_bntoen_model(self):
        """Load tokenizer, vocabularies and model; return them bundled in a dict."""
        print("Tokenizer Loading ...... : ", end="", flush=True)
        bn_tokenizer = self.load_tokenizer(tokenizer_path=self.bn_tokenizer)
        print("Done")
        print("Vocab Loading ...... : ", end="", flush=True)
        bn_vocab, en_vocab = self.get_vocab(BN_VOCAL_PATH=self.bn_vocab, EN_VOCAL_PATH=self.en_vocab)
        print("Done")
        print("Model Loading ...... : ", end="", flush=True)
        model = self.load_model(model_path=self.model, SRC_VOCAB_SIZE=len(bn_vocab), TGT_VOCAB_SIZE=len(en_vocab))
        print("Done")
        models = {
            "bn_tokenizer": bn_tokenizer,
            "bn_vocab": bn_vocab,
            "en_vocab": en_vocab,
            "model": model
        }
        return models

    def translate(self, text, models):
        """Translate one Bangla sentence to English using the bundled models dict."""
        model = models["model"]
        src_vocab = models["bn_vocab"]
        tgt_vocab = models["en_vocab"]
        src_tokenizer = models["bn_tokenizer"]
        src = text
        PAD_IDX, BOS_IDX, EOS_IDX = src_vocab['<pad>'], src_vocab['<bos>'], src_vocab['<eos>']
        stoi = src_vocab.get_stoi()
        # Map out-of-vocabulary pieces to <unk> instead of raising KeyError.
        # (assumes the vocab defines '<unk>'; falls back to PAD otherwise — TODO confirm)
        unk_idx = stoi.get('<unk>', PAD_IDX)
        tokens = [BOS_IDX] + [stoi.get(tok, unk_idx) for tok in src_tokenizer.encode(src, out_type=str)] + [EOS_IDX]
        num_tokens = len(tokens)
        src = torch.LongTensor(tokens).reshape(num_tokens, 1)
        src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
        tgt_tokens = self.greedy_decode(
            model, src, src_mask, max_len=num_tokens + 5,
            start_symbol=BOS_IDX, eos_index=EOS_IDX).flatten()
        itos = tgt_vocab.get_itos()
        p_text = " ".join([itos[tok] for tok in tgt_tokens]).replace("<bos>", "").replace("<eos>", "")
        # SentencePiece marks word starts with '▁'; rebuild whitespace from those marks.
        pts = " ".join(list(map(lambda x: x, p_text.replace(" ", "").split("▁"))))
        return pts.strip()
if __name__ == "__main__":
    # Only query the CUDA device name when CUDA exists: get_device_name(0)
    # raises a RuntimeError on CPU-only machines, yet the rest of the script
    # deliberately supports CPU via the device fallback above.
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(0))
    text = "এই উপজেলায় ১টি সরকারি কলেজ রয়েছে"
    obj = Bn2EnTranslation()
    models = obj.get_bntoen_model()
    pre = obj.translate(text, models)
    print("=" * 20)
    print(f"input : {text}")
    print(f"prediction: {pre}")