from typing import Any, Dict, List, Optional
from llama_cpp import Llama
import gemma_tools
import os

MAX_TOKENS = 1000


class EndpointHandler:
    def __init__(self, model_dir: Optional[str] = None):
        """
        Initialize the EndpointHandler with the path to the model directory.

        :param model_dir: Path to the directory containing the model file.
        """
        if not model_dir:
            raise ValueError("model_dir must be provided to locate the model file")

        # Update the model filename to match the one in your repository
        model_path = os.path.join(
            model_dir, "comic_mistral-v5.2.q5_0.gguf")
        if not os.path.exists(model_path):
            raise FileNotFoundError(
                f"The model file was not found at {model_path}")

        try:
            self.model = Llama(
                model_path=model_path,
                n_ctx=MAX_TOKENS,  # n_ctx sets the context window size in llama_cpp
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load the model: {e}") from e

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle incoming requests for model inference.

        :param data: Dictionary containing input data and parameters for the model.
        :return: A list with a dictionary containing the status and response or error details.
        """
        # Extract and validate arguments from the data.
        # get_args_or_none is assumed to return a (success, payload) pair, where
        # payload holds error details on failure and the parsed args on success.
        args_check = gemma_tools.get_args_or_none(data)

        if not args_check[0]:  # Validation failed
            error = args_check[1] if len(args_check) > 1 else {}
            return [{
                "status": error.get("status", "error"),
                "reason": error.get("reason", "unknown"),
                "description": error.get("description", "Validation error in arguments")
            }]

        # If validation passed, args are in the second element of the tuple
        args = args_check[1]

        # Define the formatting template for the prompt
        prompt_format = "<startofturn>system\n{system_prompt} <endofturn>\n<startofturn>user\n{inputs} <endofturn>\n<startofturn>model"
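        # For example, with system_prompt="Be brief." and inputs="Hi there",
        # the formatted prompt (a sketch of the expected shape) is:
        #   <startofturn>system
        #   Be brief. <endofturn>
        #   <startofturn>user
        #   Hi there <endofturn>
        #   <startofturn>model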

        try:
            formatted_prompt = prompt_format.format(**args)
        except Exception as e:
            return [{
                "status": "error",
                "reason": "Invalid format",
                "detail": str(e)
            }]

        # Parse max_length, defaulting to 212 when it is not provided
        max_length = data.get("max_length", 212)
        try:
            max_length = int(max_length)
        except (TypeError, ValueError):
            return [{
                "status": "error",
                "reason": "max_length must be an integer",
                "detail": f"max_length was not a valid integer: {max_length!r}"
            }]

        # Perform inference
        try:
            res = self.model(
                formatted_prompt,
                temperature=args["temperature"],
                top_p=args["top_p"],
                top_k=args["top_k"],
                max_tokens=max_length
            )
        except Exception as e:
            return [{
                "status": "error",
                "reason": "Inference failed",
                "detail": str(e)
            }]

        return [{
            "status": "success",
            # Extract the text from the response
            "response": res['choices'][0]['text'].strip()
        }]


# Usage in your script or where the handler is instantiated:
try:
    handler = EndpointHandler("/repository")
except (ValueError, FileNotFoundError, RuntimeError) as e:
    print(f"Initialization error: {e}")
    raise SystemExit(1)  # Exit with an error code if the handler cannot be initialized
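

# Example request payload (a sketch; the exact keys required by
# gemma_tools.get_args_or_none are assumptions, not confirmed by this file,
# though the prompt template and sampling code above use these names):
#
# sample_request = {
#     "system_prompt": "You are a comic-book narrator.",
#     "inputs": "Describe the hero's entrance in one sentence.",
#     "temperature": 0.7,
#     "top_p": 0.9,
#     "top_k": 40,
#     "max_length": 212,
# }
# result = handler(sample_request)
# print(result[0]["status"], result[0].get("response", result[0].get("reason")))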