Update handler.py
Browse files- handler.py +51 -6
handler.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import torch
|
|
|
2 |
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
3 |
|
4 |
class EndpointHandler:
|
@@ -8,8 +9,39 @@ class EndpointHandler:
|
|
8 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
9 |
self.model.to(self.device)
|
10 |
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
# Configuration de la langue source
|
14 |
self.tokenizer.src_lang = "ru_RU"
|
15 |
|
@@ -35,7 +67,8 @@ class EndpointHandler:
|
|
35 |
|
36 |
# Décodage
|
37 |
translation = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
|
38 |
-
|
|
|
39 |
|
40 |
except Exception as e:
|
41 |
return [{"error": str(e)}]
|
@@ -46,17 +79,29 @@ class EndpointHandler:
|
|
46 |
return [{"error": "Request must contain 'inputs' field"}]
|
47 |
|
48 |
inputs = data["inputs"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
#
|
51 |
if isinstance(inputs, str):
|
52 |
-
return self.process_single_text(inputs)
|
53 |
elif isinstance(inputs, list) and len(inputs) > 0:
|
54 |
if isinstance(inputs[0], dict) and "input" in inputs[0]:
|
55 |
-
return self.process_single_text(inputs[0]["input"])
|
56 |
else:
|
57 |
return [{"error": "Invalid input format"}]
|
58 |
else:
|
59 |
return [{"error": "Invalid input format"}]
|
60 |
|
61 |
except Exception as e:
|
62 |
-
return [{"error": str(e)}]
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
+
import pandas as pd
|
3 |
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
4 |
|
5 |
class EndpointHandler:
|
|
|
9 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
10 |
self.model.to(self.device)
|
11 |
|
12 |
+
# Définition des tags disponibles
|
13 |
+
self.tags = [
|
14 |
+
"[AFRICA]", "[EASTERN_EUROPE]", "[TRANSLATION]", "[FAR_RIGHT_EUROPE]",
|
15 |
+
"[CONSPIRACY_LIKELIHOOD]", "[UNITED_STATES]", "[UKRAINE]",
|
16 |
+
"[RUSSIAN_OPPOSITION]", "[OCCIDENTAL_VALUES]", "[ORGANIZATION]",
|
17 |
+
"[EUROPEAN_UNION]", "[PROPAGANDA]", "[CENTRAL_EUROPE]", "[COUNTRY]",
|
18 |
+
"[NATO]", "[HISTORICAL_REVISIONISM]", "[BRICS]", "[TOPIC_LIST]",
|
19 |
+
"[TOPIC_DETERMINISTIC]", "[BALTIC_STATES]", "[RUSSIAN_PARAMILITARY]",
|
20 |
+
"[ANTI_GLOBALISM]", "[MIDDLE_EAST]", "[NER]", "[SUMMARY]",
|
21 |
+
"[DEHUMANIZATION]"
|
22 |
+
]
|
23 |
+
|
24 |
+
def query_with_tags(self, text, tags):
    """Run the tag-based query for a single text and return the resulting frame.

    The text is wrapped in a one-row DataFrame with a single ``text`` column
    and handed to ``self.init.query_with_df`` together with ``tags``.

    NOTE(review): ``self.init`` is not assigned in the visible part of this
    file; the original comment says the call must be implemented to match
    your own init object. Confirm it is set before this method is called.

    Args:
        text: Text to place in the ``text`` column.
        tags: Tags forwarded unchanged as the ``tags`` keyword argument.

    Returns:
        Whatever ``self.init.query_with_df`` returns, or, if that call
        raises, a one-row DataFrame whose ``error`` column holds the
        exception message.
    """
    # Built outside the try block so only the query call itself is guarded.
    single_row = pd.DataFrame([{"text": text}])

    try:
        return self.init.query_with_df(df=single_row, tags=tags)
    except Exception as exc:
        return pd.DataFrame([{"error": str(exc)}])
|
34 |
+
|
35 |
+
def process_single_text(self, text, tags=None):
|
36 |
try:
|
37 |
+
# Si des tags sont fournis, traiter d'abord avec les tags
|
38 |
+
if tags:
|
39 |
+
tagged_result = self.query_with_tags(text, tags)
|
40 |
+
if "error" in tagged_result:
|
41 |
+
return [{"error": tagged_result["error"].iloc[0]}]
|
42 |
+
# Utiliser le résultat du traitement des tags comme entrée pour la traduction
|
43 |
+
text = tagged_result["text"].iloc[0]
|
44 |
+
|
45 |
# Configuration de la langue source
|
46 |
self.tokenizer.src_lang = "ru_RU"
|
47 |
|
|
|
67 |
|
68 |
# Décodage
|
69 |
translation = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
|
70 |
+
|
71 |
+
return [{"output": translation}]
|
72 |
|
73 |
except Exception as e:
|
74 |
return [{"error": str(e)}]
|
|
|
79 |
return [{"error": "Request must contain 'inputs' field"}]
|
80 |
|
81 |
inputs = data["inputs"]
|
82 |
+
tags = data.get("tags", None) # Récupérer les tags s'ils sont fournis
|
83 |
+
|
84 |
+
# Validation des tags
|
85 |
+
if tags:
|
86 |
+
invalid_tags = [tag for tag in tags if tag not in self.tags]
|
87 |
+
if invalid_tags:
|
88 |
+
return [{"error": f"Invalid tags: {invalid_tags}"}]
|
89 |
|
90 |
+
# Traitement de l'entrée
|
91 |
if isinstance(inputs, str):
|
92 |
+
return self.process_single_text(inputs, tags)
|
93 |
elif isinstance(inputs, list) and len(inputs) > 0:
|
94 |
if isinstance(inputs[0], dict) and "input" in inputs[0]:
|
95 |
+
return self.process_single_text(inputs[0]["input"], tags)
|
96 |
else:
|
97 |
return [{"error": "Invalid input format"}]
|
98 |
else:
|
99 |
return [{"error": "Invalid input format"}]
|
100 |
|
101 |
except Exception as e:
|
102 |
+
return [{"error": str(e)}]
|
103 |
+
|
104 |
+
# Example usage
|
105 |
+
def init_handler(model_dir):
    """Create an ``EndpointHandler`` for ``model_dir`` and return it.

    Args:
        model_dir: Passed unchanged as the single positional argument to
            ``EndpointHandler``.

    Returns:
        The newly constructed ``EndpointHandler`` instance.
    """
    return EndpointHandler(model_dir)
|