Spaces: Running
Update app.py
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
 from PyPDF2 import PdfReader
 import docx
 from sentence_transformers import SentenceTransformer, util
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 import re
 import torch

@@ -12,160 +12,222 @@ import torch
 embedder = SentenceTransformer("all-MiniLM-L6-v2")

 qa_pipeline = pipeline(
-    "question-answering",
     model="distilbert-base-cased-distilled-squad",
     device=0 if torch.cuda.is_available() else -1
 )

-# GPT model (using GPT-2 here – replace with a better model if you have one)
-gpt_tokenizer = AutoTokenizer.from_pretrained("gpt2")
-gpt_model = AutoModelForCausalLM.from_pretrained("gpt2")
-gpt_model.eval()
-
 # -------------------------
-# ...
 # -------------------------
-def extract_text(file):
-    if file.name.endswith(".pdf"):
         text = "\n".join([page.extract_text() or "" for page in PdfReader(file).pages])
-    elif file.name.endswith(".docx"):
         text = "\n".join([p.text for p in docx.Document(file).paragraphs])
     else:
         return ""
-    return text.strip()

-def ...  # chunking helper; body largely unrecoverable
-    for ...:
-        if len(...):
-            ...
         else:
-            ...
     return chunks

-def generate_with_gpt(prompt, max_new_tokens=...):
-    ...
-    outputs = gpt_model.generate(
-        ...
-        num_return_sequences=1,
-        no_repeat_ngram_size=2,
-        do_sample=True,
-        top_k=50,
-        top_p=0.95,
-        temperature=0.7
-    )
-    return gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-def refine_answer_with_gpt(context, question, answer):
-    """Ask GPT to refine the QA model answer"""
-    prompt = (
-        f"Context: {context}\n\n"
-        f"Question: {question}\n\n"
-        f"Answer: {answer}\n\n"
-        f"Please provide a clearer and more complete answer in simple language."
-    )
-    return generate_with_gpt(prompt, max_new_tokens=120)
-
-def extract_direct_definition(text, term):
-    """Find a direct definition of a term in the text"""
-    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
-    term = term.lower()
-    candidates = []
-    for sent in sentences:
-        lower_sent = sent.lower()
-        if term in lower_sent:
-            if (" is " in lower_sent or " are " in lower_sent or
-                " refers to " in lower_sent or " defined as " in lower_sent):
-                candidates.append(sent)
-    if candidates:
-        return candidates[0]
-    return None

 # -------------------------
-# ...
 # -------------------------
-def ask_question(file, question, history):
     if not file:
         return "Please upload a file.", history

     text = extract_text(file)
     if not text:
         return "Could not extract text from the file.", history
-
-    answer = None
-    normalized_question = question.lower().strip(" ?")
-
     try:
-        # ...
-        if "...
-            ...
-        # ...
         if not answer:
-            ...
-            if scores[best_idx] < 0.3:
-                top_k = min(3, len(chunks))
-                best_indices = scores.topk(top_k).indices.tolist()
-                best_chunk = " ".join([chunks[i] for i in best_indices])
-
-            result = qa_pipeline(question=question, context=best_chunk)
-            answer = result["answer"] if result["score"] > 0.1 else None
-
-        if answer and len(answer.split()) > 2:
-            answer = refine_answer_with_gpt(best_chunk, question, answer)
-
         if not answer:
-            answer = "Sorry, I couldn't find a clear answer in the document."
-
     except Exception as e:
         answer = f"An error occurred: {str(e)}"
-
     history.append((question, answer))
     return "", history

 # -------------------------
-# ...
 # -------------------------
 with gr.Blocks() as demo:
-    gr.Markdown("## 📘 ...")
     with gr.Row():
         file_input = gr.File(label="Upload PDF or Word", file_types=[".pdf", ".docx"])
     with gr.Row():
-        chatbot = gr.Chatbot(height=...)
     with gr.Row():
-        question = gr.Textbox(label="Ask your question", placeholder="...")
     state = gr.State([])

     question.submit(
-        ask_question,
-        [file_input, question, state],
         [question, chatbot]
     )
 from PyPDF2 import PdfReader
 import docx
 from sentence_transformers import SentenceTransformer, util
+from transformers import pipeline
 import re
 import torch

 embedder = SentenceTransformer("all-MiniLM-L6-v2")

 qa_pipeline = pipeline(
+    "question-answering",
     model="distilbert-base-cased-distilled-squad",
     device=0 if torch.cuda.is_available() else -1
 )

 # -------------------------
+# Text utilities
 # -------------------------
+SENT_SPLIT_RE = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')
+
+def normalize_ws(text: str) -> str:
+    text = re.sub(r'[ \t]+', ' ', text)
+    text = re.sub(r'\s*\n\s*', '\n', text)
+    return text.strip()
+
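A quick check of the whitespace and sentence-split helpers (illustrative inputs, not from the app):

    normalize_ws("a  b\t c \n  d\n\n e")   # -> "a b c\nd\ne"
    SENT_SPLIT_RE.split("Dr. Smith arrived. It works! Done?")
    # -> ["Dr. Smith arrived.", "It works!", "Done?"]  (the "Dr." abbreviation does not split)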
+def extract_text(file) -> str:
+    if file.name.lower().endswith(".pdf"):
         text = "\n".join([page.extract_text() or "" for page in PdfReader(file).pages])
+    elif file.name.lower().endswith(".docx"):
         text = "\n".join([p.text for p in docx.Document(file).paragraphs])
     else:
         return ""
+    return normalize_ws(text)

+def split_sentences(text: str):
+    parts = SENT_SPLIT_RE.split(text)
+    # Merge very short fragments with neighbors
+    out = []
+    buf = ""
+    for p in parts:
+        if len(p.strip()) < 40:
+            buf += (" " if buf else "") + p.strip()
+        else:
+            if buf:
+                out.append(buf.strip())
+                buf = ""
+            out.append(p.strip())
+    if buf:
+        out.append(buf.strip())
+    return [s for s in out if s]

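How the splitter behaves on a toy passage: fragments under 40 characters buffer until the next long sentence flushes them, so consecutive shorts merge together:

    split_sentences("OK. A system is a set of interrelated components. Inputs flow in. Outputs flow out.")
    # -> ["OK.", "A system is a set of interrelated components.",
    #     "Inputs flow in. Outputs flow out."]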
+def chunk_by_chars(sentences, chunk_char_limit=900, overlap_sents=1):
+    chunks, cur, cur_len = [], [], 0
+    for s in sentences:
+        if cur_len + len(s) + 1 <= chunk_char_limit:
+            cur.append(s); cur_len += len(s) + 1
+        else:
+            if cur:
+                chunks.append(" ".join(cur))
+                # overlap for context
+                cur = cur[-overlap_sents:] + [s]
+                cur_len = sum(len(x) + 1 for x in cur)
+            else:
+                # extremely long sentence, hard cut
+                chunks.append(s[:chunk_char_limit])
+                cur, cur_len = [], 0
+    if cur:
+        chunks.append(" ".join(cur))
     return chunks

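A small chunking example (limits shrunk from the app's 900 characters so the overlap is visible):

    chunk_by_chars(["one two three.", "four five six.", "seven eight nine."],
                   chunk_char_limit=32, overlap_sents=1)
    # -> ["one two three. four five six.",
    #     "four five six. seven eight nine."]   (last sentence repeats for context)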
+def clean_answer(text: str) -> str:
+    # Remove obvious footer/contact lines or emails/urls/phones
+    text = re.sub(r'\bIf you have any questions.*', '', text, flags=re.IGNORECASE)
+    text = re.sub(r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}', '', text)
+    text = re.sub(r'https?://\S+|www\.\S+', '', text)
+    text = re.sub(r'\b(?:Tel|Phone|Cell|Contact)\b.*', '', text, flags=re.IGNORECASE)
+    return normalize_ws(text)

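The scrubber in action (made-up string): the email is stripped by the address pattern, then the Contact rule removes the rest of that line:

    clean_answer("A system has inputs and outputs. Contact admin@example.com for help.")
    # -> "A system has inputs and outputs."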
 # -------------------------
+# Definition finder (for "what is / define" questions)
 # -------------------------
+DEF_PATTERNS = [
+    r'\b(system)\s+is\s+(?:an?|the)\s+[^.]+?\.',         # "system is a/the ..."
+    r'\b(system)\s+refers\s+to\s+[^.]+?\.',              # "system refers to ..."
+    r'\b(system)\s+can\s+be\s+defined\s+as\s+[^.]+?\.',  # "system can be defined as ..."
+    r'\b(system)\s+consists\s+of\s+[^.]+?\.',            # "system consists of ..."
+]
+
+KEYWORDS_BONUS = {"interrelated", "components", "objective", "objectives",
+                  "environment", "inputs", "outputs", "communication", "function"}
+
+def find_definition_sentences(text: str, term: str = "system"):
+    sentences = split_sentences(text)
+    cand = []
+    term_lc = term.lower()
+    for s in sentences:
+        s_lc = s.lower()
+        if term_lc not in s_lc:
+            continue
+        matched = any(re.search(pat.replace("system", term_lc), s_lc) for pat in DEF_PATTERNS)
+        if matched:
+            score = sum(1 for k in KEYWORDS_BONUS if k in s_lc)
+            cand.append((score, s.strip()))
+    if not cand:
+        # fallback: sentences with term + several keywords
+        for s in sentences:
+            s_lc = s.lower()
+            if term_lc in s_lc:
+                score = sum(1 for k in KEYWORDS_BONUS if k in s_lc)
+                if score >= 2:
+                    cand.append((score, s.strip()))
+    if not cand:
+        return None
+    cand.sort(key=lambda x: (-x[0], len(x[1])))
+    return cand[0][1]
+
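On a toy passage the first pattern ("system is a/the ...") fires, and the keyword bonus ranks candidates (more KEYWORDS_BONUS hits win; shorter sentences break ties):

    doc = ("A system is a set of interrelated components working toward a common "
           "objective. Systems have inputs and outputs. The weather was nice.")
    find_definition_sentences(doc, term="system")
    # -> "A system is a set of interrelated components working toward a common objective."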
+# -------------------------
+# Retrieval helpers
+# -------------------------
+def select_top_chunks(chunks, question, top_k=3):
+    emb_chunks = embedder.encode(chunks, convert_to_tensor=True, normalize_embeddings=True)
+    emb_q = embedder.encode([question], convert_to_tensor=True, normalize_embeddings=True)
+    sims = util.cos_sim(emb_q, emb_chunks)[0]  # shape [num_chunks]
+    top_k = min(top_k, len(chunks))
+    top_idx = torch.topk(sims, k=top_k).indices.tolist()
+    return [chunks[i] for i in top_idx], sims.max().item()
+
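Because both encodings pass normalize_embeddings=True, cos_sim reduces to a dot product, and the returned maximum is a rough confidence in [-1, 1]. A sketch, assuming chunks came from chunk_by_chars:

    top, best_sim = select_top_chunks(chunks, "What is a system?", top_k=3)
    # top[0] is the most similar chunk; best_sim near 0 signals a weak match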
+# -------------------------
+# Main QA logic
+# -------------------------
+def answer_from_chunks(question: str, chunks: list, strict_extractive=True):
+    """
+    Try QA over the best chunk(s). If strict_extractive, return the extractive span only.
+    We'll query the best chunk first; if low score, concatenate top-3 chunks and retry.
+    """
+    if not chunks:
+        return None
+
+    # Best single chunk first
+    result = qa_pipeline(question=question, context=chunks[0])
+    best_answer, best_score = result.get("answer", ""), result.get("score", 0.0)
+
+    # If weak, try merged top chunks
+    if best_score < 0.25 and len(chunks) > 1:
+        merged = " ".join(chunks)
+        result2 = qa_pipeline(question=question, context=merged)
+        if result2.get("score", 0.0) > best_score:
+            best_answer, best_score = result2["answer"], result2["score"]
+
+    if best_score < 0.15 or len(best_answer.strip()) < 2:
+        return None
+
+    ans = best_answer.strip()
+    # keep it extractive and clean
+    ans = clean_answer(ans)
+    if strict_extractive:
+        # ensure it's a concise span (avoid run-on junk)
+        ans = re.split(r'[\n\r]', ans)[0].strip()
+    return ans or None
+
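Two thresholds drive the flow: below 0.25 the merged-context retry kicks in, and below 0.15 the function returns None rather than a weak span. A sketch (output illustrative):

    ans = answer_from_chunks("What is a system?", top_chunks)
    # -> an extractive span such as "a set of interrelated components", or None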
+# -------------------------
+# Gradio callback
+# -------------------------
+def ask_question(file, question, history, strict_extractive=True):
     if not file:
         return "Please upload a file.", history
+    if not question or not question.strip():
+        return "Please type a question.", history

     text = extract_text(file)
     if not text:
         return "Could not extract text from the file.", history
+
+    sentences = split_sentences(text)
+    chunks = chunk_by_chars(sentences, chunk_char_limit=900, overlap_sents=1)
+
+    q_norm = question.lower().strip(" ?!")
+
     try:
+        # 1) Prefer a definition for "what is/define ..." style questions
+        if re.search(r'\b(what\s+is|define|definition of)\b', q_norm) and "system" in q_norm:
+            defin = find_definition_sentences(text, term="system")
+            if defin:
+                answer = clean_answer(defin)
+                history.append((question, answer))
+                return "", history
+
+        # 2) Retrieval + extractive QA
+        top_chunks, max_sim = select_top_chunks(chunks, question, top_k=3)
+        answer = answer_from_chunks(question, top_chunks, strict_extractive=strict_extractive)
+
+        # 3) If still nothing, try a simpler sentence retrieval: pick the most relevant sentence
         if not answer:
+            emb_sents = embedder.encode(sentences, convert_to_tensor=True, normalize_embeddings=True)
+            emb_q = embedder.encode([question], convert_to_tensor=True, normalize_embeddings=True)
+            sims = util.cos_sim(emb_q, emb_sents)[0]
+            best_i = int(torch.argmax(sims).item())
+            if sims[best_i].item() > 0.2:
+                answer = clean_answer(sentences[best_i])
+
         if not answer:
+            answer = "Sorry, I couldn't find a clear, grounded answer in the document."
+
     except Exception as e:
         answer = f"An error occurred: {str(e)}"
+
     history.append((question, answer))
     return "", history

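The definition route fires only when the trigger regex matches and the question mentions "system"; everything else falls through to retrieval plus extractive QA:

    re.search(r'\b(what\s+is|define|definition of)\b', "what is a system")  # match
    re.search(r'\b(what\s+is|define|definition of)\b', "explain systems")   # None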
 # -------------------------
+# UI
 # -------------------------
 with gr.Blocks() as demo:
+    gr.Markdown("## 📘 Document QA — Strict Extractive (No Hallucinations)")
     with gr.Row():
         file_input = gr.File(label="Upload PDF or Word", file_types=[".pdf", ".docx"])
     with gr.Row():
+        chatbot = gr.Chatbot(height=420)
     with gr.Row():
+        question = gr.Textbox(label="Ask your question", placeholder="e.g., What is a system?")
+        strict = gr.Checkbox(value=True, label="Strict extractive only (recommended)")
     state = gr.State([])

     question.submit(
+        ask_question,
+        [file_input, question, state, strict],
         [question, chatbot]
     )
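A minimal smoke test of the pipeline pieces outside the UI (assumes a local sample.pdf; the filename is hypothetical):

    doc = extract_text(open("sample.pdf", "rb"))
    sents = split_sentences(doc)
    top, sim = select_top_chunks(chunk_by_chars(sents), "What is a system?")
    print(answer_from_chunks("What is a system?", top))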