damienliccia commited on
Commit
5e7f6d8
·
verified ·
1 Parent(s): 2f8ca30

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +51 -6
handler.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
3
 
4
  class EndpointHandler:
@@ -8,8 +9,39 @@ class EndpointHandler:
8
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
  self.model.to(self.device)
10
 
11
- def process_single_text(self, text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  try:
 
 
 
 
 
 
 
 
13
  # Configuration de la langue source
14
  self.tokenizer.src_lang = "ru_RU"
15
 
@@ -35,7 +67,8 @@ class EndpointHandler:
35
 
36
  # Décodage
37
  translation = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
38
- return [{"translation": translation}]
 
39
 
40
  except Exception as e:
41
  return [{"error": str(e)}]
@@ -46,17 +79,29 @@ class EndpointHandler:
46
  return [{"error": "Request must contain 'inputs' field"}]
47
 
48
  inputs = data["inputs"]
 
 
 
 
 
 
 
49
 
50
- # Que ce soit une chaîne ou une liste, on traite comme une seule entrée
51
  if isinstance(inputs, str):
52
- return self.process_single_text(inputs)
53
  elif isinstance(inputs, list) and len(inputs) > 0:
54
  if isinstance(inputs[0], dict) and "input" in inputs[0]:
55
- return self.process_single_text(inputs[0]["input"])
56
  else:
57
  return [{"error": "Invalid input format"}]
58
  else:
59
  return [{"error": "Invalid input format"}]
60
 
61
  except Exception as e:
62
- return [{"error": str(e)}]
 
 
 
 
 
 
1
  import torch
2
+ import pandas as pd
3
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
4
 
5
  class EndpointHandler:
 
9
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
  self.model.to(self.device)
11
 
12
+ # Définition des tags disponibles
13
+ self.tags = [
14
+ "[AFRICA]", "[EASTERN_EUROPE]", "[TRANSLATION]", "[FAR_RIGHT_EUROPE]",
15
+ "[CONSPIRACY_LIKELIHOOD]", "[UNITED_STATES]", "[UKRAINE]",
16
+ "[RUSSIAN_OPPOSITION]", "[OCCIDENTAL_VALUES]", "[ORGANIZATION]",
17
+ "[EUROPEAN_UNION]", "[PROPAGANDA]", "[CENTRAL_EUROPE]", "[COUNTRY]",
18
+ "[NATO]", "[HISTORICAL_REVISIONISM]", "[BRICS]", "[TOPIC_LIST]",
19
+ "[TOPIC_DETERMINISTIC]", "[BALTIC_STATES]", "[RUSSIAN_PARAMILITARY]",
20
+ "[ANTI_GLOBALISM]", "[MIDDLE_EAST]", "[NER]", "[SUMMARY]",
21
+ "[DEHUMANIZATION]"
22
+ ]
23
+
24
+ def query_with_tags(self, text, tags):
25
+ # Créer un DataFrame temporaire avec le texte
26
+ temp_df = pd.DataFrame([{"text": text}])
27
+
28
+ try:
29
+ # Appeler la fonction init.query_with_df (à implémenter selon votre init)
30
+ result_df = self.init.query_with_df(df=temp_df, tags=tags)
31
+ return result_df
32
+ except Exception as e:
33
+ return pd.DataFrame([{"error": str(e)}])
34
+
35
+ def process_single_text(self, text, tags=None):
36
  try:
37
+ # Si des tags sont fournis, traiter d'abord avec les tags
38
+ if tags:
39
+ tagged_result = self.query_with_tags(text, tags)
40
+ if "error" in tagged_result:
41
+ return [{"error": tagged_result["error"].iloc[0]}]
42
+ # Utiliser le résultat du traitement des tags comme entrée pour la traduction
43
+ text = tagged_result["text"].iloc[0]
44
+
45
  # Configuration de la langue source
46
  self.tokenizer.src_lang = "ru_RU"
47
 
 
67
 
68
  # Décodage
69
  translation = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
70
+
71
+ return [{"output": translation}]
72
 
73
  except Exception as e:
74
  return [{"error": str(e)}]
 
79
  return [{"error": "Request must contain 'inputs' field"}]
80
 
81
  inputs = data["inputs"]
82
+ tags = data.get("tags", None) # Récupérer les tags s'ils sont fournis
83
+
84
+ # Validation des tags
85
+ if tags:
86
+ invalid_tags = [tag for tag in tags if tag not in self.tags]
87
+ if invalid_tags:
88
+ return [{"error": f"Invalid tags: {invalid_tags}"}]
89
 
90
+ # Traitement de l'entrée
91
  if isinstance(inputs, str):
92
+ return self.process_single_text(inputs, tags)
93
  elif isinstance(inputs, list) and len(inputs) > 0:
94
  if isinstance(inputs[0], dict) and "input" in inputs[0]:
95
+ return self.process_single_text(inputs[0]["input"], tags)
96
  else:
97
  return [{"error": "Invalid input format"}]
98
  else:
99
  return [{"error": "Invalid input format"}]
100
 
101
  except Exception as e:
102
+ return [{"error": str(e)}]
103
+
104
+ # Exemple d'utilisation
105
+ def init_handler(model_dir):
106
+ handler = EndpointHandler(model_dir)
107
+ return handler