damienliccia committed on
Commit
b6ff77b
·
verified ·
1 Parent(s): 79e850c

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +40 -31
handler.py CHANGED
@@ -7,55 +7,64 @@ class EndpointHandler:
7
  self.model = MBartForConditionalGeneration.from_pretrained(model_dir)
8
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
  self.model.to(self.device)
10
- self.max_length = 1024
11
 
12
- def _validate_input(self, inputs):
13
- if isinstance(inputs, str):
14
- return [inputs]
15
- elif isinstance(inputs, list) and all(isinstance(item, dict) and "input" in item for item in inputs):
16
- return [item["input"] for item in inputs]
17
- raise ValueError("Input must be a string or a list of dictionaries with 'input' key")
18
-
19
def process(self, inputs):
    """Translate Russian input text(s) into English.

    Returns a ``(text_output, json_output)`` pair: the first translation
    as a plain string (empty string when there are none) and a list of
    ``{"output": ...}`` dicts. Any failure yields
    ``("", [{"error": ...}])`` instead of raising.
    """
    try:
        texts = self._validate_input(inputs)

        # mBART needs the source language set before tokenization.
        self.tokenizer.src_lang = "ru_RU"

        encoded = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        ).to(self.device)

        # Deterministic beam search, forcing English ("en_XX") as the
        # generation target language.
        with torch.no_grad():
            generated = self.model.generate(
                **encoded,
                forced_bos_token_id=self.tokenizer.lang_code_to_id["en_XX"],
                max_length=self.max_length,
                num_beams=5,
                do_sample=False,
            )

        decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        first_text = decoded[0] if decoded else ""
        as_json = [{"output": translation} for translation in decoded]
        return first_text, as_json

    except Exception as e:
        return "", [{"error": str(e)}]
57
 
58
def __call__(self, data):
    """Endpoint entry point: expects ``{"inputs": ...}`` and delegates to process()."""
    is_valid_request = isinstance(data, dict) and "inputs" in data
    if not is_valid_request:
        return "", [{"error": "Request must contain 'inputs' field"}]
    return self.process(data["inputs"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  self.model = MBartForConditionalGeneration.from_pretrained(model_dir)
8
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
  self.model.to(self.device)
 
10
 
11
def process_single_text(self, text, max_length=512):
    """Translate a single Russian text into English.

    Parameters
    ----------
    text : str
        Source text, assumed Russian ("ru_RU").
    max_length : int, optional
        Token budget used both for input truncation and for generation.
        Defaults to 512 (kept small for efficiency); previously this was
        hard-coded in two places.

    Returns
    -------
    str
        The English translation, or an ``"Error: ..."`` string on any
        failure — this method never raises, callers rely on always
        getting a string back.
    """
    try:
        # mBART requires the source language to be set before tokenization.
        self.tokenizer.src_lang = "ru_RU"

        # Tokenization
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(self.device)

        # Deterministic beam-search generation, forcing English ("en_XX")
        # as the target language.
        with torch.no_grad():
            generated_tokens = self.model.generate(
                **inputs,
                forced_bos_token_id=self.tokenizer.lang_code_to_id["en_XX"],
                max_length=max_length,
                num_beams=4,
                length_penalty=1.0,
                do_sample=False,
            )

        # Decode only the first (and only) sequence in the batch.
        translation = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        return translation

    except Exception as e:
        # Report failures as text rather than propagating them.
        return f"Error: {str(e)}"
42
 
43
def __call__(self, data):
    """Inference entry point.

    Expects ``data = {"inputs": ...}`` where ``inputs`` is either:
    - a plain string -> returns the translation directly as a string, or
    - a list of ``{"input": ...}`` dicts -> returns a list of
      ``{"output": ...}`` / ``{"error": ...}`` dicts, order preserved.

    Errors are returned as plain strings rather than raised.
    """
    try:
        # Guard against non-dict payloads first: `"inputs" not in data`
        # would raise a confusing TypeError on e.g. None or an int.
        if not isinstance(data, dict) or "inputs" not in data:
            raise ValueError("Request must contain 'inputs' field")

        inputs = data["inputs"]

        # Single string: return the bare translation text.
        if isinstance(inputs, str):
            return self.process_single_text(inputs)

        # List: translate item by item, keeping per-item errors inline.
        if isinstance(inputs, list):
            translations = []
            for item in inputs:
                if isinstance(item, dict) and "input" in item:
                    translation = self.process_single_text(item["input"])
                    translations.append({"output": translation})
                else:
                    translations.append({"error": "Invalid input format"})
            return translations

        raise ValueError("Invalid input format")

    except Exception as e:
        return str(e)  # errors are reported as text, never raised