pendar02 committed on
Commit
0f40536
·
verified ·
1 Parent(s): 005d6b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -299
app.py CHANGED
@@ -27,17 +27,112 @@ if 'processing_started' not in st.session_state:
27
  if 'focused_summary_generated' not in st.session_state:
28
  st.session_state.focused_summary_generated = False
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def load_model(model_type):
31
  """Load appropriate model based on type with proper memory management"""
32
  try:
33
- # Clear any existing cached data
34
  gc.collect()
35
  torch.cuda.empty_cache()
36
-
37
- device = "cpu" # Force CPU usage
38
 
39
  if model_type == "summarize":
40
- # Load the new fine-tuned model directly
41
  model = AutoModelForSeq2SeqLM.from_pretrained(
42
  "pendar02/bart-large-pubmedd",
43
  cache_dir="./models",
@@ -48,7 +143,7 @@ def load_model(model_type):
48
  "pendar02/bart-large-pubmedd",
49
  cache_dir="./models"
50
  )
51
- else: # question_focused
52
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
53
  "GanjinZero/biobart-base",
54
  cache_dir="./models",
@@ -73,7 +168,6 @@ def load_model(model_type):
73
  raise
74
 
75
  def cleanup_model(model, tokenizer):
76
- """Properly cleanup model resources"""
77
  try:
78
  del model
79
  del tokenizer
@@ -82,15 +176,12 @@ def cleanup_model(model, tokenizer):
82
  except Exception:
83
  pass
84
 
85
- @st.cache_data
86
  def process_excel(uploaded_file):
87
- """Process uploaded Excel file"""
88
  try:
89
  df = pd.read_excel(uploaded_file)
90
  required_columns = ['Abstract', 'Article Title', 'Authors',
91
  'Source Title', 'Publication Year', 'DOI', 'Times Cited, All Databases']
92
 
93
- # Check required columns
94
  missing_columns = [col for col in required_columns if col not in df.columns]
95
  if missing_columns:
96
  st.error(f"Missing required columns: {', '.join(missing_columns)}")
@@ -101,305 +192,144 @@ def process_excel(uploaded_file):
101
  st.error(f"Error processing file: {str(e)}")
102
  return None
103
 
