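# Dependencies assumed from the imports below (pin these in requirements.txt
# when deploying as a Hugging Face Space): gradio, python-docx, PyMuPDF,
# sentence-transformers, langchain-community, transformers, torch, nltk,
# faiss-cpu, numpy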
import os
import gradio as gr
from docx import Document
import fitz # PyMuPDF for PDF text extraction
from sentence_transformers import SentenceTransformer
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from nltk.tokenize import sent_tokenize
import torch
import pickle
import nltk
import faiss
import numpy as np
# Ensure NLTK resources are downloaded
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
# Initialize the embedding model
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Hugging Face API token
api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
if not api_token:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")
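# Note: the token is only validated here; the models below are loaded locally,
# so api_token is not used anywhere else in this script.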
# Define RAG models
generator_model_name = "facebook/bart-base"
retriever_model_name = "facebook/bart-base" # Can be the same as generator
generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
retriever = AutoModelForSeq2SeqLM.from_pretrained(retriever_model_name)
retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)
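# The seq2seq "retriever" model and tokenizer above are loaded but never called:
# retrieval is actually done with sentence-transformer embeddings and FAISS below.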
# Initialize FAISS index using LangChain
hf_embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
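# hf_embeddings (the LangChain wrapper) is likewise unused in this script; the
# code operates on the raw FAISS index and the SentenceTransformer model directly.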
# Load or create FAISS index
index_path = "faiss_index.index"
sentences_path = "sentences.pkl"  # companion store so index rows map back to text (added fix; filename is an assumption)
if os.path.exists(index_path):
    faiss_index = faiss.read_index(index_path)
    print(f"Loaded FAISS index from {index_path}")
else:
    # Create a new FAISS index
    d = embedding_model.get_sentence_embedding_dimension()  # Dimension of the embeddings
    faiss_index = faiss.IndexFlatL2(d)  # Using IndexFlatL2 for simplicity

state = {
    "conversation": [],
    "sentences": []
}
# Restore the sentence store so a reloaded index still maps row ids to text;
# without this, a persisted index has vectors but no sentences to return.
if os.path.exists(sentences_path):
    with open(sentences_path, "rb") as f:
        state["sentences"] = pickle.load(f)
def extract_text_from_pdf(pdf_path):
    text = ""
    try:
        doc = fitz.open(pdf_path)
        for page_num in range(len(doc)):
            page = doc.load_page(page_num)
            text += page.get_text()
    except Exception as e:
        raise RuntimeError(f"Error extracting text from PDF '{pdf_path}': {e}")
    return text
def extract_text_from_docx(docx_path):
    text = ""
    try:
        doc = Document(docx_path)
        text = "\n".join([para.text for para in doc.paragraphs])
    except Exception as e:
        raise RuntimeError(f"Error extracting text from DOCX '{docx_path}': {e}")
    return text
def preprocess_text(text):
    sentences = sent_tokenize(text)
    return sentences
def upload_files(files):
    global state, faiss_index
    try:
        for file in files:
            try:
                # Gradio may hand back either a path string or a file object
                if isinstance(file, str):
                    file_path = file
                else:
                    file_path = file.name
                if file_path.lower().endswith('.pdf'):
                    text = extract_text_from_pdf(file_path)
                elif file_path.lower().endswith('.docx'):
                    text = extract_text_from_docx(file_path)
                else:
                    return {"error": f"Unsupported file format: {file_path}"}
                sentences = preprocess_text(text)
                embeddings = embedding_model.encode(sentences)
                faiss_index.add(np.array(embeddings).astype(np.float32))  # Add embeddings
                state["sentences"].extend(sentences)
            except Exception as e:
                print(f"Error processing file '{file}': {e}")
                return {"error": str(e)}
        # Save the updated index and the matching sentence store
        faiss.write_index(faiss_index, index_path)
        with open(sentences_path, "wb") as f:
            pickle.dump(state["sentences"], f)
        return {"message": "Files processed successfully"}
    except Exception as e:
        print(f"General error processing files: {e}")
        return {"error": str(e)}
def process_and_query(question):
    global state, faiss_index
    if not question:
        return {"error": "No question provided"}
    try:
        question_embedding = embedding_model.encode([question])
        # Perform FAISS search
        D, I = faiss_index.search(np.array(question_embedding).astype(np.float32), k=5)
        # Keep only indices that actually map to a stored sentence
        retrieved_results = [state["sentences"][i] for i in I[0] if 0 <= i < len(state["sentences"])]
        # Generate response based on retrieved results
        context = " ".join(retrieved_results)
        # Enhanced prompt template
        prompt_template = """
        Answer the question as detailed as possible from the provided context.
        Make sure to provide all the details. If the answer is not in the
        provided context, just say "answer is not available in the context";
        don't provide a wrong answer.
        Context:\n{context}
        Question:\n{question}
        Answer:
        --------------------------------------------------
        Prompt Suggestions:
        1. Summarize the primary theme of the context.
        2. Elaborate on the crucial concepts highlighted in the context.
        3. Pinpoint any supporting details or examples pertinent to the question.
        4. Examine any recurring themes or patterns relevant to the question within the context.
        5. Contrast differing viewpoints or elements mentioned in the context.
        6. Explore the potential implications or outcomes of the information provided.
        7. Assess the trustworthiness and validity of the information given.
        8. Propose recommendations or advice based on the presented information.
        9. Forecast likely future events or results stemming from the context.
        10. Expand on the context or background information pertinent to the question.
        11. Define any specialized terms or technical language used within the context.
        12. Analyze any visual representations like charts or graphs in the context.
        13. Highlight any restrictions or important considerations when responding to the question.
        14. Examine any presuppositions or biases evident within the context.
        15. Present alternate interpretations or viewpoints regarding the information provided.
        16. Reflect on any moral or ethical issues raised by the context.
        17. Investigate any cause-and-effect relationships identified in the context.
        18. Uncover any questions or areas requiring further exploration.
        19. Resolve any vague or conflicting information in the context.
        20. Cite case studies or examples that demonstrate the concepts discussed in the context.
        --------------------------------------------------
        Context:\n{context}
        Question:\n{question}
        Answer:
        """
        combined_input = prompt_template.format(context=context, question=question)
        inputs = generator_tokenizer(combined_input, return_tensors="pt", max_length=512, truncation=True)
        with torch.no_grad():
            # max_new_tokens raised above the generate() default so answers aren't cut off
            generator_outputs = generator.generate(**inputs, max_new_tokens=256)
        generated_text = generator_tokenizer.decode(generator_outputs[0], skip_special_tokens=True)
        # Update conversation history
        state["conversation"].append({"question": question, "answer": generated_text})
        return {"message": generated_text, "conversation": state["conversation"]}
    except Exception as e:
        print(f"Error processing query: {e}")
        return {"error": str(e)}
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Document Upload and Query System")
    with gr.Tab("Upload Files"):
        upload = gr.File(file_count="multiple", label="Upload PDF or DOCX files")
        upload_button = gr.Button("Upload")
        upload_output = gr.Textbox()
        upload_button.click(fn=upload_files, inputs=[upload], outputs=upload_output)
    with gr.Tab("Query"):
        query = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Search")
        query_output = gr.Textbox()
        query_button.click(fn=process_and_query, inputs=[query], outputs=query_output)

demo.launch()
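# On Hugging Face Spaces this file is executed automatically as app.py; locally,
# demo.launch() serves the app at http://127.0.0.1:7860 by default.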