TobDeBer committed
Commit 2e11c33 · 1 Parent(s): 0d42c5a

use local server

Files changed (1)
  1. app.py +74 -36
app.py CHANGED
@@ -2,11 +2,14 @@ from collections.abc import Iterator
 from datetime import datetime
 from pathlib import Path
 from threading import Thread
+from huggingface_hub import hf_hub_download
+from themes.research_monochrome import theme
+from typing import Iterator, List, Dict
 
+import requests
+import json
 import gradio as gr
 
-from themes.research_monochrome import theme
-
 today_date = datetime.today().strftime("%B %-d, %Y") # noqa: DTZ002
 
 SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
@@ -21,6 +24,7 @@ or enter your own. Keep in mind that AI can occasionally make mistakes.
 </span>
 </p>
 """
+LLAMA_CPP_SERVER = "http://127.0.0.1:8081"
 MAX_INPUT_TOKEN_LENGTH = 128_000
 MAX_NEW_TOKENS = 1024
 TEMPERATURE = 0.7
@@ -29,56 +33,90 @@ TOP_K = 50
 REPETITION_PENALTY = 1.05
 
 # download GGUF into local directory
+gguf_path = hf_hub_download(
+    repo_id="bartowski/granite-3.1-3b-a800m-instruct-GGUF",
+    filename="granite-3.1-3b-a800m-instruct-Q8_0.gguf",
+    local_dir="."
+)
 
-# chmod llama-server
-# start llama-server
+# TODO: chmod llama-server
+# TODO: start llama-server
+# ./llama-server -m granite-3.1-3b-a800m-instruct-Q8_0.gguf -ngl 0 --temp 0.0 -c 2048 -t 8 --port 8081
 
 def generate(
     message: str,
-    chat_history: list[dict],
+    chat_history: List[Dict],
     temperature: float = TEMPERATURE,
     repetition_penalty: float = REPETITION_PENALTY,
     top_p: float = TOP_P,
     top_k: float = TOP_K,
     max_new_tokens: int = MAX_NEW_TOKENS,
 ) -> Iterator[str]:
-    """Generate function for chat demo."""
+    """Generate function for chat demo using Llama.cpp server."""
+
     # Build messages
     conversation = []
     conversation.append({"role": "system", "content": SYS_PROMPT})
     conversation += chat_history
     conversation.append({"role": "user", "content": message})
 
-    # Convert messages to prompt format
-    input_ids = tokenizer.apply_chat_template(
-        conversation,
-        return_tensors="pt",
-        add_generation_prompt=True,
-        truncation=True,
-        max_length=MAX_INPUT_TOKEN_LENGTH - max_new_tokens,
-    )
-
-    input_ids = input_ids.to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
-
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
+    # Prepare the prompt for the Llama.cpp server
+    prompt = ""
+    for item in conversation:
+        if item["role"] == "system":
+            prompt += f"<|system|>\n{item['content']}\n<|file_separator|>\n"
+        elif item["role"] == "user":
+            prompt += f"<|user|>\n{item['content']}\n<|file_separator|>\n"
+        elif item["role"] == "assistant":
+            prompt += f"<|model|>\n{item['content']}\n<|file_separator|>\n"
+    prompt += "<|model|>\n"  # Add the beginning token for the assistant
+
+
+    # Construct the request payload
+    payload = {
+        "prompt": prompt,
+        "stream": True,  # Enable streaming
+        "max_tokens": max_new_tokens,
+        "temperature": temperature,
+        "repeat_penalty": repetition_penalty,
+        "top_p": top_p,
+        "top_k": top_k,
+        "stop": ["<|file_separator|>"],  # stops after it sees this
+    }
+
+    try:
+        # Make the request to the Llama.cpp server
+        with requests.post(f"{LLAMA_CPP_SERVER}/completion", json=payload, stream=True, timeout=60) as response:
+            response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
+
+            # Stream the response from the server
+            outputs = []
+            for line in response.iter_lines():
+                if line:
+                    # Decode the line
+                    decoded_line = line.decode('utf-8')
+                    # Remove 'data: ' prefix if present
+                    if decoded_line.startswith("data: "):
+                        decoded_line = decoded_line[6:]
+
+                    # Handle potential JSON decoding errors
+                    try:
+                        json_data = json.loads(decoded_line)
+                        text = json_data.get("content", "")  # Extract content field
+                        if text:
+                            outputs.append(text)
+                            yield "".join(outputs)
+
+                    except json.JSONDecodeError:
+                        print(f"JSONDecodeError: {decoded_line}")
+                        # Handle the error, potentially skipping the line or logging it.
+
+    except requests.exceptions.RequestException as e:
+        print(f"Request failed: {e}")
+        yield f"Error: {e}"  # Yield an error message to the user
+    except Exception as e:
+        print(f"An unexpected error occurred: {e}")
+        yield f"Error: {e}"  # Yield error message
 
 
 css_file_path = Path(Path(__file__).parent / "app.css")
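
Editor's note: the commit leaves the chmod/start steps for llama-server as TODO comments. The sketch below (not part of the commit) shows one way those steps could be wired up in Python. The binary path ./llama-server, the GET /health readiness probe, and the retry budget are assumptions; the command-line flags and port are copied from the comment in the diff.

import os
import stat
import subprocess
import time

import requests

LLAMA_CPP_SERVER = "http://127.0.0.1:8081"
SERVER_BINARY = "./llama-server"  # assumed location of the bundled binary
GGUF_FILE = "granite-3.1-3b-a800m-instruct-Q8_0.gguf"


def start_llama_server() -> subprocess.Popen:
    """chmod the bundled llama-server binary, launch it, and wait until it answers (sketch)."""
    # chmod llama-server: add the execute bit for the current user
    mode = os.stat(SERVER_BINARY).st_mode
    os.chmod(SERVER_BINARY, mode | stat.S_IXUSR)

    # start llama-server with the flags from the commented command in app.py
    proc = subprocess.Popen(
        [SERVER_BINARY, "-m", GGUF_FILE, "-ngl", "0", "--temp", "0.0",
         "-c", "2048", "-t", "8", "--port", "8081"]
    )

    # poll the server's /health endpoint (assumed available in recent builds) until ready
    for _ in range(120):
        try:
            if requests.get(f"{LLAMA_CPP_SERVER}/health", timeout=2).status_code == 200:
                return proc
        except requests.exceptions.RequestException:
            pass  # server not up yet; keep waiting
        time.sleep(1)

    proc.terminate()
    raise RuntimeError("llama-server did not become ready in time")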
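
A note on the prompt format: the hand-built tags (<|system|>, <|user|>, <|model|>, <|file_separator|>) do not appear to match the chat template published for the Granite 3.1 instruct models. If the llama-server build in use exposes the OpenAI-compatible /v1/chat/completions endpoint (recent builds do), the server can apply the chat template embedded in the GGUF itself, so the conversation list could be sent directly instead of formatting a prompt string. The following is only a sketch under that assumption, not the committed code; parameter names follow the OpenAI-style API.

import json
from collections.abc import Iterator

import requests

LLAMA_CPP_SERVER = "http://127.0.0.1:8081"


def generate_via_chat_endpoint(conversation: list[dict], max_new_tokens: int,
                               temperature: float, top_p: float) -> Iterator[str]:
    """Stream a reply using the server-side chat template (sketch, not the committed code)."""
    payload = {
        "messages": conversation,  # [{"role": "system"|"user"|"assistant", "content": ...}]
        "stream": True,
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
    }
    outputs = []
    with requests.post(f"{LLAMA_CPP_SERVER}/v1/chat/completions", json=payload,
                       stream=True, timeout=60) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if not line:
                continue
            decoded = line.decode("utf-8")
            if decoded.startswith("data: "):
                decoded = decoded[6:]
            if decoded.strip() == "[DONE]":  # end-of-stream marker in the OpenAI SSE format
                break
            try:
                chunk = json.loads(decoded)
            except json.JSONDecodeError:
                continue  # skip keep-alive or malformed lines
            choices = chunk.get("choices") or [{}]
            delta = choices[0].get("delta", {}).get("content") or ""
            if delta:
                outputs.append(delta)
                yield "".join(outputs)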