Update app.py
app.py
CHANGED
@@ -6,46 +6,54 @@ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 import re
 import torch
 
+# -------------------------
 # Load models
+# -------------------------
 embedder = SentenceTransformer("all-MiniLM-L6-v2")
-qa_pipeline = pipeline("question-answering",
-                       model="distilbert-base-cased-distilled-squad",
-                       device=0 if torch.cuda.is_available() else -1)
 
-
+qa_pipeline = pipeline(
+    "question-answering",
+    model="distilbert-base-cased-distilled-squad",
+    device=0 if torch.cuda.is_available() else -1
+)
+
+# GPT model (using GPT-2 here – replace with better model if you have)
 gpt_tokenizer = AutoTokenizer.from_pretrained("gpt2")
 gpt_model = AutoModelForCausalLM.from_pretrained("gpt2")
 gpt_model.eval()
 
+# -------------------------
+# Helper functions
+# -------------------------
 def extract_text(file):
+    """Extract text from PDF or DOCX"""
     if file.name.endswith(".pdf"):
         text = "\n".join([page.extract_text() or "" for page in PdfReader(file).pages])
     elif file.name.endswith(".docx"):
         text = "\n".join([p.text for p in docx.Document(file).paragraphs])
     else:
         return ""
-
-    text = re.sub(r'\s+', ' ', text)  # Replace multiple whitespace with single space
+    text = re.sub(r'\s+', ' ', text)  # clean whitespace
     return text.strip()
 
 def chunk_text(text, chunk_size=500, overlap=100):
+    """Split text into overlapping chunks"""
     sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
-    chunks = []
-
-
+    chunks, current_chunk = [], ""
+
     for sent in sentences:
         if len(current_chunk) + len(sent) < chunk_size:
             current_chunk += sent + " "
         else:
             chunks.append(current_chunk.strip())
-            # Keep some overlap between chunks for context
             current_chunk = current_chunk[-overlap:] + sent + " "
-
+
     if current_chunk:
         chunks.append(current_chunk.strip())
     return chunks
 
 def generate_with_gpt(prompt, max_length=150):
+    """Generate text with GPT model"""
     inputs = gpt_tokenizer(prompt, return_tensors="pt")
     with torch.no_grad():
         outputs = gpt_model.generate(
@@ -60,65 +68,20 @@ def generate_with_gpt(prompt, max_length=150):
         )
     return gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-def
-
-
-
-
-
-
-
-
-    if not chunks:
-        return "No meaningful text chunks could be created.", history
-
-    # Initialize answer as None
-    answer = None
-
-    try:
-        # Normalize question for better matching
-        normalized_question = question.lower().strip(" ?")
-
-        # First try to find direct definitions
-        if "artificial system" in normalized_question:
-            answer = extract_direct_definition(text, "artificial system")
-        elif "natural system" in normalized_question:
-            answer = extract_direct_definition(text, "natural system")
-        elif "component" in normalized_question:
-            answer = extract_direct_definition(text, "component")
-
-        # If no direct definition found, use semantic search
-        if not answer:
-            emb_chunks = embedder.encode(chunks, convert_to_tensor=True)
-            emb_question = embedder.encode(question, convert_to_tensor=True)
-            scores = util.pytorch_cos_sim(emb_question, emb_chunks)[0]
-            best_idx = scores.argmax().item()
-            best_chunk = chunks[best_idx]
-
-            if scores[best_idx] < 0.3:  # Low confidence
-                top_k = min(3, len(chunks))
-                best_indices = scores.topk(top_k).indices.tolist()
-                best_chunk = " ".join([chunks[i] for i in best_indices])
-
-            result = qa_pipeline(question=question, context=best_chunk)
-            if result["score"] > 0.1 and len(result["answer"].split()) >= 2:
-                answer = result["answer"]
-
-        # Final fallback if no answer found
-        if not answer:
-            answer = "Sorry, I couldn't find a clear answer in the document."
-
-    except Exception as e:
-        answer = f"An error occurred: {str(e)}"
-
-    history.append((question, answer))
-    return "", history
+def refine_answer_with_gpt(context, question, answer):
+    """Ask GPT to refine the QA model answer"""
+    prompt = (
+        f"Context: {context}\n\n"
+        f"Question: {question}\n\n"
+        f"Answer: {answer}\n\n"
+        f"Please provide a clearer and more complete answer in simple language."
+    )
+    return generate_with_gpt(prompt, max_length=120)
 
 def extract_direct_definition(text, term):
-    """
+    """Find a direct definition of a term in the text"""
     sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
     term = term.lower()
-
     candidates = []
     for sent in sentences:
         lower_sent = sent.lower()
@@ -126,11 +89,13 @@ def extract_direct_definition(text, term):
         if (" is " in lower_sent or " are " in lower_sent or
                 " refers to " in lower_sent or " defined as " in lower_sent):
             candidates.append(sent)
-
     if candidates:
         return candidates[0]
     return None
 
+# -------------------------
+# Main QA function
+# -------------------------
 def ask_question(file, question, history):
     if not file:
         return "Please upload a file.", history
@@ -143,11 +108,12 @@ def ask_question(file, question, history):
     if not chunks:
         return "No meaningful text chunks could be created.", history
 
-    #
+    # Initialize answer
+    answer = None
     normalized_question = question.lower().strip(" ?")
 
     try:
-        #
+        # Try direct definition
         if "artificial system" in normalized_question:
             answer = extract_direct_definition(text, "artificial system")
         elif "natural system" in normalized_question:
@@ -155,7 +121,7 @@ def ask_question(file, question, history):
         elif "component" in normalized_question:
            answer = extract_direct_definition(text, "component")
 
-        # If no direct definition
+        # If no direct definition, do semantic search + QA
         if not answer:
            emb_chunks = embedder.encode(chunks, convert_to_tensor=True)
            emb_question = embedder.encode(question, convert_to_tensor=True)
@@ -163,32 +129,32 @@ def ask_question(file, question, history):
             best_idx = scores.argmax().item()
             best_chunk = chunks[best_idx]
 
-            #
+            # Low confidence → merge top chunks
             if scores[best_idx] < 0.3:
                 top_k = min(3, len(chunks))
                 best_indices = scores.topk(top_k).indices.tolist()
                 best_chunk = " ".join([chunks[i] for i in best_indices])
 
-            # Get initial answer from QA model
             result = qa_pipeline(question=question, context=best_chunk)
             answer = result["answer"] if result["score"] > 0.1 else None
 
-            # Refine answer with GPT if available
             if answer and len(answer.split()) > 2:
                 answer = refine_answer_with_gpt(best_chunk, question, answer)
 
-        # Final fallback
         if not answer:
             answer = "Sorry, I couldn't find a clear answer in the document."
-
+
     except Exception as e:
         answer = f"An error occurred: {str(e)}"
 
     history.append((question, answer))
     return "", history
 
+# -------------------------
+# Gradio Interface
+# -------------------------
 with gr.Blocks() as demo:
-    gr.Markdown("## Enhanced Document QA with GPT Integration")
+    gr.Markdown("## 📘 Enhanced Document QA with GPT Integration")
     with gr.Row():
         file_input = gr.File(label="Upload PDF or Word", file_types=[".pdf", ".docx"])
     with gr.Row():
@@ -196,7 +162,7 @@ with gr.Blocks() as demo:
     with gr.Row():
         question = gr.Textbox(label="Ask your question", placeholder="Type your question here...")
     state = gr.State([])
-
+
     question.submit(
         ask_question,
        [file_input, question, state],
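For reviewers who want to exercise this change without launching the Space, below is a minimal sketch of the semantic-search, extractive-QA, and GPT-2 refinement path that this commit wires into ask_question. It is illustrative only: the toy chunks list, the question, and the inline refinement stand in for the real extract_text / chunk_text / refine_answer_with_gpt calls, while the 0.1 QA-score cut-off and the prompt wording are taken from the diff above.

# Illustrative sketch only (not part of app.py): mirrors the semantic-search ->
# extractive-QA -> GPT-2 refinement flow added in this commit.
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

embedder = SentenceTransformer("all-MiniLM-L6-v2")
qa = pipeline("question-answering",
              model="distilbert-base-cased-distilled-squad",
              device=0 if torch.cuda.is_available() else -1)
gpt_tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt_model = AutoModelForCausalLM.from_pretrained("gpt2")
gpt_model.eval()

# Toy stand-in for the extracted and chunked upload.
chunks = [
    "An artificial system is a system designed and built by humans.",
    "A natural system arises without human intervention.",
]
question = "What is an artificial system?"

# Semantic search: pick the chunk most similar to the question.
emb_chunks = embedder.encode(chunks, convert_to_tensor=True)
emb_question = embedder.encode(question, convert_to_tensor=True)
scores = util.pytorch_cos_sim(emb_question, emb_chunks)[0]
best_chunk = chunks[scores.argmax().item()]

# Extractive QA on the selected chunk (same 0.1 score cut-off as ask_question).
result = qa(question=question, context=best_chunk)
answer = result["answer"] if result["score"] > 0.1 else None

# GPT-2 refinement, using the same prompt shape as refine_answer_with_gpt.
if answer:
    prompt = (f"Context: {best_chunk}\n\n"
              f"Question: {question}\n\n"
              f"Answer: {answer}\n\n"
              "Please provide a clearer and more complete answer in simple language.")
    inputs = gpt_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = gpt_model.generate(**inputs, max_new_tokens=60,
                                     pad_token_id=gpt_tokenizer.eos_token_id)
    answer = gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)

print(answer)

One behavioural note that applies both to this sketch and to generate_with_gpt in the diff: decoding outputs[0] returns the prompt plus the continuation, since a causal LM's generate output includes the input tokens, so callers that want only the refined answer would need to slice the prompt off.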