NaimaAqeel committed
Commit 03bc240 · verified · 1 Parent(s): bfb0254

Update app.py

Files changed (1)
  1. app.py +31 -48
app.py CHANGED
@@ -1,22 +1,18 @@
  import os
- import faiss
- import numpy as np
- from docx import Document
+ import gradio as gr
+ import fitz  # PyMuPDF for PDF text extraction
+ from docx import Document  # python-docx for DOCX text extraction
  from sentence_transformers import SentenceTransformer
  from langchain_community.vectorstores import FAISS
  from langchain_community.embeddings import HuggingFaceEmbeddings
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
  from nltk.tokenize import sent_tokenize
  import torch
- import gradio as gr
  import pickle
  import nltk

- # Download NLTK punkt resource if not already downloaded
- try:
-     nltk.data.find('tokenizers/punkt')
- except LookupError:
-     nltk.download('punkt')
+ # Download NLTK punkt tokenizer data if not already downloaded
+ nltk.download('punkt', quiet=True)

  # Function to extract text from a PDF file
  def extract_text_from_pdf(pdf_path):
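Note: the new `import fitz` implies `extract_text_from_pdf` is PyMuPDF-based, but the function body sits outside these hunks. A minimal sketch of what such an extractor typically looks like, mirroring the error handling visible in `extract_text_from_docx` (the actual committed body may differ):

```python
import fitz  # PyMuPDF

def extract_text_from_pdf(pdf_path):
    # Sketch only: concatenate the plain text of every page.
    text = ""
    try:
        with fitz.open(pdf_path) as doc:
            for page in doc:
                text += page.get_text()
    except Exception as e:
        print(f"Error extracting text from PDF: {e}")
    return text
```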
@@ -40,63 +36,46 @@ def extract_text_from_docx(docx_path):
      print(f"Error extracting text from DOCX: {e}")
      return text

-
- # Initialize the embedding model
+ # Initialize the SentenceTransformer model for embeddings
  embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

- # Hugging Face API token
- api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
- if not api_token:
-     raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")
-
- # Define RAG models
- generator_model_name = "facebook/bart-base"
- retriever_model_name = "facebook/bart-base"  # Can be the same as generator
- generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
- generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
- retriever = AutoModelForSeq2SeqLM.from_pretrained(retriever_model_name)
- retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)
-
- # Initialize FAISS index using LangChain
+ # Initialize the HuggingFaceEmbeddings for LangChain
  hf_embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')

- # Load or create FAISS index
+ # Initialize the FAISS index
  index_path = "faiss_index.pkl"
  if os.path.exists(index_path):
      with open(index_path, "rb") as f:
          faiss_index = pickle.load(f)
      print("Loaded FAISS index from faiss_index.pkl")
  else:
-     # Initialize a new FAISS index, e.g., IndexIVFFlat
-     d = 384  # Embedding dimension
-     nlist = 100  # Number of clusters
-     quantizer = faiss.IndexFlatL2(d)
-     faiss_index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
-     faiss_index.train(np.array([]))  # Optional: Train index if needed
-     faiss_index.add(np.array([]))  # Initialize index
+     # Initialize FAISS index using LangChain
+     faiss_index = FAISS(embedding_function=hf_embeddings)

  def preprocess_text(text):
      sentences = sent_tokenize(text)
      return sentences

  def upload_files(files):
-     global faiss_index
      try:
          for file in files:
-             if file.name.endswith('.pdf'):
-                 text = extract_text_from_pdf(file.name)
-             elif file.name.endswith('.docx'):
-                 text = extract_text_from_docx(file.name)
+             if isinstance(file, str):  # Assuming `file` is a string (file path)
+                 if file.endswith('.pdf'):
+                     text = extract_text_from_pdf(file)
+                 elif file.endswith('.docx'):
+                     text = extract_text_from_docx(file)
+                 else:
+                     return {"error": "Unsupported file format"}
              else:
-                 return {"error": "Unsupported file format"}
+                 return {"error": "Invalid file format: expected a string"}

              # Preprocess text
              sentences = preprocess_text(text)

              # Encode sentences and add to FAISS index
              embeddings = embedding_model.encode(sentences)
-             for embedding in embeddings:
-                 faiss_index.add(np.expand_dims(embedding, axis=0))
+             for sentence, embedding in zip(sentences, embeddings):
+                 faiss_index.add_sentence(sentence, embedding)

              # Save the updated index
              with open(index_path, "wb") as f:
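Note: two calls in this hunk do not match the `langchain_community` FAISS API. The wrapper's constructor requires an `index`, `docstore`, and `index_to_docstore_id` in addition to `embedding_function`, so `FAISS(embedding_function=hf_embeddings)` raises a `TypeError`, and the class has `add_texts` but no `add_sentence` method. A sketch of the same intent using documented calls; `add_sentences` is a hypothetical helper name:

```python
from langchain_community.vectorstores import FAISS

faiss_index = None  # created lazily from the first batch of sentences

def add_sentences(sentences):
    # Hypothetical helper: FAISS.from_texts builds the index, docstore,
    # and id mapping in one call; add_texts appends to an existing store.
    global faiss_index
    if faiss_index is None:
        faiss_index = FAISS.from_texts(sentences, hf_embeddings)
    else:
        faiss_index.add_texts(sentences)
```

With this approach the separate `embedding_model.encode(sentences)` step becomes redundant, since the store embeds the texts itself through `hf_embeddings`.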
@@ -113,15 +92,18 @@ def process_and_query(state, files, question):
      if "error" in upload_result:
          return upload_result

-     if question and question.strip():  # Check if question is not empty
+     if question:
          # Preprocess the question
          question_embedding = embedding_model.encode([question])

          # Search the FAISS index for similar passages
-         D, I = faiss_index.search(np.array([question_embedding]), 5)  # Retrieve top 5 passages
-         retrieved_passages = []
-         for i in I[0]:
-             retrieved_passages.append(faiss_index.reconstruct(i).decode('utf-8'))
+         retrieved_results = faiss_index.similarity_search(question, k=5)  # Retrieve top 5 passages
+         retrieved_passages = [result['text'] for result in retrieved_results]
+
+         # Initialize RAG generator model
+         generator_model_name = "facebook/bart-base"
+         generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
+         generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)

          # Use generator model to generate response based on question and retrieved passages
          combined_input = question + " ".join(retrieved_passages)
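Note: `similarity_search` returns LangChain `Document` objects, which are not subscriptable, so `result['text']` in this hunk will raise a `TypeError` at query time. The text lives on the `page_content` attribute:

```python
# Document objects carry their text on .page_content, not under a 'text' key.
retrieved_results = faiss_index.similarity_search(question, k=5)
retrieved_passages = [doc.page_content for doc in retrieved_results]
```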
@@ -151,6 +133,7 @@ with gr.Blocks() as demo:
      query = gr.Textbox(label="Enter your query")
      query_button = gr.Button("Search")
      query_output = gr.Textbox()
-     query_button.click(fn=process_and_query, inputs=[upload_output, query], outputs=query_output)
-
- demo.launch()
+     query_button.click(fn=process_and_query, inputs=[query], outputs=query_output)
+
+ demo.launch()
+
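Note on the rewired click handler: the hunk header above shows `process_and_query(state, files, question)` still takes three parameters, but the new call passes only `inputs=[query]`. Gradio supplies component values to the callback positionally, so this wiring fails with a missing-arguments error. A sketch that matches the signature, assuming a state holder and an upload widget exist elsewhere in the Blocks layout (the names here are hypothetical):

```python
state = gr.State()                       # hypothetical: conversation state
upload = gr.File(file_count="multiple")  # hypothetical: the upload widget
query_button.click(
    fn=process_and_query,
    inputs=[state, upload, query],  # one component per callback parameter
    outputs=query_output,
)
```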
 
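Note: the commit keeps persisting the store with `pickle` to `faiss_index.pkl`. The LangChain wrapper holds a raw C++-backed faiss index that does not reliably survive pickle round-trips; the class ships `save_local` and `load_local` for exactly this. A sketch (the directory name is arbitrary, and recent `langchain_community` versions require the opt-in flag shown):

```python
# Persist: writes the faiss index and the docstore side by side.
faiss_index.save_local("faiss_index_dir")

# Reload: the embeddings object must be supplied again.
faiss_index = FAISS.load_local(
    "faiss_index_dir",
    hf_embeddings,
    allow_dangerous_deserialization=True,  # required by recent versions
)
```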
 
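Note: the third hunk moves the BART generator construction into the query path, so `from_pretrained` re-reads hundreds of megabytes of weights on every search. Loading once at module scope, as the removed code in the second hunk did, avoids that:

```python
# Load once at import time and reuse across queries.
generator_model_name = "facebook/bart-base"
generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
```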
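Finally, the hunks stop at `combined_input = question + " ".join(retrieved_passages)`, which also glues the question to the first passage with no separating space; the generate-and-decode step itself is outside the diff. A typical shape for that step with the objects above (a sketch, not the committed code):

```python
# Join with explicit spaces so question and passages don't run together.
combined_input = " ".join([question] + retrieved_passages)

# Tokenize within BART's 1024-token window, generate, and decode.
inputs = generator_tokenizer(
    combined_input, return_tensors="pt", truncation=True, max_length=1024
)
with torch.no_grad():
    output_ids = generator.generate(**inputs, max_length=150)
answer = generator_tokenizer.decode(output_ids[0], skip_special_tokens=True)
```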