pendar02 committed on
Commit
005d6b8
·
verified ·
1 Parent(s): 054584c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +245 -145
app.py CHANGED
@@ -107,19 +107,57 @@ def verify_facts(summary, original_text):
107
  def extract_numbers(text):
108
  return set(re.findall(r'(\d+\.?\d*)%?', text))
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  original_numbers = extract_numbers(original_text)
111
  summary_numbers = extract_numbers(summary)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- # Check if all numbers from original are in summary
114
- missing_numbers = original_numbers - summary_numbers
115
 
116
- # Extract key phrases indicating relationships
117
  relationship_patterns = [
118
  r'associated with',
119
  r'predicted',
120
  r'correlated with',
121
  r'relationship between',
122
- r'linked to'
 
 
 
 
123
  ]
124
 
125
  def extract_relationships(text):
@@ -127,7 +165,6 @@ def verify_facts(summary, original_text):
127
  for pattern in relationship_patterns:
128
  matches = re.finditer(pattern, text.lower())
129
  for match in matches:
130
- # Get surrounding context
131
  start = max(0, match.start() - 50)
132
  end = min(len(text), match.end() + 50)
133
  relationships.append(text[start:end].strip())
@@ -139,13 +176,16 @@ def verify_facts(summary, original_text):
139
  # Check for contradictions
140
  def find_contradictions(summary, original):
141
  contradictions = []
142
- # Common contradiction patterns
143
  neg_patterns = [
144
  (r'no association', r'associated with'),
145
  (r'did not predict', r'predicted'),
146
  (r'was not significant', r'was significant'),
147
  (r'decreased', r'increased'),
148
- (r'lower', r'higher')
 
 
 
 
149
  ]
150
 
151
  for pos, neg in neg_patterns:
@@ -157,176 +197,236 @@ def verify_facts(summary, original_text):
157
 
158
  contradictions = find_contradictions(summary, original_text)
159
 
 
 
 
 
 
 
 
 
 
 
160
  return {
161
- 'missing_numbers': missing_numbers,
 
 
162
  'missing_relationships': original_relationships - summary_relationships,
 
163
  'contradictions': contradictions,
164
- 'is_valid': len(missing_numbers) == 0 and len(contradictions) == 0
 
 
 
165
  }
166
 
167
  def preprocess_text(text):
168
  """Preprocess text to add appropriate formatting before summarization"""
169
  if not isinstance(text, str) or not text.strip():
170
  return text
171
-
172
- # Split text into sentences (basic implementation)
173
- sentences = [s.strip() for s in text.replace('. ', '.\n').split('\n')]
174
-
175
- # Remove empty sentences and extra whitespace
176
- sentences = [re.sub(r'\s+', ' ', s).strip() for s in sentences if s.strip()]
177
-
178
- # Join with proper line breaks
179
- formatted_text = '\n'.join(sentences)
180
-
181
- return formatted_text
182
-
183
- def post_process_summary(summary):
184
- """Enhanced post-processing for better structure and completeness"""
185
- if not summary:
186
- return summary
187
-
188
- # Split into sections
189
- sections = summary.split('\n')
190
- processed_sections = []
191
 
192
- for section in sections:
193
- if not section.strip():
194
- continue
195
-
196
- # Remove redundant section headers
197
- section = re.sub(r'^(Background and objectives|Methods|Results|Conclusions):\s*', '', section)
198
-
199
- # Split into sentences
200
- sentences = [s.strip() for s in section.split('.')]
201
- sentences = [s for s in sentences if s]
202
-
203
- processed_sentences = []
204
- for i, sentence in enumerate(sentences):
205
- # Fix common issues
206
- sentence = re.sub(r'\s+', ' ', sentence) # Fix spacing
207
- sentence = re.sub(r'(\d+)\s*%', r'\1%', sentence) # Fix percentage formatting
208
- sentence = re.sub(r'\(\s*([Nn])\s*=\s*(\d+)\s*\)', r'(n=\2)', sentence) # Fix sample size formatting
209
-
210
- # Fix common phrase issues
211
- sentence = sentence.replace(" and and ", " and ")
212
- sentence = sentence.replace("appointment and appointment", "appointment")
213
- sentence = sentence.replace("Cancers distress", "Cancer distress")
214
-
215
- # Remove redundant phrases
216
- sentence = re.sub(r'(?i)the aim of (the|this) study was to', '', sentence)
217
- sentence = re.sub(r'(?i)this study aimed to', '', sentence)
218
-
219
- # Capitalize first letter
220
- sentence = sentence.capitalize()
221
-
222
- if sentence.strip():
223
- processed_sentences.append(sentence)
224
-
225
- if processed_sentences:
226
- section = '. '.join(processed_sentences)
227
- if not section.endswith('.'):
228
- section += '.'
229
- processed_sections.append(section)
230
 
