File size: 4,326 Bytes
6249ec3 5e7f6d8 6249ec3 5e7f6d8 6249ec3 5e7f6d8 b6ff77b 79e850c 6249ec3 b6ff77b 6249ec3 9d470b0 6249ec3 b6ff77b 6249ec3 b6ff77b 6249ec3 79e850c 9d470b0 b6ff77b 6249ec3 b6ff77b 5e7f6d8 6249ec3 2f8ca30 6249ec3 b6ff77b 2f8ca30 b6ff77b 5e7f6d8 b6ff77b 5e7f6d8 b6ff77b 5e7f6d8 2f8ca30 5e7f6d8 2f8ca30 b6ff77b 2f8ca30 b6ff77b 5e7f6d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import torch
import pandas as pd
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
class EndpointHandler:
    """Inference-endpoint handler that translates Russian text to English
    with an MBart-50 model, optionally pre-processing the input through an
    external tag-based backend first.

    The ``__call__`` contract follows the HuggingFace custom-handler shape:
    it receives a dict with an ``"inputs"`` field (a string, or a list of
    dicts with an ``"input"`` key) and returns a list of one dict, either
    ``{"output": <translation>}`` or ``{"error": <message>}``.
    """

    def __init__(self, model_dir):
        """Load tokenizer and model from *model_dir* and move the model to
        GPU when available, CPU otherwise."""
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_dir)
        self.model = MBartForConditionalGeneration.from_pretrained(model_dir)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # External tag-processing backend exposing query_with_df(df=..., tags=...).
        # BUG FIX: this attribute was never initialized, so query_with_tags
        # always failed with AttributeError. It must be assigned by the
        # integrator before tag processing is used; None means "not configured".
        self.init = None
        # Control tags accepted in the request's optional "tags" field.
        self.tags = [
            "[AFRICA]", "[EASTERN_EUROPE]", "[TRANSLATION]", "[FAR_RIGHT_EUROPE]",
            "[CONSPIRACY_LIKELIHOOD]", "[UNITED_STATES]", "[UKRAINE]",
            "[RUSSIAN_OPPOSITION]", "[OCCIDENTAL_VALUES]", "[ORGANIZATION]",
            "[EUROPEAN_UNION]", "[PROPAGANDA]", "[CENTRAL_EUROPE]", "[COUNTRY]",
            "[NATO]", "[HISTORICAL_REVISIONISM]", "[BRICS]", "[TOPIC_LIST]",
            "[TOPIC_DETERMINISTIC]", "[BALTIC_STATES]", "[RUSSIAN_PARAMILITARY]",
            "[ANTI_GLOBALISM]", "[MIDDLE_EAST]", "[NER]", "[SUMMARY]",
            "[DEHUMANIZATION]"
        ]

    def query_with_tags(self, text, tags):
        """Pre-process *text* with *tags* through the external backend.

        Returns the backend's result DataFrame, or a one-row DataFrame with
        an ``error`` column on failure (errors are never raised to callers;
        this mirrors the handler's error-dict convention).
        """
        # Wrap the text in a single-row DataFrame as the backend expects.
        temp_df = pd.DataFrame([{"text": text}])
        try:
            if self.init is None:
                # Fail with an actionable message instead of the former
                # AttributeError on a missing attribute.
                raise RuntimeError(
                    "tag-processing backend is not configured: assign an object "
                    "with query_with_df(df=..., tags=...) to handler.init"
                )
            result_df = self.init.query_with_df(df=temp_df, tags=tags)
            return result_df
        except Exception as e:
            return pd.DataFrame([{"error": str(e)}])

    def process_single_text(self, text, tags=None):
        """Translate one Russian text to English; returns a one-element list
        with either an ``output`` or an ``error`` dict.

        If *tags* are given, the text is first run through the tag backend
        and the backend's ``text`` column replaces the raw input.
        """
        try:
            if tags:
                tagged_result = self.query_with_tags(text, tags)
                # "error" in a DataFrame tests column membership: the backend
                # failure path above produces exactly that column.
                if "error" in tagged_result:
                    return [{"error": tagged_result["error"].iloc[0]}]
                # Use the tag-processed text as the translation input.
                text = tagged_result["text"].iloc[0]
            # Source language is fixed to Russian for this endpoint.
            self.tokenizer.src_lang = "ru_RU"
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=1024
            ).to(self.device)
            # Deterministic beam-search generation, forced to emit English.
            with torch.no_grad():
                generated_tokens = self.model.generate(
                    **inputs,
                    forced_bos_token_id=self.tokenizer.lang_code_to_id["en_XX"],
                    max_length=1024,
                    num_beams=4,
                    length_penalty=1.0,
                    do_sample=False
                )
            translation = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
            return [{"output": translation}]
        except Exception as e:
            # Endpoint contract: report failures as error dicts, never raise.
            return [{"error": str(e)}]

    def __call__(self, data):
        """Entry point for the inference endpoint.

        *data* must contain ``"inputs"`` (str, or non-empty list of dicts
        with an ``"input"`` key) and may contain ``"tags"`` (list of tags,
        validated against ``self.tags``).
        """
        try:
            if "inputs" not in data:
                return [{"error": "Request must contain 'inputs' field"}]
            inputs = data["inputs"]
            tags = data.get("tags", None)  # Optional tag list.
            # Reject unknown tags up front, before any model work.
            if tags:
                invalid_tags = [tag for tag in tags if tag not in self.tags]
                if invalid_tags:
                    return [{"error": f"Invalid tags: {invalid_tags}"}]
            # Accept either a bare string or a [{"input": ...}] list.
            if isinstance(inputs, str):
                return self.process_single_text(inputs, tags)
            elif isinstance(inputs, list) and len(inputs) > 0:
                if isinstance(inputs[0], dict) and "input" in inputs[0]:
                    return self.process_single_text(inputs[0]["input"], tags)
                else:
                    return [{"error": "Invalid input format"}]
            else:
                return [{"error": "Invalid input format"}]
        except Exception as e:
            return [{"error": str(e)}]
# Exemple d'utilisation
def init_handler(model_dir):
handler = EndpointHandler(model_dir)
return handler |