File size: 5,990 Bytes
9e608cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import gradio as gr
import time
import requests
import json
import os
# NOTE(review): MODEL is not referenced by the request code in this file;
# confirm whether it is still needed.
MODEL = "gpt-4-0125-preview"

# Server location and credentials are read from the environment; either value
# is None when the corresponding variable is unset.
API_URL, API_KEY = os.getenv("API_URL"), os.getenv("API_KEY")

# Chat-completions endpoint on the configured server.
url = "{}/v1/chat/completions".format(API_URL)

# Headers sent with every request: JSON body, JSON response, bearer-token auth.
headers = {
    "accept": "application/json",
    "Content-Type": "application/json",
    "Authorization": "Bearer {}".format(API_KEY),
}
def is_valid_json(data):
    """Try to decode *data* as JSON.

    Returns a ``(ok, payload)`` pair:
    ``(True, decoded_value)`` when decoding succeeds, or
    ``(False, error_message)`` when ``json.loads`` raises ``ValueError``
    (``json.JSONDecodeError`` is a subclass of it).
    """
    try:
        decoded = json.loads(data)
    except ValueError as err:
        return False, str(err)
    return True, decoded
# (connect, read) timeouts in seconds for the model server. The read timeout
# bounds the wait between received chunks, not the whole generation.
_REQUEST_TIMEOUT = (10, 300)