231
- # Ensure key sections are present
232
- required_sections = ['Background and objectives', 'Methods', 'Key findings', 'Conclusions']
233
- final_sections = []
 
 
234
 
235
- for i, section in enumerate(processed_sections):
236
- if i < len(required_sections):
237
- final_sections.append(f"{required_sections[i]}: {section}")
238
- else:
239
- final_sections.append(section)
240
 
241
- return '\n\n'.join(final_sections)
242
 
243
- def improve_summary_generation(text, model, tokenizer):
244
  """Generate improved summary with better prompt and validation"""
245
  if not isinstance(text, str) or not text.strip():
246
  return "No abstract available to summarize."
247
 
248
- # Add a more specific prompt with strict guidelines
249
  formatted_text = (
250
- "Generate a precise summary of this medical research paper following these strict guidelines:\n"
251
- "1. Background and objectives: State ONLY the actual study purpose and population - no assumptions\n"
252
- "2. Methods: Include ONLY methods explicitly mentioned in the text\n"
253
- "3. Key findings: Report ALL numerical results and statistical relationships\n"
254
- "4. Conclusions: State ONLY conclusions directly supported by the reported results\n\n"
255
- "Requirements:\n"
256
- "- Include ALL percentages and numbers from the original text\n"
257
- "- Do not repeat section headers\n"
258
- "- Do not make claims beyond what's explicitly stated\n"
259
- "- Maintain the original meaning without contradiction\n"
260
- "- Do not introduce new information\n\n"
261
- "Original text: " + preprocess_text(text)
 
 
 
 
 
 
 
 
 
 
 
262
  )
263
 
264
- # Tokenize input
265
  inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
266
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
267
 
268
- def generate_attempt(temperature, num_beams, length_penalty):
269
- with torch.no_grad():
270
- return model.generate(
271
- **{
272
- "input_ids": inputs["input_ids"],
273
- "attention_mask": inputs["attention_mask"],
274
- "max_length": 300, # Increased to ensure all facts are included
275
- "min_length": 100, # Increased to encourage more complete summaries
276
- "num_beams": num_beams,
277
- "length_penalty": length_penalty,
278
- "no_repeat_ngram_size": 3,
279
- "temperature": temperature,
280
- "repetition_penalty": 2.0, # Increased to reduce repetition
281
- "do_sample": True # Enable sampling for more diverse outputs
282
- }
283
- )
284
-
285
- # Try different parameter combinations until we get a valid summary
286
  parameter_combinations = [
287
- {"temperature": 0.7, "num_beams": 5, "length_penalty": 1.5},
288
- {"temperature": 0.5, "num_beams": 8, "length_penalty": 2.0},
289
- {"temperature": 0.3, "num_beams": 10, "length_penalty": 2.5}
290
  ]
291
 
292
  best_summary = None
293
- best_verification = None
294
-
295
- for params in parameter_combinations:
296
- summary_ids = generate_attempt(**params)
297
- summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
298
- processed_summary = post_process_summary(summary)
299
-
300
- # Verify facts in the summary
301
- verification = verify_facts(processed_summary, text)
302
-
303
- if verification['is_valid']:
304
- return processed_summary
305
-
306
- # Keep track of best attempt
307
- if best_verification is None or \
308
- len(verification['missing_numbers']) < len(best_verification['missing_numbers']):
309
- best_summary = processed_summary
310
- best_verification = verification
311
-
312
- # If no perfect summary was generated, use the best attempt
313
- # Add missing information if necessary
314
- if best_verification and best_verification['missing_numbers']:
315
- # Attempt to add missing numerical information
316
- additional_info = []
317
- original_sentences = text.split('.')
318
- for num in best_verification['missing_numbers']:
319
- # Find sentences containing the missing number
320
- for sentence in original_sentences:
321
- if str(num) in sentence:
322
- additional_info.append(sentence.strip())
323
- break
 
324
 
325
- if additional_info:
326
- best_summary += "\n\nAdditional key findings: " + ". ".join(additional_info) + "."
 
 
 
 
 
 
327
 
328
  return best_summary
329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  def validate_summary(summary, original_text):
331
  """Validate summary content against original text"""