104
- def verify_facts(summary, original_text):
105
- """Verify that key facts in the summary match the original text"""
106
- # Extract numbers and percentages
107
- def extract_numbers(text):
108
- return set(re.findall(r'(\d+\.?\d*)%?', text))
109
-
110
- # Extract statistical significance statements
111
- def extract_significance(text):
112
- patterns = [
113
- r'[pP][\s-]value.*?(?:=|was|of)\s*([<>]?\s*\d+\.?\d*)',
114
- r'significant(?:ly)?\s+(?:difference|increase|decrease|change|association)',
115
- r'statistical(?:ly)?\s+significant',
116
- r'[pP]\s*[<>]\s*\d+\.?\d*'
117
- ]
118
- findings = []
119
- for pattern in patterns:
120
- matches = re.finditer(pattern, text, re.IGNORECASE)
121
- for match in matches:
122
- # Get surrounding context
123
- start = max(0, match.start() - 50)
124
- end = min(len(text), match.end() + 50)
125
- findings.append(text[start:end].strip())
126
- return set(findings)
127
-
128
- original_numbers = extract_numbers(original_text)
129
- summary_numbers = extract_numbers(summary)
130
- original_significance = extract_significance(original_text)
131
- summary_significance = extract_significance(summary)
132
-
133
- # Check for temporal sequence preservation
134
- def extract_temporal_markers(text):
135
- markers = [
136
- r'(?:after|following|within)\s+(\d+)\s*(?:weeks?|months?|years?)',
137
- r'at\s+(\d+)\s*(?:weeks?|months?|years?)',
138
- r'(?:baseline|initial|follow-up|final)'
139
- ]
140
- sequence = []
141
- for pattern in markers:
142
- matches = re.finditer(pattern, text, re.IGNORECASE)
143
- for match in matches:
144
- sequence.append(match.group())
145
- return sequence
146
-
147
- original_sequence = extract_temporal_markers(original_text)
148
- summary_sequence = extract_temporal_markers(summary)
149
-
150
- # Extract relationships
151
- relationship_patterns = [
152
- r'associated with',
153
- r'predicted',
154
- r'correlated with',
155
- r'relationship between',
156
- r'linked to',
157
- r'impact(ed)? on',
158
- r'effect(ed)? on',
159
- r'influenced?',
160
- r'dependent on'
161
- ]
162
-
163
- def extract_relationships(text):
164
- relationships = []
165
- for pattern in relationship_patterns:
166
- matches = re.finditer(pattern, text.lower())
167
- for match in matches:
168
- start = max(0, match.start() - 50)
169
- end = min(len(text), match.end() + 50)
170
- relationships.append(text[start:end].strip())
171
- return set(relationships)
172
-
173
- original_relationships = extract_relationships(original_text)
174
- summary_relationships = extract_relationships(summary)
175
-
176
- # Check for contradictions
177
- def find_contradictions(summary, original):
178
- contradictions = []
179
- neg_patterns = [
180
- (r'no association', r'associated with'),
181
- (r'did not predict', r'predicted'),
182
- (r'was not significant', r'was significant'),
183
- (r'decreased', r'increased'),
184
- (r'lower', r'higher'),
185
- (r'negative', r'positive'),
186
- (r'no effect', r'had effect'),
187
- (r'no difference', r'difference'),
188
- (r'no change', r'changed')
189
- ]
190
-
191
- for pos, neg in neg_patterns:
192
- if (re.search(pos, summary.lower()) and re.search(neg, original.lower())) or \
193
- (re.search(neg, summary.lower()) and re.search(pos, original.lower())):
194
- contradictions.append(f"Contradiction found: {pos} vs {neg}")
195
-
196
- return contradictions
197
-
198
- contradictions = find_contradictions(summary, original_text)
199
-
200
- # Check for internal consistency
201
- def check_internal_consistency(summary):
202
- inconsistencies = []
203
- # Check for contradicting statements within the summary
204
- for pos, neg in find_contradictions(summary, summary):
205
- inconsistencies.append(f"Internal contradiction: {pos} vs {neg}")
206
- return inconsistencies
207
-
208
- internal_inconsistencies = check_internal_consistency(summary)
209
-
210
- return {
211
- 'missing_numbers': original_numbers - summary_numbers,
212
- 'incorrect_numbers': summary_numbers - original_numbers,
213
- 'missing_significance': original_significance - summary_significance,
214
- 'missing_relationships': original_relationships - summary_relationships,
215
- 'temporal_sequence_preserved': all(marker in ' '.join(summary_sequence) for marker in original_sequence),
216
- 'contradictions': contradictions,
217
- 'internal_inconsistencies': internal_inconsistencies,
218
- 'is_valid': (len(original_numbers - summary_numbers) == 0 and
219
- len(contradictions) == 0 and
220
- len(internal_inconsistencies) == 0)
221
- }
222
-
223
- def preprocess_text(text):
224
- """Preprocess text to add appropriate formatting before summarization"""
225
- if not isinstance(text, str) or not text.strip():
226
- return text
227
-
228
- # Standardize spacing and line breaks
229
- text = re.sub(r'\s+', ' ', text)
230
- text = text.replace('. ', '.\n')
231
-
232
- # Fix common formatting issues
233
- text = re.sub(r'(?<=[.!?])\s*(?=[A-Z])', '\n', text) # Add breaks after sentences
234
- text = re.sub(r'\(\s*([Nn])\s*=\s*(\d+)\s*\)', r'(n=\2)', text) # Standardize sample size format
235
- text = re.sub(r'(\d+)\s*%', r'\1%', text) # Fix percentage format
236
- text = re.sub(r'([Pp])\s*([<>])\s*(\d)', r'\1\2\3', text) # Fix p-value format
237
-
238
- # Split into sentences and clean each
239
- sentences = [s.strip() for s in text.split('\n')]
240
- sentences = [s for s in sentences if s]
241
-
242
- return '\n'.join(sentences)
243
-
244
  def improve_summary_generation(text, model, tokenizer, max_attempts=3):
