Spaces:
Runtime error
Runtime error
JohnSmith9982
commited on
Commit
•
eac8ac9
1
Parent(s):
ded699d
Delete utils.py
Browse files
utils.py
DELETED
@@ -1,389 +0,0 @@
|
|
1 |
-
# -*- coding:utf-8 -*-
|
2 |
-
from __future__ import annotations
|
3 |
-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
|
4 |
-
import logging
|
5 |
-
import json
|
6 |
-
import os
|
7 |
-
import datetime
|
8 |
-
import hashlib
|
9 |
-
import csv
|
10 |
-
import requests
|
11 |
-
import re
|
12 |
-
|
13 |
-
import gradio as gr
|
14 |
-
from pypinyin import lazy_pinyin
|
15 |
-
import tiktoken
|
16 |
-
import mdtex2html
|
17 |
-
from markdown import markdown
|
18 |
-
from pygments import highlight
|
19 |
-
from pygments.lexers import get_lexer_by_name
|
20 |
-
from pygments.formatters import HtmlFormatter
|
21 |
-
|
22 |
-
from presets import *
|
23 |
-
|
24 |
-
# logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
|
25 |
-
|
26 |
-
if TYPE_CHECKING:
|
27 |
-
from typing import TypedDict
|
28 |
-
|
29 |
-
class DataframeData(TypedDict):
|
30 |
-
headers: List[str]
|
31 |
-
data: List[List[str | int | bool]]
|
32 |
-
|
33 |
-
|
34 |
-
def count_token(message):
|
35 |
-
encoding = tiktoken.get_encoding("cl100k_base")
|
36 |
-
input_str = f"role: {message['role']}, content: {message['content']}"
|
37 |
-
length = len(encoding.encode(input_str))
|
38 |
-
return length
|
39 |
-
|
40 |
-
|
41 |
-
def markdown_to_html_with_syntax_highlight(md_str):
|
42 |
-
def replacer(match):
|
43 |
-
lang = match.group(1) or "text"
|
44 |
-
code = match.group(2)
|
45 |
-
|
46 |
-
try:
|
47 |
-
lexer = get_lexer_by_name(lang, stripall=True)
|
48 |
-
except ValueError:
|
49 |
-
lexer = get_lexer_by_name("text", stripall=True)
|
50 |
-
|
51 |
-
formatter = HtmlFormatter()
|
52 |
-
highlighted_code = highlight(code, lexer, formatter)
|
53 |
-
|
54 |
-
return f'<pre><code class="{lang}">{highlighted_code}</code></pre>'
|
55 |
-
|
56 |
-
code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```"
|
57 |
-
md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)
|
58 |
-
|
59 |
-
html_str = markdown(md_str)
|
60 |
-
return html_str
|
61 |
-
|
62 |
-
|
63 |
-
def normalize_markdown(md_text: str) -> str:
|
64 |
-
lines = md_text.split("\n")
|
65 |
-
normalized_lines = []
|
66 |
-
inside_list = False
|
67 |
-
|
68 |
-
for i, line in enumerate(lines):
|
69 |
-
if re.match(r"^(\d+\.|-|\*|\+)\s", line.strip()):
|
70 |
-
if not inside_list and i > 0 and lines[i - 1].strip() != "":
|
71 |
-
normalized_lines.append("")
|
72 |
-
inside_list = True
|
73 |
-
normalized_lines.append(line)
|
74 |
-
elif inside_list and line.strip() == "":
|
75 |
-
if i < len(lines) - 1 and not re.match(
|
76 |
-
r"^(\d+\.|-|\*|\+)\s", lines[i + 1].strip()
|
77 |
-
):
|
78 |
-
normalized_lines.append(line)
|
79 |
-
continue
|
80 |
-
else:
|
81 |
-
inside_list = False
|
82 |
-
normalized_lines.append(line)
|
83 |
-
|
84 |
-
return "\n".join(normalized_lines)
|
85 |
-
|
86 |
-
|
87 |
-
def convert_mdtext(md_text):
|
88 |
-
code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
|
89 |
-
inline_code_pattern = re.compile(r"`(.*?)`", re.DOTALL)
|
90 |
-
code_blocks = code_block_pattern.findall(md_text)
|
91 |
-
non_code_parts = code_block_pattern.split(md_text)[::2]
|
92 |
-
|
93 |
-
result = []
|
94 |
-
for non_code, code in zip(non_code_parts, code_blocks + [""]):
|
95 |
-
if non_code.strip():
|
96 |
-
non_code = normalize_markdown(non_code)
|
97 |
-
if inline_code_pattern.search(non_code):
|
98 |
-
result.append(markdown(non_code, extensions=["tables"]))
|
99 |
-
else:
|
100 |
-
result.append(mdtex2html.convert(non_code, extensions=["tables"]))
|
101 |
-
if code.strip():
|
102 |
-
# _, code = detect_language(code) # 暂时去除代码高亮功能,因为在大段代码的情况下会出现问题
|
103 |
-
# code = code.replace("\n\n", "\n") # 暂时去除代码中的空行,因为在大段代码的情况下会出现问题
|
104 |
-
code = f"```{code}\n\n```"
|
105 |
-
code = markdown_to_html_with_syntax_highlight(code)
|
106 |
-
result.append(code)
|
107 |
-
result = "".join(result)
|
108 |
-
return result
|
109 |
-
|
110 |
-
def convert_user(userinput):
|
111 |
-
userinput = userinput.replace("\n", "<br>")
|
112 |
-
return f"<pre>{userinput}</pre>"
|
113 |
-
|
114 |
-
def detect_language(code):
|
115 |
-
if code.startswith("\n"):
|
116 |
-
first_line = ""
|
117 |
-
else:
|
118 |
-
first_line = code.strip().split("\n", 1)[0]
|
119 |
-
language = first_line.lower() if first_line else ""
|
120 |
-
code_without_language = code[len(first_line) :].lstrip() if first_line else code
|
121 |
-
return language, code_without_language
|
122 |
-
|
123 |
-
|
124 |
-
def construct_text(role, text):
|
125 |
-
return {"role": role, "content": text}
|
126 |
-
|
127 |
-
|
128 |
-
def construct_user(text):
|
129 |
-
return construct_text("user", text)
|
130 |
-
|
131 |
-
|
132 |
-
def construct_system(text):
|
133 |
-
return construct_text("system", text)
|
134 |
-
|
135 |
-
|
136 |
-
def construct_assistant(text):
|
137 |
-
return construct_text("assistant", text)
|
138 |
-
|
139 |
-
|
140 |
-
def construct_token_message(token, stream=False):
|
141 |
-
return f"Token 计数: {token}"
|
142 |
-
|
143 |
-
|
144 |
-
def delete_last_conversation(chatbot, history, previous_token_count):
|
145 |
-
if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
|
146 |
-
logging.info("由于包含报错信息,只删除chatbot记录")
|
147 |
-
chatbot.pop()
|
148 |
-
return chatbot, history
|
149 |
-
if len(history) > 0:
|
150 |
-
logging.info("删除了一组对话历史")
|
151 |
-
history.pop()
|
152 |
-
history.pop()
|
153 |
-
if len(chatbot) > 0:
|
154 |
-
logging.info("删除了一组chatbot对话")
|
155 |
-
chatbot.pop()
|
156 |
-
if len(previous_token_count) > 0:
|
157 |
-
logging.info("删除了一组对话的token计数记录")
|
158 |
-
previous_token_count.pop()
|
159 |
-
return (
|
160 |
-
chatbot,
|
161 |
-
history,
|
162 |
-
previous_token_count,
|
163 |
-
construct_token_message(sum(previous_token_count)),
|
164 |
-
)
|
165 |
-
|
166 |
-
|
167 |
-
def save_file(filename, system, history, chatbot):
|
168 |
-
logging.info("保存对话历史中……")
|
169 |
-
os.makedirs(HISTORY_DIR, exist_ok=True)
|
170 |
-
if filename.endswith(".json"):
|
171 |
-
json_s = {"system": system, "history": history, "chatbot": chatbot}
|
172 |
-
print(json_s)
|
173 |
-
with open(os.path.join(HISTORY_DIR, filename), "w") as f:
|
174 |
-
json.dump(json_s, f)
|
175 |
-
elif filename.endswith(".md"):
|
176 |
-
md_s = f"system: \n- {system} \n"
|
177 |
-
for data in history:
|
178 |
-
md_s += f"\n{data['role']}: \n- {data['content']} \n"
|
179 |
-
with open(os.path.join(HISTORY_DIR, filename), "w", encoding="utf8") as f:
|
180 |
-
f.write(md_s)
|
181 |
-
logging.info("保存对话历史完毕")
|
182 |
-
return os.path.join(HISTORY_DIR, filename)
|
183 |
-
|
184 |
-
|
185 |
-
def save_chat_history(filename, system, history, chatbot):
|
186 |
-
if filename == "":
|
187 |
-
return
|
188 |
-
if not filename.endswith(".json"):
|
189 |
-
filename += ".json"
|
190 |
-
return save_file(filename, system, history, chatbot)
|
191 |
-
|
192 |
-
|
193 |
-
def export_markdown(filename, system, history, chatbot):
|
194 |
-
if filename == "":
|
195 |
-
return
|
196 |
-
if not filename.endswith(".md"):
|
197 |
-
filename += ".md"
|
198 |
-
return save_file(filename, system, history, chatbot)
|
199 |
-
|
200 |
-
|
201 |
-
def load_chat_history(filename, system, history, chatbot):
|
202 |
-
logging.info("加载对话历史中……")
|
203 |
-
if type(filename) != str:
|
204 |
-
filename = filename.name
|
205 |
-
try:
|
206 |
-
with open(os.path.join(HISTORY_DIR, filename), "r") as f:
|
207 |
-
json_s = json.load(f)
|
208 |
-
try:
|
209 |
-
if type(json_s["history"][0]) == str:
|
210 |
-
logging.info("历史记录格式为旧版,正在转换……")
|
211 |
-
new_history = []
|
212 |
-
for index, item in enumerate(json_s["history"]):
|
213 |
-
if index % 2 == 0:
|
214 |
-
new_history.append(construct_user(item))
|
215 |
-
else:
|
216 |
-
new_history.append(construct_assistant(item))
|
217 |
-
json_s["history"] = new_history
|
218 |
-
logging.info(new_history)
|
219 |
-
except:
|
220 |
-
# 没有对话历史
|
221 |
-
pass
|
222 |
-
logging.info("加载对话历史完毕")
|
223 |
-
return filename, json_s["system"], json_s["history"], json_s["chatbot"]
|
224 |
-
except FileNotFoundError:
|
225 |
-
logging.info("没有找到对话历史文件,不执行任何操作")
|
226 |
-
return filename, system, history, chatbot
|
227 |
-
|
228 |
-
|
229 |
-
def sorted_by_pinyin(list):
|
230 |
-
return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
|
231 |
-
|
232 |
-
|
233 |
-
def get_file_names(dir, plain=False, filetypes=[".json"]):
|
234 |
-
logging.info(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
|
235 |
-
files = []
|
236 |
-
try:
|
237 |
-
for type in filetypes:
|
238 |
-
files += [f for f in os.listdir(dir) if f.endswith(type)]
|
239 |
-
except FileNotFoundError:
|
240 |
-
files = []
|
241 |
-
files = sorted_by_pinyin(files)
|
242 |
-
if files == []:
|
243 |
-
files = [""]
|
244 |
-
if plain:
|
245 |
-
return files
|
246 |
-
else:
|
247 |
-
return gr.Dropdown.update(choices=files)
|
248 |
-
|
249 |
-
|
250 |
-
def get_history_names(plain=False):
|
251 |
-
logging.info("获取历史记录文件名列表")
|
252 |
-
return get_file_names(HISTORY_DIR, plain)
|
253 |
-
|
254 |
-
|
255 |
-
def load_template(filename, mode=0):
|
256 |
-
logging.info(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
|
257 |
-
lines = []
|
258 |
-
logging.info("Loading template...")
|
259 |
-
if filename.endswith(".json"):
|
260 |
-
with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
|
261 |
-
lines = json.load(f)
|
262 |
-
lines = [[i["act"], i["prompt"]] for i in lines]
|
263 |
-
else:
|
264 |
-
with open(
|
265 |
-
os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8"
|
266 |
-
) as csvfile:
|
267 |
-
reader = csv.reader(csvfile)
|
268 |
-
lines = list(reader)
|
269 |
-
lines = lines[1:]
|
270 |
-
if mode == 1:
|
271 |
-
return sorted_by_pinyin([row[0] for row in lines])
|
272 |
-
elif mode == 2:
|
273 |
-
return {row[0]: row[1] for row in lines}
|
274 |
-
else:
|
275 |
-
choices = sorted_by_pinyin([row[0] for row in lines])
|
276 |
-
return {row[0]: row[1] for row in lines}, gr.Dropdown.update(
|
277 |
-
choices=choices, value=choices[0]
|
278 |
-
)
|
279 |
-
|
280 |
-
|
281 |
-
def get_template_names(plain=False):
|
282 |
-
logging.info("获取模板文件名列表")
|
283 |
-
return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
|
284 |
-
|
285 |
-
|
286 |
-
def get_template_content(templates, selection, original_system_prompt):
|
287 |
-
logging.info(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}")
|
288 |
-
try:
|
289 |
-
return templates[selection]
|
290 |
-
except:
|
291 |
-
return original_system_prompt
|
292 |
-
|
293 |
-
|
294 |
-
def reset_state():
|
295 |
-
logging.info("重置状态")
|
296 |
-
return [], [], [], construct_token_message(0)
|
297 |
-
|
298 |
-
|
299 |
-
def reset_textbox():
|
300 |
-
return gr.update(value="")
|
301 |
-
|
302 |
-
|
303 |
-
def reset_default():
|
304 |
-
global API_URL
|
305 |
-
API_URL = "https://api.openai.com/v1/chat/completions"
|
306 |
-
os.environ.pop("HTTPS_PROXY", None)
|
307 |
-
os.environ.pop("https_proxy", None)
|
308 |
-
return gr.update(value=API_URL), gr.update(value=""), "API URL 和代理已重置"
|
309 |
-
|
310 |
-
|
311 |
-
def change_api_url(url):
|
312 |
-
global API_URL
|
313 |
-
API_URL = url
|
314 |
-
msg = f"API地址更改为了{url}"
|
315 |
-
logging.info(msg)
|
316 |
-
return msg
|
317 |
-
|
318 |
-
|
319 |
-
def change_proxy(proxy):
|
320 |
-
os.environ["HTTPS_PROXY"] = proxy
|
321 |
-
msg = f"代理更改为了{proxy}"
|
322 |
-
logging.info(msg)
|
323 |
-
return msg
|
324 |
-
|
325 |
-
|
326 |
-
def hide_middle_chars(s):
|
327 |
-
if len(s) <= 8:
|
328 |
-
return s
|
329 |
-
else:
|
330 |
-
head = s[:4]
|
331 |
-
tail = s[-4:]
|
332 |
-
hidden = "*" * (len(s) - 8)
|
333 |
-
return head + hidden + tail
|
334 |
-
|
335 |
-
|
336 |
-
def submit_key(key):
|
337 |
-
key = key.strip()
|
338 |
-
msg = f"API密钥更改为了{hide_middle_chars(key)}"
|
339 |
-
logging.info(msg)
|
340 |
-
return key, msg
|
341 |
-
|
342 |
-
|
343 |
-
def sha1sum(filename):
|
344 |
-
sha1 = hashlib.sha1()
|
345 |
-
sha1.update(filename.encode("utf-8"))
|
346 |
-
return sha1.hexdigest()
|
347 |
-
|
348 |
-
|
349 |
-
def replace_today(prompt):
|
350 |
-
today = datetime.datetime.today().strftime("%Y-%m-%d")
|
351 |
-
return prompt.replace("{current_date}", today)
|
352 |
-
|
353 |
-
|
354 |
-
def get_geoip():
|
355 |
-
response = requests.get("https://ipapi.co/json/", timeout=5)
|
356 |
-
try:
|
357 |
-
data = response.json()
|
358 |
-
except:
|
359 |
-
data = {"error": True, "reason": "连接ipapi失败"}
|
360 |
-
if "error" in data.keys():
|
361 |
-
logging.warning(f"无法获取IP地址信息。\n{data}")
|
362 |
-
if data["reason"] == "RateLimited":
|
363 |
-
return (
|
364 |
-
f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
|
365 |
-
)
|
366 |
-
else:
|
367 |
-
return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
|
368 |
-
else:
|
369 |
-
country = data["country_name"]
|
370 |
-
if country == "China":
|
371 |
-
text = "**您的IP区域:中国。请立即检查代理设置,在不受支持的地区使用API可能导致账号被封禁。**"
|
372 |
-
else:
|
373 |
-
text = f"您的IP区域:{country}。"
|
374 |
-
logging.info(text)
|
375 |
-
return text
|
376 |
-
|
377 |
-
|
378 |
-
def find_n(lst, max_num):
|
379 |
-
n = len(lst)
|
380 |
-
total = sum(lst)
|
381 |
-
|
382 |
-
if total < max_num:
|
383 |
-
return n
|
384 |
-
|
385 |
-
for i in range(len(lst)):
|
386 |
-
if total - lst[i] < max_num:
|
387 |
-
return n - i -1
|
388 |
-
total = total - lst[i]
|
389 |
-
return 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|