蒲源 commited on
Commit
6bec747
·
1 Parent(s): c7557a5

fix(pu): fix conversation_history query bug

Browse files
app_mqa.py CHANGED
@@ -65,13 +65,12 @@ def rag_answer(question, temperature, k):
65
  retriever = get_retriever(vectorstore, k)
66
  rag_chain = setup_rag_chain(model_name='kimi', temperature=temperature)
67
 
68
- # 将问题添加到对话历史中
69
- conversation_history.append(("User", question))
70
-
71
  # 将对话历史转换为字符串
72
  history_str = "\n".join([f"{role}: {text}" for role, text in conversation_history])
73
-
74
- retrieved_documents, answer = execute_query(retriever, rag_chain, history_str, model_name='kimi',
 
 
75
  temperature=temperature)
76
 
77
  # 在文档中高亮显示上下文
 
65
  retriever = get_retriever(vectorstore, k)
66
  rag_chain = setup_rag_chain(model_name='kimi', temperature=temperature)
67
 
 
 
 
68
  # 将对话历史转换为字符串
69
  history_str = "\n".join([f"{role}: {text}" for role, text in conversation_history])
70
+ history_question = [history_str, question]
71
+ # 将问题添加到对话历史中
72
+ conversation_history.append(("User", question))
73
+ retrieved_documents, answer = execute_query(retriever, rag_chain, history_question, model_name='kimi',
74
  temperature=temperature)
75
 
76
  # 在文档中高亮显示上下文
app_mqa_database.py CHANGED
@@ -125,11 +125,11 @@ def rag_answer(question, temperature, k, user_id):
125
  if user_id not in conversation_history:
126
  conversation_history[user_id] = []
127
 
128
- conversation_history[user_id].append((f"User[{user_id}]", question))
129
-
130
  history_str = "\n".join([f"{role}: {text}" for role, text in conversation_history[user_id]])
 
131
 
132
- retrieved_documents, answer = execute_query(retriever, rag_chain, history_str, model_name='kimi',
 
133
  temperature=temperature)
134
 
135
  ############################
 
125
  if user_id not in conversation_history:
126
  conversation_history[user_id] = []
127
 
 
 
128
  history_str = "\n".join([f"{role}: {text}" for role, text in conversation_history[user_id]])
129
+ conversation_history[user_id].append((f"User[{user_id}]", question))
130
 
131
+ history_question = [history_str, question]
132
+ retrieved_documents, answer = execute_query(retriever, rag_chain, history_question, model_name='kimi',
133
  temperature=temperature)
134
 
135
  ############################
database/conversation_history.db CHANGED
Binary files a/database/conversation_history.db and b/database/conversation_history.db differ
 
rag_demo.py CHANGED
@@ -77,13 +77,13 @@ def setup_rag_chain(model_name="gpt-4", temperature=0):
77
  """设置检索增强生成流程"""
78
  if model_name.startswith("gpt"):
79
  # 如果是以gpt开头的模型,使用原来的逻辑
80
- prompt_template = """您是一个用于问答任务的专业助手。
81
- 在处理问答任务时,请根据所提供的[上下文信息]给出回答。
82
- 如果[上下文信息]与[问题]不相关,那么请运用您的知识库为提问者提供准确的答复。
83
- 请确保回答内容的质量, 包括相关性、准确性和可读性。
84
- [问题]: {question}
85
- [上下文信息]: {context}
86
- [回答]:
87
  """
88
  prompt = ChatPromptTemplate.from_template(prompt_template)
89
  llm = ChatOpenAI(model_name=model_name, temperature=temperature)
