import os
from typing import Any, Dict, List, Optional

from llama_cpp import Llama

import gemma_tools

MAX_TOKENS = 1000  # Context window size (n_ctx) for the llama_cpp model
class EndpointHandler:
    def __init__(self, model_dir: Optional[str] = None):
        """
        Initialize the EndpointHandler with the path to the model directory.

        :param model_dir: Path to the directory containing the GGUF model file.
        """
        if not model_dir:
            raise ValueError("model_dir must be provided so the GGUF model file can be located")

        # Update the model filename to match the one in your repository
        model_path = os.path.join(model_dir, "comic_mistral-v5.2.q5_0.gguf")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"The model file was not found at {model_path}")

        try:
            self.model = Llama(
                model_path=model_path,
                n_ctx=MAX_TOKENS,  # llama_cpp uses n_ctx for the context size
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load the model: {e}")
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle incoming requests for model inference.

        :param data: Dictionary containing input data and parameters for the model.
        :return: A list with a dictionary containing the status and response or error details.
        """
        # Extract and validate arguments from the request data.
        # get_args_or_none is assumed to return a 2-tuple: (True, args_dict) on
        # success, or (False, error_dict) when validation fails.
        args_check = gemma_tools.get_args_or_none(data)
        if not args_check[0]:  # Validation failed
            error_info = args_check[1] if len(args_check) > 1 else {}
            return [{
                "status": error_info.get("status", "error"),
                "reason": error_info.get("reason", "unknown"),
                "description": error_info.get("description", "Validation error in arguments")
            }]

        # If validation passed, the parsed arguments are the second element of the tuple
        args = args_check[1]

        # Formatting template for the prompt
        prompt_format = (
            "<startofturn>system\n{system_prompt} <endofturn>\n"
            "<startofturn>user\n{inputs} <endofturn>\n"
            "<startofturn>model"
        )
        try:
            formatted_prompt = prompt_format.format(**args)
        except Exception as e:
            return [{
                "status": "error",
                "reason": "Invalid format",
                "detail": str(e)
            }]
        # Parse max_length, defaulting to 212 tokens if not provided
        max_length = data.get("max_length", 212)
        try:
            max_length = int(max_length)
        except (TypeError, ValueError):
            return [{
                "status": "error",
                "reason": "max_length must be an integer",
                "detail": f"max_length was not a valid integer: {max_length!r}"
            }]

        # Perform inference
        try:
            res = self.model(
                formatted_prompt,
                temperature=args["temperature"],
                top_p=args["top_p"],
                top_k=args["top_k"],
                max_tokens=max_length
            )
        except Exception as e:
            return [{
                "status": "error",
                "reason": "Inference failed",
                "detail": str(e)
            }]

        return [{
            "status": "success",
            # Extract the generated text from the first completion choice
            "response": res['choices'][0]['text'].strip()
        }]
# Usage in your script or where the handler is instantiated:
try:
    handler = EndpointHandler("/repository")
except (FileNotFoundError, RuntimeError, ValueError) as e:
    print(f"Initialization error: {e}")
    exit(1)  # Exit with an error code if the handler cannot be initialized
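# Example request payload (illustrative sketch only; the exact keys expected by
# gemma_tools.get_args_or_none are assumptions inferred from how `args` is used above):
#
# result = handler({
#     "inputs": "Write a short comic strip about a forgetful robot.",
#     "system_prompt": "You are a comic script writer.",
#     "temperature": 0.7,
#     "top_p": 0.9,
#     "top_k": 40,
#     "max_length": 212,
# })
# print(result[0]["response"] if result[0]["status"] == "success" else result[0])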