import os

import faiss
import fitz  # PyMuPDF
import gradio as gr
import nltk
import numpy as np
import torch
from docx import Document
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# sent_tokenize needs the NLTK "punkt" tokenizer data
nltk.download("punkt", quiet=True)
# Extract text from a PDF file
def extract_text_from_pdf(pdf_path):
    text = ""
    with fitz.open(pdf_path) as doc:
        for page in doc:
            text += page.get_text()
    return text
# Extract text from a Word document
def extract_text_from_docx(docx_path):
    doc = Document(docx_path)
    return "\n".join(para.text for para in doc.paragraphs)
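# Usage sketch (file names are hypothetical):
#   extract_text_from_pdf("manual.pdf")   -> full text of all pages
#   extract_text_from_docx("notes.docx")  -> paragraphs joined by newlines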
# Initialize the sentence-embedding model (used for indexing and retrieval)
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Hugging Face API token
api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if not api_token:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")
# Define RAG models: BART generates the answer; retrieval is dense and reuses
# the sentence-embedding model above, whose vectors match the FAISS index.
generator_model_name = "facebook/bart-base"
generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
# Load or create the FAISS index. Indexed passages are kept in a parallel
# list so that search results (row ids) can be mapped back to their text.
index_path = "faiss_index.pkl"
embedding_dim = embedding_model.get_sentence_embedding_dimension()
if os.path.exists(index_path):
    index = faiss.read_index(index_path)
    print(f"Loaded FAISS index from {index_path}")
else:
    index = faiss.IndexFlatL2(embedding_dim)
    faiss.write_index(index, index_path)
    print(f"Created new FAISS index and saved to {index_path}")

# In-memory passage store: row i of the index corresponds to passages[i]
passages = []
def preprocess_text(text):
    sentences = sent_tokenize(text)
    return sentences
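# Illustration of the sentence splitting used for indexing:
#   preprocess_text("It works. Try it.") -> ["It works.", "Try it."]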
def upload_files(files):
    # Extract, split, embed, and index the uploaded documents
    if not files:
        return "Error: no files uploaded"
    try:
        for file_path in files:
            if file_path.endswith(".pdf"):
                text = extract_text_from_pdf(file_path)
            elif file_path.endswith(".docx"):
                text = extract_text_from_docx(file_path)
            else:
                return "Error: unsupported file format"
            # Split into sentences, embed them, and add to the FAISS index
            sentences = preprocess_text(text)
            embeddings = embedding_model.encode(sentences)
            index.add(np.asarray(embeddings, dtype="float32"))
            passages.extend(sentences)
        return "Files processed successfully"
    except Exception as e:
        print(f"Error processing files: {e}")
        return "Error processing files"
def process_and_query(question, state):
    # Answer a question against the indexed passages: retrieve, then generate
    if not question:
        return "Error: no question provided", state
    if index.ntotal == 0:
        return "Error: no documents indexed yet; upload files first", state
    # Embed the question with the same model used to index the passages
    question_embedding = np.asarray(
        embedding_model.encode([question]), dtype="float32"
    )
    # Retrieve the top-k most similar passages from the FAISS index
    k = min(5, index.ntotal)
    distances, retrieved_ids = index.search(question_embedding, k)
    retrieved_passages = [
        passages[i] for i in retrieved_ids.flatten() if 0 <= i < len(passages)
    ]
    # Generate a response conditioned on the question and retrieved passages
    prompt = question + "\n" + "\n".join(retrieved_passages)
    inputs = generator_tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=1024
    )
    with torch.no_grad():
        output_ids = generator.generate(**inputs, max_new_tokens=128)
    generated_text = generator_tokenizer.decode(
        output_ids[0], skip_special_tokens=True
    )
    # Update the conversation history kept in Gradio state
    state["conversation"].append({"question": question, "answer": generated_text})
    return generated_text, state
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Document Upload and Query System")
    state = gr.State({"conversation": []})
    with gr.Tab("Upload Files"):
        upload = gr.File(file_count="multiple", label="Upload PDF or DOCX files")
        upload_button = gr.Button("Upload")
        upload_output = gr.Textbox()
        upload_button.click(fn=upload_files, inputs=upload, outputs=upload_output)
    with gr.Tab("Query"):
        query = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Search")
        query_output = gr.Textbox()
        query_button.click(
            fn=process_and_query, inputs=[query, state], outputs=[query_output, state]
        )

demo.launch()
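# To run locally (a sketch; the "app.py" name and package set are assumptions):
#   pip install pymupdf python-docx sentence-transformers faiss-cpu \
#       transformers torch nltk gradio numpy
#   HUGGINGFACEHUB_API_TOKEN=<your token> python app.py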