njwright92 committed on
Commit
5da533d
·
verified ·
1 Parent(s): c9ca6d1

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +12 -14
handler.py CHANGED
@@ -4,18 +4,15 @@ import gemma_tools
4
 
5
  MAX_TOKENS = 1000
6
 
7
-
8
  class EndpointHandler():
9
  def __init__(self, model_dir=None):
10
  if model_dir:
11
  print(f"Initializing with model from directory: {model_dir}")
12
 
13
- # For Hugging Face endpoints, you might not need to explicitly load the model if it's already linked
14
- # But if you need to initialize it specifically:
15
- print("Initializing Llama model directly from Hugging Face repository...")
16
- self.model = Llama.from_pretrained(
17
- # Use model_id instead of filename for repo reference
18
- model_id="njwright92/ComicBot_v.2-gguf",
19
  n_ctx=MAX_TOKENS,
20
  chat_format="llama-2"
21
  )
@@ -24,8 +21,7 @@ class EndpointHandler():
24
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
25
  # Extract and validate arguments from the data
26
  print("Extracting and validating arguments from the data payload...")
27
- args_check = gemma_tools.get_args_or_none(
28
- data) # Using the new function
29
 
30
  if not args_check[0]: # If validation failed
31
  return [{
@@ -62,11 +58,13 @@ class EndpointHandler():
62
  }]
63
 
64
  print("Generating response from the model...")
65
- res = self.model(formatted_prompt,
66
- temperature=args["temperature"],
67
- top_p=args["top_p"],
68
- top_k=args["top_k"],
69
- max_tokens=max_length)
 
 
70
 
71
  print(f"Model response: {res}")
72
 
 
4
 
5
  MAX_TOKENS = 1000
6
 
 
7
  class EndpointHandler():
8
  def __init__(self, model_dir=None):
9
  if model_dir:
10
  print(f"Initializing with model from directory: {model_dir}")
11
 
12
+ # Initialize the Llama model directly
13
+ print("Initializing Llama model...")
14
+ self.model = Llama(
15
+ model_path=f"{model_dir}/ComicBot_v.2-gguf", # Adjust the path if necessary
 
 
16
  n_ctx=MAX_TOKENS,
17
  chat_format="llama-2"
18
  )
 
21
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
22
  # Extract and validate arguments from the data
23
  print("Extracting and validating arguments from the data payload...")
24
+ args_check = gemma_tools.get_args_or_none(data)
 
25
 
26
  if not args_check[0]: # If validation failed
27
  return [{
 
58
  }]
59
 
60
  print("Generating response from the model...")
61
+ res = self.model(
62
+ formatted_prompt,
63
+ temperature=args["temperature"],
64
+ top_p=args["top_p"],
65
+ top_k=args["top_k"],
66
+ max_tokens=max_length
67
+ )
68
 
69
  print(f"Model response: {res}")
70