import os

import gradio as gr
from dotenv import find_dotenv, load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredExcelLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
# Load the Hugging Face API token from a local .env file; HuggingFaceEndpoint
# reads HUGGINGFACEHUB_API_TOKEN from the environment.
_ = load_dotenv(find_dotenv())
hf_api = os.getenv("HUGGINGFACEHUB_API_TOKEN")
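# A minimal .env file for local runs might look like this (placeholder value,
# not a real token):
# HUGGINGFACEHUB_API_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx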
def indexdocs(file_path, progress=gr.Progress()):
    """Build a conversational RAG chain from the uploaded Excel files."""
    progress(0.1, desc="Loading documents...")
    # Parse each workbook into element-level documents (requires the
    # `unstructured` package with Excel support).
    loaders = [UnstructuredExcelLoader(file, mode="elements") for file in file_path]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())

    progress(0.3, desc="Splitting documents...")
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=300)
    pages = text_splitter.split_documents(documents)

    # Defaults to the sentence-transformers/all-mpnet-base-v2 model.
    embedding = HuggingFaceEmbeddings()

    progress(0.5, desc="Creating vectorstore...")
    vector = FAISS.from_documents(documents=pages, embedding=embedding)
    retriever = vector.as_retriever()

    progress(0.8, desc="Setting up language model...")
    # Keep the running dialogue so follow-up questions can refer to earlier turns.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    llm = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        temperature=0.1,
        max_new_tokens=200,
        top_k=1,
    )
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    # The second value clears the chatbot component after indexing.
    return qa_chain, None
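# A quick smoke test of the chain outside Gradio, kept commented out so it
# does not run on import. "inventory.xlsx" is a placeholder file name.
#
#   chain, _ = indexdocs(["inventory.xlsx"])
#   result = chain.invoke({"question": "Which sheet holds the Q3 totals?", "chat_history": []})
#   print(result["answer"])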
def format_chat_history(chat_history):
    """Flatten Gradio's (user, bot) message pairs into "Role: text" strings."""
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history
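# For example, a history of [("Hi", "Hello!")] becomes
# ["User: Hi", "Assistant: Hello!"].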
def chat(qa_chain, msg, history):
    formatted_chat_history = format_chat_history(history)
    response = qa_chain.invoke({"question": msg, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    # Report where the top-ranked answer came from: the workbook and the sheet.
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].metadata["filename"]
    response_source_sheet = response_sources[0].metadata["page_name"]
    new_history = history + [(msg, response_answer)]
    # Clear the prompt box and push the updated history back to the UI.
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source_sheet
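# Note: the "filename" and "page_name" metadata keys come from
# UnstructuredExcelLoader's element metadata; if you swap in a different
# loader, adjust the lookups above to match its metadata schema.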
with gr.Blocks() as demo:
    # Per-session state holding the retrieval chain built from the uploads.
    qa_chain = gr.State()
    gr.Markdown(
        """
        # MS Excel Knowledge Base QA using RAG
        """
    )
    with gr.Column():
        file_list = gr.File(
            label="Upload your MS Excel files...",
            file_count="multiple",
            file_types=[".xls", ".xlsx"],
        )
        fileuploadbtn = gr.Button("Index Documents and Start Chatting")
    with gr.Row():
        chatbot = gr.Chatbot(height=300)
    with gr.Row():
        source = gr.Textbox(info="Source", container=False, scale=4)
        source_page = gr.Textbox(info="Sheet", container=False, scale=1)
    with gr.Row():
        prompt = gr.Textbox(placeholder="Please enter your prompt...", container=False, scale=4, visible=True, interactive=False)
        promptsubmit = gr.Button("Submit", scale=1, visible=True, interactive=False)
    gr.Markdown(
        """
        # Responsible AI Usage
        Neither the documents you upload nor your interactions with the chatbot are saved.
        """
    )
    # Build the index first, then enable the prompt box and submit button.
    fileuploadbtn.click(
        fn=indexdocs, inputs=[file_list], outputs=[qa_chain, chatbot]
    ).then(
        lambda: [gr.Textbox(interactive=True), gr.Button(interactive=True)],
        inputs=None,
        outputs=[prompt, promptsubmit],
        queue=False,
    )
    promptsubmit.click(fn=chat, inputs=[qa_chain, prompt, chatbot], outputs=[qa_chain, prompt, chatbot, source, source_page], queue=False)
    prompt.submit(fn=chat, inputs=[qa_chain, prompt, chatbot], outputs=[qa_chain, prompt, chatbot, source, source_page], queue=False)
if __name__ == "__main__":
    demo.launch()