Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -151,11 +151,18 @@ def post_process_summary(summary):
|
|
151 |
return cleaned_summary
|
152 |
|
153 |
def improve_summary_generation(text, model, tokenizer):
|
|
|
|
|
|
|
|
|
154 |
# Add a more specific prompt
|
155 |
formatted_text = (
|
156 |
-
"Summarize
|
157 |
-
"1
|
158 |
-
"
|
|
|
|
|
|
|
159 |
)
|
160 |
|
161 |
# Adjust generation parameters
|
@@ -179,11 +186,76 @@ def improve_summary_generation(text, model, tokenizer):
|
|
179 |
|
180 |
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
def post_process_summary(summary):
|
183 |
"""Enhanced post-processing to catch common errors"""
|
184 |
if not summary:
|
185 |
return summary
|
186 |
-
|
187 |
# Remove contradictory age statements
|
188 |
age_statements = []
|
189 |
lines = summary.split('.')
|
@@ -199,8 +271,21 @@ def post_process_summary(summary):
|
|
199 |
seen_content = set()
|
200 |
unique_lines = []
|
201 |
for line in cleaned_lines:
|
202 |
-
|
203 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
seen_content.add(line_core)
|
205 |
unique_lines.append(line)
|
206 |
|
@@ -208,7 +293,13 @@ def post_process_summary(summary):
|
|
208 |
cleaned_summary = '. '.join(s.strip() for s in unique_lines if s.strip())
|
209 |
if cleaned_summary and not cleaned_summary.endswith('.'):
|
210 |
cleaned_summary += '.'
|
211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
return cleaned_summary
|
213 |
|
214 |
def generate_focused_summary(question, abstracts, model, tokenizer):
|
|
|
151 |
return cleaned_summary
|
152 |
|
153 |
def improve_summary_generation(text, model, tokenizer):
|
154 |
+
"""Generate improved summary with better prompt and validation"""
|
155 |
+
if not isinstance(text, str) or not text.strip():
|
156 |
+
return "No abstract available to summarize."
|
157 |
+
|
158 |
# Add a more specific prompt
|
159 |
formatted_text = (
|
160 |
+
"Summarize this medical research paper following this structure exactly:\n"
|
161 |
+
"1. Background and objectives\n"
|
162 |
+
"2. Methods\n"
|
163 |
+
"3. Key findings with specific numbers/percentages\n"
|
164 |
+
"4. Main conclusions\n"
|
165 |
+
"Original text: " + preprocess_text(text)
|
166 |
)
|
167 |
|
168 |
# Adjust generation parameters
|
|
|
186 |
|
187 |
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
188 |
|
189 |
+
# Post-process the summary
|
190 |
+
processed_summary = post_process_summary(summary)
|
191 |
+
|
192 |
+
# Validate the summary
|
193 |
+
if not validate_summary(processed_summary, text):
|
194 |
+
# If validation fails, try one more time with different parameters
|
195 |
+
with torch.no_grad():
|
196 |
+
summary_ids = model.generate(
|
197 |
+
**{
|
198 |
+
"input_ids": inputs["input_ids"],
|
199 |
+
"attention_mask": inputs["attention_mask"],
|
200 |
+
"max_length": 200,
|
201 |
+
"min_length": 50,
|
202 |
+
"num_beams": 4,
|
203 |
+
"length_penalty": 2.0,
|
204 |
+
"no_repeat_ngram_size": 4,
|
205 |
+
"temperature": 0.8,
|
206 |
+
"repetition_penalty": 2.0
|
207 |
+
}
|
208 |
+
)
|
209 |
+
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
210 |
+
processed_summary = post_process_summary(summary)
|
211 |
+
|
212 |
+
return processed_summary
|
213 |
+
|
214 |
+
def validate_summary(summary, original_text):
    """Validate a generated summary against the original text.

    Runs a series of cheap heuristic checks and returns False as soon as
    one fails:
      * both inputs must be non-empty
      * at most one numeric age mention ("NN years") — more than one is
        treated as a likely contradictory-age hallucination
      * no duplicated sentences beyond a single repeat
      * plausible length: at least 20 words and no more than 80% of the
        original's word count
      * none of the known repetition patterns ("mean ... mean", etc.)
        may appear

    Parameters
    ----------
    summary : str
        The model-generated summary to validate.
    original_text : str
        The source text the summary was generated from.

    Returns
    -------
    bool
        True if every check passes, False otherwise.
    """
    import re

    # Don't validate empty summaries (or validate against an empty source).
    if not summary or not original_text:
        return False

    # Lowercase once; several checks below reuse it.
    lowered = summary.lower()

    # Check for age inconsistencies: multiple "NN years" mentions usually
    # indicate the model contradicting itself about patient age.
    age_mentions = re.findall(r'(\d+\.?\d*)\s*years?', lowered)
    if len(age_mentions) > 1:  # multiple age mentions
        return False

    # Check for repetitive sentences. Compare only non-empty fragments so
    # the empty tail produced by a trailing '.' does not inflate the
    # duplicate count (the previous code counted len(sentences) including
    # that empty tail, making the threshold inconsistent).
    sentences = [s.strip().lower() for s in summary.split('.') if s.strip()]
    if len(sentences) - len(set(sentences)) > 1:  # more than one duplicate
        return False

    # Check the summary isn't too short or too long relative to the source.
    summary_words = len(summary.split())
    original_words = len(original_text.split())
    if summary_words < 20 or summary_words > original_words * 0.8:
        return False

    # Known repetition patterns: each pattern already encodes the error
    # (the same statistical keyword occurring twice), so a single match is
    # enough to reject. The previous `len(re.findall(...)) > 1` test could
    # effectively never fire: the greedy `.*` folds all repeats into one
    # non-overlapping match. re.DOTALL lets the pattern span newlines.
    error_patterns = [
        r'mean.*mean',
        r'median.*median',
        r'results.*results',
        r'conclusion.*conclusion',
        r'significance.*significance',
    ]

    for pattern in error_patterns:
        if re.search(pattern, lowered, re.DOTALL):
            return False

    return True
