from typing import Any, Dict

import torch
from sentence_splitter import SentenceSplitter
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer

from translation import TextPreprocessor, fix_tokenizer, sentenize_with_fillers


class PreTrainedPipeline:
    def __init__(self, path: str = ""):
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.tokenizer = NllbTokenizer.from_pretrained(path)
        # Patch the tokenizer so custom language codes added during
        # fine-tuning (e.g. agr_Latn) are usable alongside the stock NLLB codes.
        fix_tokenizer(self.tokenizer)
        # Source texts are Spanish, so use a Spanish sentence splitter.
        self.splitter = SentenceSplitter(language="es")
        self.preprocessor = TextPreprocessor()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Translate data["text"] from src_lang to tgt_lang, sentence by sentence."""
        inputs = data.get("text", "")
        src_lang = data.get("src_lang", "spa_Latn")
        tgt_lang = data.get("tgt_lang", "agr_Latn")
        preprocess = data.get("preprocess", True)

        # Split the input into sentences, keeping the separators between them
        # ("fillers") so the translation can be reassembled with the original
        # spacing.
        sentences, fillers = sentenize_with_fillers(inputs, self.splitter)
        if preprocess:
            sentences = [self.preprocessor(sent) for sent in sentences]

        # The source language is fixed for the whole request, so set it once
        # rather than on every iteration.
        self.tokenizer.src_lang = src_lang
        translated_sentences = []
        for sentence in sentences:
            encoded = self.tokenizer(sentence, return_tensors="pt")
            # Force the decoder to start with the target-language token.
            generated_tokens = self.model.generate(
                **encoded.to(self.model.device),
                forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
            )
            translated_sentences.append(
                self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            )

        # sentenize_with_fillers yields one more filler than sentences (the
        # separator before each sentence plus a trailing one), so interleave
        # fillers with translations and append the final filler.
        output = "".join(
            filler + sentence for filler, sentence in zip(fillers, translated_sentences)
        ) + fillers[-1]

        return {"translation": output}
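

# A minimal local smoke test (a sketch, not part of the serving contract):
# it assumes the current directory "." contains the fine-tuned model and
# tokenizer files that PreTrainedPipeline loads; the request keys mirror the
# ones read in __call__ above.
if __name__ == "__main__":
    pipe = PreTrainedPipeline(path=".")
    result = pipe({
        "text": "Hola. ¿Cómo estás?",
        "src_lang": "spa_Latn",
        "tgt_lang": "agr_Latn",
    })
    print(result["translation"])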