# iubaris-13b-v3_GGML / handler.py
import os
from typing import Dict, List, Any
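# Build flags so llama-cpp-python compiles with cuBLAS (CUDA) support when installed below.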
os.environ['CMAKE_ARGS'] = "-DLLAMA_CUBLAS=on"
os.environ['FORCE_CMAKE'] = "1"
import sys
import subprocess
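# Install llama-cpp-python at startup so the build picks up the CUDA flags set above.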
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'llama-cpp-python'])
from llama_cpp import Llama
import torch
from loguru import logger
import time
import psutil
# Use every logical CPU core for llama.cpp worker threads; fall back to 1 if undetected.
_logical_cores = psutil.cpu_count(logical=True)
cpu_count: int = int(_logical_cores) if _logical_cores else 1
MAX_INPUT_TOKEN_LENGTH = 3072  # context window (n_ctx) given to the model
MAX_MAX_NEW_TOKENS = 2048  # hard upper limit on requested max_tokens
DEFAULT_MAX_NEW_TOKENS = 1024  # used when the request does not set max_tokens
class EndpointHandler:
    def __init__(self, path=""):
        # Load the quantized GGUF model; `path` is part of the handler interface but unused here.
        self.llm = Llama(model_path="/repository/iubaris-13b-v3_ggml_Q4_K_M.gguf", n_ctx=MAX_INPUT_TOKEN_LENGTH, n_gpu_layers=50, n_threads=cpu_count, verbose=True)
    def get_input_token_length(self, message: str) -> int:
        # Tokenize the prompt with the loaded model to count its tokens.
        input_ids = self.llm.tokenize(message.encode('utf-8'))
        return len(input_ids)
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Expects a payload like {"inputs": "<prompt>", "parameters": {...}}.
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        parameters["max_tokens"] = parameters.pop("max_tokens", DEFAULT_MAX_NEW_TOKENS)
        #logger.info(inputs)
        #logger.info(parameters)
        if parameters["max_tokens"] > MAX_MAX_NEW_TOKENS:
            logger.error(f"requested max_tokens too high (> {MAX_MAX_NEW_TOKENS})")
            return [{"generated_text": None, "error": f"requested max_tokens too high (> {MAX_MAX_NEW_TOKENS})"}]
        #input_token_length = self.get_input_token_length(inputs)
        #if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        #    logger.error(f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})")
        #    return [{"generated_text": None, "error": f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})"}]
        #logger.info(f"inputs: {inputs}")
        # create_completion returns a completion dict; surface only the generated text.
        outputs = self.llm.create_completion(inputs, **parameters)
        #logger.info(outputs)
        return [{"generated_text": outputs["choices"][0]["text"]}]
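# Minimal local smoke test (a sketch, not part of the Inference Endpoints request
# flow; it assumes the GGUF file exists at /repository/ as in __init__ above).
if __name__ == "__main__":
    handler = EndpointHandler()
    example_payload = {
        "inputs": "Write a short poem about llamas.",
        "parameters": {"max_tokens": 64, "temperature": 0.7},
    }
    print(handler(example_payload))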