# rag_for_all / app.py
# raj999's picture
# Update app.py
# 85ac5f6 verified
import gradio as gr
from huggingface_hub import InferenceClient
from langchain.chains import RetrievalQA
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.llms import HuggingFaceHub
from langchain.chains import ConversationalRetrievalChain
from langchain_unstructured import UnstructuredLoader
import camelot
from pathlib import Path
# Hugging Face inference client and sentence-embedding model, created once at
# import time.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"
)

# Module-level index state; both names stay None until update_documents()
# builds a vector store and its retriever.
vector_store = retriever = None
def parse_page_input(page_input):
    """Parse a page specification such as ``"1, 2, 5-7"`` into page numbers.

    Args:
        page_input: Comma-separated page numbers and inclusive ``start-end``
            ranges. ``None`` or an empty string yields an empty result.

    Returns:
        A sorted list of unique page numbers. Tokens that cannot be parsed
        (non-numeric text, a missing bound, more than one dash) are skipped
        rather than raising.
    """
    if not page_input:
        return []

    pages = set()
    for part in page_input.split(","):
        part = part.strip()
        if not part:
            continue
        # partition() never fails to unpack, unlike split('-') on "1-2-3";
        # a malformed tail such as "2-3" is rejected by int() below.
        start_text, dash, end_text = part.partition("-")
        try:
            start = int(start_text)
            end = int(end_text) if dash else start
        except ValueError:
            continue  # Skip invalid page numbers and ranges
        pages.update(range(start, end + 1))
    return sorted(pages)
def extract_text_from_pdf(filepath, pages):
    """Extract text from the requested pages of a document.

    Args:
        filepath: Path of the file handed to ``UnstructuredLoader``.
        pages: Page specification understood by ``parse_page_input``
            (e.g. ``"1, 2, 5-7"``). When it yields no page numbers, text from
            every page is kept, matching ``extract_tables_from_pdf``, which
            reads the whole document when no pages are given.

    Returns:
        The ``page_content`` of every selected element, joined by newlines.
    """
    chunk_size = 1000  # Example chunk size
    overlap = 100  # Example overlap
    loader = UnstructuredLoader([filepath], chunk_size=chunk_size, overlap=overlap)

    # A set gives O(1) membership checks for every loaded element.
    wanted_pages = set(parse_page_input(pages))

    selected = []
    for doc in loader.lazy_load():
        # Document.metadata is a plain dict, so attribute access would raise;
        # elements without a page number only pass when no filter is active.
        page_number = doc.metadata.get("page_number")
        if not wanted_pages or page_number in wanted_pages:
            selected.append(doc.page_content)
    return "\n".join(selected)
def extract_tables_from_pdf(filepath, pages):
    """Extract tables from a PDF with camelot and render each one as text.

    Args:
        filepath: Path of the PDF passed to ``camelot.read_pdf``.
        pages: Page specification such as ``"1, 2, 5-7"``. ``None``, an empty
            string, or a whitespace-only string selects the whole document
            (``'1-end'``).

    Returns:
        A list with one string per extracted table, produced by calling
        ``to_string(index=False)`` on the table's ``df``.
    """
    # Drop all whitespace so user input like "1, 2, 5 - 7" reaches camelot as
    # the compact "1,2,5-7" form, and blank input falls back to every page.
    page_spec = "".join((pages or "").split()) or '1-end'
    tables = camelot.read_pdf(filepath, pages=page_spec)
    return [table.df.to_string(index=False) for table in tables]
def update_documents(text_input):
    """Rebuild the global vector store from newline-separated text.

    Each non-blank line of ``text_input`` becomes one document. The module
    globals ``vector_store`` and ``retriever`` are replaced with a new FAISS
    index and its retriever.

    Args:
        text_input: Text whose lines are the documents to index. ``None`` is
            treated like an empty string.

    Returns:
        A status message. When there is nothing to index, the existing store
        is left untouched and the message says so.
    """
    global vector_store, retriever

    # Blank lines carry no content; indexing them would only add empty
    # documents to the store.
    documents = [
        line.strip() for line in (text_input or "").split("\n") if line.strip()
    ]
    if not documents:
        return "No non-empty documents found; the vector store was not updated."

    vector_store = FAISS.from_texts(documents, embeddings)
    retriever = vector_store.as_retriever()
    return f"{len(documents)} documents successfully added to the vector store."
rag_chain = None
# Retriever the cached chain was built with; used to detect that the documents
# were re-indexed and the chain must be rebuilt.
_rag_chain_retriever = None


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Answer a chat message with the retrieval-augmented chain.

    Args:
        message: The user's current question.
        history: Previous turns as ``(user, assistant)`` pairs.
        system_message, max_tokens, temperature, top_p: Accepted because the
            chat UI passes them as additional inputs; they are not used by
            this function.

    Returns:
        The chain's answer, or a prompt to add documents when no retriever
        has been built yet.
    """
    global rag_chain, _rag_chain_retriever

    if retriever is None:
        return "Please upload or enter documents before asking a question."

    # Rebuild when the retriever changed, otherwise the chain would keep
    # answering from the documents of the first upload.
    if rag_chain is None or _rag_chain_retriever is not retriever:
        rag_chain = ConversationalRetrievalChain.from_llm(
            HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta"),
            retriever=retriever,
        )
        _rag_chain_retriever = retriever

    # The chain accepts (human, ai) tuples; the UI may deliver each turn as a
    # list, and either side of a turn may be None.
    chat_history = [
        (user_turn or "", assistant_turn or "")
        for user_turn, assistant_turn in (history or [])
    ]

    response = rag_chain.invoke({"question": message, "chat_history": chat_history})
    return response['answer']
def upload_file(filepath, pages):
    """Handle an uploaded file: index its text and report the table count.

    Args:
        filepath: Path of the uploaded file.
        pages: Page specification forwarded to both extraction helpers.

    Returns:
        A three-item list for the upload event outputs: a hidden upload
        button, a visible download button for the file, and a status string
        with the number of extracted tables.
    """
    extracted_text = extract_text_from_pdf(filepath, pages)
    extracted_tables = extract_tables_from_pdf(filepath, pages)

    # Only the text is indexed; the tables contribute to the status message.
    update_documents(extracted_text)

    hidden_upload = gr.UploadButton(visible=False)
    download_button = gr.DownloadButton(
        label=f"Download {Path(filepath).name}",
        value=filepath,
        visible=True,
    )
    status_message = f"{len(extracted_tables)} tables extracted."
    return [hidden_upload, download_button, status_message]
# Gradio interface setup
with gr.Blocks() as demo:
    # Upload controls and the page selector share one row.
    with gr.Row():
        u = gr.UploadButton("Upload a file", file_count="single")
        d = gr.DownloadButton("Download the file", visible=False)
        page_input = gr.Textbox(
            label="Pages to Parse (e.g., 1, 2, 5-7)",
            placeholder="Enter page numbers or ranges",
        )

    # Textbox that shows the status message returned by upload_file.
    status_output = gr.Textbox(label="Status", visible=True)

    # upload_file returns one value per output component listed here.
    u.upload(
        upload_file,
        inputs=[u, page_input],
        outputs=[u, d, status_output],
    )

    with gr.Row():
        chat = gr.ChatInterface(
            respond,
            additional_inputs=[
                gr.Textbox(value="You are a helpful assistant.", label="System message"),
                gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p (nucleus sampling)",
                ),
            ],
        )

if __name__ == "__main__":
    demo.launch()