lambda_hf_v2 / utils /utils_function.py
FerdinandPyCode's picture
new setting and test of our model
394de27
raw
history blame
2.14 kB
from fairseq.models.transformer import TransformerModel
import torch
import re
import string
class Translator:
def __init__(self, isFon:bool, device='cuda' if torch.cuda.is_available() else 'cpu'):
# Charger le modèle pré-entraîné avec Fairseq
inner = "fon_fr" if isFon else "fr_fon"
self.model = TransformerModel.from_pretrained(
f'./utils/checkpoints/{inner}',
checkpoint_file = 'checkpoint_best.pt',
data_name_or_path = f'utils/datas/data_prepared_{inner}/',
source_lang='fon' if isFon else 'fr',
target_lang='fr' if isFon else 'fon'
)
# Définir le périphérique sur lequel exécuter le modèle (par défaut sur 'cuda' si disponible)
self.model.to(device)
# Mettre le modèle en mode évaluation (pas de mise à jour des poids)
self.model.eval()
def preprocess(self, data):
print('Preprocessing...')
# Convertir chaque lettre en minuscule
text = data.lower().strip()
# Supprimer les apostrophes des phrases
text = re.sub("'", "", text)
# Supprimer toute ponctuation
exclude = set(string.punctuation)
text = ''.join(ch for ch in text if ch not in exclude)
# Supprimer les chiffres
digit = str.maketrans('', '', string.digits)
text = text.translate(digit)
return text
def translate(self, text):
print(text)
pre_traited = self.preprocess(text)
print(pre_traited)
# Encodage du texte en tokens
tokens = self.model.encode(pre_traited)
# Utilisation de la méthode generate avec le paramètre beam
translations = self.model.generate(tokens, beam=5)
print(type(translations))
print(translations[0])
best_translation_tokens = [translations[i]['tokens'].tolist() for i in range(5)]
# Décodage des tokens en traduction
translations = [self.model.decode(best_translation_tokens[i]) for i in range(5)]
return "\n".join(translations)