# ComicBot_v.2-gguf / handler.py
import os
import sys
from typing import Any, Dict, List, Optional

from llama_cpp import Llama

import gemma_tools

MAX_TOKENS = 1000  # context window size (n_ctx) passed to llama_cpp


class EndpointHandler:
    def __init__(self, model_dir: Optional[str] = None):
        """
        Initialize the EndpointHandler with the path to the model directory.

        :param model_dir: Path to the directory containing the model file.
        """
        if not model_dir:
            raise ValueError("model_dir is required to locate the model file")

        # Update the model filename to match the one in your repository
        model_path = os.path.join(model_dir, "comic_mistral-v5.2.q5_0.gguf")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"The model file was not found at {model_path}")

        try:
            self.model = Llama(
                model_path=model_path,
                n_ctx=MAX_TOKENS,  # n_ctx sets the context size in llama_cpp
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load the model: {e}") from e

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle incoming requests for model inference.

        :param data: Dictionary containing input data and parameters for the model.
        :return: A list with a single dictionary containing the status and either
            the model response or error details.
        """
        # Extract and validate arguments from the request payload.
        # NOTE: get_args_or_none is assumed to return an (ok, payload) tuple:
        # (True, args_dict) on success, (False, error_dict) on failure.
        ok, payload = gemma_tools.get_args_or_none(data)
        if not ok:  # validation failed
            return [{
                "status": payload.get("status", "error"),
                "reason": payload.get("reason", "unknown"),
                "description": payload.get("description", "Validation error in arguments"),
            }]
        args = payload

        # Chat template in the model's turn format.
        prompt_format = (
            "<startofturn>system\n{system_prompt} <endofturn>\n"
            "<startofturn>user\n{inputs} <endofturn>\n"
            "<startofturn>model"
        )
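        # For illustration, with system_prompt="Be funny" and inputs="Hi"
        # (hypothetical values), the rendered prompt would be:
        #   <startofturn>system
        #   Be funny <endofturn>
        #   <startofturn>user
        #   Hi <endofturn>
        #   <startofturn>model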
        try:
            formatted_prompt = prompt_format.format(**args)
        except KeyError as e:
            return [{
                "status": "error",
                "reason": "Invalid format",
                "detail": f"missing prompt field: {e}",
            }]

        # Parse max_length, defaulting to 212 when not provided.
        max_length = data.get("max_length", 212)
        try:
            max_length = int(max_length)
        except (TypeError, ValueError):
            return [{
                "status": "error",
                "reason": "max_length must be an integer",
                "detail": f"max_length was not a valid integer: {max_length!r}",
            }]

        # Perform inference with the sampling parameters from the request.
        try:
            res = self.model(
                formatted_prompt,
                temperature=args["temperature"],
                top_p=args["top_p"],
                top_k=args["top_k"],
                max_tokens=max_length,
            )
        except Exception as e:
            return [{
                "status": "error",
                "reason": "Inference failed",
                "detail": str(e),
            }]

        return [{
            "status": "success",
            # llama_cpp returns an OpenAI-style completion dict; extract the text.
            "response": res["choices"][0]["text"].strip(),
        }]


# Usage in your script or wherever the handler is instantiated:
try:
    handler = EndpointHandler("/repository")
except (FileNotFoundError, RuntimeError, ValueError) as e:
    print(f"Initialization error: {e}")
    sys.exit(1)  # Exit with an error code if the handler cannot be initialized
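

# A minimal smoke test. The payload below is hypothetical: the exact keys
# that gemma_tools.get_args_or_none expects are an assumption here, not a
# documented contract of this repository.
if __name__ == "__main__":
    sample_request = {
        "system_prompt": "You are a stand-up comedian.",
        "inputs": "Tell me a joke about robots.",
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 40,
        "max_length": 128,
    }
    print(handler(sample_request))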