import io
import os

import gradio as gr
import nltk
import PyPDF2
import torch
from docx import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from nltk.tokenize import sent_tokenize
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Download the NLTK sentence tokenizer data if not already present
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)  # required by newer NLTK releases

# Extract text from a PDF file given its raw bytes
def extract_text_from_pdf(pdf_data):
    text = ""
    try:
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_data))
        for page in pdf_reader.pages:
            text += page.extract_text() or ""  # extract_text() can return None
    except Exception as e:
        print(f"Error extracting text from PDF: {e}")
    return text

# Extract text from a Word document given its raw bytes
def extract_text_from_docx(docx_data):
    text = ""
    try:
        doc = Document(io.BytesIO(docx_data))
        text = "\n".join(para.text for para in doc.paragraphs)
    except Exception as e:
        print(f"Error extracting text from DOCX: {e}")
    return text

# Require the Hugging Face API token up front (needed if you later swap in
# gated or private models)
api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
if not api_token:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")

# Generator model for answer synthesis. Retrieval is handled entirely by the
# FAISS index below, so no separate retriever model is needed. Note that
# facebook/bart-base is not instruction-tuned; an instruction-tuned seq2seq
# model such as google/flan-t5-base will follow the prompt more reliably.
generator_model_name = "facebook/bart-base"
generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)

# Embedding model used by the LangChain FAISS vector store
hf_embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')

# Load a persisted FAISS index if one exists; otherwise defer creation until
# the first upload, since FAISS.from_texts needs at least one document.
index_path = "faiss_index"
if os.path.exists(index_path):
    faiss_index = FAISS.load_local(
        index_path,
        hf_embeddings,
        allow_dangerous_deserialization=True,  # required by recent langchain_community
    )
    print(f"Loaded FAISS index from {index_path}")
else:
    faiss_index = None

# Split raw document text into sentences for indexing
def preprocess_text(text):
    return sent_tokenize(text)

def upload_files(files):
    global faiss_index
    try:
        for file in files:
            # Gradio passes a filepath string or a tempfile wrapper, depending on version
            path = file if isinstance(file, str) else file.name
            with open(path, "rb") as f:
                data = f.read()
            if path.endswith('.pdf'):
                text = extract_text_from_pdf(data)
            elif path.endswith('.docx'):
                text = extract_text_from_docx(data)
            else:
                return "Error: unsupported file format (only PDF and DOCX are accepted)"

            # Split the text into sentences and add them to the FAISS index
            sentences = preprocess_text(text)
            if not sentences:
                continue
            if faiss_index is None:
                faiss_index = FAISS.from_texts(sentences, hf_embeddings)
            else:
                faiss_index.add_texts(sentences)

        # Persist the updated index
        if faiss_index is not None:
            faiss_index.save_local(index_path)
        return "Files processed successfully"
    except Exception as e:
        print(f"Error processing files: {e}")
        return f"Error: {e}"  # Provide an informative error message

def process_and_query(state, question):
    if not question:
        return "No question provided", state
    if faiss_index is None:
        return "No documents indexed yet; upload files first", state

    # Retrieve the passages most similar to the question
    docs = faiss_index.similarity_search(question, k=5)
    retrieved_passages = [doc.page_content for doc in docs]

    # Generate a response conditioned on the question and retrieved passages
    prompt_template = """Answer the question as detailed as possible from the provided context. Make sure to provide all the details. If the answer is not in the provided
context, just say "answer is not available in the context"; don't provide a wrong answer.

Context:
{context}

Question:
{question}

Answer:"""
    combined_input = prompt_template.format(context=' '.join(retrieved_passages), question=question)
    # Truncate to BART's 1024-token context window
    inputs = generator_tokenizer(combined_input, return_tensors="pt", truncation=True, max_length=1024)
    with torch.no_grad():
        generator_outputs = generator.generate(**inputs, max_new_tokens=256)
    generated_text = generator_tokenizer.decode(generator_outputs[0], skip_special_tokens=True)

    # Update the conversation history
    state.append({"question": question, "answer": generated_text})
    return generated_text, state

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Document Upload and Query System")
    # Per-session conversation history
    state = gr.State([])
    with gr.Tab("Upload Files"):
        upload = gr.File(file_count="multiple", label="Upload PDF or DOCX files")
        upload_button = gr.Button("Upload")
        upload_output = gr.Textbox(label="Upload status")
        upload_button.click(fn=upload_files, inputs=upload, outputs=upload_output)
    with gr.Tab("Query"):
        query = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Search")
        query_output = gr.Textbox(label="Answer")
        # The handler takes (state, question) and returns (answer, updated state)
        query_button.click(fn=process_and_query, inputs=[state, query], outputs=[query_output, state])

demo.launch()