Update handler.py

handler.py (+12 -14)
--- a/handler.py
+++ b/handler.py
@@ -4,18 +4,15 @@ import gemma_tools
 
 MAX_TOKENS = 1000
 
-
 class EndpointHandler():
     def __init__(self, model_dir=None):
         if model_dir:
             print(f"Initializing with model from directory: {model_dir}")
 
-        #
-
-
-
-            # Use model_id instead of filename for repo reference
-            model_id="njwright92/ComicBot_v.2-gguf",
+        # Initialize the Llama model directly
+        print("Initializing Llama model...")
+        self.model = Llama(
+            model_path=f"{model_dir}/ComicBot_v.2-gguf",  # Adjust the path if necessary
             n_ctx=MAX_TOKENS,
             chat_format="llama-2"
         )
@@ -24,8 +21,7 @@ class EndpointHandler():
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         # Extract and validate arguments from the data
         print("Extracting and validating arguments from the data payload...")
-        args_check = gemma_tools.get_args_or_none(
-            data)  # Using the new function
+        args_check = gemma_tools.get_args_or_none(data)
 
         if not args_check[0]:  # If validation failed
             return [{
@@ -62,11 +58,13 @@ class EndpointHandler():
         }]
 
         print("Generating response from the model...")
-        res = self.model(
-
-
-
-
+        res = self.model(
+            formatted_prompt,
+            temperature=args["temperature"],
+            top_p=args["top_p"],
+            top_k=args["top_k"],
+            max_tokens=max_length
+        )
 
         print(f"Model response: {res}")
 
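For anyone who wants to exercise this change locally, a minimal smoke test might look like the sketch below. It is an assumption-laden example, not part of the commit: the payload keys ("inputs", "temperature", "top_p", "top_k", "max_length") are inferred from the args the visible hunks read, the import path and model directory are placeholders, and it assumes llama-cpp-python is installed with a ComicBot_v.2-gguf file present under model_dir.

# Hypothetical smoke test for the updated EndpointHandler (not part of this commit).
from handler import EndpointHandler

# Placeholder path; must contain the ComicBot_v.2-gguf file the handler now loads directly.
handler = EndpointHandler(model_dir="/path/to/model_dir")

# Assumed payload shape: a prompt plus the sampling arguments the handler reads via args[...].
payload = {
    "inputs": "Write a short comic-style joke about robots.",
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 40,
    "max_length": 256,
}

# __call__ returns a list of dicts; the raw llama-cpp-python completion (a dict with a
# "choices" field) is printed by the handler before the response is assembled.
print(handler(payload))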