NaimaAqeel committed
Commit 834c71a · verified · 1 Parent(s): 0b59402

Update app.py

Files changed (1)
  1. app.py +59 -107
app.py CHANGED
@@ -1,7 +1,8 @@
 import os
-import io
+import faiss
+import numpy as np
 import PyPDF2
-import gradio as gr
+import io
 from docx import Document
 from sentence_transformers import SentenceTransformer
 from langchain_community.vectorstores import FAISS
@@ -9,16 +10,32 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from nltk.tokenize import sent_tokenize
 import torch
+import gradio as gr
 import pickle
 import nltk
-import faiss
-import numpy as np
 
-# Ensure NLTK resources are downloaded
-try:
-    nltk.data.find('tokenizers/punkt')
-except LookupError:
-    nltk.download('punkt')
+nltk.download('punkt')
+
+# Function to extract text from a PDF file
+def extract_text_from_pdf(pdf_file):
+    text = ""
+    try:
+        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_file))
+        for page in pdf_reader.pages:
+            text += page.extract_text()
+    except Exception as e:
+        print(f"Error extracting text from PDF: {e}")
+    return text
+
+# Function to extract text from a Word document
+def extract_text_from_docx(docx_file):
+    text = ""
+    try:
+        doc = Document(io.BytesIO(docx_file))
+        text = "\n".join([para.text for para in doc.paragraphs])
+    except Exception as e:
+        print(f"Error extracting text from DOCX: {e}")
+    return text
 
 # Initialize the embedding model
 embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
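The guard removed in this hunk only fetched the punkt data when it was missing, whereas the new unconditional `nltk.download('punkt')` consults NLTK's download index on every launch. A minimal sketch of the guarded pattern the previous version used:

```python
import nltk

# Fetch the punkt tokenizer data only if it is not already installed locally.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
```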
@@ -40,133 +57,69 @@ retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)
 hf_embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
 
 # Load or create FAISS index
-index_path = "faiss_index.index"
+index_path = "faiss_index.pkl"
 if os.path.exists(index_path):
-    faiss_index = faiss.read_index(index_path)
-    print("Loaded FAISS index from faiss_index.index")
+    with open(index_path, "rb") as f:
+        faiss_index = pickle.load(f)
+    print("Loaded FAISS index from faiss_index.pkl")
 else:
-    # Create a new FAISS index
-    d = embedding_model.get_sentence_embedding_dimension()  # Dimension of the embeddings
-    faiss_index = faiss.IndexFlatL2(d)  # Using IndexFlatL2 for simplicity
-
-state = {
-    "conversation": [],
-    "sentences": []
-}
-
-def extract_text_from_pdf(file):
-    text = ""
-    try:
-        pdf_data = file.read()
-        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_data))
-        pdf_pages = pdf_reader.pages
-        text = "\n\n".join(page.extract_text() for page in pdf_pages)
-    except Exception as e:
-        raise RuntimeError(f"Error extracting text from PDF: {e}")
-    return text
-
-def extract_text_from_docx(file):
-    text = ""
-    try:
-        doc = Document(file)
-        text = "\n".join([para.text for para in doc.paragraphs])
-    except Exception as e:
-        raise RuntimeError(f"Error extracting text from DOCX: {e}")
-    return text
+    faiss_index = FAISS(embedding_function=hf_embeddings)
 
 def preprocess_text(text):
     sentences = sent_tokenize(text)
     return sentences
 
 def upload_files(files):
-    global state, faiss_index
+    global faiss_index
     try:
         for file in files:
+            file_data = file.read()
             if file.name.endswith('.pdf'):
-                text = extract_text_from_pdf(file)
+                text = extract_text_from_pdf(file_data)
             elif file.name.endswith('.docx'):
-                text = extract_text_from_docx(file)
+                text = extract_text_from_docx(file_data)
             else:
-                return {"error": f"Unsupported file format: {file.name}"}
+                return {"error": "Unsupported file format"}
 
+            # Preprocess text
             sentences = preprocess_text(text)
+
+            # Encode sentences and add to FAISS index
             embeddings = embedding_model.encode(sentences)
-
-            faiss_index.add(np.array(embeddings).astype(np.float32))  # Add embeddings
-            state["sentences"].extend(sentences)
+            for embedding in embeddings:
+                faiss_index.add(np.expand_dims(embedding, axis=0))
 
         # Save the updated index
-        faiss.write_index(faiss_index, index_path)
+        with open(index_path, "wb") as f:
+            pickle.dump(faiss_index, f)
 
         return {"message": "Files processed successfully"}
-
     except Exception as e:
         print(f"Error processing files: {e}")
-        return {"error": str(e)}
-
-def process_and_query(question):
-    global state, faiss_index
-    if not question:
-        return {"error": "No question provided"}
+        return {"error": str(e)}  # Provide informative error message
 
-    try:
+def process_and_query(state, question):
+    if question:
+        # Preprocess the question
         question_embedding = embedding_model.encode([question])
 
-        # Perform FAISS search
-        D, I = faiss_index.search(np.array(question_embedding).astype(np.float32), k=5)
-        retrieved_results = [state["sentences"][i] for i in I[0] if i != -1]  # Ensure valid indices
+        # Search the FAISS index for similar passages
+        D, I = faiss_index.search(np.array(question_embedding), k=5)
+        retrieved_passages = [faiss_index.index_to_text(i) for i in I[0]]
 
-        # Generate response based on retrieved results
-        context = " ".join(retrieved_results)
-
-        # Enhanced prompt template
+        # Use generator model to generate response based on question and retrieved passages
         prompt_template = """
         Answer the question as detailed as possible from the provided context,
         make sure to provide all the details, if the answer is not in
         provided context just say, "answer is not available in the context",
         don't provide the wrong answer
 
-        Context:
-        {context}
-
-        Question:
-        {question}
-
-        Answer:
-        --------------------------------------------------
-        Prompt Suggestions:
-        1. Summarize the primary theme of the context.
-        2. Elaborate on the crucial concepts highlighted in the context.
-        3. Pinpoint any supporting details or examples pertinent to the question.
-        4. Examine any recurring themes or patterns relevant to the question within the context.
-        5. Contrast differing viewpoints or elements mentioned in the context.
-        6. Explore the potential implications or outcomes of the information provided.
-        7. Assess the trustworthiness and validity of the information given.
-        8. Propose recommendations or advice based on the presented information.
-        9. Forecast likely future events or results stemming from the context.
-        10. Expand on the context or background information pertinent to the question.
-        11. Define any specialized terms or technical language used within the context.
-        12. Analyze any visual representations like charts or graphs in the context.
-        13. Highlight any restrictions or important considerations when responding to the question.
-        14. Examine any presuppositions or biases evident within the context.
-        15. Present alternate interpretations or viewpoints regarding the information provided.
-        16. Reflect on any moral or ethical issues raised by the context.
-        17. Investigate any cause-and-effect relationships identified in the context.
-        18. Uncover any questions or areas requiring further exploration.
-        19. Resolve any vague or conflicting information in the context.
-        20. Cite case studies or examples that demonstrate the concepts discussed in the context.
-        --------------------------------------------------
-        Context:
-        {context}
-
-        Question:
-        {question}
-
+        Context:\n{context}\n
+        Question:\n{question}\n
         Answer:
         """
-
-        combined_input = prompt_template.format(context=context, question=question)
-        inputs = generator_tokenizer(combined_input, return_tensors="pt", max_length=512, truncation=True)
+        combined_input = prompt_template.format(context=' '.join(retrieved_passages), question=question)
+        inputs = generator_tokenizer(combined_input, return_tensors="pt")
         with torch.no_grad():
             generator_outputs = generator.generate(**inputs)
         generated_text = generator_tokenizer.decode(generator_outputs[0], skip_special_tokens=True)
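One caveat on this hunk: a raw FAISS index stores only vectors, so `faiss_index.index_to_text(i)` is not a method of either a native `faiss` index or LangChain's `FAISS` wrapper, and `FAISS(embedding_function=hf_embeddings)` omits the index and docstore arguments LangChain's constructor expects, so this query path will fail at runtime. The removed code's parallel sentence list is the simpler bookkeeping; a minimal sketch of that approach, reusing the script's `embedding_model` (the `add_sentences`/`retrieve` helpers are illustrative, not from the commit):

```python
import faiss
import numpy as np

d = embedding_model.get_sentence_embedding_dimension()  # 384 for all-MiniLM-L6-v2
faiss_index = faiss.IndexFlatL2(d)  # exact L2 search
sentences = []                      # list position i maps to vector id i in the index

def add_sentences(texts):
    # FAISS expects float32 row vectors.
    embeddings = embedding_model.encode(texts)
    faiss_index.add(np.asarray(embeddings, dtype=np.float32))
    sentences.extend(texts)

def retrieve(question, k=5):
    q = embedding_model.encode([question])
    D, I = faiss_index.search(np.asarray(q, dtype=np.float32), k)
    return [sentences[i] for i in I[0] if i != -1]  # -1 pads when the index holds fewer than k vectors
```

For persistence, the dropped `faiss.write_index(faiss_index, index_path)` / `faiss.read_index(index_path)` pair is the native route; pickling a raw index is unreliable because it wraps a C++ object. Alternatively, `FAISS.from_texts(sentences, hf_embeddings)` builds the LangChain wrapper with its docstore mapping intact.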
@@ -176,9 +129,7 @@ def process_and_query(question):
 
     return {"message": generated_text, "conversation": state["conversation"]}
 
-    except Exception as e:
-        print(f"Error processing query: {e}")
-        return {"error": str(e)}
+    return {"error": "No question provided"}
 
 # Create Gradio interface
 with gr.Blocks() as demo:
@@ -188,12 +139,13 @@ with gr.Blocks() as demo:
         upload = gr.File(file_count="multiple", label="Upload PDF or DOCX files")
         upload_button = gr.Button("Upload")
         upload_output = gr.Textbox()
-        upload_button.click(fn=upload_files, inputs=[upload], outputs=upload_output)
+        upload_button.click(fn=upload_files, inputs=upload, outputs=upload_output)
 
     with gr.Tab("Query"):
+        state = gr.State(initial_value={"conversation": []})
         query = gr.Textbox(label="Enter your query")
         query_button = gr.Button("Search")
         query_output = gr.Textbox()
-        query_button.click(fn=process_and_query, inputs=[query], outputs=query_output)
+        query_button.click(fn=process_and_query, inputs=[state, query], outputs=query_output)
 
 demo.launch()
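Two wiring details in this last hunk are worth flagging: as far as Gradio's documented API goes, `gr.State` takes its default through `value=` rather than `initial_value=`, and the documented pattern is to route session state both into and out of a handler so updates carry forward across clicks. A self-contained sketch of that pattern (the `answer` handler is illustrative, not the app's real one):

```python
import gradio as gr

def answer(state, question):
    # Hypothetical handler: record the question and report how many were asked.
    state["conversation"].append(question)
    return f"Question {len(state['conversation'])}: {question}", state

with gr.Blocks() as demo:
    state = gr.State(value={"conversation": []})  # per-session default via `value`
    query = gr.Textbox(label="Enter your query")
    query_button = gr.Button("Search")
    query_output = gr.Textbox()
    # Pass state in AND out so the updated dict persists across clicks.
    query_button.click(fn=answer, inputs=[state, query], outputs=[query_output, state])

demo.launch()
```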
 