332
  # Perform fact verification
 
107
  def extract_numbers(text):
108
  return set(re.findall(r'(\d+\.?\d*)%?', text))
109
 
110
+ # Extract statistical significance statements
111
+ def extract_significance(text):
112
+ patterns = [
113
+ r'[pP][\s-]value.*?(?:=|was|of)\s*([<>]?\s*\d+\.?\d*)',
114
+ r'significant(?:ly)?\s+(?:difference|increase|decrease|change|association)',
115
+ r'statistical(?:ly)?\s+significant',
116
+ r'[pP]\s*[<>]\s*\d+\.?\d*'
117
+ ]
118
+ findings = []
119
+ for pattern in patterns:
120
+ matches = re.finditer(pattern, text, re.IGNORECASE)
121
+ for match in matches:
122
+ # Get surrounding context
123
+ start = max(0, match.start() - 50)
124
+ end = min(len(text), match.end() + 50)
125
+ findings.append(text[start:end].strip())
126
+ return set(findings)
127
+
128
  original_numbers = extract_numbers(original_text)
129
  summary_numbers = extract_numbers(summary)
130
+ original_significance = extract_significance(original_text)
131
+ summary_significance = extract_significance(summary)
132
+
133
+ # Check for temporal sequence preservation
134
+ def extract_temporal_markers(text):
135
+ markers = [
136
+ r'(?:after|following|within)\s+(\d+)\s*(?:weeks?|months?|years?)',
137
+ r'at\s+(\d+)\s*(?:weeks?|months?|years?)',
138
+ r'(?:baseline|initial|follow-up|final)'
139
+ ]
140
+ sequence = []
141
+ for pattern in markers:
142
+ matches = re.finditer(pattern, text, re.IGNORECASE)
143
+ for match in matches:
144
+ sequence.append(match.group())
145
+ return sequence
146
 
147
+ original_sequence = extract_temporal_markers(original_text)
148
+ summary_sequence = extract_temporal_markers(summary)
149
 
150
+ # Extract relationships
151
  relationship_patterns = [
152
  r'associated with',
153
  r'predicted',
154
  r'correlated with',
155
  r'relationship between',
156
+ r'linked to',
157
+ r'impact(ed)? on',
158
+ r'effect(ed)? on',
159
+ r'influenced?',
160
+ r'dependent on'
161
  ]
162
 
163
  def extract_relationships(text):
 
165
  for pattern in relationship_patterns:
166
  matches = re.finditer(pattern, text.lower())
167
  for match in matches:
 
168
  start = max(0, match.start() - 50)
169
  end = min(len(text), match.end() + 50)
170
  relationships.append(text[start:end].strip())
 
176
  # Check for contradictions
177
  def find_contradictions(summary, original):
178
  contradictions = []
 
179
  neg_patterns = [
180
  (r'no association', r'associated with'),
181
  (r'did not predict', r'predicted'),
182
  (r'was not significant', r'was significant'),
183
  (r'decreased', r'increased'),
184
+ (r'lower', r'higher'),
185
+ (r'negative', r'positive'),
186
+ (r'no effect', r'had effect'),
187
+ (r'no difference', r'difference'),
188
+ (r'no change', r'changed')
189
  ]
190
 
191
  for pos, neg in neg_patterns:
 
197
 
198
  contradictions = find_contradictions(summary, original_text)
199
 
200
+ # Check for internal consistency
201
+ def check_internal_consistency(summary):
202
+ inconsistencies = []
203
+ # Check for contradicting statements within the summary
204
+ for pos, neg in find_contradictions(summary, summary):
205
+ inconsistencies.append(f"Internal contradiction: {pos} vs {neg}")
206
+ return inconsistencies
207
+
208
+ internal_inconsistencies = check_internal_consistency(summary)
209
+
210
  return {
211
+ 'missing_numbers': original_numbers - summary_numbers,
212
+ 'incorrect_numbers': summary_numbers - original_numbers,
213
+ 'missing_significance': original_significance - summary_significance,
214
  'missing_relationships': original_relationships - summary_relationships,
215
+ 'temporal_sequence_preserved': all(marker in ' '.join(summary_sequence) for marker in original_sequence),
216
  'contradictions': contradictions,
217
+ 'internal_inconsistencies': internal_inconsistencies,
218
+ 'is_valid': (len(original_numbers - summary_numbers) == 0 and
219
+ len(contradictions) == 0 and
220
+ len(internal_inconsistencies) == 0)
221
  }
