# Chatbot / app.py
import os
import gradio as gr
from docx import Document
import fitz # PyMuPDF for PDF text extraction
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from nltk.tokenize import sent_tokenize
import torch
import pickle
import nltk
import faiss
import numpy as np
# Ensure the NLTK sentence tokenizer data is available
# (newer NLTK releases look up 'punkt_tab' in addition to 'punkt')
for resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{resource}")
    except LookupError:
        nltk.download(resource)
# Initialize the embedding model
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
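# For reference: encode() maps a list of strings to a float array, e.g.
#   embedding_model.encode(["hello world"]).shape == (1, 384)
# since all-MiniLM-L6-v2 produces 384-dimensional embeddings.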
# Hugging Face API token
api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
if not api_token:
raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")
# Generator model for the RAG pipeline. Retrieval itself is handled by the
# SentenceTransformer embeddings and the FAISS index below, so no separate
# retriever network is needed.
generator_model_name = "facebook/bart-base"
generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
# Load or create the FAISS index. The sentence store is persisted alongside it
# so that search results can still be mapped back to text after a restart.
index_path = "faiss_index.index"
sentences_path = "faiss_sentences.pkl"

state = {
    "conversation": [],
    "sentences": []
}

if os.path.exists(index_path):
    faiss_index = faiss.read_index(index_path)
    print(f"Loaded FAISS index from {index_path}")
    if os.path.exists(sentences_path):
        with open(sentences_path, "rb") as f:
            state["sentences"] = pickle.load(f)
else:
    # Create a new FAISS index
    d = embedding_model.get_sentence_embedding_dimension()  # dimension of the embeddings
    faiss_index = faiss.IndexFlatL2(d)  # exact L2 search; simple and adequate at this scale
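# Quick reference for the two index calls used below (standard faiss API):
#   faiss_index.add(x)               # x: float32 array of shape (n, d); rows get ids 0..n-1
#   D, I = faiss_index.search(q, k)  # D: squared L2 distances, I: row ids (-1 where empty)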
def extract_text_from_pdf(pdf_path):
    text = ""
    try:
        with fitz.open(pdf_path) as doc:  # context manager closes the file handle
            for page in doc:
                text += page.get_text()
    except Exception as e:
        raise RuntimeError(f"Error extracting text from PDF '{pdf_path}': {e}")
    return text
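# Example usage (hypothetical path): extract_text_from_pdf("report.pdf")
# returns the text of every page concatenated into one string.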
def extract_text_from_docx(docx_path):
text = ""
try:
doc = Document(docx_path)
text = "\n".join([para.text for para in doc.paragraphs])
except Exception as e:
raise RuntimeError(f"Error extracting text from DOCX '{docx_path}': {e}")
return text
def preprocess_text(text):
sentences = sent_tokenize(text)
return sentences
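# Example: sent_tokenize("First sentence. Second one.")
#   -> ['First sentence.', 'Second one.']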
def upload_files(files):
global state, faiss_index
try:
for file in files:
try:
if isinstance(file, str):
file_path = file
else:
file_path = file.name
                lower_path = file_path.lower()
                if lower_path.endswith('.pdf'):
                    text = extract_text_from_pdf(file_path)
                elif lower_path.endswith('.docx'):
                    text = extract_text_from_docx(file_path)
                else:
                    return f"Unsupported file format: {file_path}"
                sentences = preprocess_text(text)
                if not sentences:
                    continue  # nothing to index in this file
                embeddings = embedding_model.encode(sentences)
                # FAISS row ids line up with positions in state["sentences"]
                faiss_index.add(np.array(embeddings).astype(np.float32))
                state["sentences"].extend(sentences)
            except Exception as e:
                print(f"Error processing file '{file}': {e}")
                return f"Error processing file '{file}': {e}"
        # Persist the index and the sentence store so they survive restarts
        faiss.write_index(faiss_index, index_path)
        with open(sentences_path, "wb") as f:
            pickle.dump(state["sentences"], f)
        return "Files processed successfully"
    except Exception as e:
        print(f"General error processing files: {e}")
        return f"Error: {e}"
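# Note: depending on the Gradio version, gr.File passes either plain path
# strings or tempfile-like objects with a .name attribute; upload_files()
# handles both cases.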
def process_and_query(question):
global state, faiss_index
    if not question:
        return "No question provided"
    if not state["sentences"]:
        return "No documents indexed yet. Please upload files first."
    try:
        question_embedding = embedding_model.encode([question])
        # FAISS search: D holds squared L2 distances, I the matching row ids
        D, I = faiss_index.search(np.array(question_embedding).astype(np.float32), k=5)
        # Keep only valid row ids (FAISS pads with -1 when fewer than k hits exist)
        retrieved_results = [
            state["sentences"][i] for i in I[0] if 0 <= i < len(state["sentences"])
        ]
        # Generate a response grounded in the retrieved sentences
        context = " ".join(retrieved_results)
        # Prompt template: instructions, suggested angles to consider, then the
        # grounding context and the question.
        prompt_template = """
        Answer the question in as much detail as possible using only the provided
        context. If the answer is not present in the context, reply
        "answer is not available in the context" instead of guessing.
        --------------------------------------------------
Prompt Suggestions:
1. Summarize the primary theme of the context.
2. Elaborate on the crucial concepts highlighted in the context.
3. Pinpoint any supporting details or examples pertinent to the question.
4. Examine any recurring themes or patterns relevant to the question within the context.
5. Contrast differing viewpoints or elements mentioned in the context.
6. Explore the potential implications or outcomes of the information provided.
7. Assess the trustworthiness and validity of the information given.
8. Propose recommendations or advice based on the presented information.
9. Forecast likely future events or results stemming from the context.
10. Expand on the context or background information pertinent to the question.
11. Define any specialized terms or technical language used within the context.
12. Analyze any visual representations like charts or graphs in the context.
13. Highlight any restrictions or important considerations when responding to the question.
14. Examine any presuppositions or biases evident within the context.
15. Present alternate interpretations or viewpoints regarding the information provided.
16. Reflect on any moral or ethical issues raised by the context.
17. Investigate any cause-and-effect relationships identified in the context.
18. Uncover any questions or areas requiring further exploration.
19. Resolve any vague or conflicting information in the context.
20. Cite case studies or examples that demonstrate the concepts discussed in the context.
--------------------------------------------------
Context:\n{context}
Question:\n{question}
Answer:
"""
combined_input = prompt_template.format(context=context, question=question)
        # BART's encoder input is limited; truncating at 512 tokens keeps the
        # instructions and at least part of the context within the window
        inputs = generator_tokenizer(combined_input, return_tensors="pt", max_length=512, truncation=True)
        with torch.no_grad():
            generator_outputs = generator.generate(**inputs, max_new_tokens=256)
        generated_text = generator_tokenizer.decode(generator_outputs[0], skip_special_tokens=True)
        # Update conversation history; return just the answer text for the Textbox
        state["conversation"].append({"question": question, "answer": generated_text})
        return generated_text
    except Exception as e:
        print(f"Error processing query: {e}")
        return f"Error: {e}"
# Create Gradio interface
with gr.Blocks() as demo:
gr.Markdown("## Document Upload and Query System")
with gr.Tab("Upload Files"):
upload = gr.File(file_count="multiple", label="Upload PDF or DOCX files")
upload_button = gr.Button("Upload")
        upload_output = gr.Textbox(label="Status")
upload_button.click(fn=upload_files, inputs=[upload], outputs=upload_output)
with gr.Tab("Query"):
query = gr.Textbox(label="Enter your query")
query_button = gr.Button("Search")
        query_output = gr.Textbox(label="Answer")
query_button.click(fn=process_and_query, inputs=[query], outputs=query_output)
demo.launch()
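# On Hugging Face Spaces, the launch() call above is sufficient; when running
# locally, demo.launch(share=True) would also expose a temporary public URL.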