# handler.py — custom Hugging Face Inference Endpoints handler
# (author: damienliccia, revision 9d470b0, verified)
import torch
import pandas as pd
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
class EndpointHandler:
    """Inference Endpoints handler: Russian -> English translation with
    MBart-50, plus an optional, pluggable tag-based pre-processing step.

    Request payload (see ``__call__``)::

        {"inputs": "<russian text>", "tags": ["[NER]", ...]}   # tags optional

    Responses are always a single-element list: either
    ``[{"output": <translation>}]`` or ``[{"error": <message>}]``.
    """

    def __init__(self, model_dir):
        """Load the tokenizer and model from *model_dir* and move the model
        to GPU when one is available, otherwise CPU.
        """
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_dir)
        self.model = MBartForConditionalGeneration.from_pretrained(model_dir)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Pluggable tag-processing backend used by query_with_tags().  The
        # original code referenced ``self.init`` without ever assigning it,
        # so every tagged request failed with an opaque AttributeError.
        # Making it an explicit None documents that the deployer must attach
        # an object exposing ``query_with_df(df=..., tags=...)``.
        self.init = None
        # Closed vocabulary of tags accepted in the request's "tags" field.
        self.tags = [
            "[AFRICA]", "[EASTERN_EUROPE]", "[TRANSLATION]", "[FAR_RIGHT_EUROPE]",
            "[CONSPIRACY_LIKELIHOOD]", "[UNITED_STATES]", "[UKRAINE]",
            "[RUSSIAN_OPPOSITION]", "[OCCIDENTAL_VALUES]", "[ORGANIZATION]",
            "[EUROPEAN_UNION]", "[PROPAGANDA]", "[CENTRAL_EUROPE]", "[COUNTRY]",
            "[NATO]", "[HISTORICAL_REVISIONISM]", "[BRICS]", "[TOPIC_LIST]",
            "[TOPIC_DETERMINISTIC]", "[BALTIC_STATES]", "[RUSSIAN_PARAMILITARY]",
            "[ANTI_GLOBALISM]", "[MIDDLE_EAST]", "[NER]", "[SUMMARY]",
            "[DEHUMANIZATION]"
        ]

    def query_with_tags(self, text, tags):
        """Run the tag-processing backend on *text*.

        Returns the one-row DataFrame produced by the backend; on any
        failure (including no backend being configured) returns a DataFrame
        with a single "error" column instead of raising.
        """
        # Wrap the text in a one-row DataFrame, the shape the backend expects.
        temp_df = pd.DataFrame([{"text": text}])
        try:
            backend = getattr(self, "init", None)
            if backend is None:
                # Fail with a clear message rather than the opaque
                # AttributeError the unassigned attribute used to produce.
                raise RuntimeError(
                    "No tag-processing backend configured (self.init is None)"
                )
            return backend.query_with_df(df=temp_df, tags=tags)
        except Exception as e:
            return pd.DataFrame([{"error": str(e)}])

    def process_single_text(self, text, tags=None):
        """Translate *text* (Russian -> English), optionally running it
        through the tag backend first.

        Returns ``[{"output": ...}]`` on success, ``[{"error": ...}]``
        otherwise.
        """
        try:
            if tags:
                tagged_result = self.query_with_tags(text, tags)
                # ``in`` on a DataFrame tests column labels; spell that out.
                if "error" in tagged_result.columns:
                    return [{"error": tagged_result["error"].iloc[0]}]
                # Feed the tag-processed text into the translation step.
                text = tagged_result["text"].iloc[0]
            # Source language is fixed to Russian for this endpoint.
            self.tokenizer.src_lang = "ru_RU"
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=1024
            ).to(self.device)
            # Deterministic beam search, with the decoder forced to English.
            with torch.no_grad():
                generated_tokens = self.model.generate(
                    **inputs,
                    forced_bos_token_id=self.tokenizer.lang_code_to_id["en_XX"],
                    max_length=1024,
                    num_beams=4,
                    length_penalty=1.0,
                    do_sample=False
                )
            translation = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
            return [{"output": translation}]
        except Exception as e:
            return [{"error": str(e)}]

    def __call__(self, data):
        """Entry point invoked per request by the Inference Endpoints runtime.

        *data* must contain "inputs" — either a string, or a non-empty list
        whose first element is a dict with an "input" key — and may contain
        a "tags" list, validated against ``self.tags``.
        """
        try:
            if "inputs" not in data:
                return [{"error": "Request must contain 'inputs' field"}]
            inputs = data["inputs"]
            tags = data.get("tags", None)  # Optional tag list.
            # Reject any tag outside the declared vocabulary up front.
            if tags:
                invalid_tags = [tag for tag in tags if tag not in self.tags]
                if invalid_tags:
                    return [{"error": f"Invalid tags: {invalid_tags}"}]
            # Accept either a bare string or [{"input": ...}, ...].
            if isinstance(inputs, str):
                return self.process_single_text(inputs, tags)
            elif isinstance(inputs, list) and len(inputs) > 0:
                if isinstance(inputs[0], dict) and "input" in inputs[0]:
                    return self.process_single_text(inputs[0]["input"], tags)
                else:
                    return [{"error": "Invalid input format"}]
            else:
                return [{"error": "Invalid input format"}]
        except Exception as e:
            return [{"error": str(e)}]
# Example usage
def init_handler(model_dir):
    """Build and return an ``EndpointHandler`` loaded from *model_dir*."""
    return EndpointHandler(model_dir)