Commit 6bec747 · committed by 蒲源
Parent(s): c7557a5
fix(pu): fix conversation_history query bug
Files changed:
- app_mqa.py +4 -5
- app_mqa_database.py +3 -3
- database/conversation_history.db +0 -0
- rag_demo.py +27 -33
app_mqa.py
CHANGED
@@ -65,13 +65,12 @@ def rag_answer(question, temperature, k):
     retriever = get_retriever(vectorstore, k)
     rag_chain = setup_rag_chain(model_name='kimi', temperature=temperature)
 
-    # 将问题添加到对话历史中
-    conversation_history.append(("User", question))
-
     # 将对话历史转换为字符串
     history_str = "\n".join([f"{role}: {text}" for role, text in conversation_history])
-
-    retrieved_documents, answer = execute_query(retriever, rag_chain, history_str, model_name='kimi',
+    history_question = [history_str, question]
+    # 将问题添加到对话历史中
+    conversation_history.append(("User", question))
+    retrieved_documents, answer = execute_query(retriever, rag_chain, history_question, model_name='kimi',
                                                 temperature=temperature)
 
     # 在文档中高亮显示上下文
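Read end to end, the fix reorders the history handling: the history string is now built from previous turns only, the new question is appended afterwards, and the pair [history_str, question] is what reaches execute_query. A minimal sketch of the corrected flow follows (module-level state and the helper functions are assumed from the rest of app_mqa.py; the real handler continues with context highlighting):

```python
# Sketch of the fixed ordering in rag_answer (app_mqa.py); not the full handler.
conversation_history = []  # module-level list of (role, text) tuples, assumed from the app

def rag_answer(question, temperature, k):
    retriever = get_retriever(vectorstore, k)  # helpers assumed from the repo
    rag_chain = setup_rag_chain(model_name='kimi', temperature=temperature)

    # Build the history string from *previous* turns only...
    history_str = "\n".join(f"{role}: {text}" for role, text in conversation_history)
    # ...pair it with the new question, then record the question in the history.
    history_question = [history_str, question]
    conversation_history.append(("User", question))

    retrieved_documents, answer = execute_query(
        retriever, rag_chain, history_question,
        model_name='kimi', temperature=temperature)
    return retrieved_documents, answer
```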
app_mqa_database.py
CHANGED
@@ -125,11 +125,11 @@ def rag_answer(question, temperature, k, user_id):
     if user_id not in conversation_history:
         conversation_history[user_id] = []
 
-    conversation_history[user_id].append((f"User[{user_id}]", question))
-
     history_str = "\n".join([f"{role}: {text}" for role, text in conversation_history[user_id]])
+    conversation_history[user_id].append((f"User[{user_id}]", question))
 
-    retrieved_documents, answer = execute_query(retriever, rag_chain, history_str, model_name='kimi',
+    history_question = [history_str, question]
+    retrieved_documents, answer = execute_query(retriever, rag_chain, history_question, model_name='kimi',
                                                 temperature=temperature)
 
     ############################
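app_mqa_database.py applies the same reordering per user: the history is keyed by user_id, and only that user's prior turns go into history_str. A small sketch of that per-user bookkeeping (dict layout taken from the diff; the helper function name is illustrative only):

```python
conversation_history = {}  # user_id -> list of (role, text) tuples, as in app_mqa_database.py

def build_history_question(user_id, question):
    """Illustrative helper: same fix as above, scoped to one user's history."""
    if user_id not in conversation_history:
        conversation_history[user_id] = []
    history_str = "\n".join(f"{role}: {text}"
                            for role, text in conversation_history[user_id])
    conversation_history[user_id].append((f"User[{user_id}]", question))
    return [history_str, question]  # the shape execute_query now expects
```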
database/conversation_history.db
CHANGED
Binary files a/database/conversation_history.db and b/database/conversation_history.db differ
rag_demo.py
CHANGED
@@ -77,13 +77,13 @@ def setup_rag_chain(model_name="gpt-4", temperature=0):
     """设置检索增强生成流程"""
     if model_name.startswith("gpt"):
         # 如果是以gpt开头的模型,使用原来的逻辑
-        prompt_template = """
-        ...
+        prompt_template = """
+        您是一个擅长问答任务的专业助手。在执行问答任务时,应优先考虑所提供的**上下文信息**来形成回答,并适当参照**对话历史**。
+        如果**上下文信息**与**问题**无直接相关性,您应依据自己的知识库向提问者提供准确的信息。务必确保您的答案在相关性、准确性和可读性方面达到高标准。
+        **对话历史**: {conversation_history}
+        **问题**: {question}
+        **上下文信息**: {context}
+        **回答**:
         """
         prompt = ChatPromptTemplate.from_template(prompt_template)
         llm = ChatOpenAI(model_name=model_name, temperature=temperature)
@@ -115,37 +115,31 @@ def execute_query(retriever, rag_chain, query, model_name="gpt-4", temperature=0
         retrieved_documents: 检索到的文档块列表
         response_text: 生成的回答文本
     """
+    if isinstance(query, list):
+        [conversation_history, question] = query
+    else:
+        conversation_history = ''
+        question = query
+
     # 使用检索器检索相关文档块
-    retrieved_documents = retriever.invoke(query)
+    retrieved_documents = retriever.invoke(question)
 
     if rag_chain is not None:
         # 如果有RAG链,则使用RAG链生成回答
-        rag_chain_response = rag_chain.invoke({"context": retrieved_documents, "question": query})
+        rag_chain_response = rag_chain.invoke({"context": retrieved_documents, "question": question})
         response_text = rag_chain_response
     else:
-        ...
-        【回答】:
-        """
-        else:
-            prompt_template = """您是一个用于问答任务的专业助手。
-            在处理问答任务时,请根据所提供的【上下文信息】给出回答。
-            如果【上下文信息】与【问题】不相关,那么请运用您的知识库为提问者提供准确的答复。
-            请确保回答内容的质量,包括相关性、准确性和可读性。
-            【问题】: {question}
-            【上下文信息】: {context}
-            【回答】:
-            """
-
+        prompt_template = """
+        【对话历史】: {conversation_history}
+        【上下文信息】: {context}
+        您是一个擅长问答任务的专业助手。在执行问答任务时,应优先考虑所提供的【上下文信息】来形成回答,并适当参照【对话历史】。
+        如果【上下文信息】与【问题】无直接相关性,您应依据自己的知识库向提问者提供准确的信息。务必确保您的答案在相关性、准确性和可读性方面达到高标准。
+        【问题】: {question}
+        【回答】:
+        """
         context = '\n'.join(
-            [retrieved_documents[i].page_content for i in range(len(retrieved_documents))])
-        prompt = prompt_template.format(question=question, context=context)
+            [retrieved_documents[i].page_content for i in range(len(retrieved_documents))])
+        prompt = prompt_template.format(conversation_history=conversation_history, question=question, context=context)
         response_text = execute_query_no_rag(model_name=model_name, temperature=temperature, query=prompt)
     return retrieved_documents, response_text
 
@@ -225,7 +219,7 @@ def execute_query_no_rag(model_name="gpt-4", temperature=0, query=""):
     completion = client.chat.completions.create(
         model="moonshot-v1-128k",
         messages=messages,
-        temperature=
+        temperature=temperature,
         top_p=1.0,
         n=1,  # 为每条输入消息生成多少个结果
         stream=False  # 流式输出
@@ -250,7 +244,7 @@ if __name__ == "__main__":
 
     # 创建向量存储
     vectorstore = create_vector_store(chunks, model=embedding_model)
-    retriever = get_retriever(vectorstore, k=
+    retriever = get_retriever(vectorstore, k=5)
 
     # 设置 RAG 流程
     rag_chain = setup_rag_chain(model_name=model_name, temperature=temperature)
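Taken together, the rag_demo.py changes make execute_query accept either a bare question string or a [conversation_history, question] pair, and only the question is sent to the retriever. A condensed sketch of that dispatch (retriever and rag_chain are the LangChain objects from the surrounding file; the inline Chinese prompt is abbreviated to an English placeholder here):

```python
def execute_query(retriever, rag_chain, query, model_name="gpt-4", temperature=0):
    # query may be a [conversation_history, question] pair (new behavior) or a plain string.
    if isinstance(query, list):
        conversation_history, question = query
    else:
        conversation_history, question = "", query

    # Only the question is used for retrieval; the history is reserved for the prompt.
    retrieved_documents = retriever.invoke(question)

    if rag_chain is not None:
        # With a RAG chain, delegate answer generation to the chain.
        response_text = rag_chain.invoke({"context": retrieved_documents, "question": question})
    else:
        # Without a chain, format the prompt directly and call the raw model.
        # Placeholder standing in for the Chinese prompt template defined inline in the diff.
        prompt_template = ("Conversation history: {conversation_history}\n"
                           "Context: {context}\nQuestion: {question}\nAnswer:")
        context = "\n".join(doc.page_content for doc in retrieved_documents)
        prompt = prompt_template.format(conversation_history=conversation_history,
                                        question=question, context=context)
        response_text = execute_query_no_rag(model_name=model_name,
                                             temperature=temperature, query=prompt)
    return retrieved_documents, response_text
```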