Aniun commited on
Commit
d324047
·
verified ·
1 Parent(s): 8ecfdcf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +259 -12
app.py CHANGED
@@ -10,43 +10,268 @@ import os
10
  current_file_path = os.path.dirname(os.path.abspath(__file__))
11
  root_path = os.path.abspath(current_file_path)
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  class RepoSearch:
14
  def __init__(self):
15
 
16
  db_path = os.path.join(root_path, "database", "faiss_index")
17
- print(db_path)
18
- # db_path = root_path
19
- embeddings = OpenAIEmbeddings(api_key="sk-Mo5K9m2hKXjV1DeGeBAIzXLZFxxiOTvSwUoemKmfMXdmE9Bs",
20
- base_url="https://api.wlai.vip/v1",
21
  model="text-embedding-3-small")
22
- print("embeddings already")
23
 
24
  assert os.path.exists(db_path), f"Database not found: {db_path}"
25
  self.vector_db = FAISS.load_local(db_path, embeddings,
26
  allow_dangerous_deserialization=True)
27
- print("vector_db already")
28
- # pass
29
 
30
- def search(self, query, k=20):
31
  '''
32
  name + description + html_url + topics
33
  '''
34
- # return "sss"
35
  results = self.vector_db.similarity_search(query + " technology", k=k)
36
 
37
  simple_str = ""
 
38
  for i, doc in enumerate(results):
39
  content = json.loads(doc.page_content)
40
  if content["description"] is None:
41
  content["description"] = ""
42
  desc = content["description"] if len(content["description"]) < 300 else content["description"][:300] + "..."
43
  simple_str += f"\t**{i+1}. {content['name']}** || **Description:** {desc} || **Url:** {content['html_url']} \n"
 
 
 
 
 
44
 
45
- return simple_str
46
-
47
 
48
  def main():
49
  search = RepoSearch()
 
50
 
