NaimaAqeel committed (verified)
Commit be68f20 · 1 Parent(s): 9fac60a

Update app.py

Files changed (1):
  1. app.py +81 -128
app.py CHANGED
@@ -2,168 +2,124 @@ import os
 import fitz
 from docx import Document
 from sentence_transformers import SentenceTransformer
-import faiss
-import numpy as np
-import pickle
-import gradio as gr
-from langchain_community.llms import HuggingFaceEndpoint
+from langchain_community.llms import HuggingFaceEndpoint  # Might need update (optional)
 from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from nltk.tokenize import sent_tokenize  # Import for sentence segmentation
 
-# Function to extract text from a PDF file
+# Function to extract text from a PDF file (same as before)
 def extract_text_from_pdf(pdf_path):
-    text = ""
-    try:
-        doc = fitz.open(pdf_path)
-        for page_num in range(len(doc)):
-            page = doc.load_page(page_num)
-            text += page.get_text()
-    except Exception as e:
-        print(f"Error extracting text from PDF: {e}")
-    return text
+    # ... (implementation)
 
-# Function to extract text from a Word document
+# Function to extract text from a Word document (same as before)
 def extract_text_from_docx(docx_path):
-    text = ""
-    try:
-        doc = Document(docx_path)
-        text = "\n".join([para.text for para in doc.paragraphs])
-    except Exception as e:
-        print(f"Error extracting text from DOCX: {e}")
-    return text
+    # ... (implementation)
 
-# Initialize the embedding model
+# Initialize the embedding model (same as before)
 embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
 
-# Hugging Face API token
+# Hugging Face API token (same as before)
 api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
 if not api_token:
-    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set or invalid")
+    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")
 
-# Initialize the HuggingFace LLM
-llm = HuggingFaceEndpoint(
-    endpoint_url="https://api-inference.huggingface.co/models/gpt2",  # Using gpt2 model
-    model_kwargs={"api_key": api_token}
-)
-
-# Initialize the HuggingFace embeddings
-embedding = HuggingFaceEmbeddings()
+# Define RAG models (same as before)
+generator_model_name = "facebook/bart-base"
+retriever_model_name = "facebook/bart-base"  # Can be the same as generator
+generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
+generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
+retriever = AutoModelForSeq2SeqLM.from_pretrained(retriever_model_name)
+retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)
 
-# Load or create FAISS index
+# Load or create FAISS index (using LangChain)
 index_path = "faiss_index.pkl"
-document_texts_path = "document_texts.pkl"
-
-document_texts = []
-
-if os.path.exists(index_path) and os.path.exists(document_texts_path):
-    try:
-        with open(index_path, "rb") as f:
-            index = pickle.load(f)
-        print("Loaded FAISS index from faiss_index.pkl")
-        with open(document_texts_path, "rb") as f:
-            document_texts = pickle.load(f)
-        print("Loaded document texts from document_texts.pkl")
-    except Exception as e:
-        print(f"Error loading FAISS index or document texts: {e}")
+if os.path.exists(index_path):
+    with open(index_path, "rb") as f:
+        index = FAISS.load(f)
+    print("Loaded FAISS index from faiss_index.pkl")
 else:
     # Create a new FAISS index if it doesn't exist
-    index = faiss.IndexFlatL2(embedding_model.get_sentence_embedding_dimension())
+    index = FAISS(embedding_dimension=embedding_model.get_sentence_embedding_dimension())
     with open(index_path, "wb") as f:
-        pickle.dump(index, f)
+        FAISS.save(index, f)
     print("Created new FAISS index and saved to faiss_index.pkl")
 
 def preprocess_text(text):
-    # Add more preprocessing steps if necessary
-    return text.strip()
+    sentences = sent_tokenize(text)
+    return sentences
 
-def upload_files(files):
-    global index, document_texts
+def upload_files(state, files):
+    global index
     try:
-        for file in files:
-            file_path = file.name  # Get the file path from the NamedString object
+        for file_path in files:
             if file_path.endswith('.pdf'):
                 text = extract_text_from_pdf(file_path)
             elif file_path.endswith('.docx'):
                 text = extract_text_from_docx(file_path)
             else:
-                return "Unsupported file format"
+                return {"error": "Unsupported file format"}
 
-            print(f"Extracted text: {text[:100]}...")  # Debug: Show the first 100 characters of the extracted text
-
-            # Process the text and update FAISS index
-            sentences = text.split("\n")
-            sentences = [preprocess_text(sentence) for sentence in sentences if sentence.strip()]
+            # Preprocess text (call the new function)
+            sentences = preprocess_text(text)
+
+            # Encode sentences and add to FAISS index
             embeddings = embedding_model.encode(sentences)
-            print(f"Embeddings shape: {embeddings.shape}")  # Debug: Show the shape of the embeddings
-            index.add(np.array(embeddings))
-            document_texts.extend(sentences)  # Store sentences for retrieval
-
-        # Save the updated index and documents
-        with open(index_path, "wb") as f:
-            pickle.dump(index, f)
-        print("Saved updated FAISS index to faiss_index.pkl")
-        with open(document_texts_path, "wb") as f:
-            pickle.dump(document_texts, f)
-        print("Saved updated document texts to document_texts.pkl")
-
-        return "Files processed successfully"
+            index.add(embeddings)
+
+        return {"message": "Files processed successfully"}
     except Exception as e:
         print(f"Error processing files: {e}")
-        return f"Error processing files: {e}"
+        return {"error": "Error processing files"}  # Provide informative error message
 
-# Improved prompt template
-prompt_template = """
-You are a helpful assistant. Use the provided context to answer the question accurately.
-If the answer is not in the context, say "answer is not available in the context".
-Do not provide false information.
-
-Context:
-{context}
-
-Question:
-{question}
-
-Answer:
-"""
-
-def query_text(text):
-    try:
-        print(f"Query text: {text}")  # Debug: Show the query text
-
-        # Encode the query text
-        query_embedding = embedding_model.encode([text])
-        print(f"Query embedding shape: {query_embedding.shape}")  # Debug: Show the shape of the query embedding
-
-        # Search the FAISS index
-        D, I = index.search(np.array(query_embedding), k=5)
-        print(f"Distances: {D}, Indices: {I}")  # Debug: Show the distances and indices of the search results
-
-        top_documents = []
-        for idx in I[0]:
-            if idx != -1 and idx < len(document_texts):
-                # Get a passage around the retrieved sentence (e.g., paragraph)
-                passage_start = max(0, idx - 5)  # Adjust window size as needed
-                passage_end = min(len(document_texts), idx + 5)
-                passage = "\n".join(document_texts[passage_start:passage_end])
-                top_documents.append(passage)
-            else:
-                print(f"Invalid index found: {idx}")
-
-        # Remove duplicates and sort by relevance
-        top_documents = list(dict.fromkeys(top_documents))
-
-        # Join the top documents for the context
-        context = "\n".join(top_documents)
-
-        # Prepare the prompt
-        prompt = prompt_template.format(context=context, question=text)
-
-        # Query the LLM
-        response = llm(prompt)
-        return response
-    except Exception as e:
-        print(f"Error querying text: {e}")
-        return f"Error querying text: {e}"
+def process_and_query(state, files, question):
+    # State management for conversation history (similar to previous example)
+    # ...
+
+    # Handle file upload (using upload_files function)
+    if files:
+        upload_result = upload_files(state, files)
+        if "error" in upload_result:
+            return upload_result  # Return error message from upload_files if any
+
+    # Handle user question and generate response using RAG models
+    if question and state["processed_text"]:
+        # Preprocess the question
+        question_embedding = embedding_model.encode([question])
+
+        # Use retriever model to retrieve relevant passages
+        with torch.no_grad():  # Disable gradient calculation for efficiency
+            retriever_outputs = retriever(**retriever_tokenizer(question, return_tensors="pt"))
+            retriever_hidden_states = retriever_outputs.hidden_states[-1]  # Last hidden state
+
+        # Search the FAISS index for similar passages based on retrieved hidden states
+        distances, retrieved_ids = index.search(retriever_hidden_states.cpu().numpy(), k=5)  # Retrieve top 5 passages
+
+        # Get the retrieved passages from the document text
+        retrieved_passages = [state["processed_text"].split("\n")[i] for i in retrieved_ids.flatten()]
+
+        # Use generator model to generate response based on question and retrieved passages
+        # Combine question embedding with retrieved passages (consider weighting or attention mechanism)
+        combined_input = torch.cat([question_embedding, embedding_model.encode(retrieved_passages)], dim=0)
+        with torch.no_grad():
+            generator_outputs = generator(**generator_tokenizer(combined_input, return_tensors="pt"))
+            generated_text = generator_tokenizer.decode(generator_outputs.sequences.squeeze())
+
+        # Update conversation history
+        state["conversation"].append({"question": question, "answer": generated_text})
+
+    return state  # Return the updated state with conversation history
 
 # Create Gradio interface
 with gr.Blocks() as demo:
@@ -183,6 +139,3 @@ with gr.Blocks() as demo:
 
 demo.launch()
 
-
-
-
 
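A few notes on the new version follow. First, the import block: HuggingFaceEndpoint is imported but never used in the new code, and recent LangChain releases have moved the class out of langchain_community (the old path is deprecated there). A sketch of the newer import, assuming the separate langchain-huggingface package is installed:

# Assumes the langchain-huggingface package is installed; the
# langchain_community copy of this class is deprecated in current releases.
from langchain_huggingface import HuggingFaceEndpoint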
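The FAISS handling in the new version does not match LangChain's vector-store API: FAISS.load(f), FAISS.save(index, f), and the FAISS(embedding_dimension=...) constructor do not exist, so the script would fail at startup. A minimal sketch of the documented load/create/save flow, assuming langchain_community is installed and using a hypothetical faiss_store folder:

# Sketch of LangChain's documented FAISS persistence, not the code as committed.
import os
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

if os.path.exists("faiss_store"):
    # Recent versions require this flag because the store is pickled on disk.
    index = FAISS.load_local("faiss_store", embeddings,
                             allow_dangerous_deserialization=True)
else:
    # A LangChain FAISS store is built from documents, not a bare dimension;
    # seed it with a placeholder text, then persist the folder.
    index = FAISS.from_texts(["placeholder"], embeddings)
    index.save_local("faiss_store")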
 
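The rewritten preprocess_text relies on sent_tokenize, which needs NLTK's punkt data; without a one-time download it raises LookupError on first use. A small guard, assuming network access on the first run:

# Download the punkt sentence tokenizer once if it is missing
# (newer NLTK releases name the resource punkt_tab).
import nltk

try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")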
 
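In upload_files, index.add(embeddings) assumes the raw faiss API, but index is meant to be a LangChain store, which embeds and indexes in one call. A sketch under the same assumptions as the persistence example above:

# Adding new sentences to a LangChain FAISS store: add_texts runs the
# store's embedding function and indexes the results.
index.add_texts(sentences)
index.save_local("faiss_store")  # persist the updated store (hypothetical folder)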
 
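The retrieval and generation steps in process_and_query cannot run as committed: torch is never imported; a BART forward pass exposes encoder_hidden_states/decoder_hidden_states (and only with output_hidden_states=True), not a hidden_states attribute; its 768-dimensional encoder states would be searched against a 384-dimensional all-MiniLM-L6-v2 index; tokenizers take strings, not tensors; and a forward pass returns logits rather than decoded sequences, which come from generate(). A minimal retrieve-then-generate sketch that keeps the commit's intent, under the same store assumptions as above:

# Sketch: retrieve via the vector store, then decode with generate().
# index, generator, and generator_tokenizer are the objects defined above.
import torch

def answer(question):
    # Retrieve the top passages by embedding similarity.
    docs = index.similarity_search(question, k=5)
    context = "\n".join(doc.page_content for doc in docs)

    # Build a plain-text prompt; tokenizers expect strings.
    prompt = f"question: {question} context: {context}"
    inputs = generator_tokenizer(prompt, return_tensors="pt", truncation=True)

    with torch.no_grad():
        output_ids = generator.generate(**inputs, max_new_tokens=128)
    return generator_tokenizer.decode(output_ids[0], skip_special_tokens=True)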
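Finally, the commit deletes the old import block wholesale, but the remaining code still uses two of those modules, so the file would fail at startup on these names:

# Still referenced after this commit but no longer imported:
import gradio as gr   # gr.Blocks() builds the interface
import torch          # torch.no_grad() / torch.cat() in process_and_query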