|
253 |
+
|
254 |
def post_process_summary(summary):
|
255 |
"""Enhanced post-processing to catch common errors"""
|
256 |
if not summary:
|
257 |
return summary
|
258 |
+
|
259 |
# Remove contradictory age statements
|
260 |
age_statements = []
|
261 |
lines = summary.split('.')
|
|
|
271 |
seen_content = set()
|
272 |
unique_lines = []
|
273 |
for line in cleaned_lines:
|
274 |
+
# Skip empty lines
|
275 |
+
if not line.strip():
|
276 |
+
continue
|
277 |
+
|
278 |
+
# Normalize for comparison
|
279 |
+
line_core = ' '.join(sorted(line.lower().split()))
|
280 |
+
|
281 |
+
# Check for near-duplicates
|
282 |
+
duplicate = False
|
283 |
+
for seen in seen_content:
|
284 |
+
if line_core in seen or seen in line_core:
|
285 |
+
duplicate = True
|
286 |
+
break
|
287 |
+
|
288 |
+
if not duplicate:
|
289 |
seen_content.add(line_core)
|
290 |
unique_lines.append(line)
|
291 |
|
|
|
293 |
cleaned_summary = '. '.join(s.strip() for s in unique_lines if s.strip())
|
294 |
if cleaned_summary and not cleaned_summary.endswith('.'):
|
295 |
cleaned_summary += '.'
|
296 |
+
|
297 |
+
# Additional cleaning
|
298 |
+
cleaned_summary = cleaned_summary.replace(" and and ", " and ")
|
299 |
+
cleaned_summary = cleaned_summary.replace("results showed", "")
|
300 |
+
cleaned_summary = cleaned_summary.replace("results indicated", "")
|
301 |
+
cleaned_summary = cleaned_summary.replace(" ", " ")
|
302 |
+
|
303 |
return cleaned_summary
|
304 |
|
305 |
def generate_focused_summary(question, abstracts, model, tokenizer):
|