Create app.py
app.py (new file, 183 lines):

import gradio as gr
import requests
import json
import threading
import os

stop_generation = threading.Event()
API_URL = os.environ.get('API_URL')
API_KEY = os.environ.get('API_KEY')

headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}

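# NOTE: API_URL and API_KEY are assumed to be provided as environment
# variables (e.g. Space secrets) and to point at an OpenAI-compatible
# chat-completions endpoint; requests will fail if either is unset.
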
def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
    stop_generation.clear()

    # Rebuild the conversation in chat-completions message format.
    history_format = [{"role": "system", "content": system_prompt}]
    for human, assistant in history:
        history_format.append({"role": "user", "content": human})
        if assistant:
            history_format.append({"role": "assistant", "content": assistant})
    history_format.append({"role": "user", "content": message})

    data = {
        "model": "meta-llama/Meta-Llama-3.1-405B-Instruct",
        "messages": history_format,
        "stream": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "repetition_penalty": repetition_penalty,
        "max_tokens": max_tokens
    }

    response = requests.post(API_URL, headers=headers, json=data, stream=True)
    response.raise_for_status()

    # Do not read response.text here: on a streamed response it would
    # buffer the entire body and defeat token-by-token streaming.
    partial_message = ""
    for line in response.iter_lines():
        if stop_generation.is_set():
            break
        if line:
            line = line.decode('utf-8')
            if line.startswith("data: "):
                if line.strip() == "data: [DONE]":
                    break
                try:
                    json_data = json.loads(line[6:])
                    if 'choices' in json_data and json_data['choices']:
                        content = json_data['choices'][0]['delta'].get('content', '')
                        if content:
                            partial_message += content
                            yield partial_message
                except json.JSONDecodeError:
                    continue

    if partial_message:
        yield partial_message

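# For reference, the server is expected to stream Server-Sent Events,
# one JSON chunk per "data:" line (illustrative shape, not verbatim):
#   data: {"choices": [{"delta": {"content": "Hel"}}]}
#   data: {"choices": [{"delta": {"content": "lo"}}]}
#   data: [DONE]
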
def import_chat(custom_format_string):
    # Parse a '<|role|>'-delimited transcript back into (history, system_prompt).
    try:
        sections = custom_format_string.split('<|')

        imported_history = []
        system_prompt = ""

        for section in sections:
            if section.startswith('system|>'):
                system_prompt = section.replace('system|>', '').strip()
            elif section.startswith('user|>'):
                user_message = section.replace('user|>', '').strip()
                imported_history.append([user_message, None])
            elif section.startswith('assistant|>'):
                assistant_message = section.replace('assistant|>', '').strip()
                if imported_history:
                    imported_history[-1][1] = assistant_message
                else:
                    # Assistant turn with no preceding user turn: pair it
                    # with an empty user message.
                    imported_history.append(["", assistant_message])

        return imported_history, system_prompt
    except Exception as e:
        print(f"Error importing chat: {e}")
        return None, None

def export_chat(history, system_prompt):
    # Serialize the conversation into the same '<|role|>' transcript format.
    # (history or []) guards against the None value left by the Clear button.
    export_data = f"<|system|>\n{system_prompt}\n\n"
    for user_msg, assistant_msg in (history or []):
        export_data += f"<|user|>\n{user_msg}\n\n"
        if assistant_msg:
            export_data += f"<|assistant|>\n{assistant_msg}\n\n"
    return export_data

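# The transcript format round-tripped by the two functions above looks
# like this (blank-line separated, '<|role|>' headers):
#   <|system|>
#   You are a helpful assistant.
#
#   <|user|>
#   Hi there!
#
#   <|assistant|>
#   Hello! How can I help?
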
def stop_generation_func():
    # Fired by the Stop button; both streaming loops poll this event.
    stop_generation.set()

with gr.Blocks(theme='gradio/monochrome') as demo:
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Message")
            with gr.Row():
                clear = gr.Button("Clear")
                regenerate = gr.Button("Regenerate")
                stop_btn = gr.Button("Stop")
            with gr.Row():
                with gr.Column(scale=4):
                    import_textbox = gr.Textbox(label="Import textbox", lines=5)
                with gr.Column(scale=1):
                    export_button = gr.Button("Export Chat")
                    import_button = gr.Button("Import Chat")

        with gr.Column(scale=1):
            system_prompt = gr.Textbox("", label="System Prompt", lines=5)
            temperature = gr.Slider(0, 2, value=0.8, step=0.01, label="Temperature")
            top_p = gr.Slider(0, 1, value=0.95, step=0.01, label="Top P")
            top_k = gr.Slider(1, 500, value=40, step=1, label="Top K")
            frequency_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Frequency Penalty")
            presence_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Presence Penalty")
            repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01, label="Repetition Penalty")
            max_tokens = gr.Slider(1, 1024, value=256, step=1, label="Max Output (max_tokens)")

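    # The Chatbot component here uses the classic list-of-[user, assistant]
    # pairs history format; the callbacks below mutate the last pair in
    # place while streaming.
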
    def user(user_message, history):
        # Append the new user turn; (history or []) guards against the
        # None value left by the Clear button.
        return "", (history or []) + [[user_message, None]]

    def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
        user_message = history[-1][0]
        bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens)
        history[-1][1] = ""
        for chunk in bot_message:
            if stop_generation.is_set():
                history[-1][1] += " [Generation stopped]"
                yield history  # surface the stop marker before exiting
                break
            history[-1][1] = chunk
            yield history
        stop_generation.clear()

    def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
        if history:
            history[-1][1] = None
            for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
                yield new_history
        else:
            yield history

    def import_chat_wrapper(custom_format_string):
        imported_history, imported_system_prompt = import_chat(custom_format_string)
        return imported_history, imported_system_prompt

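    # Submitting a message chains two steps: `user` appends the new turn
    # immediately (queue=False), then `bot` streams the reply by repeatedly
    # yielding the updated history.
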
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens], chatbot
    )

    clear.click(lambda: None, None, chatbot, queue=False)

    regenerate.click(
        regenerate_response,
        [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens],
        chatbot
    )

    import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt])

    export_button.click(
        export_chat,
        inputs=[chatbot, system_prompt],
        outputs=[import_textbox]
    )

    stop_btn.click(stop_generation_func, inputs=[], outputs=[])

if __name__ == "__main__":
    demo.launch()