蒲源 committed
Commit ca8843c · 1 Parent(s): ae77d05

polish(pu): add conversation_history_cache and similar question reuse mechanism, add english comments

app_mqa_database.py CHANGED
@@ -1,12 +1,14 @@
+import collections
 import os
 import sqlite3
 import threading
 
 import gradio as gr
+import numpy as np
 from dotenv import load_dotenv
 from langchain.document_loaders import TextLoader
+from sentence_transformers import SentenceTransformer, util
 
-from analyze_conversation_history import analyze_conversation_history
 from rag_demo import load_and_split_document, create_vector_store, setup_rag_chain, execute_query, get_retriever
 
 # Environment setup
@@ -21,13 +23,14 @@ if QUESTION_LANG == "cn":
 <img src="https://raw.githubusercontent.com/puyuan1996/ZeroPal/main/assets/banner.svg" width="80%" height="20%" alt="Banner Image">
 </div>
 
-📢 **操作说明**:请在下方的“问题”框中输入关于 LightZero 的问题,并点击“提交”按钮。右侧的“回答”框将展示 RAG 模型提供的答案。
-您可以在问答框下方查看当前“对话历史”,点击“清除对话历史”按钮可清空历史记录。在“对话历史”框下方,您将找到相关参考文档,其中相关文段将以黄色高亮显示。
-如果您喜欢这个项目,请在 GitHub [LightZero RAG Demo](https://github.com/puyuan1996/ZeroPal) 上给我们点赞!✨ 您的支持是我们持续更新的动力。
-
-<div align="center">
-<strong>注意:算法模型输出可能包含一定的随机性。结果不代表开发者和相关 AI 服务的态度和意见。本项目开发者不对结果作出任何保证,仅供参考之用。使用该服务即代表同意后文所述的使用条款。</strong>
-</div>
+📢 **操作说明**:请在下方的"问题"框中输入关于 LightZero 的问题,并点击"提交"按钮。右侧的"回答"框将展示 RAG 模型提供的答案。
+您可以在问答框下方查看当前"对话历史",点击"清除对话历史"按钮可清空历史记录。在"对话历史"框下方,您将找到相关参考文档,其中相关文段将以黄色高亮显示。
+如果您喜欢这个项目,请在 GitHub [LightZero RAG Demo](https://github.com/puyuan1996/ZeroPal) 上给我们点赞!✨ 您的支持是我们持续更新的动力。注意:算法模型输出可能包含一定的随机性。结果不代表开发者和相关 AI 服务的态度和意见。本项目开发者不对结果作出任何保证,仅供参考之用。使用该服务即代表同意后文所述的使用条款。
+
+📢 **Instructions**: Please enter your questions about LightZero in the "Question" box below and click the "Submit" button. The "Answer" box on the right will display the answers provided by the RAG model.
+Below the Q&A box, you can view the current "Conversation History". Clicking the "Clear Conversation History" button will erase the history records. Below the "Conversation History" box, you'll find relevant reference documents, with the pertinent sections highlighted in yellow.
+If you like this project, please give us a thumbs up on GitHub at [LightZero RAG Demo](https://github.com/puyuan1996/ZeroPal)! ✨ Your support motivates us to keep updating.
+Note: The output from the algorithm model may contain a degree of randomness. The results do not represent the attitudes and opinions of the developers and related AI services. The developers of this project make no guarantees about the results, which are for reference only. Use of this service indicates agreement with the terms of use described later in the text.
 """
 tos_markdown = """
 ### 使用条款
@@ -63,16 +66,14 @@ def get_db_connection():
     """
     conn = getattr(threadLocal, 'conn', None)
     if conn is None:
-        # Connect to the SQLite database
+        # Create the table that stores the conversation history
         conn = sqlite3.connect('database/conversation_history.db')
         c = conn.cursor()
-        # Drop the existing 'history' table if it exists
-        # c.execute('DROP TABLE IF EXISTS history')
-        # Create the table that stores the conversation history
         c.execute('''CREATE TABLE IF NOT EXISTS history
                      (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      user_id TEXT NOT NULL,
                      user_input TEXT NOT NULL,
+                     user_input_embedding BLOB NOT NULL,
                      assistant_output TEXT NOT NULL,
                      timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)''')
         threadLocal.conn = conn
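The new `user_input_embedding BLOB` column holds the raw bytes of a float32 numpy vector. A minimal sketch of that round trip against a throwaway in-memory database (the table layout is simplified and the values are illustrative):

```python
import sqlite3

import numpy as np

conn = sqlite3.connect(":memory:")
c = conn.cursor()
c.execute("CREATE TABLE history (user_input TEXT, user_input_embedding BLOB)")

emb = np.random.rand(384).astype(np.float32)  # same dtype as SBERT's encode() output
c.execute("INSERT INTO history VALUES (?, ?)", ("hello", emb.tobytes()))
conn.commit()

raw = c.execute("SELECT user_input_embedding FROM history").fetchone()[0]
restored = np.frombuffer(raw, dtype=np.float32)  # dtype must match, or the values are garbage
assert np.array_equal(emb, restored)
```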
@@ -107,6 +108,18 @@ def close_db_connection():
 chunks = load_and_split_document(file_path, chunk_size=5000, chunk_overlap=500)
 vectorstore = create_vector_store(chunks, model='OpenAI')
 
+# Load the pre-trained SBERT model
+sbert_model = SentenceTransformer('all-MiniLM-L6-v2')
+
+# Define the cosine-similarity threshold
+cosine_threshold = 0.96  # raised to improve retrieval accuracy
+
+# Set the LRU cache size
+CACHE_SIZE = 1000
+
+# Create the cache of past questions
+conversation_history_cache = collections.OrderedDict()
+
 
 # def rag_answer(question, temperature=0.01, k=5, user_id='user'):
 def rag_answer(question, k=5, user_id='user'):
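`conversation_history_cache` keys entries by the embedding's byte string and bounds its size by evicting from the front of the `OrderedDict`. A self-contained sketch of that eviction pattern (the size limit and keys here are illustrative):

```python
import collections

CACHE_SIZE = 3  # small limit so eviction is easy to observe
cache = collections.OrderedDict()

def remember(key: bytes, value: tuple) -> None:
    # New entries go to the right end; evict from the left end when over capacity
    cache[key] = value
    if len(cache) > CACHE_SIZE:
        cache.popitem(last=False)

for i in range(5):
    remember(f"q{i}".encode(), (f"question {i}", f"answer {i}"))

print([k.decode() for k in cache])  # ['q2', 'q3', 'q4']
```

As written this evicts in insertion order (FIFO); calling `cache.move_to_end(key)` on each hit would give strict least-recently-used behavior.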
@@ -120,47 +133,77 @@ def rag_answer(question, k=5, user_id='user'):
     :return: The answer generated by the model and Markdown text with the context highlighted
     """
     temperature = 0.01  # TODO: use a fixed temperature parameter
+
     try:
-        retriever = get_retriever(vectorstore, k)
-        rag_chain = setup_rag_chain(model_name='kimi', temperature=temperature)
+        # Get the current thread's database connection and cursor
+        conn = get_db_connection()
+        c = get_db_cursor()
+
+        question_embedding = sbert_model.encode(question)
+        question_embedding_bytes = question_embedding.tobytes()  # convert the numpy array to a byte string
+
+        # Fetch all users' conversation history from the database
+        c.execute("SELECT user_input, user_input_embedding, assistant_output FROM history")
+        all_history = c.fetchall()
+        # Initialize the highest cosine similarity and its corresponding answer
+        max_cosine_score = 0
+        best_answer = ""
+        # Search the cache of past questions for a similar one
+        for history_question_bytes, (history_question, history_answer) in conversation_history_cache.items():
+            history_question_embedding_numpy = np.frombuffer(history_question_bytes, dtype=np.float32)
+            cosine_score = util.cos_sim(question_embedding, history_question_embedding_numpy).item()
+            # print(f"Retrieved past question: {history_question}")
+            # print(f"Cosine similarity between the current and the past question: {cosine_score}")
+            if cosine_score > cosine_threshold and cosine_score > max_cosine_score:
+                max_cosine_score = cosine_score
+                best_answer = history_answer
 
         if user_id not in conversation_history:
             conversation_history[user_id] = []
 
-        history_str = "\n".join([f"{role}: {text}" for role, text in conversation_history[user_id]])
         conversation_history[user_id].append((f"User[{user_id}]", question))
-
-        history_question = [history_str, question]
-        retrieved_documents, answer = execute_query(retriever, rag_chain, history_question, model_name='kimi',
-                                                    temperature=temperature)
-
-        ############################
-        # Get the current thread's database connection and cursor
-        ############################
-        conn = get_db_connection()
-        c = get_db_cursor()
-
-        # Analyze the conversation history
-        # analyze_conversation_history()
+        # If the cosine similarity exceeds the threshold, reuse the stored answer directly
+        if max_cosine_score > cosine_threshold:
+            print('=' * 20)
+            print(f"Found a sufficiently similar past question; returning its stored answer. Cosine similarity: {max_cosine_score}")
+            answer = best_answer
+        else:
+            retriever = get_retriever(vectorstore, k)
+            rag_chain = setup_rag_chain(model_name='kimi', temperature=temperature)
+            history_str = "\n".join([f"{role}: {text}" for role, text in conversation_history[user_id]])
+            history_question = [history_str, question]
+            retrieved_documents, answer = execute_query(retriever, rag_chain, history_question, model_name='kimi',
+                                                        temperature=temperature)
 
         # Get the total number of conversation records
         c.execute("SELECT COUNT(*) FROM history")
         total_records = c.fetchone()[0]
         print(f"Total conversation records: {total_records}")
 
         # Store the question and answer in the database
-        c.execute("INSERT INTO history (user_id, user_input, assistant_output) VALUES (?, ?, ?)",
-                  (user_id, question, answer))
+        c.execute(
+            "INSERT INTO history (user_id, user_input, user_input_embedding, assistant_output) VALUES (?, ?, ?, ?)",
+            (user_id, question, question_embedding_bytes, answer))
         conn.commit()
 
-        # Highlight the context within the document
-        context = [retrieved_documents[i].page_content for i in range(len(retrieved_documents))]
-        highlighted_document = orig_documents[0].page_content
-        for i in range(len(context)):
-            highlighted_document = highlighted_document.replace(context[i], f"<mark>{context[i]}</mark>")
+        # Add the new question and answer to the cache of past questions
+        conversation_history_cache[question_embedding_bytes] = (question, answer)
+        # If the cache exceeds its size limit, evict the least recently used entry
+        if len(conversation_history_cache) > CACHE_SIZE:
+            conversation_history_cache.popitem(last=False)
+
+        if max_cosine_score > cosine_threshold:
+            highlighted_document = ""
+        else:
+            # Highlight the retrieved context within the document
+            context = [retrieved_documents[i].page_content for i in range(len(retrieved_documents))]
+            highlighted_document = orig_documents[0].page_content
+            for i in range(len(context)):
+                highlighted_document = highlighted_document.replace(context[i], f"<mark>{context[i]}</mark>")
 
         conversation_history[user_id].append(("Assistant", answer))
-
         full_history = "\n".join([f"{role}: {text}" for role, text in conversation_history[user_id]])
+
     except Exception as e:
         print(f"An error occurred: {e}")
        return f"处理您的问题时出现错误,请稍后再试。错误内容为:{e}", "", ""
@@ -187,22 +230,25 @@ if __name__ == "__main__":
         with gr.Row():
             with gr.Column():
                 user_id = gr.Textbox(
-                    placeholder="请输入您的真实姓名或昵称作为用户ID",
-                    label="用户ID")
+                    placeholder="请输入您的真实姓名或昵称作为用户ID(Please enter your real name or nickname as the user ID.)",
+                    label="用户ID(Username)")
                 inputs = gr.Textbox(
-                    placeholder="请您在这里输入任何关于 LightZero 的问题。",
-                    label="问题")
+                    placeholder="请您在这里输入任何关于 LightZero 的问题。(Please enter any questions about LightZero here.)",
+                    label="问题(Question)")
                 # temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.01, step=0.01, label="温度参数")
-                k = gr.Slider(minimum=1, maximum=7, value=3, step=1, label="检索相关文档块的数量")  # The README is about 35,000 characters and each chunk is 5,000, so the maximum is 35000/5000 = 7
+                k = gr.Slider(minimum=1, maximum=7, value=3, step=1,
+                              label="检索到的相关文档块的数量(The number of relevant document blocks retrieved.)")  # The README is about 35,000 characters and each chunk is 5,000, so the maximum is 35000/5000 = 7
             with gr.Row():
-                gr_submit = gr.Button('提交')
-                gr_clear = gr.Button('清除对话历史')
+                gr_submit = gr.Button('提交(Submit)')
+                gr_clear = gr.Button('清除对话历史(Clear Context)')
 
-        outputs_answer = gr.Textbox(placeholder="当你点击提交按钮后,这里会显示 RAG 模型给出的回答。",
-                                    label="回答")
-        outputs_history = gr.Textbox(label="对话历史")
+        outputs_answer = gr.Textbox(
+            placeholder="当你点击提交按钮后,这里会显示 RAG 模型给出的回答。(After you click the submit button, the answer given by the RAG model will be displayed here.)",
+            label="回答(Answer)")
+        outputs_history = gr.Textbox(label="对话历史(Conversation History)")
         with gr.Row():
-            outputs_context = gr.Markdown(label="参考的文档(检索得到的相关文段用高亮显示)")
+            outputs_context = gr.Markdown(
+                label="参考的文档(检索得到的相关文段用高亮显示) Referenced documents (the relevant excerpts retrieved are highlighted).")
         gr_clear.click(clear_context, inputs=user_id, outputs=[outputs_context, outputs_history])
         gr_submit.click(
             rag_answer,
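The interface follows the usual `gr.Blocks` pattern: declare components, then wire a handler's inputs to outputs via `click`. A stripped-down sketch with a stand-in handler (the real app passes `rag_answer` and more components):

```python
import gradio as gr

def answer(question: str, k: int) -> str:
    # Stand-in for rag_answer: echo the inputs instead of querying the RAG chain
    return f"(k={int(k)}) You asked: {question}"

with gr.Blocks() as demo:
    question = gr.Textbox(label="Question")
    k = gr.Slider(minimum=1, maximum=7, value=3, step=1, label="Number of document chunks")
    submit = gr.Button("Submit")
    output = gr.Textbox(label="Answer")
    submit.click(answer, inputs=[question, k], outputs=output)

demo.queue().launch()
```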
@@ -216,5 +262,5 @@ if __name__ == "__main__":
     favicon_path = os.path.join(os.path.dirname(__file__), 'assets', 'avatar.png')
     zero_pal.queue().launch(max_threads=concurrency, favicon_path=favicon_path, share=True)
 
-    # Call the close_db_connection function at an appropriate point, e.g., when the program exits
+    # Call the close_db_connection function at an appropriate point, e.g., when the program exits
     close_db_connection()
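`get_db_connection` and `close_db_connection` rely on `threading.local()`, since sqlite3 connections cannot be shared across threads by default and Gradio serves requests from a thread pool. A sketch of that per-thread connection pattern (mirroring the app's function names, schema elided):

```python
import sqlite3
import threading

threadLocal = threading.local()

def get_db_connection() -> sqlite3.Connection:
    # Reuse this thread's connection if one exists; otherwise open a new one
    conn = getattr(threadLocal, 'conn', None)
    if conn is None:
        conn = sqlite3.connect('database/conversation_history.db')
        threadLocal.conn = conn
    return conn

def close_db_connection() -> None:
    conn = getattr(threadLocal, 'conn', None)
    if conn is not None:
        conn.close()
        threadLocal.conn = None
```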
 
database/conversation_history.db CHANGED
Binary files a/database/conversation_history.db and b/database/conversation_history.db differ