import gradio as gr
from pathlib import Path

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFaceHub
from langchain.chains import ConversationalRetrievalChain
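# HuggingFaceHub calls the hosted Hugging Face Inference API, so the
# HUGGINGFACEHUB_API_TOKEN environment variable must be set (on a Space,
# add it as a repository secret).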
# Initialize the embedding model used for document retrieval.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
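# all-mpnet-base-v2 is a general-purpose sentence-transformers model that maps
# each text to a 768-dimensional vector; the FAISS retriever finds the
# documents whose vectors are closest to the query's vector.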
# The vector store and retriever are created lazily, once documents are added.
vector_store = None
retriever = None
def update_documents(text_input):
    global vector_store, retriever
    # Treat each non-empty line of the input as one document.
    documents = [line for line in text_input.split("\n") if line.strip()]
    if not documents:
        return "No documents found in the input."
    # Rebuild the FAISS vector store from the new documents.
    vector_store = FAISS.from_texts(documents, embeddings)
    # Point the retriever at the new vector store.
    retriever = vector_store.as_retriever()
    return f"{len(documents)} documents successfully added to the vector store."
# The ConversationalRetrievalChain is created on first use inside respond().
rag_chain = None
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    global rag_chain, retriever
    if retriever is None:
        return "Please upload or enter documents before asking a question."
    # Create the chain on first use; the sampling settings from this first
    # request are fixed into the HuggingFaceHub wrapper.
    if rag_chain is None:
        rag_chain = ConversationalRetrievalChain.from_llm(
            HuggingFaceHub(
                repo_id="HuggingFaceH4/zephyr-7b-beta",
                model_kwargs={
                    "temperature": temperature,
                    "max_new_tokens": max_tokens,
                    "top_p": top_p,
                },
            ),
            retriever=retriever,
        )
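    # ConversationalRetrievalChain condenses the question and chat history
    # into a standalone query, retrieves the closest documents from FAISS,
    # and prompts the LLM to answer using those documents as context.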
    # The chain builds its own prompts, so no OpenAI-style message list is
    # assembled here; the system_message input is accepted from the UI but
    # not forwarded to the chain.
    # Run retrieval and generation; the chain expects the chat history as a
    # list of (user, assistant) tuples, which matches Gradio's format.
    response = rag_chain({"question": message, "chat_history": history})
    # Return the model's answer text.
    return response["answer"]
def upload_file(filepath):
    # Hide the upload button and expose a download button for the received file.
    name = Path(filepath).name
    return [gr.UploadButton(visible=False), gr.DownloadButton(label=f"Download {name}", value=filepath, visible=True)]

def download_file():
    # After downloading, restore the upload button.
    return [gr.UploadButton(visible=True), gr.DownloadButton(visible=False)]
# Gradio interface setup
demo = gr.Blocks()
with demo:
    with gr.Row():
        # Input box for the user to add documents.
        doc_input = gr.Textbox(
            lines=10, placeholder="Enter your documents here, one per line.", label="Input Documents"
        )
        # Button that indexes the textbox contents, plus a status read-out.
        upload_button = gr.Button("Upload Documents")
        status = gr.Textbox(label="Status", interactive=False)
    with gr.Row():
        u = gr.UploadButton("Upload a file", file_count="single")
        d = gr.DownloadButton("Download the file", visible=False)
        u.upload(upload_file, u, [u, d])
        d.click(download_file, None, [u, d])
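    # Note: this row only round-trips a file through the UI; the uploaded
    # file's contents are not added to the vector store.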
    with gr.Row():
        # Chat interface for the RAG system.
        chat = gr.ChatInterface(
            respond,
            additional_inputs=[
                gr.Textbox(value="You are a helpful assistant.", label="System message"),
                gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
            ],
        )
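    # Gradio passes these additional inputs to respond() in order after
    # (message, history) and renders them in an accordion under the chat.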
    # Bind the button to rebuild the document vector store and show the status.
    upload_button.click(update_documents, inputs=[doc_input], outputs=status)
if __name__ == "__main__":
    demo.launch()