222
 
223
  def preprocess_text(text):
224
  """Preprocess text to add appropriate formatting before summarization"""
225
  if not isinstance(text, str) or not text.strip():
226
  return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
+ # Standardize spacing and line breaks
229
+ text = re.sub(r'\s+', ' ', text)
230
+ text = text.replace('. ', '.\n')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
+ # Fix common formatting issues
233
+ text = re.sub(r'(?<=[.!?])\s*(?=[A-Z])', '\n', text) # Add breaks after sentences
234
+ text = re.sub(r'\(\s*([Nn])\s*=\s*(\d+)\s*\)', r'(n=\2)', text) # Standardize sample size format
235
+ text = re.sub(r'(\d+)\s*%', r'\1%', text) # Fix percentage format
236
+ text = re.sub(r'([Pp])\s*([<>])\s*(\d)', r'\1\2\3', text) # Fix p-value format
237
 
238
+ # Split into sentences and clean each
239
+ sentences = [s.strip() for s in text.split('\n')]
240
+ sentences = [s for s in sentences if s]
 
 
241
 
242
+ return '\n'.join(sentences)
243
 
244
+ def improve_summary_generation(text, model, tokenizer, max_attempts=3):
245
  """Generate improved summary with better prompt and validation"""
246
  if not isinstance(text, str) or not text.strip():
247
  return "No abstract available to summarize."
248
 
 
249
  formatted_text = (
250
+ "Summarize this medical research paper, strictly following these rules:\n\n"
251
+ "1. Background and objectives:\n"
252
+ " - State ONLY the main purpose and study population\n"
253
+ " - Include sample size if mentioned (format as n=X)\n"
254
+ " - No methodology details here\n\n"
255
+ "2. Methods:\n"
256
+ " - List the specific procedures and measurements used\n"
257
+ " - Include timeframes and follow-up periods\n"
258
+ " - No results here\n\n"
259
+ "3. Key findings:\n"
260
+ " - Report ALL numerical results (%, numbers, p-values)\n"
261
+ " - Include ALL statistical relationships\n"
262
+ " - Present findings in chronological order\n\n"
263
+ "4. Conclusions:\n"
264
+ " - State ONLY conclusions directly supported by the results\n"
265
+ " - Include practical implications if mentioned\n"
266
+ " - No new information\n\n"
267
+ "Important:\n"
268
+ "- Keep each section separate and clearly labeled\n"
269
+ "- Use exact numbers from the text\n"
270
+ "- Maintain original relationships between variables\n"
271
+ "- No speculation or external information\n\n"
272
+ "Original text:\n" + preprocess_text(text)
273
  )
274
 
 
275
  inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
276
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
277
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  parameter_combinations = [
279
+ {"temperature": 0.1, "num_beams": 12, "length_penalty": 2.0, "top_k": 50},
280
+ {"temperature": 0.05, "num_beams": 15, "length_penalty": 2.5, "top_k": 30},
281
+ {"temperature": 0.0, "num_beams": 20, "length_penalty": 3.0, "top_k": 10}
282
  ]
283
 
284
  best_summary = None
285
+ best_score = -1
286
+ attempts = 0
287
+
288
+ while attempts < max_attempts:
289
+ for params in parameter_combinations:
290
+ with torch.no_grad():
291
+ summary_ids = model.generate(
292
+ **{
293
+ "input_ids": inputs["input_ids"],
294
+ "attention_mask": inputs["attention_mask"],
295
+ "max_length": 300,
296
+ "min_length": 100,
297
+ "num_beams": params["num_beams"],
298
+ "length_penalty": params["length_penalty"],
299
+ "no_repeat_ngram_size": 3,
300
+ "temperature": params["temperature"],
301
+ "top_k": params["top_k"],
302
+ "repetition_penalty": 2.5,
303
+ "do_sample": params["temperature"] > 0.0
304
+ }
305
+ )
306
+
307
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
308
+ processed_summary = post_process_summary(summary)
309
+ score = score_summary(processed_summary, text)
310
+
311
+ if score > best_score:
312
+ best_summary = processed_summary
313
+ best_score = score
314
+
315
+ if score > 0.8: # Good enough threshold
316
+ return best_summary
317
 
318
+ attempts += 1
319
+ # Adjust parameters for next attempt if needed
320
+ parameter_combinations = [
321
+ {**params,
322
+ "num_beams": params["num_beams"] + 5,
323
+ "length_penalty": params["length_penalty"] + 0.5}
324
+ for params in parameter_combinations
325
+ ]
326
 
