File size: 1,476 Bytes
113d0af
 
 
828d66c
113d0af
 
 
 
 
 
 
828d66c
 
 
 
 
 
 
 
 
113d0af
 
aa59070
 
113d0af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from fairseq.models.transformer import TransformerModel
import os
import torch
import zipfile

class Translator:
    def __init__(self, isFon:bool, device='cuda' if torch.cuda.is_available() else 'cpu'):

        # Charger le modèle pré-entraîné avec Fairseq
        inner = "fon_fr" if isFon else "fr_fon"

        if not os.path.exists('utils/data_prepared/'):
            print("Not existed")

            with zipfile.ZipFile('utils/data_prepared.zip', 'r') as zip_ref:
                zip_ref.extractall('utils/')
                
        else:
            print("Existed")

        self.model = TransformerModel.from_pretrained(
            './utils/checkpoints/fon_fr',
            checkpoint_file='checkpoint_best.pt',
            data_name_or_path='utils/data_prepared/',
            source_lang='fon',
            target_lang='fr'
        )

        print("#########################")
        print(type(self.model))
        print("#########################")
        # Définir le périphérique sur lequel exécuter le modèle (par défaut sur 'cuda' si disponible)
        self.model.to(device)
        
        # Mettre le modèle en mode évaluation (pas de mise à jour des poids)
        self.model.eval()

    def translate(self, text):
        # Encodage du texte en tokens
        tokens = self.model.encode(text)
        
        # Décodage des tokens en traduction
        translation = self.model.decode(tokens)
        
        return translation