File size: 1,639 Bytes
3481bcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from typing import Dict, Any
from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
from translation import fix_tokenizer, TextPreprocessor, sentenize_with_fillers
from sentence_splitter import SentenceSplitter
import torch

class PreTrainedPipeline():
  """Sentence-level NLLB translation pipeline (defaults: Spanish -> Awajun).

  Splits the input text into sentences, translates each sentence with an
  NLLB seq2seq model, then stitches the translations back together using
  the original inter-sentence fillers (whitespace/punctuation) so the
  surrounding formatting of the input is preserved.
  """

  def __init__(self, path=""):
    """Load model + tokenizer from ``path`` and prepare helpers.

    Args:
      path: local directory or hub id passed to ``from_pretrained``.
    """
    self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
    # Use the GPU when one is available; generation is otherwise CPU-bound.
    if torch.cuda.is_available():
      self.model = self.model.cuda()
    self.tokenizer = NllbTokenizer.from_pretrained(path)
    # Project helper: patches the tokenizer in place (presumably registers
    # the extra language codes used below — see translation.fix_tokenizer).
    fix_tokenizer(self.tokenizer)
    # Source text is expected to be Spanish, hence the 'es' sentence splitter.
    self.splitter = SentenceSplitter(language='es')
    self.preprocessor = TextPreprocessor()

  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """Translate ``data["text"]`` sentence by sentence.

    Args:
      data: request payload with keys
        - "text": input text (default "").
        - "src_lang": NLLB source language code (default "spa_Latn").
        - "tgt_lang": NLLB target language code (default "agr_Latn").
        - "preprocess": run ``TextPreprocessor`` on each sentence
          before translation (default True).

    Returns:
      ``{"translation": <translated text with original fillers restored>}``
    """
    inputs = data.get("text", "")
    src_lang = data.get("src_lang", "spa_Latn")
    tgt_lang = data.get("tgt_lang", "agr_Latn")
    preprocess = data.get("preprocess", True)

    sentences, fillers = sentenize_with_fillers(inputs, self.splitter)
    if preprocess:
      sentences = [self.preprocessor(sent) for sent in sentences]

    # src_lang and the forced-BOS id are loop-invariant: set/resolve them
    # once instead of on every sentence.
    self.tokenizer.src_lang = src_lang
    forced_bos = self.tokenizer.lang_code_to_id[tgt_lang]

    translated_sentences = []
    # inference_mode() skips autograd bookkeeping — generation is faster and
    # uses far less memory than with gradients enabled.
    with torch.inference_mode():
      for sentence in sentences:
        encoded = self.tokenizer(sentence, return_tensors="pt")
        generated_tokens = self.model.generate(
            **encoded.to(self.model.device),
            forced_bos_token_id=forced_bos,
        )
        translated_sentences.append(
            self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        )

    # sentenize_with_fillers is assumed to return len(sentences) + 1 fillers
    # (leading / between / trailing) — TODO confirm. Guard the trailing one
    # so an empty filler list does not raise IndexError.
    output = "".join(
        filler + sentence for filler, sentence in zip(fillers, translated_sentences)
    )
    if fillers:
      output += fillers[-1]

    return {"translation": output}