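"""Gradio demo of a PDF-based RAG pipeline built on LangChain.

The app has three tabs: a Settings tab for choosing the chat and embedding
models and uploading a PDF, a RAG tab for one-shot retrieval QA, and a
Chat RAG tab for conversational retrieval with memory.
"""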
from typing import List
import gradio as gr
import spacy
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import SpacyTextSplitter
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from langchain_core.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI

# Download the spaCy model used by SpacyTextSplitter, skipping if it is already installed.
if not spacy.util.is_package("en_core_web_sm"):
    spacy.cli.download("en_core_web_sm")
template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Use three sentences maximum. Keep the answer as concise as possible. Always say "thanks for asking!" at the end of the answer.
Tips: Make sure to cite your sources, and use the exact words from the context.
{context}
Question: {question}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)

def convert_chat_history_to_messages(chat_history) -> List[BaseMessage]:
    """Convert Gradio's (human, ai) tuple history into LangChain messages."""
    result = []
    for human, ai in chat_history:
        result.append(HumanMessage(content=human))
        result.append(AIMessage(content=ai))
    return result
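
# For example, a Gradio history of [("Hi", "Hello!")] becomes
# [HumanMessage(content="Hi"), AIMessage(content="Hello!")], the shape
# ChatMessageHistory expects in _chat_qa below.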

class RAGDemo:
    """Gradio app wiring a PDF-backed RAG pipeline: settings, one-shot QA, and chat."""

    def __init__(self):
        self.embedding = None
        self.vector_db = None
        self.chat_model = None

    def _init_chat_model(self, model_name, api_key):
        if not api_key:
            # gr.Error must be raised (not just called) to surface in the UI.
            raise gr.Error("Please enter model API key.")
        if 'glm' in model_name:
            raise gr.Error("GLM is not supported yet.")
        elif 'gemini' in model_name:
            self.chat_model = ChatGoogleGenerativeAI(
                google_api_key=api_key,
                model='gemini-pro',
                convert_system_message_to_human=True,
            )

    def _init_embedding(self, embedding_model_name, api_key):
        if not api_key:
            raise gr.Error("Please enter embedding API key.")
        if 'glm' in embedding_model_name:
            raise gr.Error("GLM is not supported yet.")
        self.embedding = HuggingFaceInferenceAPIEmbeddings(
            api_key=api_key, model_name=embedding_model_name
        )

    def _build_vector_db(self, file_path):
        if not file_path:
            raise gr.Error("Please upload a PDF file.")
        gr.Info("Building vector database...")
        loader = PyPDFLoader(file_path)
        pages = loader.load()
        # Sentence-aware splitting keeps chunks coherent for retrieval.
        text_splitter = SpacyTextSplitter(chunk_size=500, chunk_overlap=50)
        docs = text_splitter.split_documents(pages)
        self.vector_db = Chroma.from_documents(
            documents=docs, embedding=self.embedding
        )
        gr.Info("Vector database built successfully.")
        print("Vector database built successfully.")

    def _init_settings(self, model_name, api_key, embedding_model, embedding_api_key, data_file):
        self._init_chat_model(model_name, api_key)
        self._init_embedding(embedding_model, embedding_api_key)
        self._build_vector_db(data_file)

    def _retrieval_qa(self, input_text):
        # One-shot QA: retrieve relevant chunks and answer with the cited-context prompt.
        basic_qa = RetrievalQA.from_chain_type(
            self.chat_model,
            retriever=self.vector_db.as_retriever(),
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
            verbose=True,
        )
        resp = basic_qa.invoke(input_text)
        return resp['result']
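
    # Note: RetrievalQA returns its answer under the 'result' key, whereas
    # ConversationalRetrievalChain below returns it under 'answer'.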
    def _chat_qa(self, message, chat_history):
        if not message:
            return "", chat_history
        # Rebuild conversation memory from the Gradio chat history on every turn,
        # since the chain itself is recreated per call.
        memory = ConversationBufferMemory(
            chat_memory=ChatMessageHistory(
                messages=convert_chat_history_to_messages(chat_history)
            ),
            memory_key="chat_history",
            return_messages=True,
        )
        qa = ConversationalRetrievalChain.from_llm(
            self.chat_model,
            retriever=self.vector_db.as_retriever(),
            memory=memory,
            verbose=True,
        )
        resp = qa.invoke(message)
        print(f">>> {resp}")
        chat_history.append((message, resp['answer']))
        return "", chat_history

    def _retry_chat_qa(self, chat_history):
        # Re-ask the most recent question by popping the last (message, answer) pair.
        message = ""
        if chat_history:
            message, _ = chat_history.pop()
        return self._chat_qa(message, chat_history)

    def __call__(self):
        with gr.Blocks(title="🔥 RAG Demo") as demo:
            gr.Markdown("# RAG Demo\n\nBased on the [RAG learning note](https://www.jianshu.com/p/9792f1e6c3f9) and "
                        "[rag-practice](https://github.com/hiwei93/rag-practice/tree/main)")
            with gr.Tab("Settings"):
                with gr.Row():
                    with gr.Column():
                        model_name = gr.Dropdown(
                            choices=['gemini-1.0-pro'],
                            value='gemini-1.0-pro',
                            label="model"
                        )
                        api_key = gr.Textbox(placeholder="your api key for LLM", label="api key")
                        embedding_model = gr.Dropdown(
                            choices=['sentence-transformers/all-MiniLM-L6-v2',
                                     'intfloat/multilingual-e5-large'],
                            value="sentence-transformers/all-MiniLM-L6-v2",
                            label="embedding model"
                        )
                        embedding_api_key = gr.Textbox(placeholder="your api key for embedding", label="embedding api key")
                    with gr.Column():
                        data_file = gr.File(file_count='single', label="pdf file")
                initial_btn = gr.Button("submit")
            with gr.Tab("RAG"):
                with gr.Row():
                    with gr.Column():
                        input_text = gr.Textbox(placeholder="input your question...", label="input")
                        submit_btn = gr.Button("submit")
                    with gr.Column():
                        output = gr.TextArea(label="answer")
            with gr.Tab("Chat RAG"):
                chatbot = gr.Chatbot(label="chat with pdf")
                input_msg = gr.Textbox(placeholder="input your question...", label="input")
                with gr.Row():
                    clear_btn = gr.ClearButton([chatbot, input_msg], value="🧹 Clear")
                    retry_btn = gr.Button("♻️ Retry")

            initial_btn.click(
                self._init_settings,
                inputs=[model_name, api_key, embedding_model, embedding_api_key, data_file]
            )
            submit_btn.click(
                self._retrieval_qa,
                inputs=input_text,
                outputs=output,
            )
            input_msg.submit(
                self._chat_qa,
                inputs=[input_msg, chatbot],
                outputs=[input_msg, chatbot]
            )
            retry_btn.click(
                self._retry_chat_qa,
                inputs=chatbot,
                outputs=[input_msg, chatbot]
            )
        return demo

if __name__ == "__main__":
    app = RAGDemo()
    app().launch(debug=True)