File size: 7,434 Bytes
606df26
 
a7c7b3c
 
129b752
606df26
 
a7c7b3c
ac2cd27
a7c7b3c
a591b90
a7c7b3c
606df26
a7c7b3c
a591b90
a9d03e7
129b752
 
a7c7b3c
 
 
 
 
 
 
 
606df26
 
 
 
 
 
 
 
a7c7b3c
 
 
 
 
 
 
 
 
 
 
a591b90
a7c7b3c
a591b90
a7c7b3c
42c41b0
 
a7c7b3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a591b90
a7c7b3c
 
 
 
 
 
 
 
 
a591b90
79c8493
83588c4
 
 
 
 
a7c7b3c
 
 
 
 
a591b90
 
a7c7b3c
1fd4334
 
a7c7b3c
606df26
79c8493
 
606df26
ac2cd27
 
 
606df26
 
 
 
 
 
 
 
 
 
79c8493
606df26
 
 
79c8493
 
 
 
 
 
a7c7b3c
606df26
a7c7b3c
 
83588c4
 
 
a7c7b3c
a591b90
1fd4334
a7c7b3c
 
83588c4
a7c7b3c
a591b90
a7c7b3c
 
 
 
83588c4
 
 
 
 
 
 
 
 
 
 
606df26
 
 
79c8493
 
 
83588c4
 
 
a7c7b3c
83588c4
a7c7b3c
 
 
 
 
606df26
 
 
 
 
 
79c8493
 
 
 
 
 
a591b90
a7c7b3c
 
 
83588c4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
from typing import List

import gradio
import gradio as gr
import spacy
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import SpacyTextSplitter
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI

# SpacyTextSplitter needs the "en_core_web_sm" pipeline. Download it only when it
# is not already installed, so restarts do not hit the network every time.
if not spacy.util.is_package("en_core_web_sm"):
    spacy.cli.download("en_core_web_sm")

# Prompt for the single-turn RetrievalQA chain used by the "RAG" tab.
# `{context}` receives the retrieved document chunks and `{question}` the
# user's query; the pieces below concatenate to one multi-line string.
template = (
    "Use the following pieces of context to answer the question at the end. "
    "If you don't know the answer, just say that you don't know, don't try to make up an answer. "
    "Use three sentences maximum. "
    "Keep the answer as concise as possible. "
    'Always say "thanks for asking!" at the end of the answer.\n'
    "Tips: Make sure to cite your sources, and use the exact words from the context.\n"
    "{context}\n"
    "Question: {question}\n"
    "Helpful Answer:"
)
QA_CHAIN_PROMPT = PromptTemplate.from_template(template=template)


def convert_chat_history_to_messages(chat_history) -> List[BaseMessage]:
    """Flatten Gradio chatbot history into a LangChain message list.

    Each ``(human, ai)`` pair in ``chat_history`` becomes a ``HumanMessage``
    followed by an ``AIMessage``, keeping the original turn order.
    """
    return [
        message
        for user_text, bot_text in chat_history
        for message in (HumanMessage(content=user_text), AIMessage(content=bot_text))
    ]