def _history_to_messages(history, system_prompt):
    """Convert Chatbot history into a chat-completions ``messages`` list.

    ``history`` is a list of ``[user_text, assistant_text]`` pairs. Empty or
    ``None`` slots are skipped. As a result, the pending turn (whose assistant
    slot is still empty) contributes only its user message. A non-empty
    ``system_prompt`` becomes the leading ``system`` message.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    for user_text, assistant_text in history:
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    return messages


def _stream_deltas(response):
    """Yield the text fragments carried by a streamed chat-completions response.

    Each non-empty line may carry an optional SSE ``data:`` prefix followed by
    a JSON chunk. The text is read from ``choices[0].delta.content``.
    SSE comment lines (starting with ``:``) are ignored, and ``[DONE]`` ends
    the stream. Lines that do not decode to a JSON object are logged and
    skipped.
    """
    for raw_line in response.iter_lines():
        if not raw_line:  # keep-alive blank line
            continue
        line = raw_line.decode("utf-8")
        if line.startswith(":"):  # SSE comment / ping
            continue
        if line.startswith("data:"):
            # Remove the prefix exactly. str.lstrip("data: ") strips a
            # *character set*, which could also eat the start of the payload.
            line = line[len("data:"):].strip()
        if line == "[DONE]":
            break
        is_json, chunk = is_valid_json(line)
        if not is_json or not isinstance(chunk, dict):
            print(f"Skipping non-JSON stream line: {line!r} ({chunk})")
            continue
        choices = chunk.get("choices") or [{}]
        delta = choices[0].get("delta") or {}
        content = delta.get("content")
        if content:
            yield content


def update_globals(
    system_prompt, temperature, max_new_tokens, top_p, repetition_penalty
):
    """Store the generation settings in module-level ``global_*`` variables.

    NOTE(review): nothing in this file's visible code calls this. ``bot``
    receives the settings directly as event inputs. Remove it if unused
    elsewhere.
    """
    global global_system_prompt, global_temperature, global_max_new_tokens
    global global_top_p, global_repetition_penalty
    global_system_prompt = system_prompt
    global_temperature = temperature
    global_max_new_tokens = max_new_tokens
    global_top_p = top_p
    global_repetition_penalty = repetition_penalty


def user(user_message, history):
    """Append the submitted message as a new pending turn and clear the textbox.

    Returns ``("", history + [[user_message, None]])``. The ``None`` slot is
    filled in by ``bot``. A ``None`` history is treated as empty.
    """
    return "", (history or []) + [[user_message, None]]


def bot(
    history, system_prompt, temperature, max_new_tokens, top_p, repetition_penalty
):
    """Stream the assistant's reply to the last user turn into ``history``.

    Sends the conversation to ``url`` with ``stream=True``. After every
    received content fragment it yields the updated ``history``, so the
    Chatbot re-renders incrementally.

    Args:
        history: Chatbot value, a list of ``[user, assistant]`` pairs whose
            last entry is the turn to answer (appended by ``user``).
        system_prompt: Text sent as the leading ``system`` message.
        temperature, max_new_tokens, top_p, repetition_penalty: Slider values
            forwarded to the server as sampling settings.

    Yields:
        ``history`` with its last assistant slot progressively filled in.

    Raises:
        gr.Error: if the HTTP request fails or the server answers with a
            non-2xx status, so the message is shown in the UI instead of
            producing a silently empty reply.
    """
    if not history:
        return

    print(f"History in bot: {history}")
    print(f"System Prompt: {system_prompt}")
    print(f"Temperature: {temperature}")
    print(f"Max New Tokens: {max_new_tokens}")
    print(f"Top P: {top_p}")
    print(f"Repetition Penalty: {repetition_penalty}")

    # Reset the pending reply slot first so it is not sent as an assistant turn.
    history[-1][1] = ""
    messages = _history_to_messages(history, system_prompt)
    print(messages)

    payload = {
        "messages": messages,
        "stream": True,
        "temperature": temperature,
        "top_k": 50,
        "top_p": top_p,
        "seed": 42,
        "repeat_penalty": repetition_penalty,
        "chat_format": "mistral-instruct",
        # Sliders may deliver floats; the token budget must be an integer.
        "max_tokens": int(max_new_tokens),
        # NOTE(review): this asks the server to constrain the reply to a JSON
        # object. Confirm that is intended for a free-form chat UI.
        "response_format": {
            "type": "json_object",
        },
    }

    try:
        # The context manager releases the streamed connection even if the
        # client disconnects mid-stream.
        with requests.post(
            url,
            headers=headers,
            json=payload,
            stream=True,
            timeout=_REQUEST_TIMEOUT,
        ) as response:
            if not response.ok:
                raise gr.Error(
                    f"Model server returned HTTP {response.status_code}: "
                    f"{response.text[:500]}"
                )
            for fragment in _stream_deltas(response):
                history[-1][1] += fragment
                time.sleep(0.05)  # pacing kept from the original streaming loop
                yield history
    except requests.RequestException as err:
        raise gr.Error(f"Request to the model server failed: {err}") from err


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")
    with gr.Row():
        with gr.Column(scale=4):
            # Generation settings passed to `bot` on every submit.
            system_prompt_input = gr.Textbox(
                label="System Prompt",
                placeholder="Type system prompt here...",
                value="You are a helpful assistant.",
            )
            temperature_input = gr.Slider(
                label="Temperature", minimum=0.0, maximum=1.0, value=0.9, step=0.01
            )
            max_new_tokens_input = gr.Slider(
                label="Max New Tokens", minimum=0, maximum=1024, value=256, step=1
            )
            top_p_input = gr.Slider(
                label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.01
            )
            repetition_penalty_input = gr.Slider(
                label="Repetition Penalty",
                minimum=1.0,
                maximum=2.0,
                value=1.2,
                step=0.01,
            )
        with gr.Column(scale=1):
            markup = gr.Markdown("## Mistral 7B Instruct v0.2 GGUF")

    # First echo the user's message into the chat without queueing, then
    # stream the model's reply into the same Chatbot.
    msg.submit(
        user, [msg, chatbot], [msg, chatbot], queue=False, concurrency_limit=5
    ).then(
        bot,
        inputs=[
            chatbot,
            system_prompt_input,
            temperature_input,
            max_new_tokens_input,
            top_p_input,
            repetition_penalty_input,
        ],
        outputs=chatbot,
    )
    clear.click(lambda: None, None, chatbot, queue=False)
# Route the app's events through Gradio's request queue before serving.
demo.queue()

if __name__ == "__main__":
    # Serve locally only (no public share link) and hide the API docs page.
    launch_kwargs = dict(show_api=False, share=False)
    demo.launch(**launch_kwargs)
|