327
  return best_summary
328
 
329
+ def score_summary(summary, original_text):
330
+ """Score summary quality based on multiple factors"""
331
+ score = 1.0
332
+
333
+ # Verify facts
334
+ verification = verify_facts(summary, original_text)
335
+ if not verification['is_valid']:
336
+ score -= 0.3
337
+
338
+ # Check numbers
339
+ if verification['missing_numbers']:
340
+ score -= 0.1 * len(verification['missing_numbers'])
341
+ if verification['incorrect_numbers']:
342
+ score -= 0.2 * len(verification['incorrect_numbers'])
343
+
344
+ # Check statistical significance preservation
345
+ if verification['missing_significance']:
346
+ score -= 0.1
347
+
348
+ # Check temporal sequence
349
+ if not verification['temporal_sequence_preserved']:
350
+ score -= 0.1
351
+
352
+ # Check for contradictions and inconsistencies
353
+ if verification['contradictions']:
354
+ score -= 0.2 * len(verification['contradictions'])
355
+ if verification['internal_inconsistencies']:
356
+ score -= 0.2 * len(verification['internal_inconsistencies'])
357
+
358
+ # Check section structure and content
359
+ required_sections = ['Background and objectives', 'Methods', 'Key findings', 'Conclusions']
360
+ section_content = {}
361
+ current_section = None
362
+
363
+ for line in summary.split('\n'):
364
+ for section in required_sections:
365
+ if section.lower() in line.lower():
366
+ current_section = section
367
+ section_content[section] = []
368
+ break
369
+ if current_section and not any(section.lower() in line.lower() for section in required_sections):
370
+ section_content[current_section].append(line.strip())
371
+
372
+ for section in required_sections:
373
+ if section not in section_content:
374
+ score -= 0.15 # Missing section
375
+ elif not section_content[section]:
376
+ score -= 0.1 # Empty section
377
+ elif len(' '.join(section_content[section]).split()) < 10:
378
+ score -= 0.05 # Too short
379
+
380
+ def post_process_summary(summary):
381
+ """Enhanced post-processing focused on maintaining structure and removing artifacts"""
382
+ if not summary:
383
+ return summary
384
+
385
+ # Clean up section headers
386
+ summary = re.sub(r'(?i)background and objectives:?\s*background and objectives:?',
387
+ 'Background and objectives:', summary)
388
+ summary = re.sub(r'(?i)methods:?\s*methods:?', 'Methods:', summary)
389
+ summary = re.sub(r'(?i)(key )?findings:?\s*(key )?findings:?', 'Key findings:', summary)
390
+ summary = re.sub(r'(?i)conclusions?:?\s*conclusions?:?', 'Conclusions:', summary)
391
+ summary = re.sub(r'(?i)materials and methods:?', 'Methods:', summary)
392
+ summary = re.sub(r'(?i)objectives?:?', '', summary)
393
+ summary = re.sub(r'(?i)results:?', '', summary)
394
+
395
+ # Remove instruction artifacts
396
+ summary = re.sub(r'(?i)state only|include only|report all|no assumptions', '', summary)
397
+
398
+ # Split into sections and clean each
399
+ sections = re.split(r'(?i)(Background and objectives:|Methods:|Key findings:|Conclusions:)', summary)
400
+ sections = [s.strip() for s in sections if s.strip()]
401
+
402
+ # Reorganize into proper sections
403
+ organized_sections = {
404
+ 'Background and objectives': '',
405
+ 'Methods': '',
406
+ 'Key findings': '',
407
+ 'Conclusions': ''
408
+ }
409
+
410
+ current_section = None
411
+ for item in sections:
412
+ if item in organized_sections:
413
+ current_section = item
414
+ elif current_section:
415
+ organized_sections[current_section] = item.strip()
416
+
417
+ # Build final summary
418
+ final_sections = []
419
+ for section, content in organized_sections.items():
420
+ if content:
421
+ # Clean up the content
422
+ content = re.sub(r'\s+', ' ', content) # Fix spacing
423
+ content = re.sub(r'\.+', '.', content) # Fix multiple periods
424
+ content = content.strip('.: ') # Remove trailing periods and spaces
425
+
426
+ # Add to final sections
427
+ final_sections.append(f"{section}: {content}.")
428
+
429
+ return '\n\n'.join(final_sections)
430
  def validate_summary(summary, original_text):
431
  """Validate summary content against original text"""
432
  # Perform fact verification