NaimaAqeel committed · Commit 318fbd2 · verified · 1 Parent(s): e0d20c3

Update app.py

Files changed (1)
  1. app.py +34 -212
app.py CHANGED
@@ -1,234 +1,56 @@
-import gradio as gr
+import gradio as gr
 from PyPDF2 import PdfReader
 import docx
 from sentence_transformers import SentenceTransformer, util
 from transformers import pipeline
-import re
-import torch
 
-# -------------------------
 # Load models
-# -------------------------
 embedder = SentenceTransformer("all-MiniLM-L6-v2")
-
-qa_pipeline = pipeline(
-    "question-answering",
-    model="distilbert-base-cased-distilled-squad",
-    device=0 if torch.cuda.is_available() else -1
-)
-
-# -------------------------
-# Text utilities
-# -------------------------
-SENT_SPLIT_RE = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')
-
-def normalize_ws(text: str) -> str:
-    text = re.sub(r'[ \t]+', ' ', text)
-    text = re.sub(r'\s*\n\s*', '\n', text)
-    return text.strip()
-
-def extract_text(file) -> str:
-    if file.name.lower().endswith(".pdf"):
-        text = "\n".join([page.extract_text() or "" for page in PdfReader(file).pages])
-    elif file.name.lower().endswith(".docx"):
-        text = "\n".join([p.text for p in docx.Document(file).paragraphs])
-    else:
-        return ""
-    return normalize_ws(text)
-
-def split_sentences(text: str):
-    parts = SENT_SPLIT_RE.split(text)
-    # Merge very short fragments with neighbors
-    out = []
-    buf = ""
-    for p in parts:
-        if len(p.strip()) < 40:
-            buf += (" " if buf else "") + p.strip()
-        else:
-            if buf:
-                out.append(buf.strip())
-                buf = ""
-            out.append(p.strip())
-    if buf:
-        out.append(buf.strip())
-    return [s for s in out if s]
-
-def chunk_by_chars(sentences, chunk_char_limit=900, overlap_sents=1):
-    chunks, cur, cur_len = [], [], 0
-    for s in sentences:
-        if cur_len + len(s) + 1 <= chunk_char_limit:
-            cur.append(s); cur_len += len(s) + 1
+qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
+
+def extract_text(file):
+    if file.name.endswith(".pdf"):
+        return "\n".join([page.extract_text() or "" for page in PdfReader(file).pages])
+    elif file.name.endswith(".docx"):
+        return "\n".join([p.text for p in docx.Document(file).paragraphs])
+    return ""
+
+def chunk_text(text, chunk_size=500):
+    sentences = text.split(". ")
+    chunks, buffer = [], ""
+    for sent in sentences:
+        if len(buffer) + len(sent) < chunk_size:
+            buffer += sent + ". "
         else:
-            if cur:
-                chunks.append(" ".join(cur))
-                # overlap for context
-                cur = cur[-overlap_sents:] + [s]
-                cur_len = sum(len(x) + 1 for x in cur)
-            else:
-                # extremely long sentence, hard cut
-                chunks.append(s[:chunk_char_limit])
-                cur, cur_len = [], 0
-    if cur:
-        chunks.append(" ".join(cur))
+            chunks.append(buffer.strip())
+            buffer = sent + ". "
+    if buffer:
+        chunks.append(buffer.strip())
     return chunks
 
-def clean_answer(text: str) -> str:
-    # Remove obvious footer/contact lines or emails/urls/phones
-    text = re.sub(r'\bIf you have any questions.*', '', text, flags=re.IGNORECASE)
-    text = re.sub(r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}', '', text)
-    text = re.sub(r'https?://\S+|www\.\S+', '', text)
-    text = re.sub(r'\b(?:Tel|Phone|Cell|Contact)\b.*', '', text, flags=re.IGNORECASE)
-    return normalize_ws(text)
-
-# -------------------------
-# Definition finder (for "what is / define" questions)
-# -------------------------
-DEF_PATTERNS = [
-    r'\b(system)\s+is\s+(?:an?|the)\s+[^.]+?\.',         # "system is a/the ..."
-    r'\b(system)\s+refers\s+to\s+[^.]+?\.',              # "system refers to ..."
-    r'\b(system)\s+can\s+be\s+defined\s+as\s+[^.]+?\.',  # "system can be defined as ..."
-    r'\b(system)\s+consists\s+of\s+[^.]+?\.',            # "system consists of ..."
-]
-
-KEYWORDS_BONUS = {"interrelated", "components", "objective", "objectives",
-                  "environment", "inputs", "outputs", "communication", "function"}
-
-def find_definition_sentences(text: str, term: str = "system"):
-    sentences = split_sentences(text)
-    cand = []
-    term_lc = term.lower()
-    for s in sentences:
-        s_lc = s.lower()
-        if term_lc not in s_lc:
-            continue
-        matched = any(re.search(pat.replace("system", term_lc), s_lc) for pat in DEF_PATTERNS)
-        if matched:
-            score = sum(1 for k in KEYWORDS_BONUS if k in s_lc)
-            cand.append((score, s.strip()))
-    if not cand:
-        # fallback: sentences with term + several keywords
-        for s in sentences:
-            s_lc = s.lower()
-            if term_lc in s_lc:
-                score = sum(1 for k in KEYWORDS_BONUS if k in s_lc)
-                if score >= 2:
-                    cand.append((score, s.strip()))
-    if not cand:
-        return None
-    cand.sort(key=lambda x: (-x[0], len(x[1])))
-    return cand[0][1]
-
-# -------------------------
-# Retrieval helpers
-# -------------------------
-def select_top_chunks(chunks, question, top_k=3):
-    emb_chunks = embedder.encode(chunks, convert_to_tensor=True, normalize_embeddings=True)
-    emb_q = embedder.encode([question], convert_to_tensor=True, normalize_embeddings=True)
-    sims = util.cos_sim(emb_q, emb_chunks)[0]  # shape [num_chunks]
-    top_k = min(top_k, len(chunks))
-    top_idx = torch.topk(sims, k=top_k).indices.tolist()
-    return [chunks[i] for i in top_idx], sims.max().item()
-
-# -------------------------
-# Main QA logic
-# -------------------------
-def answer_from_chunks(question: str, chunks: list, strict_extractive=True):
-    """
-    Try QA over the best chunk(s). If strict_extractive, return the extractive span only.
-    We'll query the best chunk first; if low score, concatenate top-3 chunks and retry.
-    """
-    if not chunks:
-        return None
-
-    # Best single chunk first
-    result = qa_pipeline(question=question, context=chunks[0])
-    best_answer, best_score = result.get("answer", ""), result.get("score", 0.0)
-
-    # If weak, try merged top chunks
-    if best_score < 0.25 and len(chunks) > 1:
-        merged = " ".join(chunks)
-        result2 = qa_pipeline(question=question, context=merged)
-        if result2.get("score", 0.0) > best_score:
-            best_answer, best_score = result2["answer"], result2["score"]
-
-    if best_score < 0.15 or len(best_answer.strip()) < 2:
-        return None
-
-    ans = best_answer.strip()
-    # keep it extractive and clean
-    ans = clean_answer(ans)
-    if strict_extractive:
-        # ensure it's a concise span (avoid run-on junk)
-        ans = re.split(r'[\n\r]', ans)[0].strip()
-    return ans or None
-
-# -------------------------
-# Gradio callback
-# -------------------------
-def ask_question(file, question, history, strict_extractive=True):
+def ask_question(file, question, history):
     if not file:
         return "Please upload a file.", history
-    if not question or not question.strip():
-        return "Please type a question.", history
 
     text = extract_text(file)
-    if not text:
-        return "Could not extract text from the file.", history
-
-    sentences = split_sentences(text)
-    chunks = chunk_by_chars(sentences, chunk_char_limit=900, overlap_sents=1)
+    chunks = chunk_text(text)
+    emb_chunks = embedder.encode(chunks, convert_to_tensor=True)
+    emb_question = embedder.encode(question, convert_to_tensor=True)
+    scores = util.pytorch_cos_sim(emb_question, emb_chunks)[0]
+    best_chunk = chunks[scores.argmax().item()]
 
-    q_norm = question.lower().strip(" ?!")
-
-    try:
-        # 1) Prefer a definition for "what is/define ..." style questions
-        if re.search(r'\b(what\s+is|define|definition of)\b', q_norm) and "system" in q_norm:
-            defin = find_definition_sentences(text, term="system")
-            if defin:
-                answer = clean_answer(defin)
-                history.append((question, answer))
-                return "", history
-
-        # 2) Retrieval + extractive QA
-        top_chunks, max_sim = select_top_chunks(chunks, question, top_k=3)
-        answer = answer_from_chunks(question, top_chunks, strict_extractive=strict_extractive)
-
-        # 3) If still nothing, try a simpler sentence retrieval: pick the most relevant sentence
-        if not answer:
-            emb_sents = embedder.encode(sentences, convert_to_tensor=True, normalize_embeddings=True)
-            emb_q = embedder.encode([question], convert_to_tensor=True, normalize_embeddings=True)
-            sims = util.cos_sim(emb_q, emb_sents)[0]
-            best_i = int(torch.argmax(sims).item())
-            if sims[best_i].item() > 0.2:
-                answer = clean_answer(sentences[best_i])
-
-        if not answer:
-            answer = "Sorry, I couldn't find a clear, grounded answer in the document."
-
-    except Exception as e:
-        answer = f"An error occurred: {str(e)}"
+    result = qa_pipeline(question=question, context=best_chunk)
+    answer = result["answer"] if result["score"] > 0.1 else "Sorry, not found."
 
     history.append((question, answer))
     return "", history
 
-# -------------------------
-# UI
-# -------------------------
 with gr.Blocks() as demo:
-    gr.Markdown("## 📘 Document QA Strict Extractive (No Hallucinations)")
-    with gr.Row():
-        file_input = gr.File(label="Upload PDF or Word", file_types=[".pdf", ".docx"])
-    with gr.Row():
-        chatbot = gr.Chatbot(height=420)
-    with gr.Row():
-        question = gr.Textbox(label="Ask your question", placeholder="e.g., What is a system?")
-        strict = gr.Checkbox(value=True, label="Strict extractive only (recommended)")
+    gr.Markdown("## Document QA with Smart Retrieval")
+    file_input = gr.File(label="Upload PDF or Word", file_types=[".pdf", ".docx"])
+    chatbot = gr.Chatbot()
+    question = gr.Textbox(label="Ask your question")
     state = gr.State([])
+    question.submit(ask_question, [file_input, question, state], [question, chatbot])
 
-    question.submit(
-        ask_question,
-        [file_input, question, state, strict],
-        [question, chatbot]
-    )
-
-demo.launch()
+demo.launch()
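
For anyone trying the new pipeline outside Gradio, the retrieval-then-QA path above can be exercised on its own. The following is a minimal sketch, assuming the same two model checkpoints and the chunk_text logic from the new app.py; the sample passage and question are made-up stand-ins, not part of the commit.

import torch
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

embedder = SentenceTransformer("all-MiniLM-L6-v2")
qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")

def chunk_text(text, chunk_size=500):
    # Same naive chunking as the new app.py: split on ". " and greedily
    # pack sentences until the buffer reaches chunk_size characters.
    sentences = text.split(". ")
    chunks, buffer = [], ""
    for sent in sentences:
        if len(buffer) + len(sent) < chunk_size:
            buffer += sent + ". "
        else:
            chunks.append(buffer.strip())
            buffer = sent + ". "
    if buffer:
        chunks.append(buffer.strip())
    return chunks

# Made-up stand-in for extract_text(file) output.
text = ("A system is a set of interrelated components that work together toward "
        "a common objective. A system accepts inputs and produces outputs. "
        "The environment is everything outside the system boundary.")
question = "What is a system?"

chunks = chunk_text(text)

# Rank chunks by cosine similarity between the question and chunk embeddings.
emb_chunks = embedder.encode(chunks, convert_to_tensor=True)
emb_question = embedder.encode(question, convert_to_tensor=True)
scores = util.pytorch_cos_sim(emb_question, emb_chunks)[0]
best_chunk = chunks[scores.argmax().item()]

# Extractive QA over the best-matching chunk; same 0.1 score cutoff as the app.
result = qa_pipeline(question=question, context=best_chunk)
print(result["answer"] if result["score"] > 0.1 else "Sorry, not found.")

With a single short passage everything lands in one chunk, so the argmax is trivial; on a real document the cosine ranking is what decides which context gets handed to the extractive QA model.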