245
  """Generate improved summary with better prompt and validation"""
246
  if not isinstance(text, str) or not text.strip():
247
  return "No abstract available to summarize."
248
 
249
- formatted_text = (
250
- "Summarize this medical research paper, strictly following these rules:\n\n"
251
- "1. Background and objectives:\n"
252
- " - State ONLY the main purpose and study population\n"
253
- " - Include sample size if mentioned (format as n=X)\n"
254
- " - No methodology details here\n\n"
255
- "2. Methods:\n"
256
- " - List the specific procedures and measurements used\n"
257
- " - Include timeframes and follow-up periods\n"
258
- " - No results here\n\n"
259
- "3. Key findings:\n"
260
- " - Report ALL numerical results (%, numbers, p-values)\n"
261
- " - Include ALL statistical relationships\n"
262
- " - Present findings in chronological order\n\n"
263
- "4. Conclusions:\n"
264
- " - State ONLY conclusions directly supported by the results\n"
265
- " - Include practical implications if mentioned\n"
266
- " - No new information\n\n"
267
- "Important:\n"
268
- "- Keep each section separate and clearly labeled\n"
269
- "- Use exact numbers from the text\n"
270
- "- Maintain original relationships between variables\n"
271
- "- No speculation or external information\n\n"
272
- "Original text:\n" + preprocess_text(text)
273
- )
274
-
275
- inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
276
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
277
-
278
- parameter_combinations = [
279
- {"temperature": 0.1, "num_beams": 12, "length_penalty": 2.0, "top_k": 50},
280
- {"temperature": 0.05, "num_beams": 15, "length_penalty": 2.5, "top_k": 30},
281
- {"temperature": 0.0, "num_beams": 20, "length_penalty": 3.0, "top_k": 10}
282
- ]
283
-
284
- best_summary = None
285
- best_score = -1
286
- attempts = 0
287
-
288
- while attempts < max_attempts:
289
- for params in parameter_combinations:
290
- with torch.no_grad():
291
- summary_ids = model.generate(
292
- **{
293
- "input_ids": inputs["input_ids"],
294
- "attention_mask": inputs["attention_mask"],
295
- "max_length": 300,
296
- "min_length": 100,
297
- "num_beams": params["num_beams"],
298
- "length_penalty": params["length_penalty"],
299
- "no_repeat_ngram_size": 3,
300
- "temperature": params["temperature"],
301
- "top_k": params["top_k"],
302
- "repetition_penalty": 2.5,
303
- "do_sample": params["temperature"] > 0.0
304
- }
305
- )
306
-
307
- summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
308
- processed_summary = post_process_summary(summary)
309
- score = score_summary(processed_summary, text)
310
-
311
- if score > best_score:
312
- best_summary = processed_summary
313
- best_score = score
314
-
315
- if score > 0.8: # Good enough threshold
316
- return best_summary
317
 
318
- attempts += 1
319
- # Adjust parameters for next attempt if needed
320
  parameter_combinations = [
321
- {**params,
322
- "num_beams": params["num_beams"] + 5,
323
- "length_penalty": params["length_penalty"] + 0.5}
324
- for params in parameter_combinations
325
  ]
