damienliccia commited on
Commit
2f8ca30
·
verified ·
1 Parent(s): b6ff77b

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +13 -21
handler.py CHANGED
@@ -19,7 +19,7 @@ class EndpointHandler:
19
  return_tensors="pt",
20
  padding=True,
21
  truncation=True,
22
- max_length=512 # Réduit pour plus d'efficacité
23
  ).to(self.device)
24
 
25
  # Génération
@@ -35,36 +35,28 @@ class EndpointHandler:
35
 
36
  # Décodage
37
  translation = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
38
- return translation
39
 
40
  except Exception as e:
41
- return f"Error: {str(e)}"
42
 
43
  def __call__(self, data):
44
  try:
45
  if "inputs" not in data:
46
- raise ValueError("Request must contain 'inputs' field")
47
 
48
  inputs = data["inputs"]
49
 
50
- # Si l'entrée est une chaîne de caractères
51
  if isinstance(inputs, str):
52
- translation = self.process_single_text(inputs)
53
- return translation # Retourne directement la traduction comme texte
54
-
55
- # Si l'entrée est une liste
56
- elif isinstance(inputs, list):
57
- translations = []
58
- for item in inputs:
59
- if isinstance(item, dict) and "input" in item:
60
- translation = self.process_single_text(item["input"])
61
- translations.append({"output": translation})
62
- else:
63
- translations.append({"error": "Invalid input format"})
64
- return translations
65
-
66
  else:
67
- raise ValueError("Invalid input format")
68
 
69
  except Exception as e:
70
- return str(e) # Retourne l'erreur comme texte
 
19
  return_tensors="pt",
20
  padding=True,
21
  truncation=True,
22
+ max_length=512
23
  ).to(self.device)
24
 
25
  # Génération
 
35
 
36
  # Décodage
37
  translation = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
38
+ return [{"translation": translation}]
39
 
40
  except Exception as e:
41
+ return [{"error": str(e)}]
42
 
43
  def __call__(self, data):
44
  try:
45
  if "inputs" not in data:
46
+ return [{"error": "Request must contain 'inputs' field"}]
47
 
48
  inputs = data["inputs"]
49
 
50
+ # Que ce soit une chaîne ou une liste, on traite comme une seule entrée
51
  if isinstance(inputs, str):
52
+ return self.process_single_text(inputs)
53
+ elif isinstance(inputs, list) and len(inputs) > 0:
54
+ if isinstance(inputs[0], dict) and "input" in inputs[0]:
55
+ return self.process_single_text(inputs[0]["input"])
56
+ else:
57
+ return [{"error": "Invalid input format"}]
 
 
 
 
 
 
 
 
58
  else:
59
+ return [{"error": "Invalid input format"}]
60
 
61
  except Exception as e:
62
+ return [{"error": str(e)}]