|
from typing import Dict, List, Any |
|
from llama_cpp import Llama |
|
import gemma_tools |
|
import os |
|
|
|
MAX_TOKENS = 1000  # context window size (n_ctx) handed to Llama at load time
|
|
|
|
|
class EndpointHandler:
    """Serve inference requests against a local GGUF model via llama-cpp."""

    def __init__(self, model_dir: str = None):
        """
        Initialize the EndpointHandler with the path to the model directory.

        :param model_dir: Path to the directory containing the model file.
        :raises ValueError: If ``model_dir`` is not provided.
        :raises FileNotFoundError: If the model file is missing.
        :raises RuntimeError: If the model fails to load.
        """
        # Fail fast instead of silently skipping the load when model_dir is
        # falsy (the original left self.model unset in that case, which made
        # __call__ crash later with AttributeError).
        if not model_dir:
            raise ValueError("model_dir is required to locate the model file")

        model_path = os.path.join(model_dir, "comic_mistral-v5.2.q5_0.gguf")
        if not os.path.exists(model_path):
            raise FileNotFoundError(
                f"The model file was not found at {model_path}")

        try:
            self.model = Llama(
                model_path=model_path,
                n_ctx=MAX_TOKENS,
            )
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to load the model: {e}") from e

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle incoming requests for model inference.

        :param data: Dictionary containing input data and parameters for the model.
        :return: A list with a dictionary containing the status and response or error details.
        """
        args_check = gemma_tools.get_args_or_none(data)

        if not args_check[0]:
            # NOTE(review): the original called .get() on args_check itself,
            # but args_check is indexed like a sequence above and sequences
            # have no .get() -- this branch always raised AttributeError.
            # Read the details from the payload element instead (assumes
            # get_args_or_none returns (ok, dict); confirm against gemma_tools).
            details = (args_check[1]
                       if len(args_check) > 1 and isinstance(args_check[1], dict)
                       else {})
            return [{
                "status": details.get("status", "error"),
                "reason": details.get("reason", "unknown"),
                "description": details.get("description", "Validation error in arguments"),
            }]

        args = args_check[1]

        # Chat-template wrapper expected by the fine-tuned model.
        prompt_format = (
            "<startofturn>system\n{system_prompt} <endofturn>\n"
            "<startofturn>user\n{inputs} <endofturn>\n"
            "<startofturn>model"
        )

        try:
            formatted_prompt = prompt_format.format(**args)
        except Exception as e:
            return [{
                "status": "error",
                "reason": "Invalid format",
                "detail": str(e),
            }]

        max_length = data.get("max_length", 212)
        try:
            max_length = int(max_length)
        except (TypeError, ValueError):
            # TypeError covers non-numeric payloads such as None or a list,
            # which the original let propagate as an unhandled crash.
            return [{
                "status": "error",
                "reason": "max_length must be an integer",
                "detail": "max_length was not a valid integer",
            }]

        try:
            res = self.model(
                formatted_prompt,
                temperature=args["temperature"],
                top_p=args["top_p"],
                top_k=args["top_k"],
                max_tokens=max_length,
            )
        except Exception as e:
            return [{
                "status": "error",
                "reason": "Inference failed",
                "detail": str(e),
            }]

        return [{
            "status": "success",
            "response": res['choices'][0]['text'].strip(),
        }]
|
|
|
|
|
|
|
# Module-level handler instance expected by the serving runtime; the model
# repository is mounted at /repository.
try:
    handler = EndpointHandler("/repository")
except (FileNotFoundError, RuntimeError) as e:
    print(f"Initialization error: {e}")
    # raise SystemExit instead of exit(): the exit() builtin is injected by
    # the `site` module for interactive use and may be absent (e.g. python -S).
    raise SystemExit(1)
|
|