sijieaaa committed on
Commit
a3dfe34
·
verified ·
1 Parent(s): 083213c

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +93 -0
handler.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from typing import Dict, List, Any
4
+ # from transformers import pipeline
5
+ # import holidays
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+
8
+
9
+
10
class EndpointHandler:
    """Inference handler that answers a single chat prompt with a causal LM.

    The model and tokenizer are loaded once in ``__init__``; every call then
    wraps the incoming prompt in a two-message chat (system + user), generates
    a continuation and returns it as an assistant message.
    """

    # Hub repository the weights and tokenizer are loaded from.
    MODEL_ID = 'sijieaaa/CodeModel-V1-3B-2024-02-07'

    # Generation budget used when the request does not override it.
    DEFAULT_MAX_NEW_TOKENS = 512

    # Request-level ``parameters`` keys that are forwarded to ``generate``.
    # Anything else in ``parameters`` is ignored so that a stray key in the
    # payload cannot make ``generate`` fail.
    _GENERATION_KEYS = (
        "max_new_tokens",
        "do_sample",
        "temperature",
        "top_p",
        "top_k",
        "repetition_penalty",
    )

    def __init__(self, path=None):
        """Load the model and tokenizer identified by ``MODEL_ID``.

        Args:
            path: Accepted for interface compatibility but not used; the
                weights are always loaded from ``MODEL_ID``.
                NOTE(review): presumably the serving runtime passes the local
                repository path here - confirm before switching to it.
        """
        self.model = AutoModelForCausalLM.from_pretrained(
            self.MODEL_ID,
            torch_dtype="auto",
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_ID)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate an assistant reply for the prompt in ``data``.

        Args:
            data: Request payload.
                ``inputs`` (str, required): the user prompt.
                ``parameters`` (dict, optional): generation overrides; only
                the keys listed in ``_GENERATION_KEYS`` are honoured.

        Returns:
            A one-element list ``[{"role": "assistant", "content": <reply>}]``.

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` entry.
        """
        prompt = data["inputs"]

        # Start from the default budget and let whitelisted request
        # parameters override it.
        overrides = data.get("parameters") or {}
        generate_kwargs = {"max_new_tokens": self.DEFAULT_MAX_NEW_TOKENS}
        generate_kwargs.update(
            (key, value)
            for key, value in overrides.items()
            if key in self._GENERATION_KEYS
        )

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

        generated_ids = self.model.generate(**model_inputs, **generate_kwargs)

        # ``generate`` returns prompt + continuation; drop the prompt prefix
        # of each sequence so only the newly generated tokens are decoded.
        completion_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs["input_ids"], generated_ids)
        ]

        reply = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0]

        return [{"role": "assistant", "content": reply}]
70
+
71
+
72
+ # def test():
73
+ # # init handler
74
+ # my_handler = EndpointHandler(path=".")
75
+
76
+ # # prepare sample payload
77
+ # non_holiday_payload = {"inputs": "I am quite excited how this will turn out", "date": "2022-08-08"}
78
+ # holiday_payload = {"inputs": "Today is a tough day", "date": "2022-07-04"}
79
+
80
+ # # test the handler
81
+ # a = my_handler.__call__(non_holiday_payload)
82
+ # non_holiday_pred=my_handler(non_holiday_payload)
83
+ # holiday_payload=my_handler(holiday_payload)
84
+
85
+ # # show results
86
+ # print("non_holiday_pred", non_holiday_pred)
87
+ # print("holiday_payload", holiday_payload)
88
+
89
+ # a=1
90
+
91
+
92
+ # if __name__ == "__main__":
93
+ # test()