NaimaAqeel committed on
Commit e0d20c3 · verified · 1 Parent(s): 8889a56

Update app.py

Files changed (1):
  1. app.py +175 -113
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
 from PyPDF2 import PdfReader
 import docx
 from sentence_transformers import SentenceTransformer, util
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+from transformers import pipeline
 import re
 import torch
 
@@ -12,160 +12,222 @@ import torch
 embedder = SentenceTransformer("all-MiniLM-L6-v2")
 
 qa_pipeline = pipeline(
     "question-answering",
     model="distilbert-base-cased-distilled-squad",
     device=0 if torch.cuda.is_available() else -1
 )
 
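For reference, a minimal sketch of what this extractive pipeline returns (the context string and the score value are hypothetical):

    result = qa_pipeline(
        question="What is a system?",
        context="A system is a set of interrelated components that work toward a common objective.",
    )
    # A dict with the extracted span and a confidence score, e.g.:
    # {"score": 0.47, "start": 14, "end": 85, "answer": "a set of interrelated components ..."}

The reworked logic further down keys off that "score" field to accept or reject candidate answers.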
-# GPT model (using GPT-2 here – replace with better model if you have)
-gpt_tokenizer = AutoTokenizer.from_pretrained("gpt2")
-gpt_model = AutoModelForCausalLM.from_pretrained("gpt2")
-gpt_model.eval()
-
 # -------------------------
-# Helper functions
+# Text utilities
 # -------------------------
-def extract_text(file):
-    """Extract text from PDF or DOCX"""
-    if file.name.endswith(".pdf"):
+SENT_SPLIT_RE = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')
+
+def normalize_ws(text: str) -> str:
+    text = re.sub(r'[ \t]+', ' ', text)
+    text = re.sub(r'\s*\n\s*', '\n', text)
+    return text.strip()
+
+def extract_text(file) -> str:
+    if file.name.lower().endswith(".pdf"):
         text = "\n".join([page.extract_text() or "" for page in PdfReader(file).pages])
-    elif file.name.endswith(".docx"):
+    elif file.name.lower().endswith(".docx"):
         text = "\n".join([p.text for p in docx.Document(file).paragraphs])
     else:
         return ""
-    text = re.sub(r'\s+', ' ', text)  # clean whitespace
-    return text.strip()
+    return normalize_ws(text)
 
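A quick sanity check of the new normalizer (hypothetical input string):

    normalize_ws("Page  1 \n\n  Header\t text")
    # -> "Page 1\nHeader text"  (runs of spaces/tabs collapse; blank lines fold into one newline)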
-def chunk_text(text, chunk_size=500, overlap=100):
-    """Split text into overlapping chunks"""
-    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
-    chunks, current_chunk = [], ""
-
-    for sent in sentences:
-        if len(current_chunk) + len(sent) < chunk_size:
-            current_chunk += sent + " "
+def split_sentences(text: str):
+    parts = SENT_SPLIT_RE.split(text)
+    # Merge very short fragments with neighbors
+    out = []
+    buf = ""
+    for p in parts:
+        if len(p.strip()) < 40:
+            buf += (" " if buf else "") + p.strip()
         else:
-            chunks.append(current_chunk.strip())
-            current_chunk = current_chunk[-overlap:] + sent + " "
+            if buf:
+                out.append(buf.strip())
+                buf = ""
+            out.append(p.strip())
+    if buf:
+        out.append(buf.strip())
+    return [s for s in out if s]
 
-    if current_chunk:
-        chunks.append(current_chunk.strip())
+def chunk_by_chars(sentences, chunk_char_limit=900, overlap_sents=1):
+    chunks, cur, cur_len = [], [], 0
+    for s in sentences:
+        if cur_len + len(s) + 1 <= chunk_char_limit:
+            cur.append(s); cur_len += len(s) + 1
+        else:
+            if cur:
+                chunks.append(" ".join(cur))
+                # overlap for context
+                cur = cur[-overlap_sents:] + [s]
+                cur_len = sum(len(x) + 1 for x in cur)
+            else:
+                # extremely long sentence, hard cut
+                chunks.append(s[:chunk_char_limit])
+                cur, cur_len = [], 0
+    if cur:
+        chunks.append(" ".join(cur))
     return chunks
 
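A rough usage sketch of the new two-stage chunking (uploaded_file is a hypothetical file handle):

    sents = split_sentences(extract_text(uploaded_file))
    chunks = chunk_by_chars(sents, chunk_char_limit=900, overlap_sents=1)
    # Fragments under 40 characters are merged into neighboring sentences; each
    # chunk stays near 900 characters, and consecutive chunks share one overlapping sentence.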
-def generate_with_gpt(prompt, max_new_tokens=100):
-    """Generate text with GPT model"""
-    inputs = gpt_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
-    with torch.no_grad():
-        outputs = gpt_model.generate(
-            inputs.input_ids,
-            max_new_tokens=max_new_tokens,  # FIXED
-            num_return_sequences=1,
-            no_repeat_ngram_size=2,
-            do_sample=True,
-            top_k=50,
-            top_p=0.95,
-            temperature=0.7
-        )
-    return gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-def refine_answer_with_gpt(context, question, answer):
-    """Ask GPT to refine the QA model answer"""
-    prompt = (
-        f"Context: {context}\n\n"
-        f"Question: {question}\n\n"
-        f"Answer: {answer}\n\n"
-        f"Please provide a clearer and more complete answer in simple language."
-    )
-    return generate_with_gpt(prompt, max_new_tokens=120)
-
-def extract_direct_definition(text, term):
-    """Find a direct definition of a term in the text"""
-    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
-    term = term.lower()
-    candidates = []
-    for sent in sentences:
-        lower_sent = sent.lower()
-        if term in lower_sent:
-            if (" is " in lower_sent or " are " in lower_sent or
-                " refers to " in lower_sent or " defined as " in lower_sent):
-                candidates.append(sent)
-    if candidates:
-        return candidates[0]
-    return None
+def clean_answer(text: str) -> str:
+    # Remove obvious footer/contact lines or emails/urls/phones
+    text = re.sub(r'\bIf you have any questions.*', '', text, flags=re.IGNORECASE)
+    text = re.sub(r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}', '', text)
+    text = re.sub(r'https?://\S+|www\.\S+', '', text)
+    text = re.sub(r'\b(?:Tel|Phone|Cell|Contact)\b.*', '', text, flags=re.IGNORECASE)
+    return normalize_ws(text)
 
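A quick check of what the new clean_answer strips (hypothetical span):

    clean_answer("A system is a set of parts. Contact: admin@example.com www.example.com")
    # -> "A system is a set of parts."  (the contact tail, email, and URL are removed)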
 # -------------------------
-# Main QA function
+# Definition finder (for "what is / define" questions)
 # -------------------------
-def ask_question(file, question, history):
+DEF_PATTERNS = [
+    r'\b(system)\s+is\s+(?:an?|the)\s+[^.]+?\.',           # "system is a/the ..."
+    r'\b(system)\s+refers\s+to\s+[^.]+?\.',                # "system refers to ..."
+    r'\b(system)\s+can\s+be\s+defined\s+as\s+[^.]+?\.',    # "system can be defined as ..."
+    r'\b(system)\s+consists\s+of\s+[^.]+?\.',              # "system consists of ..."
+]
+
+KEYWORDS_BONUS = {"interrelated", "components", "objective", "objectives",
+                  "environment", "inputs", "outputs", "communication", "function"}
+
+def find_definition_sentences(text: str, term: str = "system"):
+    sentences = split_sentences(text)
+    cand = []
+    term_lc = term.lower()
+    for s in sentences:
+        s_lc = s.lower()
+        if term_lc not in s_lc:
+            continue
+        matched = any(re.search(pat.replace("system", term_lc), s_lc) for pat in DEF_PATTERNS)
+        if matched:
+            score = sum(1 for k in KEYWORDS_BONUS if k in s_lc)
+            cand.append((score, s.strip()))
+    if not cand:
+        # fallback: sentences with term + several keywords
+        for s in sentences:
+            s_lc = s.lower()
+            if term_lc in s_lc:
+                score = sum(1 for k in KEYWORDS_BONUS if k in s_lc)
+                if score >= 2:
+                    cand.append((score, s.strip()))
+    if not cand:
+        return None
+    cand.sort(key=lambda x: (-x[0], len(x[1])))
+    return cand[0][1]
+
+# -------------------------
+# Retrieval helpers
+# -------------------------
+def select_top_chunks(chunks, question, top_k=3):
+    emb_chunks = embedder.encode(chunks, convert_to_tensor=True, normalize_embeddings=True)
+    emb_q = embedder.encode([question], convert_to_tensor=True, normalize_embeddings=True)
+    sims = util.cos_sim(emb_q, emb_chunks)[0]  # shape [num_chunks]
+    top_k = min(top_k, len(chunks))
+    top_idx = torch.topk(sims, k=top_k).indices.tolist()
+    return [chunks[i] for i in top_idx], sims.max().item()
+
+# -------------------------
+# Main QA logic
+# -------------------------
+def answer_from_chunks(question: str, chunks: list, strict_extractive=True):
+    """
+    Try QA over the best chunk(s). If strict_extractive, return the extractive span only.
+    We'll query the best chunk first; if low score, concatenate top-3 chunks and retry.
+    """
+    if not chunks:
+        return None
+
+    # Best single chunk first
+    result = qa_pipeline(question=question, context=chunks[0])
+    best_answer, best_score = result.get("answer", ""), result.get("score", 0.0)
+
+    # If weak, try merged top chunks
+    if best_score < 0.25 and len(chunks) > 1:
+        merged = " ".join(chunks)
+        result2 = qa_pipeline(question=question, context=merged)
+        if result2.get("score", 0.0) > best_score:
+            best_answer, best_score = result2["answer"], result2["score"]
+
+    if best_score < 0.15 or len(best_answer.strip()) < 2:
+        return None
+
+    ans = best_answer.strip()
+    # keep it extractive and clean
+    ans = clean_answer(ans)
+    if strict_extractive:
+        # ensure it's a concise span (avoid run-on junk)
+        ans = re.split(r'[\n\r]', ans)[0].strip()
+    return ans or None
+
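For intuition, a minimal sketch of the retrieval helper defined above (the sample strings and similarity value are hypothetical):

    chunks = ["A system is a set of interrelated components.",
              "An unrelated paragraph about office hours."]
    top, max_sim = select_top_chunks(chunks, "What is a system?", top_k=1)
    # top == [chunks[0]]; max_sim is the best cosine similarity in [-1, 1],
    # and tends to be high (roughly 0.6+) when a chunk closely matches the question.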
+# -------------------------
+# Gradio callback
+# -------------------------
+def ask_question(file, question, history, strict_extractive=True):
     if not file:
         return "Please upload a file.", history
+    if not question or not question.strip():
+        return "Please type a question.", history
 
     text = extract_text(file)
     if not text:
         return "Could not extract text from the file.", history
-
-    chunks = chunk_text(text)
-    if not chunks:
-        return "No meaningful text chunks could be created.", history
-
-    # Initialize answer
-    answer = None
-    normalized_question = question.lower().strip(" ?")
-
+
+    sentences = split_sentences(text)
+    chunks = chunk_by_chars(sentences, chunk_char_limit=900, overlap_sents=1)
+
+    q_norm = question.lower().strip(" ?!")
+
     try:
-        # Try direct definition
-        if "artificial system" in normalized_question:
-            answer = extract_direct_definition(text, "artificial system")
-        elif "natural system" in normalized_question:
-            answer = extract_direct_definition(text, "natural system")
-        elif "component" in normalized_question:
-            answer = extract_direct_definition(text, "component")
-
-        # If no direct definition, do semantic search + QA
+        # 1) Prefer a definition for "what is/define ..." style questions
+        if re.search(r'\b(what\s+is|define|definition of)\b', q_norm) and "system" in q_norm:
+            defin = find_definition_sentences(text, term="system")
+            if defin:
+                answer = clean_answer(defin)
+                history.append((question, answer))
+                return "", history
+
+        # 2) Retrieval + extractive QA
+        top_chunks, max_sim = select_top_chunks(chunks, question, top_k=3)
+        answer = answer_from_chunks(question, top_chunks, strict_extractive=strict_extractive)
+
+        # 3) If still nothing, try a simpler sentence retrieval: pick the most relevant sentence
         if not answer:
-            emb_chunks = embedder.encode(chunks, convert_to_tensor=True)
-            emb_question = embedder.encode(question, convert_to_tensor=True)
-            scores = util.pytorch_cos_sim(emb_question, emb_chunks)[0]
-            best_idx = scores.argmax().item()
-            best_chunk = chunks[best_idx]
-
-            # Low confidence → merge top chunks
-            if scores[best_idx] < 0.3:
-                top_k = min(3, len(chunks))
-                best_indices = scores.topk(top_k).indices.tolist()
-                best_chunk = " ".join([chunks[i] for i in best_indices])
-
-            result = qa_pipeline(question=question, context=best_chunk)
-            answer = result["answer"] if result["score"] > 0.1 else None
-
-            if answer and len(answer.split()) > 2:
-                answer = refine_answer_with_gpt(best_chunk, question, answer)
-
+            emb_sents = embedder.encode(sentences, convert_to_tensor=True, normalize_embeddings=True)
+            emb_q = embedder.encode([question], convert_to_tensor=True, normalize_embeddings=True)
+            sims = util.cos_sim(emb_q, emb_sents)[0]
+            best_i = int(torch.argmax(sims).item())
+            if sims[best_i].item() > 0.2:
+                answer = clean_answer(sentences[best_i])
+
         if not answer:
-            answer = "Sorry, I couldn't find a clear answer in the document."
-
+            answer = "Sorry, I couldn't find a clear, grounded answer in the document."
+
     except Exception as e:
         answer = f"An error occurred: {str(e)}"
-
+
     history.append((question, answer))
     return "", history
 
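End to end, the callback behaves like this sketch (uploaded_file is a hypothetical file handle):

    history = []
    _, history = ask_question(uploaded_file, "What is a system?", history, strict_extractive=True)
    # history[-1] == ("What is a system?", <definition sentence, extractive span, or fallback message>)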
 # -------------------------
-# Gradio Interface
+# UI
 # -------------------------
 with gr.Blocks() as demo:
-    gr.Markdown("## 📘 Enhanced Document QA with GPT Integration")
+    gr.Markdown("## 📘 Document QA Strict Extractive (No Hallucinations)")
     with gr.Row():
         file_input = gr.File(label="Upload PDF or Word", file_types=[".pdf", ".docx"])
     with gr.Row():
-        chatbot = gr.Chatbot(height=400)
+        chatbot = gr.Chatbot(height=420)
     with gr.Row():
-        question = gr.Textbox(label="Ask your question", placeholder="Type your question here...")
+        question = gr.Textbox(label="Ask your question", placeholder="e.g., What is a system?")
+        strict = gr.Checkbox(value=True, label="Strict extractive only (recommended)")
     state = gr.State([])
 
     question.submit(
         ask_question,
-        [file_input, question, state],
+        [file_input, question, state, strict],
         [question, chatbot]
     )
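The hunk ends here; the launch call sits outside the diff context, so the file presumably still closes with something like (hypothetical, not shown in this commit):

    demo.launch()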