Aktraiser commited on
Commit
2a70625
·
verified ·
1 Parent(s): 421ac56

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +28 -27
handler.py CHANGED
@@ -16,35 +16,36 @@ class EndpointHandler:
16
  self.model, self.tokenizer = load_model(path)
17
  self.pipeline = TextGenerationPipeline(
18
  model=self.model,
19
- tokenizer=self.tokenizer,
20
- max_new_tokens=512,
21
- temperature=0.7,
22
- top_p=0.95,
23
- repetition_penalty=1.15,
24
- do_sample=True
25
  )
26
 
27
  def __call__(self, data):
28
- inputs = data.pop("inputs", data)
29
- parameters = data.pop("parameters", {})
30
-
31
- generation_kwargs = {
32
- "max_new_tokens": 512,
33
- "temperature": 0.7,
34
- "top_p": 0.95,
35
- "repetition_penalty": 1.15,
36
- "do_sample": True
 
 
 
 
37
  }
38
- generation_kwargs.update(parameters)
39
-
40
- if isinstance(inputs, str):
41
- inputs = [inputs]
 
 
 
42
 
43
- outputs = self.pipeline(
44
- inputs,
45
- **generation_kwargs
46
- )
47
-
48
- if len(outputs) == 1:
49
- return {"generated_text": outputs[0]["generated_text"]}
50
- return [{"generated_text": o["generated_text"]} for o in outputs]
 
16
  self.model, self.tokenizer = load_model(path)
17
  self.pipeline = TextGenerationPipeline(
18
  model=self.model,
19
+ tokenizer=self.tokenizer
 
 
 
 
 
20
  )
21
 
22
  def __call__(self, data):
23
+ # Extraire le texte d'entrée
24
+ if isinstance(data, dict):
25
+ text = data.pop("inputs", "")
26
+ else:
27
+ text = data
28
+
29
+ # Paramètres de génération
30
+ params = {
31
+ "max_new_tokens": data.get("max_new_tokens", 512),
32
+ "temperature": data.get("temperature", 0.7),
33
+ "top_p": data.get("top_p", 0.95),
34
+ "repetition_penalty": data.get("repetition_penalty", 1.15),
35
+ "do_sample": data.get("do_sample", True)
36
  }
37
+
38
+ try:
39
+ # Générer le texte
40
+ result = self.pipeline(
41
+ text,
42
+ **params
43
+ )
44
 
45
+ # Formater la sortie
46
+ if isinstance(result, list):
47
+ return {"generated_text": result[0]["generated_text"]}
48
+ return {"generated_text": result["generated_text"]}
49
+
50
+ except Exception as e:
51
+ return {"error": str(e)}