File size: 2,143 Bytes
113d0af
 
394de27
 
 
113d0af
 
 
 
 
 
 
 
c2b9868
 
 
 
 
113d0af
 
 
 
 
 
 
394de27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113d0af
 
394de27
 
 
 
 
113d0af
394de27
113d0af
d29c54a
 
 
 
 
 
113d0af
d29c54a
113d0af
394de27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from fairseq.models.transformer import TransformerModel
import torch
import re
import string


class Translator:
    def __init__(self, isFon:bool, device='cuda' if torch.cuda.is_available() else 'cpu'):

        # Charger le modèle pré-entraîné avec Fairseq
        inner = "fon_fr" if isFon else "fr_fon"

        self.model = TransformerModel.from_pretrained(
            f'./utils/checkpoints/{inner}',
            checkpoint_file = 'checkpoint_best.pt',
            data_name_or_path = f'utils/datas/data_prepared_{inner}/',
            source_lang='fon' if isFon else 'fr',
            target_lang='fr' if isFon else 'fon'
        )

        # Définir le périphérique sur lequel exécuter le modèle (par défaut sur 'cuda' si disponible)
        self.model.to(device)
        
        # Mettre le modèle en mode évaluation (pas de mise à jour des poids)
        self.model.eval()
    
    def preprocess(self, data):
        print('Preprocessing...')
        # Convertir chaque lettre en minuscule
        text = data.lower().strip()
        
        # Supprimer les apostrophes des phrases
        text = re.sub("'", "", text)
        
        # Supprimer toute ponctuation
        exclude = set(string.punctuation)
        text = ''.join(ch for ch in text if ch not in exclude)
        
        # Supprimer les chiffres
        digit = str.maketrans('', '', string.digits)
        text = text.translate(digit)
        
        return text

    def translate(self, text):

        print(text)
        pre_traited = self.preprocess(text)
        print(pre_traited)

        # Encodage du texte en tokens
        tokens = self.model.encode(pre_traited)
        
        # Utilisation de la méthode generate avec le paramètre beam
        translations = self.model.generate(tokens, beam=5)
        print(type(translations))
        print(translations[0])
        best_translation_tokens = [translations[i]['tokens'].tolist() for i in range(5)]

        # Décodage des tokens en traduction
        translations = [self.model.decode(best_translation_tokens[i]) for i in range(5)]
        
        return "\n".join(translations)