NaimaAqeel committed (verified)
Commit 944d263 · 1 Parent(s): cf26f9f

Update app.py

Files changed (1): app.py +123 -136
app.py CHANGED
@@ -1,22 +1,42 @@
 import os
-import io
-import PyPDF2
+import fitz
 from docx import Document
-import numpy as np
-from nltk.tokenize import sent_tokenize
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from sentence_transformers import SentenceTransformer
+import faiss
+import numpy as np
+import pickle
 import gradio as gr
-import torch
+from typing import List
+from langchain_community.llms import HuggingFaceEndpoint
+from langchain_community.vectorstores import FAISS
+from langchain_community.embeddings import HuggingFaceEmbeddings
 
-# Download NLTK punkt tokenizer if not already downloaded
-import nltk
-nltk.download('punkt')
+# Function to extract text from a PDF file
+def extract_text_from_pdf(pdf_path):
+    text = ""
+    try:
+        doc = fitz.open(pdf_path)
+        for page_num in range(len(doc)):
+            page = doc.load_page(page_num)
+            text += page.get_text()
+    except Exception as e:
+        print(f"Error extracting text from PDF: {e}")
+    return text
 
-# Initialize Sentence Transformer model for embeddings
+# Function to extract text from a Word document
+def extract_text_from_docx(docx_path):
+    text = ""
+    try:
+        doc = Document(docx_path)
+        text = "\n".join([para.text for para in doc.paragraphs])
+    except Exception as e:
+        print(f"Error extracting text from DOCX: {e}")
+    return text
+
+# Initialize the embedding model
 embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
 
-# Initialize Hugging Face API token
+# Hugging Face API token
 api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
 if not api_token:
     raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")
@@ -27,144 +47,111 @@ retriever_model_name = "facebook/bart-base"
 generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
 generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
 retriever = AutoModelForSeq2SeqLM.from_pretrained(retriever_model_name)
-retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)
-
-# Initialize FAISS index using LangChain
-from langchain_community.vectorstores import FAISS
-from langchain_community.embeddings import HuggingFaceEmbeddings
+retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)
 
-# Initialize Hugging Face embeddings
-hf_embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
-
-# Dummy implementations for index, docstore, and index_to_docstore_id
-# Replace with actual implementations or configurations as per LangChain documentation
-index = None
-docstore = None
-index_to_docstore_id = None
-
-# Initialize FAISS index with required parameters
-faiss_index = FAISS(
-    embedding_function=hf_embeddings,
-    index=index,
-    docstore=docstore,
-    index_to_docstore_id=index_to_docstore_id
+# Initialize the HuggingFace LLM
+llm = HuggingFaceEndpoint(
+    endpoint_url="https://api-inference.huggingface.co/models/gpt2",
+    model_kwargs={"api_key": api_token}
 )
 
-# Function to extract text from a PDF file
-def extract_text_from_pdf(pdf_data):
-    text = ""
-    try:
-        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_data))
-        for page in pdf_reader.pages:
-            text += page.extract_text()
-    except Exception as e:
-        print(f"Error extracting text from PDF: {e}")
-    return text
+# Initialize the HuggingFace embeddings
+embedding = HuggingFaceEmbeddings()
 
-# Function to extract text from a Word document
-def extract_text_from_docx(docx_data):
-    text = ""
+# Load or create FAISS index
+index_path = "faiss_index.pkl"
+document_texts_path = "document_texts.pkl"
+
+document_texts = []
+
+if os.path.exists(index_path) and os.path.exists(document_texts_path):
     try:
-        doc = Document(io.BytesIO(docx_data))
-        text = "\n".join([para.text for para in doc.paragraphs])
+        with open(index_path, "rb") as f:
+            index = pickle.load(f)
+        print("Loaded FAISS index from faiss_index.pkl")
+        with open(document_texts_path, "rb") as f:
+            document_texts = pickle.load(f)
+        print("Loaded document texts from document_texts.pkl")
     except Exception as e:
-        print(f"Error extracting text from DOCX: {e}")
-    return text
-
-# Function to preprocess text into sentences
-def preprocess_text(text):
-    sentences = sent_tokenize(text)
-    return sentences
+        print(f"Error loading FAISS index or document texts: {e}")
+else:
+    # Create a new FAISS index if it doesn't exist
+    index = faiss.IndexFlatL2(embedding_model.get_sentence_embedding_dimension())
+    with open(index_path, "wb") as f:
+        pickle.dump(index, f)
+    print("Created new FAISS index and saved to faiss_index.pkl")
 
-# Function to handle file uploads and update FAISS index
 def upload_files(files):
-    global faiss_index
+    global index, document_texts
    try:
        for file in files:
-            file_name = file.name
-            file_content = file.read()  # Get the file content as bytes
-
-            if file_name.endswith('.pdf'):
-                text = extract_text_from_pdf(file_content)
-            elif file_name.endswith('.docx'):
-                text = extract_text_from_docx(file_content)
+            file_path = file.name  # Get the file path from the NamedString object
+            if file_path.endswith('.pdf'):
+                text = extract_text_from_pdf(file_path)
+            elif file_path.endswith('.docx'):
+                text = extract_text_from_docx(file_path)
             else:
-                return {"error": "Unsupported file format"}
+                return "Unsupported file format"
 
-            # Preprocess text
-            sentences = preprocess_text(text)
+            print(f"Extracted text: {text[:100]}...")  # Debug: Show the first 100 characters of the extracted text
 
-            # Encode sentences and add to FAISS index
+            # Process the text and update FAISS index
+            sentences = text.split("\n")
             embeddings = embedding_model.encode(sentences)
-            if faiss_index is not None:
-                for embedding in embeddings:
-                    faiss_index.add(np.expand_dims(embedding, axis=0))
-
-            # Save the updated index (if needed)
-            # Add your logic here to save the FAISS index if you're using persistence
-
-            return {"message": "Files processed successfully"}
+            print(f"Embeddings shape: {embeddings.shape}")  # Debug: Show the shape of the embeddings
+            index.add(np.array(embeddings))
+            document_texts.extend(sentences)  # Store sentences for retrieval
+
+        # Save the updated index and documents
+        with open(index_path, "wb") as f:
+            pickle.dump(index, f)
+        print("Saved updated FAISS index to faiss_index.pkl")
+        with open(document_texts_path, "wb") as f:
+            pickle.dump(document_texts, f)
+        print("Saved updated document texts to document_texts.pkl")
+
+        return "Files processed successfully"
     except Exception as e:
         print(f"Error processing files: {e}")
-        return {"error": str(e)}  # Provide informative error message
-
-# Function to process queries using RAG model
-def process_and_query(state, question):
-    if question:
-        try:
-            # Search the FAISS index for similar passages
-            question_embedding = embedding_model.encode([question])
-            D, I = faiss_index.search(np.array(question_embedding), k=5)
-            retrieved_passages = [faiss_index.index_to_text(i) for i in I[0]]
-
-            # Use generator model to generate response based on question and retrieved passages
-            prompt_template = """
-            Answer the question as detailed as possible from the provided context,
-            make sure to provide all the details, if the answer is not in
-            provided context just say, "answer is not available in the context",
-            don't provide the wrong answer
-            Context:\n{context}\n
-            Question:\n{question}\n
-            Answer:
-            """
-            combined_input = prompt_template.format(context=' '.join(retrieved_passages), question=question)
-            inputs = generator_tokenizer(combined_input, return_tensors="pt")
-            with torch.no_grad():
-                generator_outputs = generator.generate(**inputs)
-            generated_text = generator_tokenizer.decode(generator_outputs[0], skip_special_tokens=True)
-
-            # Update conversation history
-            state.append({"question": question, "answer": generated_text})
-
-            return {"message": generated_text, "conversation": state}
-        except Exception as e:
-            print(f"Error processing query: {e}")
-            return {"error": str(e)}
-    else:
-        return {"error": "No question provided"}
-
-# Define the Gradio interface
-def main():
-    upload_interface = gr.Interface(
-        fn=upload_files,
-        inputs=gr.inputs.File(label="Upload PDF or DOCX files", multiple=True),
-        outputs=gr.outputs.Textbox(label="Upload Status")
-    )
-
-    query_interface = gr.Interface(
-        fn=process_and_query,
-        inputs=[gr.inputs.Textbox(label="Conversation State"), gr.inputs.Textbox(label="Enter your query")],
-        outputs=[gr.outputs.Textbox(label="Query Response"), gr.outputs.Textbox(label="Conversation State")]
-    )
-
-    gr.Interface(
-        fn=None,
-        inputs=[
-            gr.Interface.Tab("Upload Files", upload_interface),
-            gr.Interface.Tab("Query", query_interface)
-        ],
-        outputs=gr.outputs.Textbox(label="Output", default="Output will be shown here")
-    ).launch()
-
-if __name__ == "__main__":
-    main()
+        return f"Error processing files: {e}"
+
+def query_text(text):
+    try:
+        print(f"Query text: {text}")  # Debug: Show the query text
+
+        # Encode the query text
+        query_embedding = embedding_model.encode([text])
+        print(f"Query embedding shape: {query_embedding.shape}")  # Debug: Show the shape of the query embedding
+
+        # Search the FAISS index
+        D, I = index.search(np.array(query_embedding), k=5)
+        print(f"Distances: {D}, Indices: {I}")  # Debug: Show the distances and indices of the search results
+
+        top_documents = []
+        for idx in I[0]:
+            if idx != -1 and idx < len(document_texts):  # Ensure that a valid index is found
+                top_documents.append(document_texts[idx])  # Append the actual sentences for the response
+            else:
+                print(f"Invalid index found: {idx}")
+        return top_documents
+    except Exception as e:
+        print(f"Error querying text: {e}")
+        return f"Error querying text: {e}"
+
+# Create Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("## Document Upload and Query System")
+
+    with gr.Tab("Upload Files"):
+        upload = gr.File(file_count="multiple", label="Upload PDF or DOCX files")
+        upload_button = gr.Button("Upload")
+        upload_output = gr.Textbox()
+        upload_button.click(fn=upload_files, inputs=upload, outputs=upload_output)
+
+    with gr.Tab("Query"):
+        query = gr.Textbox(label="Enter your query")
+        query_button = gr.Button("Search")
+        query_output = gr.Textbox()
+        query_button.click(fn=query_text, inputs=query, outputs=query_output)
+
+    demo.launch()
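
Note on the persistence scheme introduced here: the new code serializes the faiss.IndexFlatL2 object with pickle, which only works as long as the installed faiss build keeps its SWIG-wrapped indexes picklable. faiss ships its own on-disk serializers, faiss.write_index and faiss.read_index, for exactly this purpose. Below is a minimal sketch of the same load-or-create logic on top of those calls; the faiss_index.bin filename is hypothetical (the commit uses faiss_index.pkl), and this is an alternative approach, not what the commit does.

import os
import faiss

index_path = "faiss_index.bin"  # hypothetical name; the commit pickles to faiss_index.pkl
dim = 384  # output dimension of all-MiniLM-L6-v2

if os.path.exists(index_path):
    index = faiss.read_index(index_path)  # load the index saved on a previous run
else:
    index = faiss.IndexFlatL2(dim)        # exact L2 index; no training step required
    faiss.write_index(index, index_path)  # persist the empty index for the next startup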
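For reference, a self-contained sketch of the IndexFlatL2 round trip that upload_files and query_text perform, with dummy vectors standing in for sentence embeddings (384 matches all-MiniLM-L6-v2; faiss expects float32 arrays of shape (n, d)):

import faiss
import numpy as np

dim = 384
index = faiss.IndexFlatL2(dim)

# Stand-in for embedding_model.encode(sentences): ten vectors of dimension 384.
vectors = np.random.rand(10, dim).astype("float32")
index.add(vectors)  # row i of the array is stored under id i, in insertion order

query = np.random.rand(1, dim).astype("float32")
D, I = index.search(query, 5)  # D: squared L2 distances, I: row ids; both shape (1, 5)
print(I[0])  # ids index into document_texts; faiss pads with -1 when fewer than k vectors exist

This is why query_text checks idx != -1 and bounds-checks against document_texts before using a result.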