from typing import Any, Dict, List

import psutil
from llama_cpp import Llama
from loguru import logger

# Use all logical cores for llama.cpp threading; fall back to 1 if undetectable.
_ = psutil.cpu_count(logical=True)
cpu_count: int = int(_) if _ else 1

MAX_INPUT_TOKEN_LENGTH = 4000
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024


class EndpointHandler:
    def __init__(self, path=""):
        # Load the GGML-quantized model shipped with the repository.
        self.model = Llama(
            model_path="/repository/iubaris-13b-v3_ggml_Q4_K_S.bin",
            n_ctx=4000,
            n_gpu_layers=50,
            n_threads=cpu_count,
            verbose=True,
        )

    def get_input_token_length(self, message: str) -> int:
        # Tokenize the prompt with llama.cpp's own tokenizer and count tokens.
        input_ids = self.model.tokenize(message.encode("utf-8"))
        return len(input_ids)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        parameters.setdefault("max_tokens", DEFAULT_MAX_NEW_TOKENS)

        logger.info(inputs)
        logger.info(parameters)

        # Reject requests that ask for more new tokens than the hard cap.
        if parameters["max_tokens"] > MAX_MAX_NEW_TOKENS:
            logger.error(f"requested max_new_tokens too high (> {MAX_MAX_NEW_TOKENS})")
            return [{"generated_text": None, "error": f"requested max_new_tokens too high (> {MAX_MAX_NEW_TOKENS})"}]

        #input_token_length = self.get_input_token_length(inputs)
        #if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        #    logger.error(f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})")
        #    return [{"generated_text": None, "error": f"input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})"}]

        #logger.info(f"inputs: {inputs}")

        outputs = self.model.create_completion(inputs, **parameters)
        logger.info(outputs)

        return [{"generated_text": outputs["choices"][0]["text"]}]
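

# Minimal local smoke test (a sketch, not part of the Inference Endpoints contract):
# the endpoint runtime normally instantiates EndpointHandler itself and calls it with
# a payload of the form {"inputs": ..., "parameters": {...}}. The prompt and sampling
# parameters below are illustrative assumptions, not values from the original handler.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": "Write a haiku about mountain lakes.",
        "parameters": {"max_tokens": 64, "temperature": 0.7},
    }
    result = handler(payload)
    print(result[0].get("generated_text"))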