NaimaAqeel committed on
Commit c4f7f00 · verified · 1 Parent(s): be68f20

Update app.py

Files changed (1)
  1. app.py +32 -40
app.py CHANGED
@@ -1,30 +1,36 @@
  import os
- import fitz
  from docx import Document
  from sentence_transformers import SentenceTransformer
- from langchain_community.llms import HuggingFaceEndpoint # Might need update (optional)
  from langchain_community.vectorstores import FAISS
- from langchain_community.embeddings import HuggingFaceEmbeddings
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
- from nltk.tokenize import sent_tokenize # Import for sentence segmentation

- # Function to extract text from a PDF file (same as before)
  def extract_text_from_pdf(pdf_path):
-     # ... (implementation)

- # Function to extract text from a Word document (same as before)
  def extract_text_from_docx(docx_path):
-     # ... (implementation)

- # Initialize the embedding model (same as before)
  embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

- # Hugging Face API token (same as before)
  api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
  if not api_token:
      raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")

- # Define RAG models (same as before)
  generator_model_name = "facebook/bart-base"
  retriever_model_name = "facebook/bart-base" # Can be the same as generator
  generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)
@@ -32,7 +38,7 @@ generator_tokenizer = AutoTokenizer.from_pretrained(generator_model_name)
  retriever = AutoModelForSeq2SeqLM.from_pretrained(retriever_model_name)
  retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)

- # Load or create FAISS index (using LangChain)
  index_path = "faiss_index.pkl"
  if os.path.exists(index_path):
      with open(index_path, "rb") as f:
@@ -49,7 +55,7 @@ def preprocess_text(text):
      sentences = sent_tokenize(text)
      return sentences

- def upload_files(state, files):
      global index
      try:
          for file_path in files:
@@ -60,7 +66,7 @@ def upload_files(state, files):
              else:
                  return {"error": "Unsupported file format"}

-             # Preprocess text (call the new function)
              sentences = preprocess_text(text)

              # Encode sentences and add to FAISS index
@@ -70,36 +76,20 @@ def upload_files(state, files):
          return {"message": "Files processed successfully"}
      except Exception as e:
          print(f"Error processing files: {e}")
-         return {"error": "Error processing files"} # Provide informative error message

  def process_and_query(state, files, question):
-     # State management for conversation history (similar to previous example)
-     # ...
-
-     # Handle file upload (using upload_files function)
-     if files:
-         upload_result = upload_files(state, files)
-         if "error" in upload_result:
-             return upload_result # Return error message from upload_files if any
-
-     # Handle user question and generate response using RAG models if question and state.
- def process_and_query(state, files, question):
-     # State management for conversation history (similar to previous example)
-     # ...
-
-     # Handle file upload (using upload_files function)
      if files:
-         upload_result = upload_files(state, files)
          if "error" in upload_result:
-             return upload_result # Return error message from upload_files if any

-     # Handle user question and generate response using RAG models
-     if question and state["processed_text"]:
          # Preprocess the question
          question_embedding = embedding_model.encode([question])

          # Use retriever model to retrieve relevant passages
-         with torch.no_grad(): # Disable gradient calculation for efficiency
              retriever_outputs = retriever(**retriever_tokenizer(question, return_tensors="pt"))
              retriever_hidden_states = retriever_outputs.hidden_states[-1] # Last hidden state
@@ -110,7 +100,6 @@ def process_and_query(state, files, question):
          retrieved_passages = [state["processed_text"].split("\n")[i] for i in retrieved_ids.flatten()]

          # Use generator model to generate response based on question and retrieved passages
-         # Combine question embedding with retrieved passages (consider weighting or attention mechanism)
          combined_input = torch.cat([question_embedding, embedding_model.encode(retrieved_passages)], dim=0)
          with torch.no_grad():
              generator_outputs = generator(**generator_tokenizer(combined_input, return_tensors="pt"))
@@ -119,23 +108,26 @@ def process_and_query(state, files, question):
          # Update conversation history
          state["conversation"].append({"question": question, "answer": generated_text})

-         return state # Return the updated state with conversation history

  # Create Gradio interface
  with gr.Blocks() as demo:
      gr.Markdown("## Document Upload and Query System")
-
      with gr.Tab("Upload Files"):
          upload = gr.File(file_count="multiple", label="Upload PDF or DOCX files")
          upload_button = gr.Button("Upload")
          upload_output = gr.Textbox()
          upload_button.click(fn=upload_files, inputs=upload, outputs=upload_output)
-
      with gr.Tab("Query"):
          query = gr.Textbox(label="Enter your query")
          query_button = gr.Button("Search")
          query_output = gr.Textbox()
-         query_button.click(fn=query_text, inputs=query, outputs=query_output)

  demo.launch()

  import os
+ import fitz # PyMuPDF
  from docx import Document
  from sentence_transformers import SentenceTransformer
  from langchain_community.vectorstores import FAISS
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ from nltk.tokenize import sent_tokenize
+ import torch
+ import gradio as gr

+ # Function to extract text from a PDF file
  def extract_text_from_pdf(pdf_path):
+     text = ""
+     doc = fitz.open(pdf_path)
+     for page in doc:
+         text += page.get_text()
+     return text

+ # Function to extract text from a Word document
  def extract_text_from_docx(docx_path):
+     doc = Document(docx_path)
+     text = "\n".join([para.text for para in doc.paragraphs])
+     return text

+ # Initialize the embedding model
  embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

+ # Hugging Face API token
  api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
  if not api_token:
      raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable is not set")

+ # Define RAG models
  generator_model_name = "facebook/bart-base"
  retriever_model_name = "facebook/bart-base" # Can be the same as generator
  generator = AutoModelForSeq2SeqLM.from_pretrained(generator_model_name)

  retriever = AutoModelForSeq2SeqLM.from_pretrained(retriever_model_name)
  retriever_tokenizer = AutoTokenizer.from_pretrained(retriever_model_name)

+ # Load or create FAISS index
  index_path = "faiss_index.pkl"
  if os.path.exists(index_path):
      with open(index_path, "rb") as f:

      sentences = sent_tokenize(text)
      return sentences

+ def upload_files(files):
      global index
      try:
          for file_path in files:

              else:
                  return {"error": "Unsupported file format"}

+             # Preprocess text
              sentences = preprocess_text(text)

              # Encode sentences and add to FAISS index

          return {"message": "Files processed successfully"}
      except Exception as e:
          print(f"Error processing files: {e}")
+         return {"error": "Error processing files"}

  def process_and_query(state, files, question):
      if files:
+         upload_result = upload_files(files)
          if "error" in upload_result:
+             return upload_result

+     if question:
          # Preprocess the question
          question_embedding = embedding_model.encode([question])

          # Use retriever model to retrieve relevant passages
+         with torch.no_grad():
              retriever_outputs = retriever(**retriever_tokenizer(question, return_tensors="pt"))
              retriever_hidden_states = retriever_outputs.hidden_states[-1] # Last hidden state

          retrieved_passages = [state["processed_text"].split("\n")[i] for i in retrieved_ids.flatten()]

          # Use generator model to generate response based on question and retrieved passages
          combined_input = torch.cat([question_embedding, embedding_model.encode(retrieved_passages)], dim=0)
          with torch.no_grad():
              generator_outputs = generator(**generator_tokenizer(combined_input, return_tensors="pt"))

          # Update conversation history
          state["conversation"].append({"question": question, "answer": generated_text})

+         return {"message": generated_text, "conversation": state["conversation"]}
+
+     return {"error": "No question provided"}

  # Create Gradio interface
  with gr.Blocks() as demo:
      gr.Markdown("## Document Upload and Query System")
+
      with gr.Tab("Upload Files"):
          upload = gr.File(file_count="multiple", label="Upload PDF or DOCX files")
          upload_button = gr.Button("Upload")
          upload_output = gr.Textbox()
          upload_button.click(fn=upload_files, inputs=upload, outputs=upload_output)
+
      with gr.Tab("Query"):
          query = gr.Textbox(label="Enter your query")
          query_button = gr.Button("Search")
          query_output = gr.Textbox()
+         query_button.click(fn=process_and_query, inputs=[query], outputs=query_output)

  demo.launch()
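
A few sketches for the parts of the updated app.py that the diff leaves incomplete or that will not run as written. First, the hunk above cuts off before the FAISS index is actually created or saved. A minimal sketch of that bootstrap, assuming the raw faiss package and an IndexFlatL2 over all-MiniLM-L6-v2 vectors (only index_path comes from the diff; EMBEDDING_DIM, save_index, and the serialization calls are assumptions):

import os
import pickle

import faiss  # faiss-cpu

EMBEDDING_DIM = 384  # output size of all-MiniLM-L6-v2
index_path = "faiss_index.pkl"

if os.path.exists(index_path):
    # Restore a previously built index from disk.
    with open(index_path, "rb") as f:
        index = faiss.deserialize_index(pickle.load(f))
else:
    # Start with an empty L2 index matching the embedding dimension.
    index = faiss.IndexFlatL2(EMBEDDING_DIM)

def save_index():
    # Persist the index so uploaded documents survive a restart.
    with open(index_path, "wb") as f:
        pickle.dump(faiss.serialize_index(index), f)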
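
Second, the committed retrieval step keeps retriever_outputs.hidden_states[-1], but BART does not return hidden states unless output_hidden_states=True is requested, and retrieved_ids is never defined before it is used. A simpler drop-in, assuming the sentences encoded at upload time are also kept in a plain Python list (corpus_sentences below is such an assumption, not something in the diff), is to search the FAISS index with the same SentenceTransformer already loaded as embedding_model:

import numpy as np

def retrieve_passages(question, corpus_sentences, top_k=3):
    # Embed the question with the same model used for the indexed sentences.
    question_embedding = embedding_model.encode([question]).astype(np.float32)
    # FAISS returns the indices of the top_k nearest sentence vectors.
    _, ids = index.search(question_embedding, top_k)
    return [corpus_sentences[i] for i in ids[0] if i != -1]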
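
Third, the generation step will also fail as written: torch.cat cannot concatenate the NumPy arrays returned by embedding_model.encode, and generator_tokenizer expects text, not a tensor. A more conventional seq2seq call builds a text prompt from the question and retrieved passages and uses generate(); the prompt format below is an assumption, and facebook/bart-base is not fine-tuned for question answering, so answer quality will be limited:

def generate_answer(question, passages):
    # Concatenate the question and retrieved context as plain text.
    prompt = "question: " + question + " context: " + " ".join(passages)
    inputs = generator_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    with torch.no_grad():
        output_ids = generator.generate(**inputs, max_new_tokens=128)
    return generator_tokenizer.decode(output_ids[0], skip_special_tokens=True)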
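
Finally, query_button.click passes only [query] to process_and_query, which takes (state, files, question), so the click handler will raise a TypeError. One way to line the signatures up, keeping conversation history in a gr.State component (answer_query is a hypothetical wrapper, not part of the commit):

import gradio as gr

def answer_query(state, question):
    # Initialize per-session state on first use.
    state = state or {"conversation": [], "processed_text": ""}
    result = process_and_query(state, files=None, question=question)
    # Show either the generated answer or the error message.
    return result.get("message", result.get("error", "")), state

with gr.Blocks() as demo:
    chat_state = gr.State()
    query = gr.Textbox(label="Enter your query")
    query_button = gr.Button("Search")
    query_output = gr.Textbox()
    query_button.click(fn=answer_query, inputs=[chat_state, query], outputs=[query_output, chat_state])

demo.launch()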