Aktraiser commited on
Commit
79bdc18
1 Parent(s): b1cbdc4

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +57 -0
handler.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
2
+ import torch
3
+
4
+ def load_model(model_id):
5
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
6
+ model = AutoModelForCausalLM.from_pretrained(
7
+ model_id,
8
+ device_map="auto",
9
+ torch_dtype=torch.float16,
10
+ load_in_4bit=True
11
+ )
12
+ return model, tokenizer
13
+
14
+ class EndpointHandler:
15
+ def __init__(self, path=""):
16
+ self.model, self.tokenizer = load_model(path)
17
+ self.pipeline = TextGenerationPipeline(
18
+ model=self.model,
19
+ tokenizer=self.tokenizer
20
+ )
21
+
22
+ def __call__(self, data):
23
+ # Extraire le texte d'entrée
24
+ if isinstance(data, dict):
25
+ text = data.get("inputs", "")
26
+ else:
27
+ text = data
28
+
29
+ # Paramètres de génération par défaut
30
+ generation_kwargs = {
31
+ "max_new_tokens": 512,
32
+ "temperature": 0.7,
33
+ "top_p": 0.95,
34
+ "repetition_penalty": 1.15,
35
+ "do_sample": True,
36
+ "pad_token_id": self.tokenizer.pad_token_id,
37
+ "eos_token_id": self.tokenizer.eos_token_id,
38
+ }
39
+
40
+ # Mettre à jour avec les paramètres de la requête si fournis
41
+ if isinstance(data, dict) and "parameters" in data:
42
+ generation_kwargs.update(data["parameters"])
43
+
44
+ try:
45
+ # Générer la réponse
46
+ outputs = self.pipeline(
47
+ text,
48
+ **generation_kwargs
49
+ )
50
+
51
+ # Formater la sortie en tableau comme requis par l'API
52
+ if isinstance(outputs, list):
53
+ return [{"generated_text": output["generated_text"]} for output in outputs]
54
+ return [{"generated_text": outputs["generated_text"]}]
55
+
56
+ except Exception as e:
57
+ return [{"error": str(e)}]