kajdun committed
Commit b1cba66 · 1 Parent(s): a833c3c

Create handler.py

Files changed (1):
  handler.py +37 -0
handler.py ADDED
@@ -0,0 +1,37 @@
+ from multiprocessing import cpu_count
+ from typing import Any, Dict, List
+
+ from llama_cpp import Llama
+ from loguru import logger
+
+ MAX_INPUT_TOKEN_LENGTH = 4000
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         # Load the GGML model; offload 50 layers to the GPU and use all CPU threads.
+         self.model = Llama(
+             model_path="/repository/iubaris-13b-v3_ggml_Q4_K_S.bin",
+             n_ctx=4000,
+             n_gpu_layers=50,
+             n_threads=cpu_count(),
+             verbose=True,
+         )
+
+     def get_input_token_length(self, message: str) -> int:
+         # Tokenize the prompt so its length can be checked against the context limit.
+         input_ids = self.model.tokenize(message.encode("utf-8"))
+         return len(input_ids)
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         inputs = data.pop("inputs", data)
+         parameters = data.pop("parameters", {})
+
+         # llama-cpp-python expects `max_tokens`, so translate the HF-style
+         # `max_new_tokens` parameter before forwarding it to the model.
+         max_new_tokens = parameters.pop("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
+         if max_new_tokens > MAX_MAX_NEW_TOKENS:
+             logger.error(f"requested max_new_tokens too high (> {MAX_MAX_NEW_TOKENS})")
+             return [{"generated_text": None, "error": f"requested max_new_tokens too high (> {MAX_MAX_NEW_TOKENS})"}]
+         parameters["max_tokens"] = max_new_tokens
+
+         input_token_length = self.get_input_token_length(inputs)
+         if input_token_length > MAX_INPUT_TOKEN_LENGTH:
+             logger.error(f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})")
+             return [{"generated_text": None, "error": f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})"}]
+
+         logger.info(f"inputs: {inputs}")
+
+         outputs = self.model(inputs, **parameters)
+
+         return [{"generated_text": outputs["choices"][0]["text"]}]
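For reference, a minimal sketch of how this handler could be exercised locally, assuming the GGML model file referenced in __init__ exists at that path; the payload shape follows the Hugging Face Inference Endpoints convention of "inputs" plus optional "parameters":

    # Hypothetical smoke test; the model path baked into EndpointHandler
    # must be present on disk for this to run.
    from handler import EndpointHandler

    handler = EndpointHandler()
    response = handler({
        "inputs": "Write a haiku about llamas.",
        "parameters": {"max_new_tokens": 64, "temperature": 0.7},
    })
    print(response[0]["generated_text"])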