# nllb-es-agr / pipeline.py
# angelLino: Add custom pipeline (#1) — commit 3481bcb (verified)
from typing import Dict, Any
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
from translation import fix_tokenizer, TextPreprocessor, sentenize_with_fillers
from sentence_splitter import SentenceSplitter
import torch
class PreTrainedPipeline():
    """Custom Inference Endpoints pipeline: Spanish -> Awajun translation
    with a fine-tuned NLLB seq2seq model.

    The input text is split into sentences (preserving the inter-sentence
    "filler" text), each sentence is translated independently, and the
    translations are re-joined with the original fillers so surrounding
    whitespace/punctuation layout is kept.
    """

    def __init__(self, path=""):
        """Load model and tokenizer from `path` (a local dir or model id)."""
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        self.tokenizer = NllbTokenizer.from_pretrained(path)
        # Project helper that patches the tokenizer — presumably registers
        # the custom agr_Latn language code; see translation.py to confirm.
        fix_tokenizer(self.tokenizer)
        self.splitter = SentenceSplitter(language='es')
        self.preprocessor = TextPreprocessor()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Translate `data["text"]` sentence by sentence.

        Keys read from `data`:
            text (str): source text; defaults to "".
            src_lang (str): NLLB source language code, default "spa_Latn".
            tgt_lang (str): NLLB target language code, default "agr_Latn".
            preprocess (bool): apply TextPreprocessor per sentence, default True.

        Returns:
            dict: {"translation": <translated text with original fillers>}.
        """
        inputs = data.get("text", "")
        src_lang = data.get("src_lang", "spa_Latn")
        tgt_lang = data.get("tgt_lang", "agr_Latn")
        preprocess = data.get("preprocess", True)

        # Sentences plus the text between them. The join below assumes
        # len(fillers) == len(sentences) + 1 (fillers[i] precedes sentence i,
        # fillers[-1] trails the last) — TODO confirm against
        # translation.sentenize_with_fillers.
        sentences, fillers = sentenize_with_fillers(inputs, self.splitter)
        if preprocess:
            sentences = [self.preprocessor(sent) for sent in sentences]

        # These were previously recomputed inside the loop although they are
        # loop-invariant: set/compute them once.
        self.tokenizer.src_lang = src_lang
        forced_bos = self.tokenizer.lang_code_to_id[tgt_lang]

        translated_sentences = []
        for sentence in sentences:
            encoded = self.tokenizer(sentence, return_tensors="pt")
            # Inference only: disable autograd to avoid tracking gradients
            # (saves memory and time during generation).
            with torch.no_grad():
                generated_tokens = self.model.generate(
                    **encoded.to(self.model.device),
                    forced_bos_token_id=forced_bos,
                )
            translated_sentences.append(
                self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            )

        # Re-interleave fillers and translations, then append trailing filler.
        output = "".join(
            filler + sentence for filler, sentence in zip(fillers, translated_sentences)
        ) + fillers[-1]
        return {"translation": output}