Update app.py

app.py CHANGED
@@ -1,234 +1,56 @@
Removed — previous app.py (the page extraction lost several spans; lines filled back in from context are flagged with "reconstructed" comments and should not be read as the verbatim original):

import gradio as gr
from PyPDF2 import PdfReader
import docx
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
import re
import torch

# -------------------------
# Load models
# -------------------------
embedder = SentenceTransformer("all-MiniLM-L6-v2")
qa_pipeline = pipeline(                            # reconstructed: only the closing ")" survived;
    "question-answering",                          # task and model assumed to match the new version below
    model="distilbert-base-cased-distilled-squad",
)

SENT_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')       # reconstructed: a typical sentence splitter; exact pattern lost

def normalize_ws(text: str) -> str:
    text = re.sub(r'\s+', ' ', text)               # reconstructed: body lost except for the return below
    return text.strip()

def extract_text(file) -> str:
    if file.name.lower().endswith(".pdf"):
        text = "\n".join([page.extract_text() or "" for page in PdfReader(file).pages])
    elif file.name.lower().endswith(".docx"):
        text = "\n".join([p.text for p in docx.Document(file).paragraphs])
    else:
        return ""
    return normalize_ws(text)

def split_sentences(text: str):
    parts = SENT_SPLIT_RE.split(text)
    # Merge very short fragments with neighbors
    out = []
    buf = ""
    for p in parts:
        if len(p.strip()) < 40:
            buf += (" " if buf else "") + p.strip()
        else:
            if buf:
                out.append(buf.strip())
                buf = ""
            out.append(p.strip())
    if buf:
        out.append(buf.strip())
    return [s for s in out if s]

def chunk_by_chars(sentences, chunk_char_limit=900, overlap_sents=1):
    chunks, cur, cur_len = [], [], 0
    for s in sentences:
        if cur_len + len(s) + 1 <= chunk_char_limit:
            cur.append(s); cur_len += len(s) + 1
        else:
            if cur:                                                         # reconstructed: flush the
                chunks.append(" ".join(cur))                                # reconstructed: current chunk
            if len(s) <= chunk_char_limit:                                  # reconstructed
                cur = cur[-overlap_sents:] + [s] if overlap_sents else [s]  # reconstructed: keep overlap
                cur_len = sum(len(x) + 1 for x in cur)
            else:
                # extremely long sentence, hard cut
                chunks.append(s[:chunk_char_limit])
                cur, cur_len = [], 0
    if cur:
        chunks.append(" ".join(cur))
    return chunks

def clean_answer(text: str) -> str:               # reconstructed signature: the "def" line was truncated
    # Remove obvious footer/contact lines or emails/urls/phones
    text = re.sub(r'\bIf you have any questions.*', '', text, flags=re.IGNORECASE)
    text = re.sub(r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}', '', text)
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    text = re.sub(r'\b(?:Tel|Phone|Cell|Contact)\b.*', '', text, flags=re.IGNORECASE)
    return normalize_ws(text)

# -------------------------
# Definition finder (for "what is / define" questions)
# -------------------------
DEF_PATTERNS = [
    r'\b(system)\s+is\s+(?:an?|the)\s+[^.]+?\.',           # "system is a/the ..."
    r'\b(system)\s+refers\s+to\s+[^.]+?\.',                # "system refers to ..."
    r'\b(system)\s+can\s+be\s+defined\s+as\s+[^.]+?\.',    # "system can be defined as ..."
    r'\b(system)\s+consists\s+of\s+[^.]+?\.',              # "system consists of ..."
]

KEYWORDS_BONUS = {"interrelated", "components", "objective", "objectives",
                  "environment", "inputs", "outputs", "communication", "function"}

def find_definition_sentences(text: str, term: str = "system"):
    sentences = split_sentences(text)
    cand = []
    term_lc = term.lower()
    for s in sentences:
        s_lc = s.lower()
        if term_lc not in s_lc:
            continue
        matched = any(re.search(pat.replace("system", term_lc), s_lc) for pat in DEF_PATTERNS)
        if matched:
            score = sum(1 for k in KEYWORDS_BONUS if k in s_lc)
            cand.append((score, s.strip()))
    if not cand:
        # fallback: sentences with term + several keywords
        for s in sentences:
            s_lc = s.lower()
            if term_lc in s_lc:
                score = sum(1 for k in KEYWORDS_BONUS if k in s_lc)
                if score >= 2:
                    cand.append((score, s.strip()))
    if not cand:
        return None
    cand.sort(key=lambda x: (-x[0], len(x[1])))
    return cand[0][1]

# -------------------------
# Retrieval helpers
# -------------------------
def select_top_chunks(chunks, question, top_k=3):
    emb_chunks = embedder.encode(chunks, convert_to_tensor=True, normalize_embeddings=True)
    emb_q = embedder.encode([question], convert_to_tensor=True, normalize_embeddings=True)
    sims = util.cos_sim(emb_q, emb_chunks)[0]  # shape [num_chunks]
    top_k = min(top_k, len(chunks))
    top_idx = torch.topk(sims, k=top_k).indices.tolist()
    return [chunks[i] for i in top_idx], sims.max().item()

# -------------------------
# Main QA logic
# -------------------------
def answer_from_chunks(question: str, chunks: list, strict_extractive=True):
    """
    Try QA over the best chunk(s). If strict_extractive, return the extractive span only.
    We'll query the best chunk first; if low score, concatenate top-3 chunks and retry.
    """
    if not chunks:
        return None

    # Best single chunk first
    result = qa_pipeline(question=question, context=chunks[0])
    best_answer, best_score = result.get("answer", ""), result.get("score", 0.0)

    # If weak, try merged top chunks
    if best_score < 0.25 and len(chunks) > 1:
        merged = " ".join(chunks)
        result2 = qa_pipeline(question=question, context=merged)
        if result2.get("score", 0.0) > best_score:
            best_answer, best_score = result2["answer"], result2["score"]

    if best_score < 0.15 or len(best_answer.strip()) < 2:
        return None

    ans = best_answer.strip()
    # keep it extractive and clean
    ans = clean_answer(ans)
    if strict_extractive:
        # ensure it's a concise span (avoid run-on junk)
        ans = re.split(r'[\n\r]', ans)[0].strip()
    return ans or None

# -------------------------
# Gradio callback
# -------------------------
def ask_question(file, question, history, strict_extractive=True):
    if not file:
        return "Please upload a file.", history
    if not question or not question.strip():
        return "Please type a question.", history

    text = extract_text(file)
    q_norm = question.lower().strip()        # reconstructed: q_norm, sentences and chunks are all
    sentences = split_sentences(text)        # reconstructed: referenced below, but their defining
    chunks = chunk_by_chars(sentences)       # reconstructed: lines were lost in extraction

    try:
        # 1) Prefer a definition for "what is/define ..." style questions
        if re.search(r'\b(what\s+is|define|definition of)\b', q_norm) and "system" in q_norm:
            defin = find_definition_sentences(text, term="system")
            if defin:
                answer = clean_answer(defin)
                history.append((question, answer))
                return "", history

        # 2) Retrieval + extractive QA
        top_chunks, max_sim = select_top_chunks(chunks, question, top_k=3)
        answer = answer_from_chunks(question, top_chunks, strict_extractive=strict_extractive)

        # 3) If still nothing, try a simpler sentence retrieval: pick the most relevant sentence
        if not answer:
            emb_sents = embedder.encode(sentences, convert_to_tensor=True, normalize_embeddings=True)
            emb_q = embedder.encode([question], convert_to_tensor=True, normalize_embeddings=True)
            sims = util.cos_sim(emb_q, emb_sents)[0]
            best_i = int(torch.argmax(sims).item())
            if sims[best_i].item() > 0.2:
                answer = clean_answer(sentences[best_i])

        if not answer:
            answer = "Sorry, I couldn't find a clear, grounded answer in the document."

    except Exception as e:
        answer = f"An error occurred: {str(e)}"

    history.append((question, answer))
    return "", history

# -------------------------
# UI
# -------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Document QA")            # reconstructed: the heading string was truncated in extraction
    file_input = gr.File(label="Upload PDF or Word", file_types=[".pdf", ".docx"])  # reconstructed: referenced below
    chatbot = gr.Chatbot(height=420)
    with gr.Row():
        question = gr.Textbox(label="Ask your question", placeholder="e.g., What is a system?")
        strict = gr.Checkbox(value=True, label="Strict extractive only (recommended)")
    state = gr.State([])

    question.submit(                         # reconstructed: only the arguments and closing ")" survived
        ask_question,
        [file_input, question, state, strict],
        [question, chatbot]
    )

demo.launch()
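For context on the removed definition finder: the DEF_PATTERNS regexes key on copular phrasings ("X is a/the ...", "X refers to ..."), with the literal "system" swapped for the query term at runtime. A standalone check (an illustrative sketch only, not part of either version of the file) shows the kind of sentence they accept:

import re

# Two of the removed patterns, verbatim (illustration only).
DEF_PATTERNS = [
    r'\b(system)\s+is\s+(?:an?|the)\s+[^.]+?\.',
    r'\b(system)\s+refers\s+to\s+[^.]+?\.',
]

sentence = "A system is a set of interrelated components working toward a common objective."
print(any(re.search(p, sentence.lower()) for p in DEF_PATTERNS))  # True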
Added — new app.py:

import gradio as gr
from PyPDF2 import PdfReader
import docx
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline

# Load models
embedder = SentenceTransformer("all-MiniLM-L6-v2")
qa_pipeline = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")

def extract_text(file):
    if file.name.endswith(".pdf"):
        return "\n".join([page.extract_text() or "" for page in PdfReader(file).pages])
    elif file.name.endswith(".docx"):
        return "\n".join([p.text for p in docx.Document(file).paragraphs])
    return ""

def chunk_text(text, chunk_size=500):
    sentences = text.split(". ")
    chunks, buffer = [], ""
    for sent in sentences:
        if len(buffer) + len(sent) < chunk_size:
            buffer += sent + ". "
        else:
            chunks.append(buffer.strip())
            buffer = sent + ". "
    if buffer:
        chunks.append(buffer.strip())
    return chunks

def ask_question(file, question, history):
    if not file:
        return "Please upload a file.", history

    text = extract_text(file)
    chunks = chunk_text(text)
    emb_chunks = embedder.encode(chunks, convert_to_tensor=True)
    emb_question = embedder.encode(question, convert_to_tensor=True)
    scores = util.pytorch_cos_sim(emb_question, emb_chunks)[0]
    best_chunk = chunks[scores.argmax().item()]

    result = qa_pipeline(question=question, context=best_chunk)
    answer = result["answer"] if result["score"] > 0.1 else "Sorry, not found."

    history.append((question, answer))
    return "", history

with gr.Blocks() as demo:
    gr.Markdown("## Document QA with Smart Retrieval")
    file_input = gr.File(label="Upload PDF or Word", file_types=[".pdf", ".docx"])
    chatbot = gr.Chatbot()
    question = gr.Textbox(label="Ask your question")
    state = gr.State([])
    question.submit(ask_question, [file_input, question, state], [question, chatbot])

demo.launch()
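A minimal smoke test for the simplified chunker, assuming the listing above has been pasted into (or imported by) the same Python session; the synthetic text and sizes are illustrative only, not part of the commit:

# Sanity-check chunk_text on synthetic text (illustrative sketch).
text = "Systems accept inputs. Systems produce outputs. Systems use feedback. " * 30
chunks = chunk_text(text, chunk_size=500)
print(len(chunks))                           # a handful of chunks
print(all(len(c) <= 520 for c in chunks))    # True for this input; note a single very
                                             # long sentence could still overflow chunk_size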