NaimaAqeel committed
Commit 80e4cb4 · verified
1 Parent(s): a37ef5b

Update app.py

Files changed (1):
  app.py +58 -86
app.py CHANGED
@@ -1,20 +1,43 @@
  import os
  import gradio as gr
  import fitz  # PyMuPDF for PDF text extraction
- from docx import Document  # python-docx for DOCX text extraction
  from sentence_transformers import SentenceTransformer
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
  from nltk.tokenize import sent_tokenize
  import torch
  import pickle
- import nltk
- import faiss
- import numpy as np

- # Download NLTK punkt tokenizer data if not already downloaded
- nltk.download('punkt', quiet=True)

- # Function to extract text from a PDF file
  def extract_text_from_pdf(pdf_path):
      text = ""
      try:
@@ -23,85 +46,54 @@ def extract_text_from_pdf(pdf_path):
          page = doc.load_page(page_num)
          text += page.get_text()
      except Exception as e:
-         print(f"Error extracting text from PDF: {e}")
      return text

- # Function to extract text from a Word document
  def extract_text_from_docx(docx_path):
      text = ""
      try:
          doc = Document(docx_path)
          text = "\n".join([para.text for para in doc.paragraphs])
      except Exception as e:
-         print(f"Error extracting text from DOCX: {e}")
      return text

- # Initialize the SentenceTransformer model for embeddings
- embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
-
- # Initialize the HuggingFaceEmbeddings for LangChain
- # Since we're not using it directly for index, initialization may be skipped here
-
- # Initialize the FAISS index
- class FAISSIndex:
-     def __init__(self, dimension):
-         self.dimension = dimension
-         self.index = faiss.IndexFlatL2(dimension)
-
-     def add_sentences(self, sentences, embeddings):
-         # Ensure embeddings are numpy arrays
-         embeddings = np.array(embeddings)
-
-         # Check if embeddings and sentences have the same length
-         assert len(embeddings) == len(sentences), "Number of embeddings should match number of sentences"
-
-         # Add each sentence embedding to the index
-         for emb in embeddings:
-             self.index.add(np.expand_dims(emb, axis=0))
-
-     def similarity_search(self, query_embedding, k=5):
-         # Search for similar embeddings in the index
-         D, I = self.index.search(query_embedding, k)
-         return [{"text": str(i), "score": float(d)} for i, d in zip(I[0], D[0])]
-
- # Initialize the FAISS index instance
- index_dimension = 512  # Dimensionality of SentenceTransformer embeddings
- faiss_index = FAISSIndex(index_dimension)
-
  def preprocess_text(text):
      sentences = sent_tokenize(text)
      return sentences

  def upload_files(files):
      try:
          for file in files:
-             file_path = file.name  # Assuming `file` is a Gradio File object
-
-             if file_path.endswith('.pdf'):
-                 text = extract_text_from_pdf(file_path)
-             elif file_path.endswith('.docx'):
-                 text = extract_text_from_docx(file_path)
-             else:
-                 return {"error": f"Unsupported file format: {file_path}"}
-
-             # Preprocess text
-             sentences = preprocess_text(text)
-
-             # Encode sentences
-             embeddings = embedding_model.encode(sentences)
-
-             # Add sentences to FAISS index
-             faiss_index.add_sentences(sentences, embeddings)
-
-             # Save the updated index
-             with open("faiss_index.pkl", "wb") as f:
                  pickle.dump(faiss_index, f)

          return {"message": "Files processed successfully"}
      except Exception as e:
-         print(f"Error processing files: {e}")
-         return {"error": str(e)}  # Provide informative error message
-

  def process_and_query(state, files, question):
      if files:
@@ -110,29 +102,9 @@ def process_and_query(state, files, question):
              return upload_result

      if question:
-         # Preprocess the question
          question_embedding = embedding_model.encode([question])

-         # Search the FAISS index for similar passages
-         retrieved_results = faiss_index.similarity_search(question_embedding, k=5)  # Retrieve top 5 passages
-         retrieved_passages = [result['text'] for result in retrieved_results]
-
-         # Initialize RAG generator model
-         generator_model_name = "facebook/bart-base"
-         generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
-         generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
-
-         # Use generator model to generate response based on question and retrieved passages
-         combined_input = question + " ".join(retrieved_passages)
-         inputs = generator_tokenizer(combined_input, return_tensors="pt")
-         with torch.no_grad():
-             generator_outputs = generator.generate(**inputs)
-         generated_text = generator_tokenizer.decode(generator_outputs[0], skip_special_tokens=True)
-
-         # Update conversation history
-         state["conversation"].append({"question": question, "answer": generated_text})
-
-         return {"message": generated_text, "conversation": state["conversation"]}

      return {"error": "No question provided"}
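One note on the removed code: the FAISSIndex wrapper above hard-codes index_dimension = 512, but all-MiniLM-L6-v2 produces 384-dimensional embeddings, so a raw faiss index for this model would normally be sized from the model rather than from a constant. A minimal illustrative sketch:

from sentence_transformers import SentenceTransformer
import faiss

model = SentenceTransformer('all-MiniLM-L6-v2')
dim = model.get_sentence_embedding_dimension()  # 384 for this model
index = faiss.IndexFlatL2(dim)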
 
 
app.py (new version)

  import os
  import gradio as gr
+ from docx import Document
  import fitz  # PyMuPDF for PDF text extraction
  from sentence_transformers import SentenceTransformer
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import HuggingFaceEmbeddings
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
  from nltk.tokenize import sent_tokenize
  import torch
  import pickle

+ # Initialize the embedding model
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
+
+ # Hugging Face API token
+ api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
+ if not api_token:
+     raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")
+
+ # Define RAG models
+ generator_model_name = "facebook/bart-base"
+ retriever_model_name = "facebook/bart-base"  # Can be the same as generator
+ generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
+ generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
+ retriever = AutoModelForSeq2SeqLM.from_pretrained(retriever_model_name)
+ retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)
+
+ # Initialize FAISS index using LangChain
+ hf_embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
+
+ # Load or create FAISS index
+ index_path = "faiss_index.pkl"
+ if os.path.exists(index_path):
+     with open(index_path, "rb") as f:
+         faiss_index = pickle.load(f)
+     print("Loaded FAISS index from faiss_index.pkl")
+ else:
+     faiss_index = FAISS()
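In current langchain_community releases the FAISS vector store is not constructed with a bare FAISS() call; it expects an embedding function, a faiss index, and a docstore, and is usually obtained through FAISS.from_texts or FAISS.load_local. A minimal sketch of the load-or-create step under that assumption (the folder name and the allow_dangerous_deserialization flag are illustrative and version-dependent):

import os
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

hf_embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
index_dir = "faiss_index"  # hypothetical folder written by save_local, not the pickle used above

if os.path.isdir(index_dir):
    # Restore an index previously persisted with save_local
    faiss_index = FAISS.load_local(index_dir, hf_embeddings, allow_dangerous_deserialization=True)
else:
    # Seed the store with a placeholder so it exists before any upload,
    # or defer creation until the first document arrives
    faiss_index = FAISS.from_texts(["placeholder"], hf_embeddings)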
 
 
  def extract_text_from_pdf(pdf_path):
      text = ""
      try:

              page = doc.load_page(page_num)
              text += page.get_text()
      except Exception as e:
+         raise RuntimeError(f"Error extracting text from PDF '{pdf_path}': {e}")
      return text

  def extract_text_from_docx(docx_path):
      text = ""
      try:
          doc = Document(docx_path)
          text = "\n".join([para.text for para in doc.paragraphs])
      except Exception as e:
+         raise RuntimeError(f"Error extracting text from DOCX '{docx_path}': {e}")
      return text

  def preprocess_text(text):
      sentences = sent_tokenize(text)
      return sentences

  def upload_files(files):
      try:
+         global faiss_index
+
          for file in files:
+             try:
+                 file_path = file.name
+                 if file_path.endswith('.pdf'):
+                     text = extract_text_from_pdf(file_path)
+                 elif file_path.endswith('.docx'):
+                     text = extract_text_from_docx(file_path)
+                 else:
+                     return {"error": f"Unsupported file format: {file_path}"}
+
+                 sentences = preprocess_text(text)
+                 embeddings = embedding_model.encode(sentences)
+
+                 for sentence, embedding in zip(sentences, embeddings):
+                     faiss_index.add_sentence(sentence, embedding)
+
+             except Exception as e:
+                 print(f"Error processing file '{file.name}': {e}")
+                 return {"error": str(e)}
+
+         with open(index_path, "wb") as f:
              pickle.dump(faiss_index, f)

          return {"message": "Files processed successfully"}
+
      except Exception as e:
+         print(f"General error processing files: {e}")
+         return {"error": str(e)}
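The add_sentence call above is not part of LangChain's FAISS API; the store exposes add_texts (which embeds the strings itself) and add_embeddings (for precomputed vectors), plus save_local for persistence. A sketch of the indexing step inside the per-file loop, under that assumption:

# Assuming faiss_index is a langchain_community FAISS store
sentences = preprocess_text(text)

# Option 1: let the store embed the sentences with its own embedding function
faiss_index.add_texts(sentences)

# Option 2: reuse the SentenceTransformer vectors computed above
embeddings = embedding_model.encode(sentences)
faiss_index.add_embeddings(list(zip(sentences, [emb.tolist() for emb in embeddings])))

# save_local avoids pickling the whole store by hand
faiss_index.save_local("faiss_index")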
 
  def process_and_query(state, files, question):
      if files:

              return upload_result

      if question:
          question_embedding = embedding_model.encode([question])

+         # Perform FAISS search and generate response as before

      return {"error": "No question provided"}
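The placeholder comment in process_and_query stands in for the retrieval-and-generation code this commit removed from the previous version. A minimal sketch of that step against a LangChain FAISS store (similarity_search takes the raw query string and returns Document objects whose page_content holds the matched text; max_new_tokens and truncation are illustrative additions, and generator/generator_tokenizer are the BART objects loaded at startup):

# Retrieve the most similar passages for the question
results = faiss_index.similarity_search(question, k=5)
retrieved_passages = [doc.page_content for doc in results]

# Generate an answer conditioned on the question and retrieved passages
combined_input = question + " " + " ".join(retrieved_passages)
inputs = generator_tokenizer(combined_input, return_tensors="pt", truncation=True)
with torch.no_grad():
    outputs = generator.generate(**inputs, max_new_tokens=128)
generated_text = generator_tokenizer.decode(outputs[0], skip_special_tokens=True)

# Update the conversation history, as the previous version did
state["conversation"].append({"question": question, "answer": generated_text})
return {"message": generated_text, "conversation": state["conversation"]}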