326
-
327
- return best_summary
328
-
329
- def score_summary(summary, original_text):
330
- """Score summary quality based on multiple factors"""
331
- score = 1.0
332
-
333
- # Verify facts
334
- verification = verify_facts(summary, original_text)
335
- if not verification['is_valid']:
336
- score -= 0.3
337
-
338
- # Check numbers
339
- if verification['missing_numbers']:
340
- score -= 0.1 * len(verification['missing_numbers'])
341
- if verification['incorrect_numbers']:
342
- score -= 0.2 * len(verification['incorrect_numbers'])
343
-
344
- # Check statistical significance preservation
345
- if verification['missing_significance']:
346
- score -= 0.1
347
-
348
- # Check temporal sequence
349
- if not verification['temporal_sequence_preserved']:
350
- score -= 0.1
351
-
352
- # Check for contradictions and inconsistencies
353
- if verification['contradictions']:
354
- score -= 0.2 * len(verification['contradictions'])
355
- if verification['internal_inconsistencies']:
356
- score -= 0.2 * len(verification['internal_inconsistencies'])
357
-
358
- # Check section structure and content
359
- required_sections = ['Background and objectives', 'Methods', 'Key findings', 'Conclusions']
360
- section_content = {}
361
- current_section = None
362
-
363
- for line in summary.split('\n'):
364
- for section in required_sections:
365
- if section.lower() in line.lower():
366
- current_section = section
367
- section_content[section] = []
368
- break
369
- if current_section and not any(section.lower() in line.lower() for section in required_sections):
370
- section_content[current_section].append(line.strip())
371
-
372
- for section in required_sections:
373
- if section not in section_content:
374
- score -= 0.15 # Missing section
375
- elif not section_content[section]:
376
- score -= 0.1 # Empty section
377
- elif len(' '.join(section_content[section]).split()) < 10:
378
- score -= 0.05 # Too short
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
 
380
  def post_process_summary(summary):
381
  """Enhanced post-processing focused on maintaining structure and removing artifacts"""
382
  if not summary:
383
  return summary
384
-
385
  # Clean up section headers
386
- summary = re.sub(r'(?i)background and objectives:?\s*background and objectives:?',
387
- 'Background and objectives:', summary)
388
- summary = re.sub(r'(?i)methods:?\s*methods:?', 'Methods:', summary)
389
- summary = re.sub(r'(?i)(key )?findings:?\s*(key )?findings:?', 'Key findings:', summary)
390
- summary = re.sub(r'(?i)conclusions?:?\s*conclusions?:?', 'Conclusions:', summary)
391
- summary = re.sub(r'(?i)materials and methods:?', 'Methods:', summary)
392
- summary = re.sub(r'(?i)objectives?:?', '', summary)
393
- summary = re.sub(r'(?i)results:?', '', summary)
394
-
395
- # Remove instruction artifacts
396
- summary = re.sub(r'(?i)state only|include only|report all|no assumptions', '', summary)
397
-
398
- # Split into sections and clean each
 
 
399
  sections = re.split(r'(?i)(Background and objectives:|Methods:|Key findings:|Conclusions:)', summary)
400
  sections = [s.strip() for s in sections if s.strip()]
401
 
