update LightZero RAG
Files changed:
- .gitignore          +0 -2   (deleted)
- README.md           +3 -4
- README_zh.md        +2 -3
- app.py              +90 -58
- app_mqa.py          +132 -0  (new)
- app_qa.py           +106 -0  (new)
- assets/avatar.png   +0 -0   (new binary asset)
- rag_demo.py         +199 -37
- rag_demo_v0.py      +0 -136 (deleted)
- requirements.txt    +1 -0
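The new code paths read several provider keys from the environment. As a quick orientation before the diffs, here is a small sanity-check sketch listing the variables the updated rag_demo.py reads (names taken from the diff below, including the `MIMIMAX_*` spelling as it appears in the source; this helper itself is not part of the commit):

```python
# Report which of the environment variables used by the updated code are set.
import os

REQUIRED = [
    "QUESTION_LANG",                         # 'cn' or 'en', asserted at startup
    "OPENAI_API_KEY",                        # OpenAI gpt-* models and OpenAI embeddings
    "MIMIMAX_API_KEY", "MIMIMAX_GROUP_ID",   # MiniMax ('abab6-chat')
    "ZHIPUAI_API_KEY",                       # ZhipuAI ('glm-4')
    "KIMI_OPENAI_API_KEY",                   # Moonshot ('kimi')
    "AZURE_OPENAI_KEY", "AZURE_ENDPOINT",    # 'azure_gpt-*' models
]
for name in REQUIRED:
    print(f"{name}: {'set' if os.getenv(name) else 'MISSING'}")
```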
.gitignore
DELETED
@@ -1,2 +0,0 @@
-.env
-*bkp.py
README.md
CHANGED
@@ -56,7 +56,6 @@ QUESTION_LANG='cn' # The language of the question, currently available option is
 
 ```python
 
-# The difference between rag_demo.py and rag_demo_v0.py is that it can output the retrieved document chunks.
 if __name__ == "__main__":
     # Assuming documents are already present locally
     file_path = './documents/LightZero_README.zh.md'
@@ -91,9 +90,9 @@ if __name__ == "__main__":
 ```
 RAG/
 │
-├── rag_demo_v0.py        # RAG demonstration script without support for outputting retrieved document chunks.
 ├── rag_demo.py           # RAG demonstration script with support for outputting retrieved document chunks.
-├──
+├── app_qa.py             # Web-based interactive application built with Gradio and rag_demo.py.
+├── app_mqa.py            # Web-based interactive application built with Gradio and rag_demo.py. Supports maintaining conversation history.
 ├── .env                  # Environment variable configuration file
 └── documents/            # Documents folder
     └── your_document.txt # Context document
@@ -114,4 +113,4 @@ If you encounter any issues or require assistance, please submit a problem
 
 ## License
 
-All code in this repository is compliant with [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
+All code in this repository is compliant with [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
README_zh.md
CHANGED
@@ -43,7 +43,6 @@ QUESTION_LANG='cn' # Question language; currently the only available option is 'cn'
 
 ```python
 
-# rag_demo.py differs from rag_demo_v0.py in that it can output the retrieved document chunks.
 if __name__ == "__main__":
     # Assume the document already exists locally
     file_path = './documents/LightZero_README.zh.md'
@@ -78,9 +77,9 @@ if __name__ == "__main__":
 ```
 RAG/
 │
-├── rag_demo_v0.py        # RAG demo script; does not support outputting the retrieved document chunks.
 ├── rag_demo.py           # RAG demo script; supports outputting the retrieved document chunks.
-├──
+├── app_qa.py             # Web-based interactive application built with Gradio and rag_demo.py.
+├── app_mqa.py            # Web-based interactive application built with Gradio and rag_demo.py; maintains conversation history.
 ├── .env                  # Environment variable configuration file
 └── documents/            # Documents folder
     └── your_document.txt # Context document
app.py
CHANGED
@@ -1,16 +1,3 @@
-"""
-This app builds a Gradio interface where users type a question and a Retrieval-Augmented Generation (RAG) model finds the answer and displays it in the UI.
-The retrieved context is highlighted in the Markdown document so users can see where the answer came from. The interface has two parts: a QA area on top and, below it, the context the RAG model consulted.
-
-Structure overview:
-- Import the required libraries and functions.
-- Set environment variables and global variables.
-- Load and process the Markdown document.
-- Define the function that handles a user question and returns the answer plus the highlighted context.
-- Build the Gradio UI: Markdown, input box, buttons, and output boxes.
-- Launch the Gradio app with sharing enabled.
-"""
-
 import os
 
 import gradio as gr
@@ -22,7 +9,6 @@ from rag_demo import load_and_split_document, create_vector_store, setup_rag_chain, execute_query
 # Environment setup
 load_dotenv()  # load environment variables
 QUESTION_LANG = os.getenv("QUESTION_LANG")  # read QUESTION_LANG from the environment
-
 assert QUESTION_LANG in ['cn', 'en'], QUESTION_LANG
 
 if QUESTION_LANG == "cn":
@@ -31,8 +17,8 @@ if QUESTION_LANG == "cn":
     <div align="center">
     <img src="https://raw.githubusercontent.com/puyuan1996/RAG/main/assets/banner.svg" width="80%" height="20%" alt="Banner Image">
     </div>
-    <h2 style="text-align: center; color: black;"><a href="https://github.com/puyuan1996/RAG">
-    <h4 align="center"> 📢说明:请您在下面的"
+    <h2 style="text-align: center; color: black;"><a href="https://github.com/puyuan1996/RAG"> LightZero RAG Demo</a></h2>
+    <h4 align="center"> 📢说明:请您在下面的"问题(Q)"框中输入任何关于 LightZero 的问题,然后点击"提交"按钮。右侧"回答(A)"框中会显示 RAG 模型给出的回答。在 QA 栏的下方会给出参考文档(其中检索得到的相关文段会用黄色高亮显示)。</h4>
     <h4 align="center"> 如果你喜欢这个项目,请给我们在 GitHub 点个 star ✨ 。我们将会持续保持更新。 </h4>
     <strong><h5 align="center">注意:算法模型的输出可能包含一定的随机性。相关结果不代表任何开发者和相关 AI 服务的态度和意见。本项目开发者不对生成结果作任何保证,仅供参考。<h5></strong>
     """
@@ -47,55 +33,101 @@ if QUESTION_LANG == "cn":
 
 # Path variable reused by the code below
 file_path = './documents/LightZero_README.zh.md'
-chunks = load_and_split_document(file_path)
-retriever = create_vector_store(chunks)
-# rag_chain = setup_rag_chain(model_name="gpt-4")
-rag_chain = setup_rag_chain(model_name="gpt-3.5-turbo")
 
 # Load the original Markdown document
 loader = TextLoader(file_path)
 orig_documents = loader.load()
 
+# Conversation history store
+conversation_history = []
+
-def rag_answer(question):
-    ...
+
+def rag_answer(question, model_name, temperature, embedding_model, k):
+    """
+    Handle a user question and return the answer plus the highlighted context.
+
+    :param question: the user's question
+    :param model_name: name of the language model to use
+    :param temperature: sampling temperature used when generating the answer
+    :param embedding_model: embedding model to use
+    :param k: number of document chunks to retrieve
+    :return: the generated answer and a Markdown string with the retrieved context highlighted
+    """
+    try:
+        chunks = load_and_split_document(file_path, chunk_size=5000, chunk_overlap=500)
+        retriever = create_vector_store(chunks, model=embedding_model, k=k)
+        rag_chain = setup_rag_chain(model_name=model_name, temperature=temperature)
+
+        # Append the question to the conversation history
+        conversation_history.append(("User", question))
+
+        # Serialize the conversation history into a single string
+        history_str = "\n".join([f"{role}: {text}" for role, text in conversation_history])
+
+        retrieved_documents, answer = execute_query(retriever, rag_chain, history_str, model_name=model_name,
+                                                    temperature=temperature)
+        # Highlight the retrieved context inside the original document
+        context = [retrieved_documents[i].page_content for i in range(len(retrieved_documents))]
+        highlighted_document = orig_documents[0].page_content
+        for i in range(len(context)):
+            highlighted_document = highlighted_document.replace(context[i], f"<mark>{context[i]}</mark>")
+
+        # Append the answer to the conversation history
+        conversation_history.append(("Assistant", answer))
+    except Exception as e:
+        print(f"An error occurred: {e}")
+        return "处理您的问题时出现错误,请稍后再试。", ""
     return answer, highlighted_document
 
-    ...
-    with gr.Row():
-        with gr.Column():
-            inputs = gr.Textbox(
-                placeholder="请您输入任何关于 LightZero 的问题。",
-                label="问题 (Q)")  # set up the output boxes: the answer and the highlighted reference document
-            gr_submit = gr.Button('提交')
-
-        outputs_answer = gr.Textbox(placeholder="当你点击提交按钮后,这里会显示 RAG 模型给出的回答。",
-                                    label="回答 (A)")
-    with gr.Row():
-        # placeholder="当你点击提交按钮后,这里会显示参考的文档,其中检索得到的与问题最相关的 context 用高亮显示。"
-        outputs_context = gr.Markdown(label="参考的文档,检索得到的 context 用高亮显示 (C)")
-
-    gr.Markdown(tos_markdown)
-
-    gr_submit.click(
-        rag_answer,
-        inputs=inputs,
-        outputs=[outputs_answer, outputs_context],
-    )
+
+def clear_context():
+    """Clear the conversation history."""
+    global conversation_history
+    conversation_history = []
+    return "", ""
+
 
 if __name__ == "__main__":
-    ...
+    with gr.Blocks(title=title, theme='ParityError/Interstellar') as rag_demo:
+        gr.Markdown(title_markdown)
+
+        with gr.Row():
+            with gr.Column():
+                inputs = gr.Textbox(
+                    placeholder="请您输入任何关于 LightZero 的问题。",
+                    label="问题 (Q)")
+                model_name = gr.Dropdown(
+                    choices=['kimi', 'abab6-chat', 'glm-4', 'gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo', 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo'],
+                    # value='azure_gpt-4',
+                    value='kimi',
+                    label="选择语言模型")
+                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.01, step=0.01, label="温度参数")
+                embedding_model = gr.Dropdown(
+                    choices=['HuggingFace', 'TensorflowHub', 'OpenAI'],
+                    value='OpenAI',
+                    label="选择嵌入模型")
+                k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="检索到的文档块数量")
+                with gr.Row():
+                    gr_submit = gr.Button('提交')
+                    gr_clear = gr.Button('清除上下文')
+
+            outputs_answer = gr.Textbox(placeholder="当你点击提交按钮后,这里会显示 RAG 模型给出的回答。",
+                                        label="回答 (A)")
+        with gr.Row():
+            outputs_context = gr.Markdown(label="参考的文档,检索得到的 context 用高亮显示 (C)")
+
+        gr.Markdown(tos_markdown)
+
+        gr_submit.click(
+            rag_answer,
+            inputs=[inputs, model_name, temperature, embedding_model, k],
+            outputs=[outputs_answer, outputs_context],
+        )
+        gr_clear.click(clear_context, outputs=[outputs_answer, outputs_context])
+
+    concurrency = int(os.environ.get('CONCURRENCY', os.cpu_count()))
+    favicon_path = os.path.join(os.path.dirname(__file__), 'assets', 'avatar.png')
+    rag_demo.queue().launch(max_threads=concurrency, favicon_path=favicon_path, share=True)
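One consequence of the new design is that `rag_answer` rebuilds the chunks, vector store, and chain on every submit, so each question re-embeds the whole document. A minimal memoization sketch that avoids this (a hypothetical helper, not part of this commit, reusing the names defined above):

```python
_pipeline_cache = {}  # keyed by the UI-selectable knobs

def get_pipeline(model_name, temperature, embedding_model, k):
    """Build the retriever/chain once per unique settings tuple and reuse it."""
    key = (model_name, temperature, embedding_model, k)
    if key not in _pipeline_cache:
        chunks = load_and_split_document(file_path, chunk_size=5000, chunk_overlap=500)
        retriever = create_vector_store(chunks, model=embedding_model, k=k)
        rag_chain = setup_rag_chain(model_name=model_name, temperature=temperature)
        _pipeline_cache[key] = (retriever, rag_chain)
    return _pipeline_cache[key]
```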
app_mqa.py
ADDED
@@ -0,0 +1,132 @@
The 132 added lines are identical to the rewritten app.py above: the multi-turn Gradio QA app with model/temperature/embedding/k controls, the module-level conversation_history list, rag_answer(), clear_context(), and the queue-and-launch block.
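For clarity, this is what the serialized history passed to `execute_query` looks like after a couple of turns (illustrative values only; the join expression is the one used in rag_answer):

```python
conversation_history = [
    ("User", "什么是 LightZero?"),
    ("Assistant", "LightZero 是一个统一的 MCTS+RL 基准库。"),  # illustrative answer
    ("User", "它支持哪些算法?"),
]
history_str = "\n".join([f"{role}: {text}" for role, text in conversation_history])
print(history_str)
# User: 什么是 LightZero?
# Assistant: LightZero 是一个统一的 MCTS+RL 基准库。
# User: 它支持哪些算法?
```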
app_qa.py
ADDED
@@ -0,0 +1,106 @@
+import os
+
+import gradio as gr
+from dotenv import load_dotenv
+from langchain.document_loaders import TextLoader
+
+from rag_demo import load_and_split_document, create_vector_store, setup_rag_chain, execute_query
+
+# Environment setup
+load_dotenv()  # load environment variables
+QUESTION_LANG = os.getenv("QUESTION_LANG")  # read QUESTION_LANG from the environment
+assert QUESTION_LANG in ['cn', 'en'], QUESTION_LANG
+
+if QUESTION_LANG == "cn":
+    title = "LightZero RAG Demo"
+    title_markdown = """
+    <div align="center">
+    <img src="https://raw.githubusercontent.com/puyuan1996/RAG/main/assets/banner.svg" width="80%" height="20%" alt="Banner Image">
+    </div>
+    <h2 style="text-align: center; color: black;"><a href="https://github.com/puyuan1996/RAG"> LightZero RAG Demo</a></h2>
+    <h4 align="center"> 📢说明:请您在下面的"问题(Q)"框中输入任何关于 LightZero 的问题,然后点击"提交"按钮。右侧"回答(A)"框中会显示 RAG 模型给出的回答。在 QA 栏的下方会给出参考文档(其中检索得到的相关文段会用黄色高亮显示)。</h4>
+    <h4 align="center"> 如果你喜欢这个项目,请给我们在 GitHub 点个 star ✨ 。我们将会持续保持更新。 </h4>
+    <strong><h5 align="center">注意:算法模型的输出可能包含一定的随机性。相关结果不代表任何开发者和相关 AI 服务的态度和意见。本项目开发者不对生成结果作任何保证,仅供参考。<h5></strong>
+    """
+    tos_markdown = """
+    ### 使用条款
+    玩家使用本服务须同意以下条款:
+    该服务是一项探索性研究预览版,仅供非商业用途。它仅提供有限的安全措施,并可能生成令人反感的内容。不得将其用于任何非法、有害、暴力、种族主义等目的。
+    如果您的游玩体验有不佳之处,请发送邮件至 [email protected] ! 我们将删除相关信息,并不断改进这个项目。
+    为了获得最佳体验,请使用台式电脑,因为移动设备可能会影响可视化效果。
+    **版权所有 2024 OpenDILab。**
+    """
+
+# Path variable reused by the code below
+file_path = './documents/LightZero_README.zh.md'
+
+# Load the original Markdown document
+loader = TextLoader(file_path)
+orig_documents = loader.load()
+
+def rag_answer(question, model_name, temperature, embedding_model, k):
+    """
+    Handle a user question and return the answer plus the highlighted context.
+
+    :param question: the user's question
+    :param model_name: name of the language model to use
+    :param temperature: sampling temperature used when generating the answer
+    :param embedding_model: embedding model to use
+    :param k: number of document chunks to retrieve
+    :return: the generated answer and a Markdown string with the retrieved context highlighted
+    """
+    try:
+        chunks = load_and_split_document(file_path, chunk_size=5000, chunk_overlap=500)
+        retriever = create_vector_store(chunks, model=embedding_model, k=k)
+        rag_chain = setup_rag_chain(model_name=model_name, temperature=temperature)
+
+        retrieved_documents, answer = execute_query(retriever, rag_chain, question, model_name=model_name, temperature=temperature)
+        # Highlight the retrieved context inside the original document
+        context = [retrieved_documents[i].page_content for i in range(len(retrieved_documents))]
+        highlighted_document = orig_documents[0].page_content
+        for i in range(len(context)):
+            highlighted_document = highlighted_document.replace(context[i], f"<mark>{context[i]}</mark>")
+    except Exception as e:
+        print(f"An error occurred: {e}")
+        return "处理您的问题时出现错误,请稍后再试。", ""
+    return answer, highlighted_document
+
+
+if __name__ == "__main__":
+    with gr.Blocks(title=title, theme='ParityError/Interstellar') as rag_demo:
+        gr.Markdown(title_markdown)
+
+        with gr.Row():
+            with gr.Column():
+                inputs = gr.Textbox(
+                    placeholder="请您输入任何关于 LightZero 的问题。",
+                    label="问题 (Q)")
+                model_name = gr.Dropdown(
+                    choices=['kimi', 'abab6-chat', 'glm-4', 'gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo', 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo'],
+                    # value='azure_gpt-4',
+                    value='kimi',
+                    label="选择语言模型")
+                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.01, step=0.01, label="温度参数")
+                embedding_model = gr.Dropdown(
+                    choices=['HuggingFace', 'TensorflowHub', 'OpenAI'],
+                    value='OpenAI',
+                    label="选择嵌入模型")
+                k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="检索到的文档块数量")
+                gr_submit = gr.Button('提交')
+
+            outputs_answer = gr.Textbox(placeholder="当你点击提交按钮后,这里会显示 RAG 模型给出的回答。",
+                                        label="回答 (A)")
+        with gr.Row():
+            outputs_context = gr.Markdown(label="参考的文档,检索得到的 context 用高亮显示 (C)")
+
+        gr.Markdown(tos_markdown)
+
+        gr_submit.click(
+            rag_answer,
+            inputs=[inputs, model_name, temperature, embedding_model, k],
+            outputs=[outputs_answer, outputs_context],
+        )
+
+    concurrency = int(os.environ.get('CONCURRENCY', os.cpu_count()))
+    favicon_path = os.path.join(os.path.dirname(__file__), 'assets', 'avatar.png')
+    rag_demo.queue().launch(max_threads=concurrency, favicon_path=favicon_path, share=True)
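A quick way to smoke-test this pipeline without the UI is to call `rag_answer` directly (a sketch; it assumes valid API keys in `.env` and the document under `./documents`):

```python
# Hypothetical standalone check, not part of the commit:
answer, highlighted = rag_answer(
    "LightZero 支持哪些算法?",
    model_name="gpt-3.5-turbo",
    temperature=0.01,
    embedding_model="OpenAI",
    k=5,
)
print(answer)                        # the generated answer
print(highlighted.count("<mark>"))   # how many retrieved chunks matched the document verbatim
```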
assets/avatar.png
ADDED
Binary file not shown.
rag_demo.py
CHANGED
@@ -2,24 +2,34 @@
 Reference blog: https://mp.weixin.qq.com/s/RUdZjQMSlVOfHfhErSNXnA
 """
 # Import the required libraries and modules
+import json
 import os
 import textwrap
 
+import requests
 from dotenv import load_dotenv
 from langchain.chat_models import ChatOpenAI
 from langchain.document_loaders import TextLoader
-from langchain.embeddings import OpenAIEmbeddings
+from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings, TensorflowHubEmbeddings
 from langchain.prompts import ChatPromptTemplate
 from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import RunnablePassthrough
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores import Weaviate
 from weaviate import Client
 from weaviate.embedded import EmbeddedOptions
+from zhipuai import ZhipuAI
+from openai import AzureOpenAI
 
 # Environment setup and document download
 load_dotenv()  # load environment variables
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # read the OpenAI API key from the environment
+MIMIMAX_API_KEY = os.getenv("MIMIMAX_API_KEY")
+MIMIMAX_GROUP_ID = os.getenv("MIMIMAX_GROUP_ID")
+ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY")
+KIMI_OPENAI_API_KEY = os.getenv("KIMI_OPENAI_API_KEY")
+
+AZURE_OPENAI_KEY = os.getenv("AZURE_OPENAI_KEY")
+AZURE_ENDPOINT = os.getenv("AZURE_ENDPOINT")
 
 # Ensure OPENAI_API_KEY is set
 if not OPENAI_API_KEY:
@@ -37,79 +47,231 @@ def load_and_split_document(file_path, chunk_size=500, chunk_overlap=50):
 
 
 # Build the vector store
-def create_vector_store(chunks, model="OpenAI"):
+def create_vector_store(chunks, model="OpenAI", k=4):
     """Embed the document chunks and store them in Weaviate."""
     client = Client(embedded_options=EmbeddedOptions())
-    embedding_model = OpenAIEmbeddings() if model == "OpenAI" else None  # swap in another embedding model if needed
+
+    if model == "OpenAI":
+        embedding_model = OpenAIEmbeddings()
+    elif model == "HuggingFace":
+        embedding_model = HuggingFaceEmbeddings()
+    elif model == "TensorflowHub":
+        embedding_model = TensorflowHubEmbeddings()
+    else:
+        raise ValueError(f"Unsupported embedding model: {model}")
+
     vectorstore = Weaviate.from_documents(
         client=client,
        documents=chunks,
         embedding=embedding_model,
         by_text=False
     )
-    return vectorstore.as_retriever()
+    return vectorstore.as_retriever(search_kwargs={'k': k})
 
 
-# Define the retrieval-augmented generation pipeline
 def setup_rag_chain(model_name="gpt-4", temperature=0):
     """Set up the retrieval-augmented generation pipeline."""
-    ...
+    if model_name.startswith("gpt"):
+        # For gpt-* models, keep the original LangChain logic
+        prompt_template = """您是一个用于问答任务的专业助手。
+        在处理问答任务时,请根据所提供的[上下文信息]给出回答。
+        如果[上下文信息]与[问题]不相关,那么请运用您的知识库为提问者提供准确的答复。
+        请确保回答内容的质量, 包括相关性、准确性和可读性。
+        [问题]: {question}
+        [上下文信息]: {context}
+        [回答]:
+        """
+        prompt = ChatPromptTemplate.from_template(prompt_template)
+        llm = ChatOpenAI(model_name=model_name, temperature=temperature)
+        # Build the RAG chain; see https://python.langchain.com/docs/expression_language/
+        rag_chain = (
+                prompt
+                | llm
+                | StrOutputParser()
+        )
+    else:
+        # For models not starting with "gpt", return None; execute_query falls back to a direct API call
+        rag_chain = None
     return rag_chain
 
 
 # Run a query and print the result
-def execute_query(retriever, rag_chain, query):
-    """..."""
+def execute_query(retriever, rag_chain, query, model_name="gpt-4", temperature=0):
+    """
+    Run a query and return the result together with the retrieved document chunks.
+
+    Args:
+        retriever: the document retriever
+        rag_chain: the RAG chain; if None, the RAG chain is bypassed
+        query: the question
+        model_name: name of the language model, default "gpt-4"
+        temperature: sampling temperature, default 0
+
+    Returns:
+        retrieved_documents: the retrieved document chunks
+        response_text: the generated answer
+    """
+    # Retrieve the relevant document chunks
     retrieved_documents = retriever.invoke(query)
-    ...
+
+    if rag_chain is not None:
+        # With a RAG chain, generate the answer through the chain
+        rag_chain_response = rag_chain.invoke({"context": retrieved_documents, "question": query})
+        response_text = rag_chain_response
+    else:
+        # Without a RAG chain, format the retrieved chunks and the question into a prompt for the model
+        if model_name == "kimi":
+            # Models with their own retrieval ability get a different template
+            prompt_template = """您是一个用于问答任务的专业助手。
+            在处理问答任务时,请根据所提供的【上下文信息】和【你的知识库和检索到的相关文档】给出回答。
+            请确保回答内容的质量,包括相关性、准确性和可读性。
+            【问题】: {question}
+            【上下文信息】: {context}
+            【回答】:
+            """
+        else:
+            prompt_template = """您是一个用于问答任务的专业助手。
+            在处理问答任务时,请根据所提供的【上下文信息】给出回答。
+            如果【上下文信息】与【问题】不相关,那么请运用您的知识库为提问者提供准确的答复。
+            请确保回答内容的质量,包括相关性、准确性和可读性。
+            【问题】: {question}
+            【上下文信息】: {context}
+            【回答】:
+            """
+
+        context = '\n'.join(
+            [f'**Document {i}**: ' + retrieved_documents[i].page_content for i in range(len(retrieved_documents))])
+        prompt = prompt_template.format(question=query, context=context)
+        response_text = execute_query_no_rag(model_name=model_name, temperature=temperature, query=prompt)
+    return retrieved_documents, response_text
 
 
-# Run a query without the RAG chain
 def execute_query_no_rag(model_name="gpt-4", temperature=0, query=""):
     """Run a query without the RAG chain."""
-    llm = ChatOpenAI(model_name=model_name, temperature=temperature)
-    response = llm.invoke(query)
-    return response.content
+    if model_name.startswith("gpt"):
+        # For gpt-* models, keep the original logic
+        llm = ChatOpenAI(model_name=model_name, temperature=temperature)
+        response = llm.invoke(query)
+        return response.content
+    elif model_name.startswith("azure_gpt"):
+        client = AzureOpenAI(
+            azure_endpoint=AZURE_ENDPOINT,
+            api_key=AZURE_OPENAI_KEY,
+            api_version="2024-02-15-preview"
+        )
+        message_text = [{"role": "user", "content": query}, ]
+        completion = client.chat.completions.create(
+            model=model_name[6:],  # model_name = 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo'
+            messages=message_text,
+            temperature=temperature,
+            top_p=0.95,
+            frequency_penalty=0,
+            presence_penalty=0,
+            stop=None
+        )
+        return completion.choices[0].message.content
+    elif model_name == 'abab6-chat':
+        # The 'abab6-chat' model goes through MiniMax's dedicated API
+        url = "https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId=" + MIMIMAX_GROUP_ID
+        headers = {"Content-Type": "application/json", "Authorization": "Bearer " + MIMIMAX_API_KEY}
+        payload = {
+            "bot_setting": [
+                {
+                    "bot_name": "MM智能助理",
+                    "content": "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。",
+                }
+            ],
+            "messages": [{"sender_type": "USER", "sender_name": "小明", "text": query}],
+            "reply_constraints": {"sender_type": "BOT", "sender_name": "MM智能助理"},
+            "model": model_name,
+            "tokens_to_generate": 1034,
+            "temperature": temperature,
+            "top_p": 0.9,
+        }
+
+        response = requests.request("POST", url, headers=headers, json=payload)
+        # Parse the JSON response into a dict
+        response_dict = json.loads(response.text)
+        # Return the value of the 'reply' key
+        return response_dict['reply']
+
+    elif model_name == 'glm-4':
+        # The 'glm-4' model goes through ZhipuAI's dedicated API
+        client = ZhipuAI(api_key=ZHIPUAI_API_KEY)  # supply your own API key
+        response = client.chat.completions.create(
+            model=model_name,  # the model to call
+            messages=[{"role": "user", "content": query}]
+        )
+        return response.choices[0].message.content
+    elif model_name == 'kimi':
+        # The 'kimi' model goes through Moonshot's OpenAI-compatible API
+        from openai import OpenAI
+        client = OpenAI(
+            api_key=KIMI_OPENAI_API_KEY,
+            base_url="https://api.moonshot.cn/v1",
+        )
+        messages = [
+            {
+                "role": "system",
+                "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一切涉及恐怖主义,种族歧视,黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。",
+            },
+            {"role": "user",
+             "content": query},
+        ]
+        completion = client.chat.completions.create(
+            model="moonshot-v1-128k",
+            messages=messages,
+            temperature=0.01,
+            top_p=1.0,
+            n=1,  # number of completions per input message
+            stream=False  # streaming output
+        )
+        return completion.choices[0].message.content
+    else:
+        # Unsupported model: raise
+        raise ValueError(f"Unsupported model: {model_name}")
 
 
-# rag_demo.py differs from rag_demo_v0.py in that it can output the retrieved document chunks.
 if __name__ == "__main__":
     # Assume the document already exists locally
     file_path = './documents/LightZero_README.zh.md'
+    # model_name = "glm-4"  # model_name in ['abab6-chat', 'glm-4', 'gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo', 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo']
+    model_name = 'azure_gpt-4'
+    temperature = 0.01
+    # embedding_model = 'HuggingFace'  # embedding_model in ['HuggingFace', 'TensorflowHub', 'OpenAI']
+    embedding_model = 'OpenAI'  # embedding_model in ['HuggingFace', 'TensorflowHub', 'OpenAI']
 
     # Load and split the document
-    chunks = load_and_split_document(file_path)
+    chunks = load_and_split_document(file_path, chunk_size=5000, chunk_overlap=500)
 
     # Build the vector store
-    retriever = create_vector_store(chunks)
+    retriever = create_vector_store(chunks, model=embedding_model, k=5)
 
     # Set up the RAG pipeline
-    rag_chain = setup_rag_chain()
+    rag_chain = setup_rag_chain(model_name=model_name, temperature=temperature)
 
     # Ask the question and get the answer
-    query = "请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 任务上运行吗?"
+    query = ("GitHub - opendilab/LightZero: [NeurIPS 2023 Spotlight] LightZero: A Unified Benchmark for Monte Carl 请根据这个仓库回答下面的问题:(1)请简要介绍一下 LightZero (2)请详细介绍 LightZero 的框架结构。 (3)请给出安装 LightZero,运行他们的示例代码的详细步骤 (4)- 请问 LightZero 具体支持什么任务(tasks/environments)? (5)请问 LightZero 具体支持什么算法?(6)请问 LightZero 具体支持什么算法,各自支持在哪些任务上运行? (7)请问 LightZero 里面实现的 MuZero 算法支持在 Atari 任务上运行吗?(8)请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 任务上运行吗?(9)LightZero 支持哪些算法? 各自的优缺点是什么? 我应该如何根据任务特点进行选择呢?(10)请结合 LightZero 中的代码介绍他们是如何实现 MCTS 的。(11)请问对这个仓库提出详细的改进建议")
+    """
+    (1)请简要介绍一下 LightZero
+    (2)请详细介绍 LightZero 的框架结构。
+    (3)请给出安装 LightZero,运行他们的示例代码的详细步骤
+    (4)请问 LightZero 具体支持什么任务(tasks/environments)?
+    (5)请问 LightZero 具体支持什么算法?
+    (6)请问 LightZero 具体支持什么算法,各自支持在哪些任务上运行?
+    (7)请问 LightZero 里面实现的 MuZero 算法支持在 Atari 任务上运行吗?
+    (8)请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 任务上运行吗?
+    (9)LightZero 支持哪些算法? 各自的优缺点是什么? 我应该如何根据任务特点进行选择呢?
+    (10)请结合 LightZero 中的代码介绍他们是如何实现 MCTS 的。
+    (11)请问对这个仓库提出详细的改进建议。
+    """
 
     # Get the reference documents and the answer through the RAG chain
-    retrieved_documents, result_with_rag = execute_query(retriever, rag_chain, query)
+    retrieved_documents, result_with_rag = execute_query(retriever, rag_chain, query, model_name=model_name,
+                                                         temperature=temperature)
 
     # Get the answer without the RAG chain
-    result_without_rag = execute_query_no_rag(query=query)
+    result_without_rag = execute_query_no_rag(model_name=model_name, query=query, temperature=temperature)
 
     # Print and compare the two results
     # textwrap.fill wraps the text automatically; adjust width to your screen width
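Two behavioral points of the rework are worth pinning down: the embedding backend is now selectable, and for gpt-* models the chain no longer wires the retriever in (contrast setup_rag_chain_v0 in the deleted file below, which used RunnablePassthrough); the caller passes the retrieved chunks explicitly. A minimal usage sketch under those assumptions, reusing only the functions defined above:

```python
# Sketch: pick a local embedding backend, retrieve, then invoke the chain with explicit context.
chunks = load_and_split_document('./documents/LightZero_README.zh.md',
                                 chunk_size=5000, chunk_overlap=500)
retriever = create_vector_store(chunks, model='HuggingFace', k=3)  # backed by sentence-transformers

question = "LightZero 支持哪些算法?"
docs = retriever.invoke(question)  # up to k=3 chunks

rag_chain = setup_rag_chain(model_name="gpt-3.5-turbo", temperature=0)
if rag_chain is not None:  # gpt-* path: context is passed explicitly
    answer = rag_chain.invoke({"context": docs, "question": question})
else:                      # non-gpt path: execute_query formats a prompt and calls the provider API
    _, answer = execute_query(retriever, None, question, model_name="kimi", temperature=0.01)
print(answer)
```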
rag_demo_v0.py
DELETED
@@ -1,136 +0,0 @@
-"""
-Reference blog: https://mp.weixin.qq.com/s/RUdZjQMSlVOfHfhErSNXnA
-"""
-# Import the required libraries and modules
-import os
-import textwrap
-
-from dotenv import load_dotenv
-from langchain.chat_models import ChatOpenAI
-from langchain.document_loaders import TextLoader
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.prompts import ChatPromptTemplate
-from langchain.schema.output_parser import StrOutputParser
-from langchain.schema.runnable import RunnablePassthrough
-from langchain.text_splitter import CharacterTextSplitter
-from langchain.vectorstores import Weaviate
-from weaviate import Client
-from weaviate.embedded import EmbeddedOptions
-
-# Environment setup and document download
-load_dotenv()  # load environment variables
-OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # read the OpenAI API key from the environment
-
-# Ensure OPENAI_API_KEY is set
-if not OPENAI_API_KEY:
-    raise ValueError("OpenAI API Key not found in the environment variables.")
-
-
-# Load and split the document
-def load_and_split_document(file_path, chunk_size=500, chunk_overlap=50):
-    """Load a document and split it into small chunks."""
-    loader = TextLoader(file_path)
-    documents = loader.load()
-    text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
-    chunks = text_splitter.split_documents(documents)
-    return chunks
-
-
-# Build the vector store
-def create_vector_store(chunks, model="OpenAI"):
-    """Embed the document chunks and store them in Weaviate."""
-    client = Client(embedded_options=EmbeddedOptions())
-    embedding_model = OpenAIEmbeddings() if model == "OpenAI" else None  # swap in another embedding model if needed
-    vectorstore = Weaviate.from_documents(
-        client=client,
-        documents=chunks,
-        embedding=embedding_model,
-        by_text=False
-    )
-    return vectorstore.as_retriever()
-
-
-# Define the retrieval-augmented generation pipeline
-def setup_rag_chain_v0(retriever, model_name="gpt-4", temperature=0):
-    """Set up the retrieval-augmented generation pipeline."""
-    prompt_template = """You are an assistant for question-answering tasks.
-    Use your knowledge to answer the question if the provided context is not relevant.
-    Otherwise, use the context to inform your answer.
-    Question: {question}
-    Context: {context}
-    Answer:
-    """
-    prompt = ChatPromptTemplate.from_template(prompt_template)
-    llm = ChatOpenAI(model_name=model_name, temperature=temperature)
-    # Build the RAG chain; see https://python.langchain.com/docs/expression_language/
-    rag_chain = (
-            {"context": retriever, "question": RunnablePassthrough()}
-            | prompt
-            | llm
-            | StrOutputParser()
-    )
-    return rag_chain
-
-
-# Run a query and print the result
-def execute_query_v0(rag_chain, query):
-    """Run a query and return the result."""
-    return rag_chain.invoke(query)
-
-
-# Run a query without the RAG chain
-def execute_query_no_rag(model_name="gpt-4", temperature=0, query=""):
-    """Run a query without the RAG chain."""
-    llm = ChatOpenAI(model_name=model_name, temperature=temperature)
-    response = llm.invoke(query)
-    return response.content
-
-
-# rag_demo.py differs from rag_demo_v0.py in that it can output the retrieved document chunks.
-if __name__ == "__main__":
-    # Download and save the document locally (commented out; the document is assumed to exist locally)
-    # url = "https://raw.githubusercontent.com/langchain-ai/langchain/master/docs/docs/modules/state_of_the_union.txt"
-    # res = requests.get(url)
-    # with open("state_of_the_union.txt", "w") as f:
-    #     f.write(res.text)
-
-    # Assume the document already exists locally
-    # file_path = './documents/state_of_the_union.txt'
-    file_path = './documents/LightZero_README.zh.md'
-
-    # Load and split the document
-    chunks = load_and_split_document(file_path)
-
-    # Build the vector store
-    retriever = create_vector_store(chunks)
-
-    # Set up the RAG pipeline
-    rag_chain = setup_rag_chain_v0(retriever)
-
-    # Ask the question and get the answer
-    # query = "请你分别用中英文简介 LightZero"
-    # query = "请你用英文简介 LightZero"
-    query = "请你用中文简介 LightZero"
-    # query = "请问 LightZero 支持哪些环境和算法,应该如何快速上手使用?"
-    # query = "请问 LightZero 里面实现的 MuZero 算法支持在 Atari 环境上运行吗?"
-    # query = "请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 环境上运行吗?请详细解释原因"
-    # query = "请详细解释 MCTS 算法的原理,并给出带有详细中文注释的 Python 代码示例"
-
-    # Get the answer through the RAG chain
-    result_with_rag = execute_query_v0(rag_chain, query)
-
-    # Get the answer without the RAG chain
-    result_without_rag = execute_query_no_rag(query=query)
-
-    # Print and compare the two results
-    # Use textwrap.fill to wrap the text automatically; adjust width to your screen width
-    wrapped_result_with_rag = textwrap.fill(result_with_rag, width=80)
-    wrapped_result_without_rag = textwrap.fill(result_without_rag, width=80)
-
-    # Print the wrapped text
-    print("="*40)
-    print(f"我的问题是:\n{query}")
-    print("="*40)
-    print(f"Result with RAG:\n{wrapped_result_with_rag}")
-    print("="*40)
-    print(f"Result without RAG:\n{wrapped_result_without_rag}")
requirements.txt
CHANGED
@@ -5,3 +5,4 @@ weaviate-client
 requests
 python-dotenv
 tiktoken
+sentence-transformers
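The new `sentence-transformers` dependency is what backs the HuggingFace embedding option added in rag_demo.py: LangChain's HuggingFaceEmbeddings loads a sentence-transformers model under the hood (to my knowledge the default is sentence-transformers/all-mpnet-base-v2, which produces 768-dimensional vectors). A quick local sanity check that needs no API key:

```python
# Verify the new dependency works; downloads the default model on first run.
from langchain.embeddings import HuggingFaceEmbeddings

emb = HuggingFaceEmbeddings()        # default model is an assumption, see note above
vec = emb.embed_query("LightZero")
print(len(vec))                      # expected 768 for the default model
```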