File size: 6,431 Bytes
643bd7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6aae17d
 
643bd7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6aae17d
 
 
643bd7e
6aae17d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643bd7e
 
6aae17d
 
 
 
 
 
 
 
 
643bd7e
 
6aae17d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os

import gradio as gr
from dotenv import load_dotenv
from langchain.document_loaders import TextLoader

from rag_demo import load_and_split_document, create_vector_store, setup_rag_chain, execute_query

# 环境设置
load_dotenv()  # 加载环境变量
QUESTION_LANG = os.getenv("QUESTION_LANG")  # 从环境变量获取 QUESTION_LANG
assert QUESTION_LANG in ['cn', 'en'], QUESTION_LANG

if QUESTION_LANG == "cn":
    title = "LightZero RAG Demo"
    title_markdown = """
    <div align="center">
        <img src="https://raw.githubusercontent.com/puyuan1996/RAG/main/assets/banner.svg" width="80%" height="20%" alt="Banner Image">
    </div>
    <h2 style="text-align: center; color: black;"><a href="https://github.com/puyuan1996/RAG"> LightZero RAG Demo</a></h2>
    <h4 align="center"> 📢说明:请您在下面的"问题(Q)"框中输入任何关于 LightZero 的问题,然后点击"提交"按钮。右侧"回答(A)"框中会显示 RAG 模型给出的回答。在 QA 栏的下方会给出参考文档(其中检索得到的相关文段会用黄色高亮显示)。</h4>
    <h4 align="center"> 如果你喜欢这个项目,请给我们在 GitHub 点个 star ✨ 。我们将会持续保持更新。  </h4>
    <strong><h5 align="center">注意:算法模型的输出可能包含一定的随机性。相关结果不代表任何开发者和相关 AI 服务的态度和意见。本项目开发者不对生成结果作任何保证,仅供参考。<h5></strong>
    """
    tos_markdown = """
    ### 使用条款
    玩家使用本服务须同意以下条款:
    该服务是一项探索性研究预览版,仅供非商业用途。它仅提供有限的安全措施,并可能生成令人反感的内容。不得将其用于任何非法、有害、暴力、种族主义等目的。
    如果您的游玩体验有不佳之处,请发送邮件至 [email protected] ! 我们将删除相关信息,并不断改进这个项目。
    为了获得最佳体验,请使用台式电脑,因为移动设备可能会影响可视化效果。
    **版权所有 2024 OpenDILab。**
    """

# 路径变量,方便之后的文件使用
file_path = './documents/LightZero_README.zh.md'

# 加载原始Markdown文档
loader = TextLoader(file_path)
orig_documents = loader.load()

# 存储对话历史
conversation_history = []


def rag_answer(question, model_name, temperature, embedding_model, k):
    """
    处理用户问题并返回答案和高亮显示的上下文

    :param question: 用户输入的问题
    :param model_name: 使用的语言模型名称
    :param temperature: 生成答案时使用的温度参数
    :param embedding_model: 使用的嵌入模型
    :param k: 检索到的文档块数量
    :return: 模型生成的答案和高亮显示上下文的Markdown文本
    """
    try:
        chunks = load_and_split_document(file_path, chunk_size=5000, chunk_overlap=500)
        retriever = create_vector_store(chunks, model=embedding_model, k=k)
        rag_chain = setup_rag_chain(model_name=model_name, temperature=temperature)

        # 将问题添加到对话历史中
        conversation_history.append(("User", question))

        # 将对话历史转换为字符串
        history_str = "\n".join([f"{role}: {text}" for role, text in conversation_history])

        retrieved_documents, answer = execute_query(retriever, rag_chain, history_str, model_name=model_name,
                                                    temperature=temperature)
        # 在文档中高亮显示上下文
        context = [retrieved_documents[i].page_content for i in range(len(retrieved_documents))]
        highlighted_document = orig_documents[0].page_content
        for i in range(len(context)):
            highlighted_document = highlighted_document.replace(context[i], f"<mark>{context[i]}</mark>")

        # 将回答添加到对话历史中
        conversation_history.append(("Assistant", answer))
    except Exception as e:
        print(f"An error occurred: {e}")
        return "处理您的问题时出现错误,请稍后再试。", ""
    return answer, highlighted_document


def clear_context():
    """
    清除对话历史
    """
    global conversation_history
    conversation_history = []
    return "", ""


if __name__ == "__main__":
    with gr.Blocks(title=title, theme='ParityError/Interstellar') as rag_demo:
        gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column():
                inputs = gr.Textbox(
                    placeholder="请您输入任何关于 LightZero 的问题。",
                    label="问题 (Q)")
                model_name = gr.Dropdown(
                    choices=['kimi', 'abab6-chat', 'glm-4', 'gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo', 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo'],
                    # value='azure_gpt-4',
                    value='kimi',
                    label="选择语言模型")
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.01, step=0.01, label="温度参数")
                embedding_model = gr.Dropdown(
                    choices=['HuggingFace', 'TensorflowHub', 'OpenAI'],
                    value='OpenAI',
                    label="选择嵌入模型")
                k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="检索到的文档块数量")
                with gr.Row():
                    gr_submit = gr.Button('提交')
                    gr_clear = gr.Button('清除上下文')

            outputs_answer = gr.Textbox(placeholder="当你点击提交按钮后,这里会显示 RAG 模型给出的回答。",
                                        label="回答 (A)")
        with gr.Row():
            outputs_context = gr.Markdown(label="参考的文档,检索得到的 context 用高亮显示 (C)")

        gr.Markdown(tos_markdown)

        gr_submit.click(
            rag_answer,
            inputs=[inputs, model_name, temperature, embedding_model, k],
            outputs=[outputs_answer, outputs_context],
        )
        gr_clear.click(clear_context, outputs=[outputs_answer, outputs_context])

    concurrency = int(os.environ.get('CONCURRENCY', os.cpu_count()))
    favicon_path = os.path.join(os.path.dirname(__file__), 'assets', 'avatar.png')
    rag_demo.queue().launch(max_threads=concurrency, favicon_path=favicon_path, share=True)