njwright92 committed
Commit e5bbead · verified · 1 Parent(s): 1cc0294

Update handler.py

Files changed (1):
  handler.py +41 -51
handler.py CHANGED
@@ -1,52 +1,42 @@
-from ctransformers import AutoModelForCausalLM, AutoTokenizer
-from transformers import pipeline
 import json
-
-class EndpointHandler:
-    def __init__(self, model_dir):
-        self.model_dir = model_dir
-        self.model = None
-        self.tokenizer = None
-        self.pipe = None
-
-    def load_model(self):
-        self.model = AutoModelForCausalLM.from_pretrained(
-            f"{self.model_dir}/comic_mistral-v5.2.q5_0.gguf",
-            model_type="mistral",
-            lib="avx2",
-            gpu_layers=0,
-            hf=True
-        )
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model)
-
-    def preprocess(self, data):
-        return data
-
-    def __call__(self, data):
-        if self.model is None or self.tokenizer is None:
-            self.load_model()
-
-        inputs = self.preprocess(data)
-        prompt = inputs["inputs"]
-
-        # Generate text using the model
-        generated_text = ""
-        for text in self.model(prompt,
-                               max_new_tokens=256,
-                               temperature=0.8,
-                               repetition_penalty=1.1,
-                               do_sample=True,
-                               stream=True):
-            generated_text += text
-
-        # Return a JSON-serializable response
-        response = {"generated_text": generated_text}
-        return json.dumps(response)
-
-    def postprocess(self, data):
-        return data
-
-def get_handler(model_dir):
-    handler = EndpointHandler(model_dir)
-    handler.load_model()
-    return handler
+import os
+from typing import Dict, List, Any
+from llama_cpp import Llama
+import gemma_tools as gem
+
+MAX_TOKENS = 512
+
+class EndpointHandler():
+    def __init__(self, data):
+        # Update the model path and filename with your ComicBot model
+        self.model = Llama.from_pretrained("njwright92/ComicBot_v.2-gguf", filename="ComicBot_v.2-q4_k_m.gguf", n_ctx=8192)
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        args = gem.get_args_or_none(data)
+        fmat = "<startofturn>system\n{system_prompt} <endofturn>\n<startofturn>user\n{prompt} <endofturn>\n<startofturn>model"
+        print(args, fmat)
+        if not args[0]:
+            return {
+                "status": args["status"],
+                "message": args["description"]
+            }
+        try:
+            fmat = fmat.format(system_prompt=args["system_prompt"], prompt=args["inputs"])
+        except Exception as e:
+            return json.dumps({
+                "status": "error",
+                "reason": "invalid format"
+            })
+
+        max_length = data.pop("max_length", 512)
+        try:
+            max_length = int(max_length)
+        except Exception as e:
+            return json.dumps({
+                "status": "error",
+                "reason": "max_length was passed as something that was absolutely not a plain old int"
+            })
+
+        res = self.model(fmat, temperature=args["temperature"], top_p=args["top_p"], top_k=args["top_k"], max_tokens=max_length)
+
+        return res
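
For reference, the fmat template in the new __call__ wraps each request in Gemma-style turn markers before it reaches the model. With system_prompt="You are a comedian." and prompt="Tell me a joke.", fmat.format(...) renders as:

    <startofturn>system
    You are a comedian. <endofturn>
    <startofturn>user
    Tell me a joke. <endofturn>
    <startofturn>model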
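
gemma_tools itself is not part of this commit, so the helper below is only a guess at the contract the handler appears to rely on: get_args_or_none must return something indexable both positionally (args[0] as a success flag) and by string key (either the error fields or the generation parameters). A plain dict with an integer 0 key is one way to satisfy both access patterns; every name and default value here is an assumption, not the module's real code:

    # HYPOTHETICAL sketch of gemma_tools.get_args_or_none -- illustrates the
    # contract __call__ depends on, not the actual module.
    from typing import Any, Dict

    def get_args_or_none(data: Dict[str, Any]) -> Dict[Any, Any]:
        # Failure shape: args[0] is falsy; "status"/"description" carry the error.
        if not isinstance(data, dict) or "inputs" not in data:
            return {0: False, "status": "error", "description": "payload must include 'inputs'"}
        # Success shape: args[0] is truthy, plus the fields __call__ reads by name.
        return {
            0: True,
            "inputs": data["inputs"],
            "system_prompt": data.get("system_prompt", ""),      # assumed default
            "temperature": float(data.get("temperature", 0.8)),  # assumed default
            "top_p": float(data.get("top_p", 0.95)),             # assumed default
            "top_k": int(data.get("top_k", 40)),                 # assumed default
        }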
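
Under that assumption, a minimal local smoke test for the new handler might look like this; it presumes the GGUF weights can be pulled from njwright92/ComicBot_v.2-gguf, that gemma_tools resolves on the import path, and that the payload keys mirror exactly what __call__ reads (the sampling values are illustrative, not repo defaults):

    # Hypothetical smoke test for the updated handler.py -- not shipped with the repo.
    from handler import EndpointHandler

    handler = EndpointHandler(None)  # __init__ ignores its argument apart from loading the model
    payload = {
        "inputs": "Do a short bit about airline food.",
        "system_prompt": "You are ComicBot, a stand-up comedian.",
        "temperature": 0.8,
        "top_p": 0.95,
        "top_k": 40,
        "max_length": 256,  # popped and cast to int inside __call__
    }
    res = handler(payload)
    # On success, llama-cpp-python returns a completion dict
    print(res["choices"][0]["text"])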