NaimaAqeel committed on
Commit 965462a · verified · 1 Parent(s): 60d9162

Update app.py

Files changed (1)
  1. app.py +43 -77
app.py CHANGED
@@ -6,46 +6,54 @@ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 import re
 import torch
 
+# -------------------------
 # Load models
+# -------------------------
 embedder = SentenceTransformer("all-MiniLM-L6-v2")
-qa_pipeline = pipeline("question-answering",
-                       model="distilbert-base-cased-distilled-squad",
-                       device=0 if torch.cuda.is_available() else -1)
 
-# Load GPT model (using GPT-2 as example - replace with GPT-3/4 if available)
+qa_pipeline = pipeline(
+    "question-answering",
+    model="distilbert-base-cased-distilled-squad",
+    device=0 if torch.cuda.is_available() else -1
+)
+
+# GPT model (using GPT-2 here; replace with a better model if available)
 gpt_tokenizer = AutoTokenizer.from_pretrained("gpt2")
 gpt_model = AutoModelForCausalLM.from_pretrained("gpt2")
 gpt_model.eval()
 
+# -------------------------
+# Helper functions
+# -------------------------
 def extract_text(file):
+    """Extract text from a PDF or DOCX file."""
     if file.name.endswith(".pdf"):
         text = "\n".join([page.extract_text() or "" for page in PdfReader(file).pages])
     elif file.name.endswith(".docx"):
         text = "\n".join([p.text for p in docx.Document(file).paragraphs])
     else:
         return ""
-    # Clean up text
-    text = re.sub(r'\s+', ' ', text)  # Replace multiple whitespace with single space
+    text = re.sub(r'\s+', ' ', text)  # clean whitespace
     return text.strip()
 
 def chunk_text(text, chunk_size=500, overlap=100):
+    """Split text into overlapping chunks."""
     sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
-    chunks = []
-    current_chunk = ""
-
+    chunks, current_chunk = [], ""
+
     for sent in sentences:
         if len(current_chunk) + len(sent) < chunk_size:
             current_chunk += sent + " "
         else:
             chunks.append(current_chunk.strip())
-            # Keep some overlap between chunks for context
             current_chunk = current_chunk[-overlap:] + sent + " "
-
+
     if current_chunk:
         chunks.append(current_chunk.strip())
     return chunks
 
 def generate_with_gpt(prompt, max_length=150):
+    """Generate text with the GPT model."""
     inputs = gpt_tokenizer(prompt, return_tensors="pt")
     with torch.no_grad():
         outputs = gpt_model.generate(
@@ -60,65 +68,20 @@ def generate_with_gpt(prompt, max_length=150):
         )
     return gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-def ask_question(file, question, history):
-    if not file:
-        return "Please upload a file.", history
-
-    text = extract_text(file)
-    if not text:
-        return "Could not extract text from the file.", history
-
-    chunks = chunk_text(text)
-    if not chunks:
-        return "No meaningful text chunks could be created.", history
-
-    # Initialize answer as None
-    answer = None
-
-    try:
-        # Normalize question for better matching
-        normalized_question = question.lower().strip(" ?")
-
-        # First try to find direct definitions
-        if "artificial system" in normalized_question:
-            answer = extract_direct_definition(text, "artificial system")
-        elif "natural system" in normalized_question:
-            answer = extract_direct_definition(text, "natural system")
-        elif "component" in normalized_question:
-            answer = extract_direct_definition(text, "component")
-
-        # If no direct definition found, use semantic search
-        if not answer:
-            emb_chunks = embedder.encode(chunks, convert_to_tensor=True)
-            emb_question = embedder.encode(question, convert_to_tensor=True)
-            scores = util.pytorch_cos_sim(emb_question, emb_chunks)[0]
-            best_idx = scores.argmax().item()
-            best_chunk = chunks[best_idx]
-
-            if scores[best_idx] < 0.3:  # Low confidence
-                top_k = min(3, len(chunks))
-                best_indices = scores.topk(top_k).indices.tolist()
-                best_chunk = " ".join([chunks[i] for i in best_indices])
-
-            result = qa_pipeline(question=question, context=best_chunk)
-            if result["score"] > 0.1 and len(result["answer"].split()) >= 2:
-                answer = result["answer"]
-
-        # Final fallback if no answer found
-        if not answer:
-            answer = "Sorry, I couldn't find a clear answer in the document."
-
-    except Exception as e:
-        answer = f"An error occurred: {str(e)}"
-
-    history.append((question, answer))
-    return "", history
+def refine_answer_with_gpt(context, question, answer):
+    """Ask GPT to refine the QA model's answer."""
+    prompt = (
+        f"Context: {context}\n\n"
+        f"Question: {question}\n\n"
+        f"Answer: {answer}\n\n"
+        f"Please provide a clearer and more complete answer in simple language."
+    )
+    return generate_with_gpt(prompt, max_length=120)
 
 def extract_direct_definition(text, term):
-    """Try to find a sentence that directly defines the term"""
+    """Find a direct definition of a term in the text."""
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
    term = term.lower()
-
    candidates = []
    for sent in sentences:
        lower_sent = sent.lower()
@@ -126,11 +89,13 @@ def extract_direct_definition(text, term):
         if (" is " in lower_sent or " are " in lower_sent or
             " refers to " in lower_sent or " defined as " in lower_sent):
             candidates.append(sent)
-
     if candidates:
         return candidates[0]
     return None
 
+# -------------------------
+# Main QA function
+# -------------------------
 def ask_question(file, question, history):
     if not file:
         return "Please upload a file.", history
@@ -143,11 +108,12 @@ def ask_question(file, question, history):
     if not chunks:
         return "No meaningful text chunks could be created.", history
 
-    # Normalize question for better matching
+    # Initialize answer
+    answer = None
     normalized_question = question.lower().strip(" ?")
 
     try:
-        # First try to find direct definitions
+        # Try direct definition
         if "artificial system" in normalized_question:
             answer = extract_direct_definition(text, "artificial system")
         elif "natural system" in normalized_question:
@@ -155,7 +121,7 @@ def ask_question(file, question, history):
         elif "component" in normalized_question:
             answer = extract_direct_definition(text, "component")
 
-        # If no direct definition found, use semantic search
+        # If no direct definition, do semantic search + QA
         if not answer:
             emb_chunks = embedder.encode(chunks, convert_to_tensor=True)
             emb_question = embedder.encode(question, convert_to_tensor=True)
@@ -163,32 +129,32 @@ def ask_question(file, question, history):
             best_idx = scores.argmax().item()
             best_chunk = chunks[best_idx]
 
-            # Combine top chunks if confidence is low
+            # Low confidence: merge top chunks
             if scores[best_idx] < 0.3:
                 top_k = min(3, len(chunks))
                 best_indices = scores.topk(top_k).indices.tolist()
                 best_chunk = " ".join([chunks[i] for i in best_indices])
 
-            # Get initial answer from QA model
             result = qa_pipeline(question=question, context=best_chunk)
             answer = result["answer"] if result["score"] > 0.1 else None
 
-            # Refine answer with GPT if available
             if answer and len(answer.split()) > 2:
                 answer = refine_answer_with_gpt(best_chunk, question, answer)
 
-        # Final fallback
         if not answer:
             answer = "Sorry, I couldn't find a clear answer in the document."
-
+
     except Exception as e:
         answer = f"An error occurred: {str(e)}"
 
     history.append((question, answer))
     return "", history
 
+# -------------------------
+# Gradio Interface
+# -------------------------
 with gr.Blocks() as demo:
-    gr.Markdown("## Enhanced Document QA with GPT Integration")
+    gr.Markdown("## 📘 Enhanced Document QA with GPT Integration")
     with gr.Row():
         file_input = gr.File(label="Upload PDF or Word", file_types=[".pdf", ".docx"])
     with gr.Row():
@@ -196,7 +162,7 @@ with gr.Blocks() as demo:
     with gr.Row():
         question = gr.Textbox(label="Ask your question", placeholder="Type your question here...")
     state = gr.State([])
-
+
     question.submit(
         ask_question,
         [file_input, question, state],
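
One behavior of the new refine_answer_with_gpt is worth flagging: generate_with_gpt decodes the full output sequence, and for a decoder-only model like GPT-2 that sequence begins with the prompt, so the "refined" answer comes back with the whole prompt still attached. Below is a minimal sketch of a variant that strips it, reusing the names from this commit; the slicing heuristic is an assumption, since encode/decode round-trips are not guaranteed to reproduce the prompt exactly.

def refine_answer_with_gpt(context, question, answer):
    """Ask GPT to refine the QA model's answer, returning only the continuation."""
    prompt = (
        f"Context: {context}\n\n"
        f"Question: {question}\n\n"
        f"Answer: {answer}\n\n"
        f"Please provide a clearer and more complete answer in simple language."
    )
    full = generate_with_gpt(prompt, max_length=120)
    # Assumption: the decoded text starts with the prompt verbatim.
    # Fall back to the original QA answer if nothing remains after stripping.
    refined = full[len(prompt):].strip()
    return refined or answer

Note also that max_length caps prompt plus continuation together, so a long context can leave no room for new tokens; passing max_new_tokens through to gpt_model.generate instead would sidestep that.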
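
For a quick manual check of the retrieval step outside the Gradio UI, here is a standalone sketch of the same embed / score / top-k-merge pattern that ask_question uses. The chunk texts and question are made-up examples, and 0.3 is the low-confidence threshold from this commit.

from sentence_transformers import SentenceTransformer, util

embedder = SentenceTransformer("all-MiniLM-L6-v2")
chunks = [
    "An artificial system is a system designed and built by humans.",
    "Photosynthesis converts light energy into chemical energy in plants.",
    "A natural system arises without human design or intervention.",
]
question = "What is an artificial system?"

# Embed the question and chunks, then score by cosine similarity.
scores = util.pytorch_cos_sim(
    embedder.encode(question, convert_to_tensor=True),
    embedder.encode(chunks, convert_to_tensor=True),
)[0]
best_idx = scores.argmax().item()
best_chunk = chunks[best_idx]

# Low confidence: merge the top chunks, as in ask_question.
if scores[best_idx] < 0.3:
    top_k = min(3, len(chunks))
    best_chunk = " ".join(chunks[i] for i in scores.topk(top_k).indices.tolist())

print(f"score={scores[best_idx].item():.3f}")
print(best_chunk)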