pendar02 committed on
Commit
883d34a
·
verified ·
1 Parent(s): d0820e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -7
app.py CHANGED
@@ -151,11 +151,18 @@ def post_process_summary(summary):
151
  return cleaned_summary
152
 
153
  def improve_summary_generation(text, model, tokenizer):
 
 
 
 
154
  # Add a more specific prompt
155
  formatted_text = (
156
- "Summarize the following medical research paper, focusing on: "
157
- "1) Study objectives 2) Methods 3) Key findings 4) Main conclusions. "
158
- "Text: " + preprocess_text(text)
 
 
 
159
  )
160
 
161
  # Adjust generation parameters
@@ -179,11 +186,76 @@ def improve_summary_generation(text, model, tokenizer):
179
 
180
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  def post_process_summary(summary):
183
  """Enhanced post-processing to catch common errors"""
184
  if not summary:
185
  return summary
186
-
187
  # Remove contradictory age statements
188
  age_statements = []
189
  lines = summary.split('.')
@@ -199,8 +271,21 @@ def post_process_summary(summary):
199
  seen_content = set()
200
  unique_lines = []
201
  for line in cleaned_lines:
202
- line_core = ' '.join(sorted(line.lower().split())) # Normalize for comparison
203
- if line_core not in seen_content:
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  seen_content.add(line_core)
205
  unique_lines.append(line)
206
 
@@ -208,7 +293,13 @@ def post_process_summary(summary):
208
  cleaned_summary = '. '.join(s.strip() for s in unique_lines if s.strip())
209
  if cleaned_summary and not cleaned_summary.endswith('.'):
210
  cleaned_summary += '.'
211
-
 
 
 
 
 
 
212
  return cleaned_summary
213
 
214
  def generate_focused_summary(question, abstracts, model, tokenizer):
 
151
  return cleaned_summary
152
 
153
  def improve_summary_generation(text, model, tokenizer):
154
+ """Generate improved summary with better prompt and validation"""
155
+ if not isinstance(text, str) or not text.strip():
156
+ return "No abstract available to summarize."
157
+
158
  # Add a more specific prompt
159
  formatted_text = (
160
+ "Summarize this medical research paper following this structure exactly:\n"
161
+ "1. Background and objectives\n"
162
+ "2. Methods\n"
163
+ "3. Key findings with specific numbers/percentages\n"
164
+ "4. Main conclusions\n"
165
+ "Original text: " + preprocess_text(text)
166
  )
167
 
168
  # Adjust generation parameters
 
186
 
187
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
188
 
189
+ # Post-process the summary
190
+ processed_summary = post_process_summary(summary)
191
+
192
+ # Validate the summary
193
+ if not validate_summary(processed_summary, text):
194
+ # If validation fails, try one more time with different parameters
195
+ with torch.no_grad():
196
+ summary_ids = model.generate(
197
+ **{
198
+ "input_ids": inputs["input_ids"],
199
+ "attention_mask": inputs["attention_mask"],
200
+ "max_length": 200,
201
+ "min_length": 50,
202
+ "num_beams": 4,
203
+ "length_penalty": 2.0,
204
+ "no_repeat_ngram_size": 4,
205
+ "temperature": 0.8,
206
+ "repetition_penalty": 2.0
207
+ }
208
+ )
209
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
210
+ processed_summary = post_process_summary(summary)
211
+
212
+ return processed_summary
213
+
214
def validate_summary(summary, original_text):
    """Validate a generated summary against the original text.

    Runs a series of cheap heuristic checks and returns False as soon as
    one fails:
      * both inputs must be non-empty,
      * the summary must not mention two *different* ages (a sign of a
        contradictory or hallucinated summary),
      * it must not contain more than one duplicated sentence,
      * its length must be plausible (at least 20 words and at most 80%
        of the original's word count),
      * it must not repeat key report words ("results", "conclusion", ...),
        which usually indicates degenerate generation.

    Args:
        summary: Candidate summary text (may be None or empty).
        original_text: Source text the summary was generated from.

    Returns:
        bool: True if every heuristic passes, False otherwise.
    """
    import re  # kept function-local to match the original file's style

    # Empty inputs can never be validated.
    if not summary or not original_text:
        return False

    # Lowercase once; the original recomputed summary.lower() three times.
    lowered = summary.lower()

    # Contradictory age statements: two *different* numeric ages
    # (e.g. "aged 45 years ... aged 62 years") indicate an inconsistent
    # summary. Repeating the same age is fine, so compare distinct values.
    age_mentions = re.findall(r'(\d+\.?\d*)\s*years?', lowered)
    if len(set(age_mentions)) > 1:
        return False

    # Repetitive sentences: compare non-empty sentences only. The original
    # counted raw split('.') fragments against a set of stripped non-empty
    # ones, so the empty fragment after a trailing '.' was miscounted as
    # an extra duplicate.
    sentences = [s.strip().lower() for s in summary.split('.') if s.strip()]
    if len(sentences) - len(set(sentences)) > 1:  # more than one duplicate
        return False

    # Length sanity: too short to be informative, or so long it cannot be
    # a real condensation of the original.
    summary_words = len(summary.split())
    original_words = len(original_text.split())
    if summary_words < 20 or summary_words > original_words * 0.8:
        return False

    # Degenerate repetition of key report words. Each pattern already
    # encodes "word ... word", so a single match is enough to reject.
    # (The original `len(re.findall(...)) > 1` almost never fired: a
    # greedy `.*` yields at most one non-overlapping match per pattern.)
    # NOTE(review): patterns lack word boundaries, so e.g. "meaningful"
    # could contribute a false positive — kept as authored; confirm intent.
    error_patterns = [
        r'mean.*mean',
        r'median.*median',
        r'results.*results',
        r'conclusion.*conclusion',
        r'significance.*significance'
    ]
    for pattern in error_patterns:
        if re.search(pattern, lowered):
            return False

    return True
253
+
254
  def post_process_summary(summary):
255
  """Enhanced post-processing to catch common errors"""
256
  if not summary:
257
  return summary
258
+
259
  # Remove contradictory age statements
260
  age_statements = []
261
  lines = summary.split('.')
 
271
  seen_content = set()
272
  unique_lines = []
273
  for line in cleaned_lines:
274
+ # Skip empty lines
275
+ if not line.strip():
276
+ continue
277
+
278
+ # Normalize for comparison
279
+ line_core = ' '.join(sorted(line.lower().split()))
280
+
281
+ # Check for near-duplicates
282
+ duplicate = False
283
+ for seen in seen_content:
284
+ if line_core in seen or seen in line_core:
285
+ duplicate = True
286
+ break
287
+
288
+ if not duplicate:
289
  seen_content.add(line_core)
290
  unique_lines.append(line)
291
 
 
293
  cleaned_summary = '. '.join(s.strip() for s in unique_lines if s.strip())
294
  if cleaned_summary and not cleaned_summary.endswith('.'):
295
  cleaned_summary += '.'
296
+
297
+ # Additional cleaning
298
+ cleaned_summary = cleaned_summary.replace(" and and ", " and ")
299
+ cleaned_summary = cleaned_summary.replace("results showed", "")
300
+ cleaned_summary = cleaned_summary.replace("results indicated", "")
301
+ cleaned_summary = cleaned_summary.replace(" ", " ")
302
+
303
  return cleaned_summary
304
 
305
  def generate_focused_summary(question, abstracts, model, tokenizer):