File size: 1,799 Bytes
b1cba66
 
 
 
b659443
 
 
 
 
 
b1cba66
 
 
 
 
 
 
 
 
 
b873d66
b1cba66
 
 
 
 
 
671ce7b
b565944
 
671ce7b
b1cba66
 
 
6211160
 
 
 
b1cba66
b565944
b1cba66
8aab45d
b565944
b1cba66
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from typing import  Dict, List, Any
from llama_cpp import Llama
import torch
from loguru import logger
import time

import psutil

def _detect_cpu_count() -> int:
    """Return the logical CPU count reported by psutil, or 1 when it is unknown.

    ``psutil.cpu_count`` can report a falsy value (e.g. ``None``) when the count
    cannot be determined; in that case a single thread is assumed.
    """
    logical = psutil.cpu_count(logical=True)
    if not logical:
        return 1
    return int(logical)


# Number of threads handed to llama.cpp when the model is loaded.
cpu_count: int = _detect_cpu_count()

# Token budgets for a single request.
MAX_INPUT_TOKEN_LENGTH = 4000   # intended upper bound on prompt length, in tokens
MAX_MAX_NEW_TOKENS = 2048       # largest `max_tokens` a request may ask for
DEFAULT_MAX_NEW_TOKENS = 1024   # `max_tokens` used when a request omits it

class EndpointHandler():
    """Inference handler that serves a local GGML model through llama-cpp-python.

    NOTE(review): presumably instantiated by Hugging Face Inference Endpoints'
    custom-handler loader (``EndpointHandler(path)`` then ``handler(data)``);
    confirm against the deployment.
    """

    def __init__(self, path=""):
        # `path` is accepted for the handler interface but unused: the model file
        # location is hard-coded. n_ctx (4000) equals MAX_INPUT_TOKEN_LENGTH, and
        # llama.cpp gets one thread per logical CPU (`cpu_count`).
        self.model = Llama(model_path="/repository/iubaris-13b-v3_ggml_Q4_K_S.bin", n_ctx=4000, n_gpu_layers=50, n_threads=cpu_count, verbose=True)

    def get_input_token_length(self, message: str) -> int:
        """Return how many tokens `message` encodes to with the loaded model's tokenizer.

        Args:
            message: prompt text; it is UTF-8 encoded before tokenizing.

        Returns:
            The length of the token list produced by ``Llama.tokenize``.
        """
        # Llama.tokenize takes bytes. The previous `self.model([...])` invoked
        # Llama.__call__, i.e. ran a completion instead of tokenizing.
        input_ids = self.model.tokenize(message.encode("utf-8"))
        return len(input_ids)

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Run one text-completion request.

        Args:
            data: request payload (a dict). The prompt is taken from
                ``data["inputs"]`` (falling back to ``data`` itself when that key
                is absent). Optional ``create_completion`` keyword arguments come
                from ``data["parameters"]``. Both keys are popped from ``data``.
                ``max_new_tokens`` is accepted as an alias for ``max_tokens``;
                ``max_tokens`` wins when both are given.

        Returns:
            A one-element list: ``[{"generated_text": text}]`` on success, or
            ``[{"generated_text": None, "error": message}]`` when the request is
            rejected before generation (too many requested tokens, or a prompt
            longer than MAX_INPUT_TOKEN_LENGTH tokens).
        """
        inputs = data.pop("inputs", data)
        # Copy so the defaults written below never mutate the caller's nested
        # dict; an explicit `"parameters": null` is treated like a missing key.
        parameters = dict(data.pop("parameters", None) or {})

        # create_completion expects `max_tokens`; `max_new_tokens` (the name used
        # in this handler's error messages) would otherwise be forwarded as an
        # unknown keyword argument. A null `max_tokens` falls back to the default
        # instead of failing the comparison below.
        max_new_tokens = parameters.pop("max_new_tokens", None)
        if parameters.get("max_tokens") is None:
            parameters["max_tokens"] = (
                DEFAULT_MAX_NEW_TOKENS if max_new_tokens is None else max_new_tokens
            )

        logger.info(inputs)
        logger.info(parameters)
        if parameters["max_tokens"] > MAX_MAX_NEW_TOKENS:
            logger.error(f"requested max_new_tokens too high (> {MAX_MAX_NEW_TOKENS})")
            return [{"generated_text": None, "error": f"requested max_new_tokens too high (> {MAX_MAX_NEW_TOKENS})"}]

        # Reject prompts longer than the model context (n_ctx=4000) with a
        # structured error rather than passing them to llama-cpp.
        input_token_length = self.get_input_token_length(inputs)
        if input_token_length > MAX_INPUT_TOKEN_LENGTH:
            logger.error(f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})")
            return [{"generated_text": None, "error": f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})"}]

        outputs = self.model.create_completion(inputs, **parameters)
        logger.info(outputs)
        return [{"generated_text": outputs["choices"][0]["text"]}]