51
  def respond(
52
  prompt: str,
@@ -59,7 +284,29 @@ def main():
59
  yield history
60
 
61
  response = {"role": "assistant", "content": ""}
62
- response["content"] = search.search(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  yield history + [response]
64
 
65
  with gr.Blocks() as demo:
 
10
  current_file_path = os.path.dirname(os.path.abspath(__file__))
11
  root_path = os.path.abspath(current_file_path)
12
 
13
+ from datetime import datetime, time
14
+ from textwrap import dedent
15
+
16
+ from langchain_openai import ChatOpenAI
17
+ from langchain_core.prompts import ChatPromptTemplate
18
+
19
+ import re
20
+
21
+
22
+ # 获取当前文件位置
23
+ current_path = os.path.dirname(os.path.abspath(__file__))
24
+
25
+ class OurLLM():
26
+ def __init__(self, model="gpt-4o"):
27
+
28
+ self.base_url = os.environ["OPENAI_BASE_URL"]
29
+ self.api_key = os.environ["OPENAI_API_KEY"]
30
+
31
+ # model: str, 模型名称 ["gpt-4o-mini", "gpt-4o", "o1-mini", "gemini-1.5-flash-002", "gemini-1.5-pro-002"]
32
+ self.big_model = "gpt-4o"
33
+
34
+ chat_prompt = ChatPromptTemplate.from_messages(
35
+ [
36
+ ("system", "{system_prompt}"),
37
+ ("user", "{input}"),
38
+ ]
39
+ )
40
+
41
+ self.chat_prompt = chat_prompt
42
+ self.llm = self.get_llm(self.big_model)
43
+
44
+ # 2. 获取指定仓库的 README 内容
45
+ def clean_json(self, s):
46
+ return s.replace("```json", "").replace("```", "")
47
+
48
+ def get_system_prompt(self, mode="assistant"):
49
+ prompt_map = {
50
+ "assistant": dedent("""
51
+ 你是一个智能助手,擅长用简洁的中文回答用户的问题。
52
+ 请确保你的回答准确、清晰、有条理,并且符合中文的语言习惯。
53
+ 重要提示:
54
+ 1. 回答要简洁明了,避免冗长
55
+ 2. 使用适当的专业术语
56
+ 3. 保持客观中立的语气
57
+ 4. 如果不确定,要明确指出
58
+ """),
59
+ # paper
60
+ "keyword_expand": dedent("""
61
+ 你是一个搜索关键词扩展专家,擅长将用户的搜索意图转化为多个相关的搜索词或短语。
62
+ 用户会输入一段描述他们搜索需求的文本,请你生成与之相关的关键词列表。
63
+ 你需要返回一个可以直接被 json 库解析的响应,不要使用任何 markdown 格式,包含以下内容:
64
+ {
65
+ 'keywords': [关键词列表],
66
+ }
67
+ 重要提示:
68
+ 1. 关键词应该包含同义词、近义词、上位词、下位词
69
+ 2. 短语要体现不同的表达方式和组合
70
+ 3. 描述句子要涵盖不同的应用场景和用途
71
+ 4. 所有内容必须与原始搜索意图高度相关
72
+ 5. 扩展搜索意图到相关的应用场景和工具,例如:
73
+ - 如果搜索"PDF转MD",应包含PDF内容提取、PDF解析工具、PDF数据处理等
74
+ - 如果搜索"图片压缩",应包含批量压缩工具、图片格式转换等
75
+ - 如果搜索"代码格式化",应包含代码美化工具、语法检查器、代码风格统一等
76
+ - 如果搜索"文本翻译",应包含机器翻译API、多语言翻译工具、离线翻译软件等
77
+ - 如果搜索"数据可视化",应包含图表生成工具、数据分析库、交互式图表等
78
+ - 如果搜索"网络爬虫",应包含数据采集框架、反爬虫绕过、数据解析工具等
79
+ - 如果搜索"API测试",应包含接口测试工具、性能监控、自动化测试框架等
80
+ 6. 所有内容主要使用英文表达,并对部分关键词添加额外的中文表示
81
+ """),
82
+ "github_match": dedent("""
83
+ 你是一个仓库匹配专家,擅长根据用户需求从多个仓库中选择最合适的仓库。
84
+ 用户会输入两部分内容:
85
+ 1. 用户的具体需求描述
86
+ 2. 多个仓库的描述列表(以1,2,3等数字开头)
87
+
88
+ 请你仔细分析用户需求,并从仓库列表中选择最符合需求的仓库。
89
+ 确保返回一个可以直接被 json 库解析的响应,不要使用任何 markdown 格式,尤其是不要使用 ```json 格式,包含以下内容:
90
+ {
91
+ "matched_repos": [匹配到的仓库编号列表,按相关度从高到低排序],
92
+ "match_scores": [对应的匹配度评分列表,0-100的整数,表示匹配程度]
93
+ }
94
+
95
+ 重要提示:
96
+ 1. 如果所有仓库都不相关,只返回编号1
97
+ 2. 匹配度评分要客观反映仓库与需求的契合度
98
+ 3. 返回的仓库数量不能超过输入仓库总数的一半
99
+ 4. 所有内容必须使用中文表达
100
+ """),
101
+ "github_score": dedent("""
102
+ 你是一个仓库评分专家,擅长根据用户需求对仓库进行评分。
103
+ 用户会输入两部分内容:
104
+ 1. 用户的具体需求描述
105
+ 2. 多个仓库的描述列表(以1,2,3等数字开头)
106
+
107
+ 请你仔细分析用户需求,并对每个仓库进行评分。
108
+ 确保返回一个可以直接被 json 库解析的响应,不要使用任何 markdown 格式,包含以下内容:
109
+ {
110
+ 'indices': [仓库编号列表,按分数从高到低],
111
+ 'scores': [编号对应的匹配度评分列表,0-100的整数,表示匹配程度]
112
+ }
113
+
114
+ 重要提示:
115
+ 2. 评分范围为0-100的整数,高于60分表示具有明显相关性
116
+ 1. 评分要客观反映仓库与需求的契合度
117
+ 3. 只返回评分大于 60 的仓库
118
+ """),
119
+ "title_class": dedent("""
120
+ 你是一个内容分类专家,用户会把一些论文的题目通过多行的形式发给你,
121
+ 请你使用尽可能简短的中文将所有题目分为不超过十类,并描述每个类别的名称以及其对应的文章数量,
122
+ 注意你不用重复论文题目,只需要给出类别名称以及数量即可即可。
123
+ """),
124
+ "summary_struct": dedent("""
125
+ 你是一个论文总结专家,同时也是一个翻译专家,对中文有深入的了解,包括词汇、语法和修辞技巧,
126
+ 能够深入分析所给英文内容的含义,可以将准备回复给用户的中文内容表示的流畅且符合中文语法习惯。
127
+ 用户将会论文的标题以及摘要通过 json 格式发给你,请你使用尽可能简短的中文按照下列要求对所给内容进行总结:
128
+ 回复格式为直接的 json 格式 ,样例如下:
129
+ {
130
+ "field": "研究领域(一个词组,使用多个逗号隔开的,保证分解后的各个词语之间具有较小的重叠性)",
131
+ "summary": "对论文的 abstract 进行摘要总结(保证内容尽可能的少,仅包含最关键的信息)",
132
+ "translation": "将论文的 abstract 翻译为中文,保证翻译的准确性和流畅性,并忠于原文"
133
+ }
134
+ 重要提示:
135
+ 1. 直接返回JSON对象,不要添加任何其他文本、注释或标记。
136
+ 2. 严禁在 key 对应的 value 中任何位置使用双引号!!!这样会导致我解析失败,切记!!!。
137
+ 3. 确保你的回复可以直接通过 JSON.parse() 解析,即不要返回非 json 格式的内容和字符。
138
+ 4. 保持回复简洁,避免重复内容。
139
+ 5. 直接回答问题,不要重复问题内容。
140
+ """),
141
+ "field_summary": dedent("""
142
+ 你是一个中文标签清洗专家,对中文词汇有深入了解,擅长将一对含有重复含义的标签进行合并并重命名
143
+ 用户会将一些标签名称通过逗号隔开的形式发给你,请你将含义重复但是名称不同的标签进行合并
144
+ 而没有含义重复的标签则不用理会,你最后需要返回你所修改的标签内容,回复为 json 格式,样例如下:
145
+ {
146
+ "label1": "合并后的标签1",
147
+ "label2": "合并后的标签1",
148
+ "label3": "合并后的标签2",
149
+ "label4": "合并后的标签2",
150
+ "label5": "合并后的标签2"
151
+ }
152
+ 重要提示:
153
+ 1. 直接返回JSON对象, 不要添加任何其他文本、注释或标记。
154
+ 2. 不要使用```json或任何其他格式标记。
155
+ 3. 确保你的回复可以直接通过JSON.parse()解析。
156
+ """)
157
+ }
158
+ return prompt_map[mode]
159
+
160
+ def get_llm(self, model="gpt-4o-mini"):
161
+ '''
162
+ params:
163
+ model: str, 模型名称 ["gpt-4o-mini", "gpt-4o", "o1-mini", "gemini-1.5-flash-002"]
164
+ '''
165
+ llm = ChatOpenAI(model=model,
166
+ base_url=self.base_url,
167
+ api_key=self.api_key)
168
+ print(f"Init model {model} successfully!")
169
+ return llm
170
+
171
+ def ask_question(self, question, system_prompt=None):
172
+ # 1. 获取系统提示
173
+ if system_prompt is None:
174
+ system_prompt = self.get_system_prompt()
175
+
176
+ # 2. 生成聊天提示
177
+ prompt = self.chat_prompt.format(input=question, system_prompt=system_prompt)
178
+ config = {
179
+ "configurable": {"response_format": {"type": "json_object"}}
180
+ }
181
+
182
+ # 3. 调用 OpenAI 模型进行回答(重调用三次,三次不成功就结束)
183
+ for _ in range(10):
184
+ try:
185
+ response = self.llm.invoke(prompt, config=config)
186
+ response.content = self.clean_json(response.content)
187
+ return response
188
+ except Exception as e:
189
+ print(e)
190
+ time.sleep(10)
191
+ continue
192
+ print(f"Failed to call llm for prompt: {prompt[0:10]}")
193
+ return None
194
+
195
+ async def ask_questions_parallel(self, questions, system_prompt=None):
196
+ import asyncio
197
+ import re
198
+ # 1. 获取系统提示
199
+ if system_prompt is None:
200
+ system_prompt = self.get_system_prompt()
201
+
202
+ # 2. 定义异步函数
203
+ async def call_llm(prompt):
204
+ for _ in range(10):
205
+ try:
206
+ config = {
207
+ "configurable": {"response_format": {"type": "json_object"}}
208
+ }
209
+ response = await self.llm.ainvoke(prompt, config=config)
210
+ # 1. 移除 json 标记
211
+ response.content = re.sub(r'^```json\s*', '', response.content)
212
+ response.content = re.sub(r'\s*```$', '', response.content)
213
+ # 2. 移除公式包裹符号
214
+ response.content = re.sub(r'\$', '', response.content)
215
+ # 3. 移除转移符号
216
+ response.content = re.sub(r'\\', '', response.content)
217
+ response.content = re.sub(r'/', '', response.content)
218
+ # 4. 移除各种引号
219
+ response.content = re.sub(r'[“”]', '', response.content)
220
+ response.content = re.sub(r'(\w+)"(\w+)', r'\1\2', response.content, flags=re.UNICODE)
221
+ return response
222
+ except Exception as e:
223
+ print(e)
224
+ await asyncio.sleep(10)
225
+ continue
226
+ print(f"Failed to call llm for prompt: {prompt[0:10]}")
227
+ return None
228
+
229
+ # 3. 构建 prompt
230
+ prompts = [self.chat_prompt.format(input=question, system_prompt=system_prompt) for question in questions]
231
+
232
+ # 4. 异步调用
233
+ tasks = [call_llm(prompt) for prompt in prompts]
234
+ results = await asyncio.gather(*tasks)
235
+
236
+ return results
237
+
238
  class RepoSearch:
239
  def __init__(self):
240
 
241
  db_path = os.path.join(root_path, "database", "faiss_index")
242
+ embeddings = OpenAIEmbeddings(api_key=os.environ["OPENAI_API_KEY"],
243
+ base_url=os.environ["OPENAI_BASE_URL"],
 
 
244
  model="text-embedding-3-small")
 
245
 
246
  assert os.path.exists(db_path), f"Database not found: {db_path}"
247
  self.vector_db = FAISS.load_local(db_path, embeddings,
248
  allow_dangerous_deserialization=True)
 
 
249
 
250
+ def search(self, query, k=10):
251
  '''
252
  name + description + html_url + topics
253
  '''
 
254
  results = self.vector_db.similarity_search(query + " technology", k=k)
255
 
256
  simple_str = ""
257
+ simple_list = []
258
  for i, doc in enumerate(results):
259
  content = json.loads(doc.page_content)
260
  if content["description"] is None:
261
  content["description"] = ""
262
  desc = content["description"] if len(content["description"]) < 300 else content["description"][:300] + "..."
263
  simple_str += f"\t**{i+1}. {content['name']}** || **Description:** {desc} || **Url:** {content['html_url']} \n"
264
+ simple_list.append({
265
+ "name": content["name"],
266
+ "description": desc,
267
+ "url": content["html_url"]
268
+ })
269
 
270
+ return simple_str, simple_list
 
271
 
272
  def main():
273
  search = RepoSearch()
274
+ llm = OurLLM()
275
 
276
  def respond(
277
  prompt: str,
 
284
  yield history
285
 
286
  response = {"role": "assistant", "content": ""}
287
+ response["content"] = "开始扩展关键词..."
288
+ yield history + [response]
289
+
290
+ query = llm.ask_question(prompt, system_prompt=llm.get_system_prompt("keyword_expand")).content
291
+ json_obj = ", ".join(json.loads(query)["keywords"])
292
+ # response["content"] = "拓展后关键词:" + json_obj
293
+ # yield history + [response]
294
+
295
+ response["content"] = "开始通过 LLM 评分得到最匹配的仓库..."
296
+ yield history + [response]
297
+ simple_str, simple_list = search.search(query, 40)
298
+ query = json_obj + '\n' + simple_str
299
+ out = llm.ask_question(query, system_prompt=llm.get_system_prompt("github_score")).content
300
+
301
+ out = out.replace('```json','').replace('```','').strip()
302
+ matched_repos = json.loads(out)["indices"]
303
+
304
+ result = [simple_list[idx-1] for idx in matched_repos]
305
+ simple_str = ""
306
+ for repo in result:
307
+ simple_str += f"\t**{repo['name']}** || **Description:** {repo['description']} || **Url:** {repo['url']} \n"
308
+ response = {"role": "assistant", "content": ""}
309
+ response["content"] = simple_str
310
  yield history + [response]
311
 
312
  with gr.Blocks() as demo: