lambda_hf_v2 / utils /utils_function.py
FerdinandPyCode's picture
well done now
d29c54a
raw
history blame
1.88 kB
from fairseq.models.transformer import TransformerModel
import os
import torch
import zipfile
import shutil
class Translator:
def __init__(self, isFon:bool, device='cuda' if torch.cuda.is_available() else 'cpu'):
# Charger le modèle pré-entraîné avec Fairseq
inner = "fon_fr" if isFon else "fr_fon"
# if not os.path.exists('utils/data_prepared/'):
# print("Not existed")
# shutil.chmod('utils/', 0o777)
# with zipfile.ZipFile('utils/data_prepared.zip', 'r') as zip_ref:
# zip_ref.extractall('utils/')
# else:
# print("Existed")
self.model = TransformerModel.from_pretrained(
'./utils/checkpoints/fon_fr',
checkpoint_file='checkpoint_best.pt',
data_name_or_path='utils/data_prepared/',
source_lang='fon',
target_lang='fr'
)
print("#########################")
print(type(self.model))
print("#########################")
# Définir le périphérique sur lequel exécuter le modèle (par défaut sur 'cuda' si disponible)
self.model.to(device)
# Mettre le modèle en mode évaluation (pas de mise à jour des poids)
self.model.eval()
def translate(self, text):
# Encodage du texte en tokens
tokens = self.model.encode(text)
# Utilisation de la méthode generate avec le paramètre beam
translations = self.model.generate(tokens, beam=5)
print(type(translations))
print(translations[0])
best_translation_tokens = [translations[i]['tokens'].tolist() for i in range(5)]
# Décodage des tokens en traduction
translations = [self.model.decode(best_translation_tokens[i]) for i in range(5)]
return "\n".join(translations)