Corvius committed
Commit 7324de2 · verified · 1 Parent(s): 11da62d

Create app.py

Files changed (1)
  1. app.py +183 -0
app.py ADDED
@@ -0,0 +1,183 @@
import gradio as gr
import requests
import json
import threading
import os

stop_generation = threading.Event()
API_URL = os.environ.get('API_URL')
API_KEY = os.environ.get('API_KEY')

headers = {
    "Authorization": f"Bearer {API_KEY}",
    "Content-Type": "application/json"
}
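
# NOTE: both values above are read from the environment. An illustrative
# (hypothetical) setup for any OpenAI-compatible streaming chat endpoint:
#   API_URL = "https://api.example.com/v1/chat/completions"
#   API_KEY = "sk-..."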

def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
    stop_generation.clear()

    # Rebuild the conversation in OpenAI chat format, system prompt first.
    history_format = [{"role": "system", "content": system_prompt}]
    for human, assistant in history:
        history_format.append({"role": "user", "content": human})
        if assistant:
            history_format.append({"role": "assistant", "content": assistant})
    history_format.append({"role": "user", "content": message})

    data = {
        "model": "meta-llama/Meta-Llama-3.1-405B-Instruct",
        "messages": history_format,
        "stream": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "repetition_penalty": repetition_penalty,
        "max_tokens": max_tokens
    }

    response = requests.post(API_URL, headers=headers, data=json.dumps(data), stream=True)
    # Do not read response.text on a streaming response: it consumes the body
    # and leaves nothing for iter_lines() below. Fail fast on HTTP errors instead.
    response.raise_for_status()

    partial_message = ""
    for line in response.iter_lines():
        if stop_generation.is_set():
            break
        if line:
            line = line.decode('utf-8')
            if line.startswith("data: "):
                if line.strip() == "data: [DONE]":
                    break
                try:
                    json_data = json.loads(line[6:])
                    if 'choices' in json_data and json_data['choices']:
                        content = json_data['choices'][0]['delta'].get('content', '')
                        if content:
                            partial_message += content
                            yield partial_message
                except json.JSONDecodeError:
                    continue

    if partial_message:
        yield partial_message

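# The parser in predict() assumes OpenAI-style server-sent events; an
# illustrative stream looks like:
#   data: {"choices": [{"delta": {"content": "Hel"}}]}
#   data: {"choices": [{"delta": {"content": "lo"}}]}
#   data: [DONE]
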
def import_chat(custom_format_string):
    try:
        sections = custom_format_string.split('<|')

        imported_history = []
        system_prompt = ""

        for section in sections:
            if section.startswith('system|>'):
                system_prompt = section.replace('system|>', '').strip()
            elif section.startswith('user|>'):
                user_message = section.replace('user|>', '').strip()
                imported_history.append([user_message, None])
            elif section.startswith('assistant|>'):
                assistant_message = section.replace('assistant|>', '').strip()
                if imported_history:
                    imported_history[-1][1] = assistant_message
                else:
                    imported_history.append(["", assistant_message])

        return imported_history, system_prompt
    except Exception as e:
        print(f"Error importing chat: {e}")
        return None, None

def export_chat(history, system_prompt):
    export_data = f"<|system|>\n{system_prompt}\n\n"
    for user_msg, assistant_msg in history:
        export_data += f"<|user|>\n{user_msg}\n\n"
        if assistant_msg:
            export_data += f"<|assistant|>\n{assistant_msg}\n\n"
    return export_data

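# Round-trip example of the custom format: export_chat() emits, and
# import_chat() parses back, text shaped like this (illustrative content):
#   <|system|>
#   You are a helpful assistant.
#
#   <|user|>
#   Hello!
#
#   <|assistant|>
#   Hi there!
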
def stop_generation_func():
    # Signal the streaming loops in predict()/bot() to stop at the next chunk.
    stop_generation.set()


with gr.Blocks(theme='gradio/monochrome') as demo:
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Message")
            with gr.Row():
                clear = gr.Button("Clear")
                regenerate = gr.Button("Regenerate")
                stop_btn = gr.Button("Stop")
            with gr.Row():
                with gr.Column(scale=4):
                    import_textbox = gr.Textbox(label="Import textbox", lines=5)
                with gr.Column(scale=1):
                    export_button = gr.Button("Export Chat")
                    import_button = gr.Button("Import Chat")

        with gr.Column(scale=1):
            system_prompt = gr.Textbox("", label="System Prompt", lines=5)
            temperature = gr.Slider(0, 2, value=0.8, step=0.01, label="Temperature")
            top_p = gr.Slider(0, 1, value=0.95, step=0.01, label="Top P")
            top_k = gr.Slider(1, 500, value=40, step=1, label="Top K")
            frequency_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Frequency Penalty")
            presence_penalty = gr.Slider(-2, 2, value=0, step=0.1, label="Presence Penalty")
            repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01, label="Repetition Penalty")
            max_tokens = gr.Slider(1, 1024, value=256, step=1, label="Max Output (max_tokens)")

    def user(user_message, history):
        return "", history + [[user_message, None]]

    def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
        user_message = history[-1][0]
        bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens)
        history[-1][1] = ""
        for chunk in bot_message:
            if stop_generation.is_set():
                # Surface the interruption in the chat, then stop streaming.
                history[-1][1] += " [Generation stopped]"
                yield history
                break
            history[-1][1] = chunk
            yield history
        stop_generation.clear()

    def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
        if len(history) > 0:
            # Drop the last assistant reply and stream a fresh one.
            history[-1][1] = None
            for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
                yield new_history
        else:
            yield history

    def import_chat_wrapper(custom_format_string):
        imported_history, imported_system_prompt = import_chat(custom_format_string)
        return imported_history, imported_system_prompt

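    # Event wiring: submitting the textbox first runs user() to append the new
    # message, then chains bot(), which streams partial replies into chatbot.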
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens], chatbot
    )

    clear.click(lambda: None, None, chatbot, queue=False)

    regenerate.click(
        regenerate_response,
        [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens],
        chatbot
    )

    import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt])

    export_button.click(
        export_chat,
        inputs=[chatbot, system_prompt],
        outputs=[import_textbox]
    )

    stop_btn.click(stop_generation_func, inputs=[], outputs=[])

if __name__ == "__main__":
    demo.launch()
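
# To run locally (assuming API_URL and API_KEY are exported in the shell):
#   API_URL=... API_KEY=... python app.py
# Gradio serves the UI on http://127.0.0.1:7860 by default.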