tykiww committed on
Commit 0fb279d · verified · 1 Parent(s): 71c7288

Create handler.py

Files changed (1)
  handler.py +72 -0
handler.py ADDED
@@ -0,0 +1,72 @@
from utilities.setup import *

import json
import os

from typing import Dict, List, Any
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer


class EndpointHandler:
    def __init__(self, path=""):
        """Initialize the handler and load the model of interest on init."""

        print("Reading config")
        self.path = path
        self.HF_TOKEN = os.getenv("HF_TOKEN")
        self.wd = os.getcwd()
        self.model_name = os.path.basename(self.wd)

        print("Loading model")
        self.model, self.tokenizer = self.load_model()

    def load_model(self):
        """Load the PEFT adapter model (unsloth fine-tune) and its tokenizer."""
        model = AutoPeftModelForCausalLM.from_pretrained(
            self.path,
            load_in_4bit=True,  # 4-bit quantization to cut GPU memory use
        )
        tokenizer = AutoTokenizer.from_pretrained(self.path)

        return model, tokenizer

    def prompt_formatter(self, prompt):
        """Prompts must be formatted in Alpaca style before calling the API."""
        # Tokenize and move the tensors to the GPU; assumes a CUDA device.
        inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")

        return inputs, prompt

    def infer(self, prompt, max_new_tokens=1000):  # TODO: add streaming capability
        """Bring it all together: tokenize the prompt, generate, and decode."""
        inputs, prompt_text = self.prompt_formatter(prompt)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )
        completion = self.tokenizer.batch_decode(outputs)

        return completion

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`)
            kwargs
        Return:
            A :obj:`list` | :obj:`dict`: will be serialized and returned
        """
        # .get() avoids a KeyError when the payload has no "inputs" key.
        if data.get("inputs") is not None:
            request = data["inputs"]
            prediction = self.infer(request)
            return {"prediction": prediction}
        else:
            return [{"Error": "no input received."}]
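The prompt_formatter docstring calls for Alpaca-style prompts, but the commit itself does not include the template; it presumably lives in utilities.setup or on the caller's side. For reference, a sketch of the commonly used Alpaca format — the template string and field names below are assumptions, not part of this commit:

    # Reference sketch only: the widely used Alpaca prompt template.
    ALPACA_TEMPLATE = (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Input:\n{input}\n\n"
        "### Response:\n"
    )

    prompt = ALPACA_TEMPLATE.format(instruction="Summarize the text.", input="...")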
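And a minimal local smoke test of the handler, assuming the adapter weights and tokenizer sit in the current directory and a CUDA GPU is available. The payload shape follows the __call__ docstring above; the path and prompt are illustrative:

    # Minimal smoke test: not part of the commit, just a usage sketch.
    from handler import EndpointHandler

    handler = EndpointHandler(path=".")  # assumed: model files in the cwd

    # Payload shape per the __call__ docstring: {"inputs": <str>}
    response = handler({"inputs": "Write one sentence about LoRA adapters."})
    print(response)  # {"prediction": ["<decoded completion>"]}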