import json
import subprocess
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path
from typing import Dict, List

import gradio as gr
import requests
from huggingface_hub import hf_hub_download

from themes.research_monochrome import theme

today_date = datetime.today().strftime("%B %-d, %Y")  # noqa: DTZ002
SYS_PROMPT = f"""Today's Date: {today_date}.
You are Granite, developed by IBM. You are a helpful AI assistant."""
TITLE = "IBM Granite 3.1 3B A800M MoE Instruct from a local GGUF server"
DESCRIPTION = """
<p>Granite 3.1 3B A800M Instruct is an open-source LLM supporting a 128K context window; this demo runs the
server with a 2K context.
<span class="gr_docs_link">
<a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a>
</span>
</p>
"""
LLAMA_CPP_SERVER = "http://127.0.0.1:8081"
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05
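# Default decoding parameters; each is exposed as a slider in the
# Advanced Settings accordion defined further below.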
# Download the GGUF into the working directory; hf_hub_download returns the local file path.
gguf_path = hf_hub_download(
    repo_id="bartowski/granite-3.1-3b-a800m-instruct-GGUF",
    filename="granite-3.1-3b-a800m-instruct-Q8_0.gguf",
    local_dir=".",
)

# Start llama-server with the downloaded model: CPU-only (-ngl 0), 2K context,
# 8 threads, listening on port 8081.
subprocess.run(["chmod", "+x", "llama-server"], check=True)
command = [
    "./llama-server",
    "-m", "granite-3.1-3b-a800m-instruct-Q8_0.gguf",
    "-ngl", "0",
    "--temp", "0.0",
    "-c", "2048",
    "-t", "8",
    "--port", "8081",
]
process = subprocess.Popen(command)
print(f"llama-server process started with PID {process.pid}")

def generate(
    message: str,
    chat_history: List[Dict],
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: int = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> Iterator[str]:
    """Generate function for the chat demo, streaming from the llama.cpp server."""
    # Build the conversation: system prompt, prior turns, then the new user message.
    conversation = [{"role": "system", "content": SYS_PROMPT}]
    conversation += chat_history
    conversation.append({"role": "user", "content": message})

    # Render the conversation with the Granite 3.x chat template:
    # <|start_of_role|>{role}<|end_of_role|>{content}<|end_of_text|>
    prompt = ""
    for item in conversation:
        prompt += f"<|start_of_role|>{item['role']}<|end_of_role|>{item['content']}<|end_of_text|>\n"
    prompt += "<|start_of_role|>assistant<|end_of_role|>"  # generation prompt for the assistant turn

    # Request payload for the llama.cpp server's /completion endpoint.
    payload = {
        "prompt": prompt,
        "stream": True,  # stream tokens back as they are generated
        "n_predict": max_new_tokens,  # /completion calls the generation limit n_predict
        "temperature": temperature,
        "repeat_penalty": repetition_penalty,
        "top_p": top_p,
        "top_k": top_k,
        "stop": ["<|end_of_text|>"],  # stop generation at the end-of-turn token
    }
    try:
        # Make the streaming request to the llama.cpp server.
        with requests.post(f"{LLAMA_CPP_SERVER}/completion", json=payload, stream=True, timeout=60) as response:
            response.raise_for_status()  # raise HTTPError for bad responses (4xx or 5xx)
            # The server streams server-sent events: one JSON object per "data: " line.
            outputs = []
            for line in response.iter_lines():
                if not line:
                    continue
                decoded_line = line.decode("utf-8")
                if decoded_line.startswith("data: "):
                    decoded_line = decoded_line[6:]  # strip the SSE "data: " prefix
                try:
                    json_data = json.loads(decoded_line)
                    text = json_data.get("content", "")  # the newly generated token text
                    if text:
                        outputs.append(text)
                        yield "".join(outputs)  # yield the accumulated response so far
                except json.JSONDecodeError:
                    print(f"JSONDecodeError: {decoded_line}")  # log and skip malformed lines
    except requests.exceptions.RequestException as e:
        print(f"Request failed: {e}")
        yield f"Error: {e}"  # surface the error in the chat window
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        yield f"Error: {e}"
css_file_path = Path(__file__).parent / "app.css"

# Advanced settings (displayed in an Accordion)
temperature_slider = gr.Slider(
    minimum=0, maximum=1.0, value=TEMPERATURE, step=0.1, label="Temperature", elem_classes=["gr_accordion_element"]
)
top_p_slider = gr.Slider(
    minimum=0, maximum=1.0, value=TOP_P, step=0.05, label="Top P", elem_classes=["gr_accordion_element"]
)
top_k_slider = gr.Slider(
    minimum=0, maximum=100, value=TOP_K, step=1, label="Top K", elem_classes=["gr_accordion_element"]
)
repetition_penalty_slider = gr.Slider(
    minimum=0,
    maximum=2.0,
    value=REPETITION_PENALTY,
    step=0.05,
    label="Repetition Penalty",
    elem_classes=["gr_accordion_element"],
)
max_new_tokens_slider = gr.Slider(
    minimum=1,
    maximum=2000,
    value=MAX_NEW_TOKENS,
    step=1,
    label="Max New Tokens",
    elem_classes=["gr_accordion_element"],
)
chat_interface_accordion = gr.Accordion(label="Advanced Settings", open=False)
with gr.Blocks(fill_height=True, css_paths=css_file_path, theme=theme, title=TITLE) as demo:
    gr.HTML(f"<h2>{TITLE}</h2>", elem_classes=["gr_title"])
    gr.HTML(DESCRIPTION)
    chat_interface = gr.ChatInterface(
        fn=generate,
        examples=[
            ["Explain the concept of quantum computing to someone with no background in physics or computer science."],
            ["What is OpenShift?"],
            ["What's the importance of low latency inference?"],
            ["Help me boost productivity habits."],
        ],
        example_labels=[
            "Explain quantum computing",
            "What is OpenShift?",
            "Importance of low latency inference",
            "Boosting productivity habits",
        ],
        cache_examples=False,
        type="messages",
        additional_inputs=[
            temperature_slider,
            repetition_penalty_slider,
            top_p_slider,
            top_k_slider,
            max_new_tokens_slider,
        ],
        additional_inputs_accordion=chat_interface_accordion,
    )
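# Note: additional_inputs are passed positionally after (message, chat_history),
# so the slider order above must match generate()'s parameter order:
# temperature, repetition_penalty, top_p, top_k, max_new_tokens.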
if __name__ == "__main__":
    demo.queue().launch()