@@ -115,37 +115,31 @@ def execute_query(retriever, rag_chain, query, model_name="gpt-4", temperature=0
115
  retrieved_documents: 检索到的文档块列表
116
  response_text: 生成的回答文本
117
  """
 
 
 
 
 
 
118
  # 使用检索器检索相关文档块
119
- retrieved_documents = retriever.invoke(query)
120
 
121
  if rag_chain is not None:
122
  # 如果有RAG链,则使用RAG链生成回答
123
- rag_chain_response = rag_chain.invoke({"context": retrieved_documents, "question": query})
124
  response_text = rag_chain_response
125
  else:
126
- # 如果没有RAG链,则将检索到的文档块和查询问题按照指定格式输入给语言模型
127
- if model_name == "kimi":
128
- # 对于有检索能力的模型,使用不同的模板
129
- prompt_template = """您是一个用于问答任务的专业助手。
130
- 在处理问答任务时,请根据所提供的【上下文信息】和【你的知识库和检索到的相关文档】给出回答。
131
- 请确保回答内容的质量,包括相关性、准确性和可读性。
132
- 【问题】: {question}
133
- 【上下文信息】: {context}
134
- 【回答】:
135
- """
136
- else:
137
- prompt_template = """您是一个用于问答任务的专业助手。
138
- 在处理问答任务时,请根据所提供的【上下文信息】给出回答。
139
- 如果【上下文信息】与【问题】不相关,那么请运用您的知识库为提问者提供准确的答复。
140
- 请确保回答内容的质量,包括相关性、准确性和可读性。
141
- 【问题】: {question}
142
- 【上下文信息】: {context}
143
- 【回答】:
144
- """
145
-
146
  context = '\n'.join(
147
- [f'**Document {i}**: ' + retrieved_documents[i].page_content for i in range(len(retrieved_documents))])
148
- prompt = prompt_template.format(question=query, context=context)
149
  response_text = execute_query_no_rag(model_name=model_name, temperature=temperature, query=prompt)
150
  return retrieved_documents, response_text
151
 
@@ -225,7 +219,7 @@ def execute_query_no_rag(model_name="gpt-4", temperature=0, query=""):
225
  completion = client.chat.completions.create(
226
  model="moonshot-v1-128k",
227
  messages=messages,
228
- temperature=0.01,
229
  top_p=1.0,
230
  n=1, # 为每条输入消息生成多少个结果
231
  stream=False # 流式输出
@@ -250,7 +244,7 @@ if __name__ == "__main__":
250
 
251
  # 创建向量存储
252
  vectorstore = create_vector_store(chunks, model=embedding_model)
253
- retriever = get_retriever(vectorstore, k=4)
254
 
255
  # 设置 RAG 流程
256
  rag_chain = setup_rag_chain(model_name=model_name, temperature=temperature)
 
77
  """设置检索增强生成流程"""
78
  if model_name.startswith("gpt"):
79
  # 如果是以gpt开头的模型,使用原来的逻辑
80
+ prompt_template = """
81
+ 您是一个擅长问答任务的专业助手。在执行问答任务时,应优先考虑所提供的**上下文信息**来形成回答,并适当参照**对话历史**。
82
+ 如果**上下文信息**与**问题**无直接相关性,您应依据自己的知识库向提问者提供准确的信息。务必确保您的答案在相关性、准确性和可读性方面达到高标准。
83
+ **对话历史**: {conversation_history}
84
+ **问题**: {question}
85
+ **上下文信息**: {context}
86
+ **回答**:
87
  """
88
  prompt = ChatPromptTemplate.from_template(prompt_template)
89
  llm = ChatOpenAI(model_name=model_name, temperature=temperature)
 
115
  retrieved_documents: 检索到的文档块列表
116
  response_text: 生成的回答文本
117
  """
118
+ if isinstance(query, list):
119
+ [conversation_history, question] = query
120
+ else:
121
+ conversation_history = ''
122
+ question = query
123
+
124
  # 使用检索器检索相关文档块
125
+ retrieved_documents = retriever.invoke(question)
126
 
127
  if rag_chain is not None:
128
  # 如果有RAG链,则使用RAG链生成回答
129
+ rag_chain_response = rag_chain.invoke({"context": retrieved_documents, "question": question})
130
  response_text = rag_chain_response
131
  else:
132
+ prompt_template = """
133
+ 【对话历史】: {conversation_history}
134
+ 【上下文信息】: {context}
135
+ 您是一个擅长问答任务的专业助手。在执行问答任务时,应优先考虑所提供的【上下文信息】来形成回答,并适当参照【对话历史】。
136
+ 如果【上下文信息】与【问题】无直接相关性,您应依据自己的知识库向提问者提供准确的信息。务必确保您的答案在相关性、准确性和可读性方面达到高标准。
137
+ 【问题】: {question}
138
+ 【回答】:
139
+ """
 
 
 
 
 
 
 
 
 
 
 
 
140
  context = '\n'.join(
141
+ [retrieved_documents[i].page_content for i in range(len(retrieved_documents))])
142
+ prompt = prompt_template.format(conversation_history=conversation_history, question=question, context=context)
143
  response_text = execute_query_no_rag(model_name=model_name, temperature=temperature, query=prompt)
144
  return retrieved_documents, response_text
145
 
 
219
  completion = client.chat.completions.create(
220
  model="moonshot-v1-128k",
221
  messages=messages,
222
+ temperature=temperature,
223
  top_p=1.0,
224
  n=1, # 为每条输入消息生成多少个结果
225
  stream=False # 流式输出
 
244
 
245
  # 创建向量存储
246
  vectorstore = create_vector_store(chunks, model=embedding_model)
247
+ retriever = get_retriever(vectorstore, k=5)
248
 
249
  # 设置 RAG 流程
250
  rag_chain = setup_rag_chain(model_name=model_name, temperature=temperature)