SourcezZ commited on
Commit
bfcd0ca
·
1 Parent(s): 9944fe6
Files changed (1) hide show
  1. app.py +351 -0
app.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import gradio as gr
3
+ import os
4
+ import sys
5
+ import requests
6
+ import csv
7
+
8
+ my_api_key = os.environ.get('MY_API_KEY') # 在这里输入你的 API 密钥
9
+ HIDE_MY_KEY = True # 如果你想在UI中隐藏你的 API 密钥,将此值设置为 True
10
+
11
+ initial_prompt = "You are a helpful assistant."
12
+ API_URL = "https://api.openai.com/v1/chat/completions"
13
+ HISTORY_DIR = "history"
14
+ TEMPLATES_DIR = "templates"
15
+ login_username = os.environ.get('LOGIN_USERNAME')
16
+ login_password = os.environ.get('LOGIN_PASSWORD')
17
+
18
+ # if we are running in Docker
19
+ if os.environ.get('dockerrun') == 'yes':
20
+ dockerflag = True
21
+ else:
22
+ dockerflag = False
23
+
24
+ if dockerflag:
25
+ my_api_key = os.environ.get('my_api_key')
26
+ if my_api_key == "empty":
27
+ print("Please give a api key!")
28
+ sys.exit(1)
29
+ # auth
30
+ username = os.environ.get('USERNAME')
31
+ password = os.environ.get('PASSWORD')
32
+ if isinstance(username, type(None)) or isinstance(password, type(None)):
33
+ authflag = False
34
+ else:
35
+ authflag = True
36
+
37
+
38
+ def parse_text(text):
39
+ lines = text.split("\n")
40
+ lines = [line for line in lines if line != ""]
41
+ count = 0
42
+ firstline = False
43
+ for i, line in enumerate(lines):
44
+ if "```" in line:
45
+ count += 1
46
+ items = line.split('`')
47
+ if count % 2 == 1:
48
+ lines[
49
+ i] = f'<pre><code class="{items[-1]}" style="display: block; white-space: pre; padding: 0 1em 1em 1em; color: #fff; background: #000;">'
50
+ firstline = True
51
+ else:
52
+ lines[i] = f'</code></pre>'
53
+ else:
54
+ if i > 0:
55
+ if count % 2 == 1:
56
+ line = line.replace("&", "&amp;")
57
+ line = line.replace("\"", "`\"`")
58
+ line = line.replace("\'", "`\'`")
59
+ line = line.replace("<", "&lt;")
60
+ line = line.replace(">", "&gt;")
61
+ line = line.replace(" ", "&nbsp;")
62
+ line = line.replace("*", "&ast;")
63
+ line = line.replace("_", "&lowbar;")
64
+ line = line.replace("#", "&#35;")
65
+ line = line.replace("-", "&#45;")
66
+ line = line.replace(".", "&#46;")
67
+ line = line.replace("!", "&#33;")
68
+ line = line.replace("(", "&#40;")
69
+ line = line.replace(")", "&#41;")
70
+ lines[i] = "<br>" + line
71
+ text = "".join(lines)
72
+ return text
73
+
74
+
75
+ def predict(inputs, top_p, temperature, openai_api_key, chatbot=[], history=[], system_prompt=initial_prompt,
76
+ retry=False, summary=False): # repetition_penalty, top_k
77
+
78
+ print(f"chatbot 1: {chatbot}")
79
+
80
+ headers = {
81
+ "Content-Type": "application/json",
82
+ "Authorization": f"Bearer {openai_api_key}"
83
+ }
84
+
85
+ chat_counter = len(history) // 2
86
+
87
+ print(f"chat_counter - {chat_counter}")
88
+
89
+ messages = [compose_system(system_prompt)]
90
+ if chat_counter:
91
+ for data in chatbot:
92
+ temp1 = {}
93
+ temp1["role"] = "user"
94
+ temp1["content"] = data[0]
95
+ temp2 = {}
96
+ temp2["role"] = "assistant"
97
+ temp2["content"] = data[1]
98
+ if temp1["content"] != "":
99
+ messages.append(temp1)
100
+ messages.append(temp2)
101
+ else:
102
+ messages[-1]['content'] = temp2['content']
103
+ if retry and chat_counter:
104
+ messages.pop()
105
+ elif summary:
106
+ messages.append(compose_user(
107
+ "请帮我总结一下上述对话的内容,实现减少字数的同时,保证对话的质量。在总结中不要加入这一句话。"))
108
+ history = ["我们刚刚聊了什么?"]
109
+ else:
110
+ temp3 = {}
111
+ temp3["role"] = "user"
112
+ temp3["content"] = inputs
113
+ messages.append(temp3)
114
+ chat_counter += 1
115
+ # messages
116
+ payload = {
117
+ "model": "gpt-3.5-turbo",
118
+ "messages": messages, # [{"role": "user", "content": f"{inputs}"}],
119
+ "temperature": temperature, # 1.0,
120
+ "top_p": top_p, # 1.0,
121
+ "n": 1,
122
+ "stream": True,
123
+ "presence_penalty": 0,
124
+ "frequency_penalty": 0,
125
+ }
126
+
127
+ if not summary:
128
+ history.append(inputs)
129
+ print(f"payload is - {payload}")
130
+ # make a POST request to the API endpoint using the requests.post method, passing in stream=True
131
+ response = requests.post(API_URL, headers=headers, json=payload, stream=True, proxies=proxies)
132
+ # response = requests.post(API_URL, headers=headers, json=payload, stream=True)
133
+
134
+ token_counter = 0
135
+ partial_words = ""
136
+
137
+ counter = 0
138
+ chatbot.append((history[-1], ""))
139
+ for chunk in response.iter_lines():
140
+ if counter == 0:
141
+ counter += 1
142
+ continue
143
+ counter += 1
144
+ # check whether each line is non-empty
145
+ if chunk:
146
+ # decode each line as response data is in bytes
147
+ try:
148
+ if len(json.loads(chunk.decode()[6:])['choices'][0]["delta"]) == 0:
149
+ break
150
+ except Exception as e:
151
+ chatbot.pop()
152
+ chatbot.append((history[-1], f"☹️发生了错误<br>返回值:{response.text}<br>异常:{e}"))
153
+ history.pop()
154
+ yield chatbot, history
155
+ break
156
+ # print(json.loads(chunk.decode()[6:])['choices'][0]["delta"] ["content"])
157
+ partial_words = partial_words + \
158
+ json.loads(chunk.decode()[6:])[
159
+ 'choices'][0]["delta"]["content"]
160
+ if token_counter == 0:
161
+ history.append(" " + partial_words)
162
+ else:
163
+ history[-1] = parse_text(partial_words)
164
+ chatbot[-1] = (history[-2], history[-1])
165
+ # chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2) ] # convert to tuples of list
166
+ token_counter += 1
167
+ # resembles {chatbot: chat, state: history}
168
+ yield chatbot, history
169
+
170
+
171
+ def delete_last_conversation(chatbot, history):
172
+ chatbot.pop()
173
+ history.pop()
174
+ history.pop()
175
+ return chatbot, history
176
+
177
+
178
+ def save_chat_history(filename, system, history, chatbot):
179
+ if filename == "":
180
+ return
181
+ if not filename.endswith(".json"):
182
+ filename += ".json"
183
+ os.makedirs(HISTORY_DIR, exist_ok=True)
184
+ json_s = {"system": system, "history": history, "chatbot": chatbot}
185
+ with open(os.path.join(HISTORY_DIR, filename), "w") as f:
186
+ json.dump(json_s, f)
187
+
188
+
189
+ def load_chat_history(filename):
190
+ with open(os.path.join(HISTORY_DIR, filename), "r") as f:
191
+ json_s = json.load(f)
192
+ return filename, json_s["system"], json_s["history"], json_s["chatbot"]
193
+
194
+
195
+ def get_file_names(dir, plain=False, filetype=".json"):
196
+ # find all json files in the current directory and return their names
197
+ try:
198
+ files = [f for f in os.listdir(dir) if f.endswith(filetype)]
199
+ except FileNotFoundError:
200
+ files = []
201
+ if plain:
202
+ return files
203
+ else:
204
+ return gr.Dropdown.update(choices=files)
205
+
206
+
207
+ def get_history_names(plain=False):
208
+ return get_file_names(HISTORY_DIR, plain)
209
+
210
+
211
+ def load_template(filename):
212
+ lines = []
213
+ with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as csvfile:
214
+ reader = csv.reader(csvfile)
215
+ lines = list(reader)
216
+ lines = lines[1:]
217
+ return {row[0]: row[1] for row in lines}, gr.Dropdown.update(choices=[row[0] for row in lines])
218
+
219
+
220
+ def get_template_names(plain=False):
221
+ return get_file_names(TEMPLATES_DIR, plain, filetype=".csv")
222
+
223
+
224
+ def reset_state():
225
+ return [], []
226
+
227
+
228
+ def compose_system(system_prompt):
229
+ return {"role": "system", "content": system_prompt}
230
+
231
+
232
+ def compose_user(user_input):
233
+ return {"role": "user", "content": user_input}
234
+
235
+
236
+ def reset_textbox():
237
+ return gr.update(value='')
238
+
239
+
240
+ title = """<h1 align="center">ChatGPT 🚀</h1>"""
241
+ description = """<div align=center>此App使用 `gpt-3.5-turbo` 大语言模型
242
+ </div>
243
+ """
244
+ with gr.Blocks() as demo:
245
+ gr.HTML(title)
246
+ keyTxt = gr.Textbox(show_label=True, placeholder=f"在这里输入你的OpenAI API-key...",
247
+ value=my_api_key, label="API Key", type="password", visible=not HIDE_MY_KEY).style(
248
+ container=True)
249
+ chatbot = gr.Chatbot() # .style(color_map=("#1D51EE", "#585A5B"))
250
+ history = gr.State([])
251
+ promptTemplates = gr.State({})
252
+ TRUECOMSTANT = gr.State(True)
253
+ FALSECONSTANT = gr.State(False)
254
+ topic = gr.State("未命名对话历史记录")
255
+
256
+ with gr.Row():
257
+ with gr.Column(scale=12):
258
+ txt = gr.Textbox(show_label=False, placeholder="在这里输入").style(
259
+ container=False)
260
+ with gr.Column(min_width=50, scale=1):
261
+ submitBtn = gr.Button("🚀", variant="primary")
262
+ with gr.Row():
263
+ emptyBtn = gr.Button("🧹 新的对话")
264
+ retryBtn = gr.Button("🔄 重新生成")
265
+ delLastBtn = gr.Button("🗑️ 删除上条对话")
266
+ reduceTokenBtn = gr.Button("♻️ 总结对话")
267
+ systemPromptTxt = gr.Textbox(show_label=True, placeholder=f"在这里输入System Prompt...",
268
+ label="System prompt", value=initial_prompt).style(container=True)
269
+ with gr.Accordion(label="加载Prompt模板", open=False):
270
+ with gr.Column():
271
+ with gr.Row():
272
+ with gr.Column(scale=6):
273
+ templateFileSelectDropdown = gr.Dropdown(label="选择Prompt模板集合文件(.csv)",
274
+ choices=get_template_names(plain=True), multiselect=False)
275
+ with gr.Column(scale=1):
276
+ templateRefreshBtn = gr.Button("🔄 刷新")
277
+ templaeFileReadBtn = gr.Button("📂 读入模板")
278
+ with gr.Row():
279
+ with gr.Column(scale=6):
280
+ templateSelectDropdown = gr.Dropdown(label="从Prompt模板中加载", choices=[], multiselect=False)
281
+ with gr.Column(scale=1):
282
+ templateApplyBtn = gr.Button("⬇️ 应用")
283
+ with gr.Accordion(
284
+ label="保存/加载对话历史记录(在文本框中输入文件名,点击“保存对话”按钮,历史记录文件会被存储到Python文件旁边)",
285
+ open=False):
286
+ with gr.Column():
287
+ with gr.Row():
288
+ with gr.Column(scale=6):
289
+ saveFileName = gr.Textbox(
290
+ show_label=True, placeholder=f"在这里输入保存的文件名...", label="设置保存文件名",
291
+ value="对话历史记录").style(container=True)
292
+ with gr.Column(scale=1):
293
+ saveBtn = gr.Button("💾 保存对话")
294
+ with gr.Row():
295
+ with gr.Column(scale=6):
296
+ historyFileSelectDropdown = gr.Dropdown(label="从列表中加载对话",
297
+ choices=get_history_names(plain=True), multiselect=False)
298
+ with gr.Column(scale=1):
299
+ historyRefreshBtn = gr.Button("🔄 刷新")
300
+ historyReadBtn = gr.Button("📂 读入对话")
301
+ # inputs, top_p, temperature, top_k, repetition_penalty
302
+ with gr.Accordion("参数", open=False):
303
+ top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05,
304
+ interactive=True, label="Top-p (nucleus sampling)", )
305
+ temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0,
306
+ step=0.1, interactive=True, label="Temperature", )
307
+ # top_k = gr.Slider( minimum=1, maximum=50, value=4, step=1, interactive=True, label="Top-k",)
308
+ # repetition_penalty = gr.Slider( minimum=0.1, maximum=3.0, value=1.03, step=0.01, interactive=True, label="Repetition Penalty", )
309
+ gr.Markdown(description)
310
+
311
+ txt.submit(predict, [txt, top_p, temperature, keyTxt,
312
+ chatbot, history, systemPromptTxt], [chatbot, history])
313
+ txt.submit(reset_textbox, [], [txt])
314
+ submitBtn.click(predict, [txt, top_p, temperature, keyTxt, chatbot,
315
+ history, systemPromptTxt], [chatbot, history], show_progress=True)
316
+ submitBtn.click(reset_textbox, [], [txt])
317
+ emptyBtn.click(reset_state, outputs=[chatbot, history])
318
+ retryBtn.click(predict, [txt, top_p, temperature, keyTxt, chatbot, history,
319
+ systemPromptTxt, TRUECOMSTANT], [chatbot, history], show_progress=True)
320
+ delLastBtn.click(delete_last_conversation, [chatbot, history], [
321
+ chatbot, history], show_progress=True)
322
+ reduceTokenBtn.click(predict, [txt, top_p, temperature, keyTxt, chatbot, history,
323
+ systemPromptTxt, FALSECONSTANT, TRUECOMSTANT], [chatbot, history],
324
+ show_progress=True)
325
+ saveBtn.click(save_chat_history, [
326
+ saveFileName, systemPromptTxt, history, chatbot], None, show_progress=True)
327
+ saveBtn.click(get_history_names, None, [historyFileSelectDropdown])
328
+ historyRefreshBtn.click(get_history_names, None, [historyFileSelectDropdown])
329
+ historyReadBtn.click(load_chat_history, [historyFileSelectDropdown],
330
+ [saveFileName, systemPromptTxt, history, chatbot], show_progress=True)
331
+ templateRefreshBtn.click(get_template_names, None, [templateFileSelectDropdown])
332
+ templaeFileReadBtn.click(load_template, [templateFileSelectDropdown], [promptTemplates, templateSelectDropdown],
333
+ show_progress=True)
334
+ templateApplyBtn.click(lambda x, y: x[y], [promptTemplates, templateSelectDropdown], [systemPromptTxt],
335
+ show_progress=True)
336
+
337
+ print("温馨提示:访问 http://localhost:7860 查看界面")
338
+ # 默认开启本地服务器,默认可以直接从IP访问,默认不创建公开分享链接
339
+ demo.title = "ChatGPT 🚀"
340
+
341
+ # if running in Docker
342
+ if dockerflag:
343
+ if authflag:
344
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860, auth=(username, password))
345
+ else:
346
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)
347
+ # if not running in Docker
348
+ else:
349
+ # demo.queue().launch(share=False) # 改为 share=True 可以创建公开分享链接
350
+ # demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False) # 可自定义端口
351
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860, auth=(login_username, login_password)) # 可设置用户名与密码