NaimaAqeel committed
Commit f812db9 · verified · 1 Parent(s): 70fd172

Update app.py
Files changed (1): app.py (+63, -93)
app.py CHANGED
@@ -10,135 +10,105 @@ from typing import List
 from langchain_community.llms import HuggingFaceEndpoint
 from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings
 
-# Function to extract text from a PDF file
 def extract_text_from_pdf(pdf_path):
-    text = ""
-    try:
-        doc = fitz.open(pdf_path)
-        for page_num in range(len(doc)):
-            page = doc.load_page(page_num)
-            text += page.get_text()
-    except Exception as e:
-        print(f"Error extracting text from PDF: {e}")
-    return text
 
-# Function to extract text from a Word document
 def extract_text_from_docx(docx_path):
-    text = ""
-    try:
-        doc = Document(docx_path)
-        text = "\n".join([para.text for para in doc.paragraphs])
-    except Exception as e:
-        print(f"Error extracting text from DOCX: {e}")
-    return text
 
-# Initialize the embedding model
 embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
 
-# Hugging Face API token
 api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
 if not api_token:
     raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")
 
-# Initialize the HuggingFace LLM
-llm = HuggingFaceEndpoint(
-    endpoint_url="https://api-inference.huggingface.co/models/gpt2",
-    model_kwargs={"api_key": api_token}
-)
 
-# Initialize the HuggingFace embeddings
-embedding = HuggingFaceEmbeddings()
 
-# Load or create FAISS index
 index_path = "faiss_index.pkl"
 document_texts_path = "document_texts.pkl"
-
 document_texts = []
 
-if os.path.exists(index_path) and os.path.exists(document_texts_path):
-    try:
-        with open(index_path, "rb") as f:
-            index = pickle.load(f)
-        print("Loaded FAISS index from faiss_index.pkl")
-        with open(document_texts_path, "rb") as f:
-            document_texts = pickle.load(f)
-        print("Loaded document texts from document_texts.pkl")
-    except Exception as e:
-        print(f"Error loading FAISS index or document texts: {e}")
-else:
-    # Create a new FAISS index if it doesn't exist
-    index = faiss.IndexFlatL2(embedding_model.get_sentence_embedding_dimension())
-    with open(index_path, "wb") as f:
-        pickle.dump(index, f)
-    print("Created new FAISS index and saved to faiss_index.pkl")
 
 def upload_files(files):
     global index, document_texts
     try:
         for file_path in files:
-            if file_path.endswith('.pdf'):
-                text = extract_text_from_pdf(file_path)
-            elif file_path.endswith('.docx'):
-                text = extract_text_from_docx(file_path)
-            else:
-                return "Unsupported file format"
-
-            # Process the text and update FAISS index
-            sentences = text.split("\n")
             embeddings = embedding_model.encode(sentences)
             index.add(np.array(embeddings))
-            document_texts.append(text)
-
-        # Save the updated index and documents
-        with open(index_path, "wb") as f:
-            pickle.dump(index, f)
-        print("Saved updated FAISS index to faiss_index.pkl")
-        with open(document_texts_path, "wb") as f:
-            pickle.dump(document_texts, f)
-        print("Saved updated document texts to document_texts.pkl")
-
         return "Files processed successfully"
     except Exception as e:
         print(f"Error processing files: {e}")
         return f"Error processing files: {e}"
 
 def query_text(text):
     try:
-        # Encode the query text
-        query_embedding = embedding_model.encode([text])
-
-        # Search the FAISS index
-        D, I = index.search(np.array(query_embedding), k=5)
-
-        top_documents = []
-        for idx in I[0]:
-            if idx != -1 and idx < len(document_texts):  # Ensure that a valid index is found
-                top_documents.append(document_texts[idx])
-            else:
-                print(f"Invalid index found: {idx}")
-        return top_documents
     except Exception as e:
         print(f"Error querying text: {e}")
         return f"Error querying text: {e}"
 
-# Create Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("## Document Upload and Query System")
-
-    with gr.Tab("Upload Files"):
-        upload = gr.File(file_count="multiple", label="Upload PDF or DOCX files")
-        upload_button = gr.Button("Upload")
-        upload_output = gr.Textbox()
-        upload_button.click(fn=upload_files, inputs=upload, outputs=upload_output)
-
-    with gr.Tab("Query"):
-        query = gr.Textbox(label="Enter your query")
-        query_button = gr.Button("Search")
-        query_output = gr.Textbox()
-        query_button.click(fn=query_text, inputs=query, outputs=query_output)
-
-demo.launch()
 
 from langchain_community.llms import HuggingFaceEndpoint
 from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings
+from nltk.tokenize import sent_tokenize  # Import for sentence segmentation
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
+# Function to extract text from a PDF file (same as before)
 def extract_text_from_pdf(pdf_path):
+    # ...
 
+# Function to extract text from a Word document (same as before)
 def extract_text_from_docx(docx_path):
+    # ...
 
+# Initialize the embedding model (same as before)
 embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
 
+# Hugging Face API token (same as before)
 api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
 if not api_token:
     raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")
 
+# Define RAG models (replace with your chosen models)
+generator_model_name = "facebook/bart-base"
+retriever_model_name = "facebook/bart-base"  # Can be the same as the generator
+
+generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
+generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
+
+retriever = AutoModelForSeq2SeqLM.from_pretrained(retriever_model_name)
+retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)
+
+# Load or create FAISS index (same as before)
 index_path = "faiss_index.pkl"
 document_texts_path = "document_texts.pkl"
 document_texts = []
+# ... (rest of the FAISS index loading logic)
+
+def preprocess_text(text):
+    # ... (text preprocessing logic)
 
 def upload_files(files):
     global index, document_texts
     try:
         for file_path in files:
+            # ... (file processing logic, same as before)
+
+            # Preprocess text (call the new function)
+            sentences = preprocess_text(text)
+
+            # Encode sentences and add to FAISS index
             embeddings = embedding_model.encode(sentences)
             index.add(np.array(embeddings))
+
+        # Save the updated index and documents (same as before)
+        # ...
         return "Files processed successfully"
     except Exception as e:
         print(f"Error processing files: {e}")
         return f"Error processing files: {e}"
 
 def query_text(text):
     try:
+        # Preprocess query text
+        query_sentences = preprocess_text(text)
+        query_embeddings = embedding_model.encode(query_sentences)
+
+        # Retrieve relevant documents using FAISS
+        D, I = index.search(np.array(query_embeddings), k=5)
+        retrieved_docs = [document_texts[idx] for idx in I[0] if idx != -1]
+
+        # Retrieval-Augmented Generation (RAG)
+        retriever_inputs = retriever_tokenizer(
+            text=retrieved_docs, return_tensors="pt", padding=True
+        )
+        retriever_outputs = retriever(**retriever_inputs)
+        retrieved_texts = retriever_tokenizer.batch_decode(retriever_outputs.logits)
+
+        # Generate response using retrieved information (as prompts/context)
+        generator_inputs = generator_tokenizer(
+            text=[text] + retrieved_texts, return_tensors="pt", padding=True
+        )
+        generator_outputs = generator(**generator_inputs)
+        response = generator_tokenizer.decode(generator_outputs.sequences[0], skip_special_tokens=True)
+
+        return response
     except Exception as e:
         print(f"Error querying text: {e}")
         return f"Error querying text: {e}"
 
+# Create Gradio interface
 with gr.Blocks() as demo:
+    # ... (rest of the Gradio interface definition)
+    query_button.click(fn=query_text, inputs=query, outputs=query_output)
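
A few notes on the committed code. preprocess_text is new in this commit (the old version chunked with text.split("\n")), but its body is elided; the added sent_tokenize import points at sentence segmentation. A minimal sketch under that assumption (the punkt download and the empty-line filtering are not part of the commit):

    import nltk
    from nltk.tokenize import sent_tokenize

    nltk.download('punkt', quiet=True)  # sent_tokenize needs the punkt data

    def preprocess_text(text):
        # Split raw document text into sentences and drop the empty
        # fragments that PDF extraction tends to produce.
        return [s.strip() for s in sent_tokenize(text) if s.strip()]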
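
In both the old and new versions, upload_files adds one vector per sentence to the FAISS index while document_texts grows by one entry per file, so the row ids returned by index.search do not line up with document_texts (the old query_text guarded against exactly this with idx < len(document_texts)). A sketch of one way to keep the id space aligned; the helper name add_to_index is hypothetical:

    import numpy as np

    def add_to_index(index, document_texts, embedding_model, sentences):
        # Keep one stored text per indexed vector so that FAISS row ids
        # map 1:1 onto entries of document_texts.
        embeddings = embedding_model.encode(sentences)
        index.add(np.array(embeddings, dtype="float32"))  # FAISS expects float32
        document_texts.extend(sentences)  # one entry per vector, not per file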
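
As committed, the generation step in query_text is unlikely to run: batch_decode expects token ids, while a plain forward pass returns float logits, and the output object returned by generator(**generator_inputs) carries logits, not a .sequences field (sequences come from generate()). A sketch of a working generation step for the same facebook/bart-base generator, folding the retrieved sentences into the prompt; the helper name, prompt format, and generation parameters are assumptions:

    def generate_answer(query, retrieved_docs):
        # Concatenate retrieved context into a single prompt for the generator.
        context = " ".join(retrieved_docs)
        inputs = generator_tokenizer(
            f"question: {query} context: {context}",
            return_tensors="pt", truncation=True, max_length=1024,
        )
        # generate() produces token ids, which decode() turns back into text.
        output_ids = generator.generate(**inputs, max_new_tokens=128, num_beams=4)
        return generator_tokenizer.decode(output_ids[0], skip_special_tokens=True)

Note that bart-base is not instruction- or QA-tuned, so this runs but answer quality is a separate question, and the extra "retriever" model arguably adds nothing here since FAISS already performs the retrieval.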
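
Put together, a corrected query path could look like the following; embedding the query as one vector (instead of per sentence, as the commit does) is an assumption, made because each FAISS hit is scored against the query as a whole:

    def query_text(text):
        try:
            query_embedding = embedding_model.encode([text])
            D, I = index.search(np.array(query_embedding, dtype="float32"), k=5)
            retrieved = [document_texts[i] for i in I[0] if 0 <= i < len(document_texts)]
            return generate_answer(text, retrieved)  # sketched above
        except Exception as e:
            print(f"Error querying text: {e}")
            return f"Error querying text: {e}"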
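
Finally, both versions persist the index with pickle, which can fail on FAISS's SWIG-backed index objects depending on the installed version. FAISS ships its own serializers; a sketch of the load-or-create block using them (the .index filename is an assumption; the commit pickles to faiss_index.pkl):

    import os
    import faiss

    INDEX_FILE = "faiss_index.index"

    if os.path.exists(INDEX_FILE):
        index = faiss.read_index(INDEX_FILE)   # FAISS's own deserializer
    else:
        index = faiss.IndexFlatL2(embedding_model.get_sentence_embedding_dimension())
        faiss.write_index(index, INDEX_FILE)   # and serializer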