JohnSmith9982 commited on
Commit
eac8ac9
1 Parent(s): ded699d

Delete utils.py

Browse files
Files changed (1) hide show
  1. utils.py +0 -389
utils.py DELETED
@@ -1,389 +0,0 @@
1
- # -*- coding:utf-8 -*-
2
- from __future__ import annotations
3
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
4
- import logging
5
- import json
6
- import os
7
- import datetime
8
- import hashlib
9
- import csv
10
- import requests
11
- import re
12
-
13
- import gradio as gr
14
- from pypinyin import lazy_pinyin
15
- import tiktoken
16
- import mdtex2html
17
- from markdown import markdown
18
- from pygments import highlight
19
- from pygments.lexers import get_lexer_by_name
20
- from pygments.formatters import HtmlFormatter
21
-
22
- from presets import *
23
-
24
- # logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s")
25
-
26
- if TYPE_CHECKING:
27
- from typing import TypedDict
28
-
29
- class DataframeData(TypedDict):
30
- headers: List[str]
31
- data: List[List[str | int | bool]]
32
-
33
-
34
- def count_token(message):
35
- encoding = tiktoken.get_encoding("cl100k_base")
36
- input_str = f"role: {message['role']}, content: {message['content']}"
37
- length = len(encoding.encode(input_str))
38
- return length
39
-
40
-
41
- def markdown_to_html_with_syntax_highlight(md_str):
42
- def replacer(match):
43
- lang = match.group(1) or "text"
44
- code = match.group(2)
45
-
46
- try:
47
- lexer = get_lexer_by_name(lang, stripall=True)
48
- except ValueError:
49
- lexer = get_lexer_by_name("text", stripall=True)
50
-
51
- formatter = HtmlFormatter()
52
- highlighted_code = highlight(code, lexer, formatter)
53
-
54
- return f'<pre><code class="{lang}">{highlighted_code}</code></pre>'
55
-
56
- code_block_pattern = r"```(\w+)?\n([\s\S]+?)\n```"
57
- md_str = re.sub(code_block_pattern, replacer, md_str, flags=re.MULTILINE)
58
-
59
- html_str = markdown(md_str)
60
- return html_str
61
-
62
-
63
- def normalize_markdown(md_text: str) -> str:
64
- lines = md_text.split("\n")
65
- normalized_lines = []
66
- inside_list = False
67
-
68
- for i, line in enumerate(lines):
69
- if re.match(r"^(\d+\.|-|\*|\+)\s", line.strip()):
70
- if not inside_list and i > 0 and lines[i - 1].strip() != "":
71
- normalized_lines.append("")
72
- inside_list = True
73
- normalized_lines.append(line)
74
- elif inside_list and line.strip() == "":
75
- if i < len(lines) - 1 and not re.match(
76
- r"^(\d+\.|-|\*|\+)\s", lines[i + 1].strip()
77
- ):
78
- normalized_lines.append(line)
79
- continue
80
- else:
81
- inside_list = False
82
- normalized_lines.append(line)
83
-
84
- return "\n".join(normalized_lines)
85
-
86
-
87
- def convert_mdtext(md_text):
88
- code_block_pattern = re.compile(r"```(.*?)(?:```|$)", re.DOTALL)
89
- inline_code_pattern = re.compile(r"`(.*?)`", re.DOTALL)
90
- code_blocks = code_block_pattern.findall(md_text)
91
- non_code_parts = code_block_pattern.split(md_text)[::2]
92
-
93
- result = []
94
- for non_code, code in zip(non_code_parts, code_blocks + [""]):
95
- if non_code.strip():
96
- non_code = normalize_markdown(non_code)
97
- if inline_code_pattern.search(non_code):
98
- result.append(markdown(non_code, extensions=["tables"]))
99
- else:
100
- result.append(mdtex2html.convert(non_code, extensions=["tables"]))
101
- if code.strip():
102
- # _, code = detect_language(code) # 暂时去除代码高亮功能,因为在大段代码的情况下会出现问题
103
- # code = code.replace("\n\n", "\n") # 暂时去除代码中的空行,因为在大段代码的情况下会出现问题
104
- code = f"```{code}\n\n```"
105
- code = markdown_to_html_with_syntax_highlight(code)
106
- result.append(code)
107
- result = "".join(result)
108
- return result
109
-
110
- def convert_user(userinput):
111
- userinput = userinput.replace("\n", "<br>")
112
- return f"<pre>{userinput}</pre>"
113
-
114
- def detect_language(code):
115
- if code.startswith("\n"):
116
- first_line = ""
117
- else:
118
- first_line = code.strip().split("\n", 1)[0]
119
- language = first_line.lower() if first_line else ""
120
- code_without_language = code[len(first_line) :].lstrip() if first_line else code
121
- return language, code_without_language
122
-
123
-
124
- def construct_text(role, text):
125
- return {"role": role, "content": text}
126
-
127
-
128
- def construct_user(text):
129
- return construct_text("user", text)
130
-
131
-
132
- def construct_system(text):
133
- return construct_text("system", text)
134
-
135
-
136
- def construct_assistant(text):
137
- return construct_text("assistant", text)
138
-
139
-
140
- def construct_token_message(token, stream=False):
141
- return f"Token 计数: {token}"
142
-
143
-
144
- def delete_last_conversation(chatbot, history, previous_token_count):
145
- if len(chatbot) > 0 and standard_error_msg in chatbot[-1][1]:
146
- logging.info("由于包含报错信息,只删除chatbot记录")
147
- chatbot.pop()
148
- return chatbot, history
149
- if len(history) > 0:
150
- logging.info("删除了一组对话历史")
151
- history.pop()
152
- history.pop()
153
- if len(chatbot) > 0:
154
- logging.info("删除了一组chatbot对话")
155
- chatbot.pop()
156
- if len(previous_token_count) > 0:
157
- logging.info("删除了一组对话的token计数记录")
158
- previous_token_count.pop()
159
- return (
160
- chatbot,
161
- history,
162
- previous_token_count,
163
- construct_token_message(sum(previous_token_count)),
164
- )
165
-
166
-
167
- def save_file(filename, system, history, chatbot):
168
- logging.info("保存对话历史中……")
169
- os.makedirs(HISTORY_DIR, exist_ok=True)
170
- if filename.endswith(".json"):
171
- json_s = {"system": system, "history": history, "chatbot": chatbot}
172
- print(json_s)
173
- with open(os.path.join(HISTORY_DIR, filename), "w") as f:
174
- json.dump(json_s, f)
175
- elif filename.endswith(".md"):
176
- md_s = f"system: \n- {system} \n"
177
- for data in history:
178
- md_s += f"\n{data['role']}: \n- {data['content']} \n"
179
- with open(os.path.join(HISTORY_DIR, filename), "w", encoding="utf8") as f:
180
- f.write(md_s)
181
- logging.info("保存对话历史完毕")
182
- return os.path.join(HISTORY_DIR, filename)
183
-
184
-
185
- def save_chat_history(filename, system, history, chatbot):
186
- if filename == "":
187
- return
188
- if not filename.endswith(".json"):
189
- filename += ".json"
190
- return save_file(filename, system, history, chatbot)
191
-
192
-
193
- def export_markdown(filename, system, history, chatbot):
194
- if filename == "":
195
- return
196
- if not filename.endswith(".md"):
197
- filename += ".md"
198
- return save_file(filename, system, history, chatbot)
199
-
200
-
201
- def load_chat_history(filename, system, history, chatbot):
202
- logging.info("加载对话历史中……")
203
- if type(filename) != str:
204
- filename = filename.name
205
- try:
206
- with open(os.path.join(HISTORY_DIR, filename), "r") as f:
207
- json_s = json.load(f)
208
- try:
209
- if type(json_s["history"][0]) == str:
210
- logging.info("历史记录格式为旧版,正在转换……")
211
- new_history = []
212
- for index, item in enumerate(json_s["history"]):
213
- if index % 2 == 0:
214
- new_history.append(construct_user(item))
215
- else:
216
- new_history.append(construct_assistant(item))
217
- json_s["history"] = new_history
218
- logging.info(new_history)
219
- except:
220
- # 没有对话历史
221
- pass
222
- logging.info("加载对话历史完毕")
223
- return filename, json_s["system"], json_s["history"], json_s["chatbot"]
224
- except FileNotFoundError:
225
- logging.info("没有找到对话历史文件,不执行任何操作")
226
- return filename, system, history, chatbot
227
-
228
-
229
- def sorted_by_pinyin(list):
230
- return sorted(list, key=lambda char: lazy_pinyin(char)[0][0])
231
-
232
-
233
- def get_file_names(dir, plain=False, filetypes=[".json"]):
234
- logging.info(f"获取文件名列表,目录为{dir},文件类型为{filetypes},是否为纯文本列表{plain}")
235
- files = []
236
- try:
237
- for type in filetypes:
238
- files += [f for f in os.listdir(dir) if f.endswith(type)]
239
- except FileNotFoundError:
240
- files = []
241
- files = sorted_by_pinyin(files)
242
- if files == []:
243
- files = [""]
244
- if plain:
245
- return files
246
- else:
247
- return gr.Dropdown.update(choices=files)
248
-
249
-
250
- def get_history_names(plain=False):
251
- logging.info("获取历史记录文件名列表")
252
- return get_file_names(HISTORY_DIR, plain)
253
-
254
-
255
- def load_template(filename, mode=0):
256
- logging.info(f"加载模板文件{filename},模式为{mode}(0为返回字典和下拉菜单,1为返回下拉菜单,2为返回字典)")
257
- lines = []
258
- logging.info("Loading template...")
259
- if filename.endswith(".json"):
260
- with open(os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8") as f:
261
- lines = json.load(f)
262
- lines = [[i["act"], i["prompt"]] for i in lines]
263
- else:
264
- with open(
265
- os.path.join(TEMPLATES_DIR, filename), "r", encoding="utf8"
266
- ) as csvfile:
267
- reader = csv.reader(csvfile)
268
- lines = list(reader)
269
- lines = lines[1:]
270
- if mode == 1:
271
- return sorted_by_pinyin([row[0] for row in lines])
272
- elif mode == 2:
273
- return {row[0]: row[1] for row in lines}
274
- else:
275
- choices = sorted_by_pinyin([row[0] for row in lines])
276
- return {row[0]: row[1] for row in lines}, gr.Dropdown.update(
277
- choices=choices, value=choices[0]
278
- )
279
-
280
-
281
- def get_template_names(plain=False):
282
- logging.info("获取模板文件名列表")
283
- return get_file_names(TEMPLATES_DIR, plain, filetypes=[".csv", "json"])
284
-
285
-
286
- def get_template_content(templates, selection, original_system_prompt):
287
- logging.info(f"应用模板中,选择为{selection},原始系统提示为{original_system_prompt}")
288
- try:
289
- return templates[selection]
290
- except:
291
- return original_system_prompt
292
-
293
-
294
- def reset_state():
295
- logging.info("重置状态")
296
- return [], [], [], construct_token_message(0)
297
-
298
-
299
- def reset_textbox():
300
- return gr.update(value="")
301
-
302
-
303
- def reset_default():
304
- global API_URL
305
- API_URL = "https://api.openai.com/v1/chat/completions"
306
- os.environ.pop("HTTPS_PROXY", None)
307
- os.environ.pop("https_proxy", None)
308
- return gr.update(value=API_URL), gr.update(value=""), "API URL 和代理已重置"
309
-
310
-
311
- def change_api_url(url):
312
- global API_URL
313
- API_URL = url
314
- msg = f"API地址更改为了{url}"
315
- logging.info(msg)
316
- return msg
317
-
318
-
319
- def change_proxy(proxy):
320
- os.environ["HTTPS_PROXY"] = proxy
321
- msg = f"代理更改为了{proxy}"
322
- logging.info(msg)
323
- return msg
324
-
325
-
326
- def hide_middle_chars(s):
327
- if len(s) <= 8:
328
- return s
329
- else:
330
- head = s[:4]
331
- tail = s[-4:]
332
- hidden = "*" * (len(s) - 8)
333
- return head + hidden + tail
334
-
335
-
336
- def submit_key(key):
337
- key = key.strip()
338
- msg = f"API密钥更改为了{hide_middle_chars(key)}"
339
- logging.info(msg)
340
- return key, msg
341
-
342
-
343
- def sha1sum(filename):
344
- sha1 = hashlib.sha1()
345
- sha1.update(filename.encode("utf-8"))
346
- return sha1.hexdigest()
347
-
348
-
349
- def replace_today(prompt):
350
- today = datetime.datetime.today().strftime("%Y-%m-%d")
351
- return prompt.replace("{current_date}", today)
352
-
353
-
354
- def get_geoip():
355
- response = requests.get("https://ipapi.co/json/", timeout=5)
356
- try:
357
- data = response.json()
358
- except:
359
- data = {"error": True, "reason": "连接ipapi失败"}
360
- if "error" in data.keys():
361
- logging.warning(f"无法获取IP地址信息。\n{data}")
362
- if data["reason"] == "RateLimited":
363
- return (
364
- f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
365
- )
366
- else:
367
- return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
368
- else:
369
- country = data["country_name"]
370
- if country == "China":
371
- text = "**您的IP区域:中国。请立即检查代理设置,在不受支持的地区使用API可能导致账号被封禁。**"
372
- else:
373
- text = f"您的IP区域:{country}。"
374
- logging.info(text)
375
- return text
376
-
377
-
378
- def find_n(lst, max_num):
379
- n = len(lst)
380
- total = sum(lst)
381
-
382
- if total < max_num:
383
- return n
384
-
385
- for i in range(len(lst)):
386
- if total - lst[i] < max_num:
387
- return n - i -1
388
- total = total - lst[i]
389
- return 1