402
- # Reorganize into proper sections
403
  organized_sections = {
404
  'Background and objectives': '',
405
  'Methods': '',
@@ -412,21 +342,21 @@ def post_process_summary(summary):
412
  if item in organized_sections:
413
  current_section = item
414
  elif current_section:
415
- organized_sections[current_section] = item.strip()
 
 
 
 
416
 
417
  # Build final summary
418
  final_sections = []
419
  for section, content in organized_sections.items():
420
  if content:
421
- # Clean up the content
422
- content = re.sub(r'\s+', ' ', content) # Fix spacing
423
- content = re.sub(r'\.+', '.', content) # Fix multiple periods
424
- content = content.strip('.: ') # Remove trailing periods and spaces
425
-
426
- # Add to final sections
427
- final_sections.append(f"{section}: {content}.")
428
 
429
  return '\n\n'.join(final_sections)
 
 
430
  def validate_summary(summary, original_text):
431
  """Validate summary content against original text"""
432
  # Perform fact verification
 
27
  if 'focused_summary_generated' not in st.session_state:
28
  st.session_state.focused_summary_generated = False
29
 
30
+ def extract_biomedical_facts(text):
31
+ """Extract biomedical-specific facts and measurements"""
32
+ facts = {
33
+ 'p_values': [],
34
+ 'measurements': [],
35
+ 'demographics': [],
36
+ 'statistical_measures': [],
37
+ 'timeframes': []
38
+ }
39
+
40
+ # P-value patterns
41
+ p_value_patterns = [
42
+ r'[pP][\s-]*(?:value)?[\s-]*[=<>]\s*\.?\d+\.?\d*e?-?\d*', # p = 0.001, p<.05, p < 1e-6
43
+ r'[pP][\s-]*(?:value)?[\s-]*(?:was|of|is|were)\s*\.?\d+\.?\d*e?-?\d*' # p value was 0.001
44
+ ]
45
+
46
+ # Statistical measures patterns
47
+ stat_patterns = [
48
+ r'(?:CI|confidence interval)[\s:]*(?:\d+\.?\d*%?)?\s*[-–]\s*(?:\d+\.?\d*%?)', # 95% CI: 1.2-3.4
49
+ r'(?:OR|odds ratio)[\s:]*(?:\d+\.?\d*)', # OR: 1.5
50
+ r'(?:HR|hazard ratio)[\s:]*(?:\d+\.?\d*)', # HR: 2.1
51
+ r'(?:RR|relative risk)[\s:]*(?:\d+\.?\d*)', # RR: 1.3
52
+ r'(?:SD|standard deviation)[\s:]*[±]?\s*\d+\.?\d*' # SD: ±2.1
53
+ ]
54
+
55
+ # Measurement patterns
56
+ measurement_patterns = [
57
+ r'\d+\.?\d*\s*(?:mg|kg|ml|mmol|µg|ng|mm|cm|µl|g/dl|mmHg)', # Units
58
+ r'\d+\.?\d*\s*(?:weeks?|months?|years?|hours?|days?)', # Time units
59
+ r'\d+\.?\d*\s*(?:%|percent|percentage)' # Percentages
60
+ ]
61
+
62
+ # Demographic patterns
63
+ demographic_patterns = [
64
+ r'(?:mean|median)\s*age[\s:]*(?:was|of|=)?\s*\d+\.?\d*',
65
+ r'(?:\d+\.?\d*%?\s*(?:men|women|male|female))',
66
+ r'(?:\d+\.?\d*%?\s*of\s*(?:patients|participants|subjects))',
67
+ r'(?:[Nn]\s*=\s*\d+)', # Sample size
68
+ r'(?:aged?\s*\d+[-–]\d+\s*(?:years?|yrs?)?)' # Age range
69
+ ]
70
+
71
+ # Extract all patterns
72
+ for pattern in p_value_patterns:
73
+ matches = re.finditer(pattern, text, re.IGNORECASE)
74
+ facts['p_values'].extend([m.group() for m in matches])
75
+
76
+ for pattern in stat_patterns:
77
+ matches = re.finditer(pattern, text, re.IGNORECASE)
78
+ facts['statistical_measures'].extend([m.group() for m in matches])
79
+
80
+ for pattern in measurement_patterns:
81
+ matches = re.finditer(pattern, text, re.IGNORECASE)
82
+ facts['measurements'].extend([m.group() for m in matches])
83
+
84
+ for pattern in demographic_patterns:
85
+ matches = re.finditer(pattern, text, re.IGNORECASE)
86
+ facts['demographics'].extend([m.group() for m in matches])
87
+
88
+ # Extract timeframes
89
+ timeframe_patterns = [
90
+ r'(?:followed|monitored|observed|tracked)\s*(?:for|over|during)\s*\d+\.?\d*\s*(?:weeks?|months?|years?)',
91
+ r'(?:follow-up|duration)\s*(?:of|was|=)\s*\d+\.?\d*\s*(?:weeks?|months?|years?)',
92
+ r'\d+[-–]\s*(?:week|month|year)\s*(?:follow-up|period|duration)'
93
+ ]
94
+
95
+ for pattern in timeframe_patterns:
96
+ matches = re.finditer(pattern, text, re.IGNORECASE)
97
+ facts['timeframes'].extend([m.group() for m in matches])
98
+
99
+ return facts
100
+
101
+ def identify_abstract_structure(text):
102
+ """Identify the structure of the biomedical abstract"""
103
+ # Common section headers in biomedical abstracts
104
+ section_patterns = {
105
+ 'background': r'(?:background|introduction|objective|purpose|aim)',
106
+ 'methods': r'(?:methods|materials|design|study design|procedure)',
107
+ 'results': r'(?:results|findings|outcome)',
108
+ 'conclusions': r'(?:conclusion|discussion|summary|implications)'
109
+ }
110
+
111
+ # Check if abstract has clear section headers
112
+ has_sections = any(
113
+ re.search(f"{pattern}s?:?", text, re.IGNORECASE)
114
+ for pattern in section_patterns.values()
115
+ )
116
+
117
+ if not has_sections:
118
+ return "unstructured"
119
+
120
+ # Identify present sections
121
+ present_sections = []
122
+ for section, pattern in section_patterns.items():
123
+ if re.search(f"{pattern}s?:?", text, re.IGNORECASE):
124
+ present_sections.append(section)
125
+
126
+ return present_sections
127
+
128
  def load_model(model_type):
129
  """Load appropriate model based on type with proper memory management"""
130
  try:
 
131
  gc.collect()
132
  torch.cuda.empty_cache()
133
+ device = "cpu"
 
134
 
135
  if model_type == "summarize":
 
136
  model = AutoModelForSeq2SeqLM.from_pretrained(
137
  "pendar02/bart-large-pubmedd",
138
  cache_dir="./models",
 
143
  "pendar02/bart-large-pubmedd",
144
  cache_dir="./models"
145
  )
146
+ else:
147
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
148
  "GanjinZero/biobart-base",
149
  cache_dir="./models",
 
168
  raise
169
 
170
  def cleanup_model(model, tokenizer):
 
171
  try:
172
  del model
173
  del tokenizer
 
176
  except Exception:
177
  pass
178
 
 
179
  def process_excel(uploaded_file):
 
180
  try:
181
  df = pd.read_excel(uploaded_file)
182
  required_columns = ['Abstract', 'Article Title', 'Authors',
183
  'Source Title', 'Publication Year', 'DOI', 'Times Cited, All Databases']
184
 
 
185
  missing_columns = [col for col in required_columns if col not in df.columns]
186
  if missing_columns:
187
  st.error(f"Missing required columns: {', '.join(missing_columns)}")
 
192
  st.error(f"Error processing file: {str(e)}")
193
  return None
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  def improve_summary_generation(text, model, tokenizer, max_attempts=3):
196
  """Generate improved summary with better prompt and validation"""
197
  if not isinstance(text, str) or not text.strip():
198
  return "No abstract available to summarize."
199
 
200
+ try:
201
+ # Identify abstract structure and extract facts
202
+ structure = identify_abstract_structure(text)
203
+ facts = extract_biomedical_facts(text)
204
+
205
+ # Build prompt based on structure
206
+ if structure == "unstructured":
207
+ section_prompt = (
208
+ "Organize this unstructured biomedical abstract into clear sections:\n"
209
+ "1. Background/Objectives\n"
210
+ "2. Methods\n"
211
+ "3. Results\n"
212
+ "4. Conclusions\n\n"
213
+ )
214
+ else:
215
+ section_prompt = "Summarize while maintaining these sections:\n"
216
+ for section in structure:
217
+ section_prompt += f"- {section.capitalize()}\n"
218
+
219
+ formatted_text = (
220
+ f"{section_prompt}\n"
221
+ "Requirements:\n"
222
+ "- Include ALL statistical findings (p-values, CIs, ORs)\n"
223
+ "- Preserve ALL demographic information\n"
224
+ "- Maintain ALL measurements and units\n"
225
+ "- Keep ALL timeframes and follow-up periods\n"
226
+ "- Report numerical results with original precision\n"
227
+ "- Preserve relationships between variables\n"
228
+ "- Maintain chronological order of findings\n\n"
229
+ "Original text:\n" + text
230
+ )
231
+
232
+ inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
233
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
 
 
235
  parameter_combinations = [
236
+ {"temperature": 0.1, "num_beams": 12, "length_penalty": 2.0, "top_k": 50},
237
+ {"temperature": 0.05, "num_beams": 15, "length_penalty": 2.5, "top_k": 30},
238
+ {"temperature": 0.0, "num_beams": 20, "length_penalty": 3.0, "top_k": 10}
 
239
  ]
240
+
241
+ best_summary = None
242
+ best_score = -1
243
+ attempts = 0
244
+
245
+ while attempts < max_attempts:
246
+ for params in parameter_combinations:
247
+ try:
248
+ with torch.no_grad():
249
+ summary_ids = model.generate(
250
+ **{
251
+ "input_ids": inputs["input_ids"],
252
+ "attention_mask": inputs["attention_mask"],
253
+ "max_length": 300,
254
+ "min_length": 100,
255
+ "num_beams": params["num_beams"],
256
+ "length_penalty": params["length_penalty"],
257
+ "no_repeat_ngram_size": 3,
258
+ "temperature": params["temperature"],
259
+ "top_k": params["top_k"],
260
+ "repetition_penalty": 2.5,
261
+ "do_sample": params["temperature"] > 0.0
262
+ }
263
+ )
264
+
265
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
266
+ if not summary:
267
+ continue
268
+
269
+ processed_summary = post_process_summary(summary)
270
+ if not processed_summary:
271
+ continue
272
+
273
+ # Validate biomedical content
274
+ summary_facts = extract_biomedical_facts(processed_summary)
275
+ missing_facts = {k: set(v) - set(summary_facts[k]) for k, v in facts.items()}
276
+
277
+ # Calculate score
278
+ score = 1.0
279
+ for category, missing in missing_facts.items():
280
+ if missing:
281
+ score -= 0.1 * len(missing)
282
+
283
+ if score > best_score:
284
+ best_summary = processed_summary
285
+ best_score = score
286
+
287
+ if score > 0.8:
288
+ return best_summary
289
+
290
+ except Exception as e:
291
+ print(f"Error in generation attempt: {str(e)}")
292
+ continue
293
+
294
+ attempts += 1
295
+ parameter_combinations = [
296
+ {**params,
297
+ "num_beams": params["num_beams"] + 5,
298
+ "length_penalty": params["length_penalty"] + 0.5}
299
+ for params in parameter_combinations
300
+ ]
301
+
302
+ return best_summary if best_summary is not None else "Unable to generate a satisfactory summary."
303
+
304
+ except Exception as e:
305
+ print(f"Error in summary generation: {str(e)}")
306
+ return "Error generating summary."
307
 
308
  def post_process_summary(summary):
309
  """Enhanced post-processing focused on maintaining structure and removing artifacts"""
310
  if not summary:
311
  return summary
312
+
313
  # Clean up section headers
314
+ header_mappings = {
315
+ r'(?i)background.*objectives?:?': 'Background and objectives:',
316
+ r'(?i)(materials?\s*and\s*)?methods?:?': 'Methods:',
317
+ r'(?i)(key\s*)?findings?:?|results?:?': 'Key findings:',
318
+ r'(?i)conclusions?:?': 'Conclusions:',
319
+ r'(?i)(study\s*)?aims?:?|goals?:?|purpose:?': '',
320
+ r'(?i)objectives?:?': '',
321
+ r'(?i)outcomes?:?': '',
322
+ r'(?i)discussion:?': ''
323
+ }
324
+
325
+ for pattern, replacement in header_mappings.items():
326
+ summary = re.sub(pattern, replacement, summary)
327
+
328
+ # Split into sections and clean
329
  sections = re.split(r'(?i)(Background and objectives:|Methods:|Key findings:|Conclusions:)', summary)
330
  sections = [s.strip() for s in sections if s.strip()]
331
 
332
+ # Reorganize sections
333
  organized_sections = {
334
  'Background and objectives': '',
335
  'Methods': '',
 
342
  if item in organized_sections:
343
  current_section = item
344
  elif current_section:
345
+ # Clean up content
346
+ content = re.sub(r'\s+', ' ', item) # Fix spacing
347
+ content = re.sub(r'\.+', '.', content) # Fix multiple periods
348
+ content = content.strip('.: ') # Remove trailing periods and spaces
349
+ organized_sections[current_section] = content
350
 
351
  # Build final summary
352
  final_sections = []
353
  for section, content in organized_sections.items():
354
  if content:
355
+ final_sections.append(f"{section} {content}.")
 
 
 
 
 
 
356
 
357
  return '\n\n'.join(final_sections)
358
+
359
+ # The rest of your app.py code (main function, UI components, etc.) remains the same...
360
  def validate_summary(summary, original_text):
361
  """Validate summary content against original text"""
362
  # Perform fact verification