# app.py: RAG chatbot over uploaded PDF/DOCX documents (Gradio app)
import os
import io
import PyPDF2
from docx import Document
from nltk.tokenize import sent_tokenize
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import gradio as gr
import torch
# Download NLTK punkt tokenizer if not already downloaded
import nltk
nltk.download('punkt')
# Initialize Sentence Transformer model for embeddings
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
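# Example (a sketch): all-MiniLM-L6-v2 maps each sentence to a 384-dimensional
# vector, so encoding two sentences yields a float32 array of shape (2, 384):
#   vecs = embedding_model.encode(["hello world", "goodbye world"])
# Note: hf_embeddings below wraps the same model for LangChain's FAISS store.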
# Hugging Face API token. The models below are public, so the token is only
# needed if you switch to gated models or the hosted Inference API.
api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not api_token:
    print("Warning: HUGGINGFACEHUB_API_TOKEN environment variable is not set")
# Generator model for answer synthesis. BART-base is not instruction-tuned,
# so treat its answers as a baseline. Retrieval is handled by the
# sentence-transformer embeddings plus FAISS below, so no separate
# retriever model is needed.
generator_model_name = "facebook/bart-base"
generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
# FAISS vector store via LangChain. The store cannot be constructed with
# FAISS(embedding_function=...) alone (it also needs an index and a docstore),
# so it is created lazily from the first uploaded document in upload_files().
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

hf_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
faiss_index = None  # populated on first upload
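# Alternative (a sketch): seed the store eagerly with a throwaway passage so
# queries never hit an empty index; "placeholder" is just a dummy string.
#   faiss_index = FAISS.from_texts(["placeholder"], hf_embeddings)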
# Function to extract text from a PDF file given its raw bytes
def extract_text_from_pdf(pdf_data):
    text = ""
    try:
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_data))
        for page in pdf_reader.pages:
            text += page.extract_text() or ""  # extract_text() can return None
    except Exception as e:
        print(f"Error extracting text from PDF: {e}")
    return text
# Function to extract text from a Word document given its raw bytes
def extract_text_from_docx(docx_data):
    text = ""
    try:
        doc = Document(io.BytesIO(docx_data))
        text = "\n".join(para.text for para in doc.paragraphs)
    except Exception as e:
        print(f"Error extracting text from DOCX: {e}")
    return text
# Function to split raw text into sentences for embedding
def preprocess_text(text):
    return sent_tokenize(text)
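# Example (a sketch): punkt splits on sentence boundaries, so
#   preprocess_text("It works. Ask me anything!")
# returns ["It works.", "Ask me anything!"].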
# Function to handle file uploads and update the FAISS index
def upload_files(files):
    global faiss_index
    try:
        all_sentences = []
        for file in files:
            # Gradio's File component yields either a path string or an object
            # with a .name attribute pointing at a temp file; read raw bytes
            # either way (PDF and DOCX are binary, so no text decoding here)
            file_path = file if isinstance(file, str) else file.name
            with open(file_path, "rb") as f:
                file_bytes = f.read()
            if file_path.endswith(".pdf"):
                text = extract_text_from_pdf(file_bytes)
            elif file_path.endswith(".docx"):
                text = extract_text_from_docx(file_bytes)
            else:
                return "Error: unsupported file format"
            all_sentences.extend(preprocess_text(text))
        if not all_sentences:
            return "Error: no text could be extracted"
        # Add the sentences to the vector store; LangChain's FAISS wrapper
        # computes the embeddings via hf_embeddings internally
        if faiss_index is None:
            faiss_index = FAISS.from_texts(all_sentences, hf_embeddings)
        else:
            faiss_index.add_texts(all_sentences)
        return "Files processed successfully"
    except Exception as e:
        print(f"Error processing files: {e}")
        return f"Error: {e}"
# Function to answer a query with retrieval-augmented generation
def process_and_query(state, question):
    if not question:
        return "Error: no question provided", state
    if faiss_index is None:
        return "Error: please upload documents first", state
    try:
        # Retrieve the most similar passages from the vector store
        docs = faiss_index.similarity_search(question, k=5)
        retrieved_passages = [doc.page_content for doc in docs]
        # Build the generator prompt from the retrieved context
        prompt_template = """Answer the question as detailed as possible from the provided context.
Make sure to provide all the details. If the answer is not in the provided
context, just say "answer is not available in the context"; do not provide
a wrong answer.

Context:
{context}

Question:
{question}

Answer:"""
        combined_input = prompt_template.format(
            context=" ".join(retrieved_passages), question=question
        )
        # BART accepts at most 1024 tokens, so truncate long contexts
        inputs = generator_tokenizer(
            combined_input, return_tensors="pt", truncation=True, max_length=1024
        )
        with torch.no_grad():
            generator_outputs = generator.generate(**inputs, max_new_tokens=256)
        generated_text = generator_tokenizer.decode(
            generator_outputs[0], skip_special_tokens=True
        )
        # Update conversation history
        state.append({"question": question, "answer": generated_text})
        return generated_text, state
    except Exception as e:
        print(f"Error processing query: {e}")
        return f"Error: {e}", state
# Define the Gradio interface: one tab for uploads, one for querying
def main():
    with gr.Blocks() as demo:
        state = gr.State([])  # conversation history
        with gr.Tab("Upload Files"):
            file_input = gr.File(label="Upload PDF or DOCX files", file_count="multiple")
            upload_status = gr.Textbox(label="Upload Status", value="No file uploaded yet")
            file_input.upload(upload_files, inputs=file_input, outputs=upload_status)
        with gr.Tab("Query"):
            question_box = gr.Textbox(label="Enter your query")
            response_box = gr.Textbox(label="Query Response", value="No query processed yet")
            question_box.submit(process_and_query,
                                inputs=[state, question_box],
                                outputs=[response_box, state])
    demo.launch()
if __name__ == "__main__":
main()