Aktraiser commited on
Commit
421ac56
·
verified ·
1 Parent(s): 5d0a536

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +50 -0
handler.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
2
+ import torch
3
+
4
+ def load_model(model_id):
5
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
6
+ model = AutoModelForCausalLM.from_pretrained(
7
+ model_id,
8
+ device_map="auto",
9
+ torch_dtype=torch.float16,
10
+ load_in_4bit=True
11
+ )
12
+ return model, tokenizer
13
+
14
+ class EndpointHandler:
15
+ def __init__(self, path=""):
16
+ self.model, self.tokenizer = load_model(path)
17
+ self.pipeline = TextGenerationPipeline(
18
+ model=self.model,
19
+ tokenizer=self.tokenizer,
20
+ max_new_tokens=512,
21
+ temperature=0.7,
22
+ top_p=0.95,
23
+ repetition_penalty=1.15,
24
+ do_sample=True
25
+ )
26
+
27
+ def __call__(self, data):
28
+ inputs = data.pop("inputs", data)
29
+ parameters = data.pop("parameters", {})
30
+
31
+ generation_kwargs = {
32
+ "max_new_tokens": 512,
33
+ "temperature": 0.7,
34
+ "top_p": 0.95,
35
+ "repetition_penalty": 1.15,
36
+ "do_sample": True
37
+ }
38
+ generation_kwargs.update(parameters)
39
+
40
+ if isinstance(inputs, str):
41
+ inputs = [inputs]
42
+
43
+ outputs = self.pipeline(
44
+ inputs,
45
+ **generation_kwargs
46
+ )
47
+
48
+ if len(outputs) == 1:
49
+ return {"generated_text": outputs[0]["generated_text"]}
50
+ return [{"generated_text": o["generated_text"]} for o in outputs]