class RAGDemo(object):
    """Gradio app that answers questions about one uploaded PDF with RAG.

    The "Settings" tab configures the chat model, the embedding model and the
    PDF. The "RAG" tab answers one-off questions with a ``RetrievalQA`` chain.
    The "Chat RAG" tab holds a multi-turn conversation through a
    ``ConversationalRetrievalChain``.
    """

    def __init__(self):
        # Populated by ``_init_settings``; ``None`` means "not configured yet".
        self.embedding = None
        self.vector_db = None
        self.chat_model = None

    @staticmethod
    def _resolve_file_path(file_obj):
        """Return a filesystem path for the value produced by ``gr.File``.

        Depending on the Gradio version/config this may be a plain path string
        or a temp-file wrapper exposing the path as ``.name``; accept both.
        """
        return getattr(file_obj, "name", file_obj)

    def _ensure_ready(self):
        """Raise a user-visible error unless settings were submitted successfully."""
        if self.chat_model is None or self.vector_db is None:
            raise gr.Error("Please submit the settings first.")

    def _init_chat_model(self, model_name, api_key):
        """Create the chat model chosen in the settings tab.

        Raises:
            gr.Error: if the API key is missing or the model is unsupported.
        """
        # gr.Error is only shown in the UI when it is *raised*; merely
        # constructing it (as before) did nothing.
        if not api_key:
            raise gr.Error("Please enter model API key.")
        model_name = model_name or ""
        if 'glm' in model_name:
            raise gr.Error("GLM is not supported yet.")
        if 'gemini' not in model_name:
            raise gr.Error(f"Model {model_name!r} is not supported.")
        self.chat_model = ChatGoogleGenerativeAI(
            google_api_key=api_key,
            # Honour the dropdown selection instead of a hard-coded model name.
            model=model_name,
            convert_system_message_to_human=True,
        )

    def _init_embedding(self, embedding_model_name, api_key):
        """Create the HuggingFace Inference API embedding client.

        Raises:
            gr.Error: if the API key or model name is missing, or GLM is chosen.
        """
        if not api_key:
            raise gr.Error("Please enter embedding API key.")
        if not embedding_model_name:
            raise gr.Error("Please select an embedding model.")
        if 'glm' in embedding_model_name:
            raise gr.Error("GLM is not supported yet.")
        self.embedding = HuggingFaceInferenceAPIEmbeddings(
            api_key=api_key, model_name=embedding_model_name
        )

    def _build_vector_db(self, file_path):
        """Load the PDF, split it into ~500-char chunks and index them in Chroma.

        Raises:
            gr.Error: if no file was given or no embedding model is configured.
        """
        if not file_path:
            raise gr.Error("Please enter vector database file path.")
        if self.embedding is None:
            raise gr.Error("Please configure the embedding model first.")
        gr.Info("Building vector database...")
        loader = PyPDFLoader(self._resolve_file_path(file_path))
        pages = loader.load()

        text_splitter = SpacyTextSplitter(chunk_size=500, chunk_overlap=50)
        docs = text_splitter.split_documents(pages)

        if self.vector_db is not None:
            # NOTE: Chroma.from_documents without a persist directory may reuse
            # the same in-process default collection, so drop the previous one
            # to avoid mixing chunks (or embedding sizes) from earlier uploads.
            self.vector_db.delete_collection()
            self.vector_db = None
        self.vector_db = Chroma.from_documents(
            documents=docs, embedding=self.embedding
        )
        gr.Info("Vector database built successfully.")
        print("Vector database built successfully.")

    def _init_settings(self, model_name, api_key, embedding_model, embedding_api_key, data_file):
        """Handler for the settings "submit" button: configure model, embeddings and index."""
        self._init_chat_model(model_name, api_key)
        self._init_embedding(embedding_model, embedding_api_key)
        self._build_vector_db(data_file)

    def _retrieval_qa(self, input_text):
        """Answer a single question with a RetrievalQA chain; returns the answer text."""
        if not input_text:
            return ""
        self._ensure_ready()
        basic_qa = RetrievalQA.from_chain_type(
            self.chat_model,
            retriever=self.vector_db.as_retriever(),
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
            verbose=True,
        )
        resp = basic_qa.invoke(input_text)
        return resp['result']

    def _chat_qa(self, message, chat_history):
        """Answer ``message`` using the prior chat turns as conversation memory.

        Returns ``("", chat_history)`` so Gradio clears the input box and shows
        the updated history (the new ``(message, answer)`` pair is appended).
        """
        if not message:
            return "", chat_history
        self._ensure_ready()
        memory = ConversationBufferMemory(
            chat_memory=ChatMessageHistory(
                messages=convert_chat_history_to_messages(chat_history)
            ),
            memory_key="chat_history",
            return_messages=True,
        )
        qa = ConversationalRetrievalChain.from_llm(
            self.chat_model,
            retriever=self.vector_db.as_retriever(),
            memory=memory,
            verbose=True,
        )
        resp = qa.invoke(message)
        print(f">>> {resp}")
        chat_history.append((message, resp['answer']))
        return "", chat_history

    def _retry_chat_qa(self, chat_history):
        """Drop the last turn and re-ask its question; no-op on empty history."""
        message = ""
        if chat_history:
            message, _ = chat_history.pop()
        return self._chat_qa(message, chat_history)

    def __call__(self):
        """Build and return the Gradio Blocks UI wired to this instance's handlers."""
        with gr.Blocks(title="🔥 RAG Demo") as demo:
            gr.Markdown("# RAG Demo\n\nbase on the [RAG learning note](https://www.jianshu.com/p/9792f1e6c3f9) and "
                        "[rag-practice](https://github.com/hiwei93/rag-practice/tree/main)")
            with gr.Tab("Settings"):
                with gr.Row():
                    with gr.Column():
                        model_name = gr.Dropdown(
                            choices=['gemini-1.0-pro'],
                            value='gemini-1.0-pro',
                            label="model"
                        )
                        api_key = gr.Textbox(placeholder="your api key for LLM", label="api key")
                        embedding_model = gr.Dropdown(
                            choices=['sentence-transformers/all-MiniLM-L6-v2',
                                     'intfloat/multilingual-e5-large'],
                            value="sentence-transformers/all-MiniLM-L6-v2",
                            label="embedding model"
                        )
                        embedding_api_key = gr.Textbox(placeholder="your api key for embedding", label="embedding api key")
                    with gr.Column():
                        data_file = gr.File(file_count='single', label="pdf file")
                        initial_btn = gr.Button("submit")
            with gr.Tab("RAG"):
                with gr.Row():
                    with gr.Column():
                        input_text = gr.Textbox(placeholder="input your question...", label="input")
                        submit_btn = gr.Button("submit")
                    with gr.Column():
                        output = gr.TextArea(label="answer")
            with gr.Tab("Chat RAG"):
                chatbot = gr.Chatbot(label="chat with pdf")
                input_msg = gr.Textbox(placeholder="input your question...", label="input")
                with gr.Row():
                    clear_btn = gr.ClearButton([chatbot, input_msg], value="🧹 Clear")
                    retry_btn = gr.Button("♻️ Retry")
            initial_btn.click(
                self._init_settings,
                inputs=[model_name, api_key, embedding_model, embedding_api_key, data_file]
            )

            submit_btn.click(
                self._retrieval_qa,
                inputs=input_text,
                outputs=output,
            )

            input_msg.submit(
                self._chat_qa,
                inputs=[input_msg, chatbot],
                outputs=[input_msg, chatbot]
            )

            retry_btn.click(
                self._retry_chat_qa,
                inputs=chatbot,
                outputs=[input_msg, chatbot]
            )
        return demo


# Entry point: build the demo object, construct its Gradio UI, and launch it
# with Gradio's debug mode enabled.
app = RAGDemo()
ui = app()
ui.launch(debug=True)