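"""Custom pipeline handler for an NLLB-based translation model.

The handler splits incoming text into sentences, translates them one by one,
and re-joins them with the original inter-sentence whitespace. By default it
translates Spanish (spa_Latn) into Awajún/Aguaruna (agr_Latn); both codes can
be overridden per request via the "src_lang" and "tgt_lang" fields.
"""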
from typing import Any, Dict

import torch
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer

from sentence_splitter import SentenceSplitter
from translation import TextPreprocessor, fix_tokenizer, sentenize_with_fillers


class PreTrainedPipeline:
    def __init__(self, path=""):
        # Load the NLLB checkpoint and move it to the GPU when one is available.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.tokenizer = NllbTokenizer.from_pretrained(path)
        # Local helper that patches the tokenizer so the added language code works.
        fix_tokenizer(self.tokenizer)
        # Sentence segmentation follows Spanish rules, matching the default source language.
        self.splitter = SentenceSplitter(language="es")
        self.preprocessor = TextPreprocessor()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data.get("text", "")
        src_lang = data.get("src_lang", "spa_Latn")
        tgt_lang = data.get("tgt_lang", "agr_Latn")
        preprocess = data.get("preprocess", True)

        # Split the input into sentences, keeping the inter-sentence whitespace
        # ("fillers") so the original spacing can be restored afterwards.
        sentences, fillers = sentenize_with_fillers(inputs, self.splitter)
        if preprocess:
            sentences = [self.preprocessor(sent) for sent in sentences]

        # Translate sentence by sentence; NLLB selects the output language via
        # the forced BOS token. The source language only needs to be set once.
        self.tokenizer.src_lang = src_lang
        translated_sentences = []
        for sentence in sentences:
            encoded = self.tokenizer(sentence, return_tensors="pt")
            generated_tokens = self.model.generate(
                **encoded.to(self.model.device),
                forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            )
            translated_sentences.append(
                self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            )

        # Reassemble: fillers interleave with the sentences, with one extra
        # trailing filler, so the original spacing survives the round trip.
        output = "".join(
            filler + sentence for filler, sentence in zip(fillers, translated_sentences)
        ) + fillers[-1]
        return {"translation": output}
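

# Illustrative smoke test only; the hosting runtime constructs
# PreTrainedPipeline itself. The path "." is a placeholder for a local
# directory that holds the model checkpoint and tokenizer files.
if __name__ == "__main__":
    pipeline = PreTrainedPipeline(path=".")
    result = pipeline({"text": "Hola. ¿Cómo estás hoy?"})
    print(result["translation"])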