pendar02 commited on
Commit
0a57b0f
·
verified ·
1 Parent(s): 0f40536

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -225
app.py CHANGED
@@ -27,103 +27,52 @@ if 'processing_started' not in st.session_state:
27
  if 'focused_summary_generated' not in st.session_state:
28
  st.session_state.focused_summary_generated = False
29
 
30
- def extract_biomedical_facts(text):
31
- """Extract biomedical-specific facts and measurements"""
32
- facts = {
33
- 'p_values': [],
34
- 'measurements': [],
35
- 'demographics': [],
36
- 'statistical_measures': [],
37
- 'timeframes': []
38
- }
39
-
40
- # P-value patterns
41
- p_value_patterns = [
42
- r'[pP][\s-]*(?:value)?[\s-]*[=<>]\s*\.?\d+\.?\d*e?-?\d*', # p = 0.001, p<.05, p < 1e-6
43
- r'[pP][\s-]*(?:value)?[\s-]*(?:was|of|is|were)\s*\.?\d+\.?\d*e?-?\d*' # p value was 0.001
44
- ]
45
-
46
- # Statistical measures patterns
47
- stat_patterns = [
48
- r'(?:CI|confidence interval)[\s:]*(?:\d+\.?\d*%?)?\s*[-–]\s*(?:\d+\.?\d*%?)', # 95% CI: 1.2-3.4
49
- r'(?:OR|odds ratio)[\s:]*(?:\d+\.?\d*)', # OR: 1.5
50
- r'(?:HR|hazard ratio)[\s:]*(?:\d+\.?\d*)', # HR: 2.1
51
- r'(?:RR|relative risk)[\s:]*(?:\d+\.?\d*)', # RR: 1.3
52
- r'(?:SD|standard deviation)[\s:]*[±]?\s*\d+\.?\d*' # SD: ±2.1
53
- ]
54
-
55
- # Measurement patterns
56
- measurement_patterns = [
57
- r'\d+\.?\d*\s*(?:mg|kg|ml|mmol|µg|ng|mm|cm|µl|g/dl|mmHg)', # Units
58
- r'\d+\.?\d*\s*(?:weeks?|months?|years?|hours?|days?)', # Time units
59
- r'\d+\.?\d*\s*(?:%|percent|percentage)' # Percentages
60
- ]
61
-
62
- # Demographic patterns
63
- demographic_patterns = [
64
- r'(?:mean|median)\s*age[\s:]*(?:was|of|=)?\s*\d+\.?\d*',
65
- r'(?:\d+\.?\d*%?\s*(?:men|women|male|female))',
66
- r'(?:\d+\.?\d*%?\s*of\s*(?:patients|participants|subjects))',
67
- r'(?:[Nn]\s*=\s*\d+)', # Sample size
68
- r'(?:aged?\s*\d+[-–]\d+\s*(?:years?|yrs?)?)' # Age range
69
- ]
70
-
71
- # Extract all patterns
72
- for pattern in p_value_patterns:
73
- matches = re.finditer(pattern, text, re.IGNORECASE)
74
- facts['p_values'].extend([m.group() for m in matches])
75
-
76
- for pattern in stat_patterns:
77
- matches = re.finditer(pattern, text, re.IGNORECASE)
78
- facts['statistical_measures'].extend([m.group() for m in matches])
79
-
80
- for pattern in measurement_patterns:
81
- matches = re.finditer(pattern, text, re.IGNORECASE)
82
- facts['measurements'].extend([m.group() for m in matches])
83
-
84
- for pattern in demographic_patterns:
85
- matches = re.finditer(pattern, text, re.IGNORECASE)
86
- facts['demographics'].extend([m.group() for m in matches])
87
 
88
- # Extract timeframes
89
- timeframe_patterns = [
90
- r'(?:followed|monitored|observed|tracked)\s*(?:for|over|during)\s*\d+\.?\d*\s*(?:weeks?|months?|years?)',
91
- r'(?:follow-up|duration)\s*(?:of|was|=)\s*\d+\.?\d*\s*(?:weeks?|months?|years?)',
92
- r'\d+[-–]\s*(?:week|month|year)\s*(?:follow-up|period|duration)'
93
- ]
94
 
95
- for pattern in timeframe_patterns:
96
- matches = re.finditer(pattern, text, re.IGNORECASE)
97
- facts['timeframes'].extend([m.group() for m in matches])
 
98
 
99
- return facts
100
 
101
- def identify_abstract_structure(text):
102
- """Identify the structure of the biomedical abstract"""
103
- # Common section headers in biomedical abstracts
104
- section_patterns = {
105
- 'background': r'(?:background|introduction|objective|purpose|aim)',
106
- 'methods': r'(?:methods|materials|design|study design|procedure)',
107
- 'results': r'(?:results|findings|outcome)',
108
- 'conclusions': r'(?:conclusion|discussion|summary|implications)'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  }
110
-
111
- # Check if abstract has clear section headers
112
- has_sections = any(
113
- re.search(f"{pattern}s?:?", text, re.IGNORECASE)
114
- for pattern in section_patterns.values()
115
- )
116
-
117
- if not has_sections:
118
- return "unstructured"
119
-
120
- # Identify present sections
121
- present_sections = []
122
- for section, pattern in section_patterns.items():
123
- if re.search(f"{pattern}s?:?", text, re.IGNORECASE):
124
- present_sections.append(section)
125
-
126
- return present_sections
127
 
128
  def load_model(model_type):
129
  """Load appropriate model based on type with proper memory management"""
@@ -192,114 +141,47 @@ def process_excel(uploaded_file):
192
  st.error(f"Error processing file: {str(e)}")
193
  return None
194
 
195
- def improve_summary_generation(text, model, tokenizer, max_attempts=3):
196
  """Generate improved summary with better prompt and validation"""
197
  if not isinstance(text, str) or not text.strip():
198
  return "No abstract available to summarize."
199
 
200
  try:
201
- # Identify abstract structure and extract facts
202
- structure = identify_abstract_structure(text)
203
- facts = extract_biomedical_facts(text)
204
-
205
- # Build prompt based on structure
206
- if structure == "unstructured":
207
- section_prompt = (
208
- "Organize this unstructured biomedical abstract into clear sections:\n"
209
- "1. Background/Objectives\n"
210
- "2. Methods\n"
211
- "3. Results\n"
212
- "4. Conclusions\n\n"
213
- )
214
- else:
215
- section_prompt = "Summarize while maintaining these sections:\n"
216
- for section in structure:
217
- section_prompt += f"- {section.capitalize()}\n"
218
-
219
  formatted_text = (
220
- f"{section_prompt}\n"
221
- "Requirements:\n"
222
- "- Include ALL statistical findings (p-values, CIs, ORs)\n"
223
- "- Preserve ALL demographic information\n"
224
- "- Maintain ALL measurements and units\n"
225
- "- Keep ALL timeframes and follow-up periods\n"
226
- "- Report numerical results with original precision\n"
227
- "- Preserve relationships between variables\n"
228
- "- Maintain chronological order of findings\n\n"
229
- "Original text:\n" + text
230
  )
231
 
232
  inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
233
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
234
 
235
- parameter_combinations = [
236
- {"temperature": 0.1, "num_beams": 12, "length_penalty": 2.0, "top_k": 50},
237
- {"temperature": 0.05, "num_beams": 15, "length_penalty": 2.5, "top_k": 30},
238
- {"temperature": 0.0, "num_beams": 20, "length_penalty": 3.0, "top_k": 10}
239
- ]
240
-
241
- best_summary = None
242
- best_score = -1
243
- attempts = 0
 
 
 
 
 
 
244
 
245
- while attempts < max_attempts:
246
- for params in parameter_combinations:
247
- try:
248
- with torch.no_grad():
249
- summary_ids = model.generate(
250
- **{
251
- "input_ids": inputs["input_ids"],
252
- "attention_mask": inputs["attention_mask"],
253
- "max_length": 300,
254
- "min_length": 100,
255
- "num_beams": params["num_beams"],
256
- "length_penalty": params["length_penalty"],
257
- "no_repeat_ngram_size": 3,
258
- "temperature": params["temperature"],
259
- "top_k": params["top_k"],
260
- "repetition_penalty": 2.5,
261
- "do_sample": params["temperature"] > 0.0
262
- }
263
- )
264
-
265
- summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
266
- if not summary:
267
- continue
268
-
269
- processed_summary = post_process_summary(summary)
270
- if not processed_summary:
271
- continue
272
-
273
- # Validate biomedical content
274
- summary_facts = extract_biomedical_facts(processed_summary)
275
- missing_facts = {k: set(v) - set(summary_facts[k]) for k, v in facts.items()}
276
-
277
- # Calculate score
278
- score = 1.0
279
- for category, missing in missing_facts.items():
280
- if missing:
281
- score -= 0.1 * len(missing)
282
-
283
- if score > best_score:
284
- best_summary = processed_summary
285
- best_score = score
286
-
287
- if score > 0.8:
288
- return best_summary
289
-
290
- except Exception as e:
291
- print(f"Error in generation attempt: {str(e)}")
292
- continue
293
 
294
- attempts += 1
295
- parameter_combinations = [
296
- {**params,
297
- "num_beams": params["num_beams"] + 5,
298
- "length_penalty": params["length_penalty"] + 0.5}
299
- for params in parameter_combinations
300
- ]
301
-
302
- return best_summary if best_summary is not None else "Unable to generate a satisfactory summary."
303
 
304
  except Exception as e:
305
  print(f"Error in summary generation: {str(e)}")
@@ -356,13 +238,12 @@ def post_process_summary(summary):
356
 
357
  return '\n\n'.join(final_sections)
358
 
359
- # The rest of your app.py code (main function, UI components, etc.) remains the same...
360
  def validate_summary(summary, original_text):
361
  """Validate summary content against original text"""
362
  # Perform fact verification
363
  verification = verify_facts(summary, original_text)
364
 
365
- if not verification['is_valid']:
366
  return False
367
 
368
  # Check for age inconsistencies
@@ -386,34 +267,40 @@ def validate_summary(summary, original_text):
386
 
387
  def generate_focused_summary(question, abstracts, model, tokenizer):
388
  """Generate focused summary based on question"""
389
- # Preprocess each abstract
390
- formatted_abstracts = [preprocess_text(abstract) for abstract in abstracts]
391
- combined_input = f"Question: {question} Abstracts: " + " [SEP] ".join(formatted_abstracts)
392
-
393
- inputs = tokenizer(combined_input, return_tensors="pt", max_length=1024, truncation=True)
394
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
395
-
396
- with torch.no_grad():
397
- summary_ids = model.generate(
398
- **{
399
- "input_ids": inputs["input_ids"],
400
- "attention_mask": inputs["attention_mask"],
401
- "max_length": 200,
402
- "min_length": 50,
403
- "num_beams": 4,
404
- "length_penalty": 2.0,
405
- "early_stopping": True
406
- }
407
- )
408
-
409
- return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
 
 
 
 
 
 
410
 
411
  def create_filter_controls(df, sort_column):
412
  """Create appropriate filter controls based on the selected column"""
413
  filtered_df = df.copy()
414
 
415
  if sort_column == 'Publication Year':
416
- # Year range slider
417
  year_min = int(df['Publication Year'].min())
418
  year_max = int(df['Publication Year'].max())
419
  col1, col2 = st.columns(2)
@@ -433,7 +320,6 @@ def create_filter_controls(df, sort_column):
433
  ]
434
 
435
  elif sort_column == 'Authors':
436
- # Multi-select for authors
437
  unique_authors = sorted(set(
438
  author.strip()
439
  for authors in df['Authors'].dropna()
@@ -451,7 +337,6 @@ def create_filter_controls(df, sort_column):
451
  ]
452
 
453
  elif sort_column == 'Source Title':
454
- # Multi-select for source titles
455
  unique_sources = sorted(df['Source Title'].unique())
456
  selected_sources = st.multiselect(
457
  'Select Sources',
@@ -460,13 +345,7 @@ def create_filter_controls(df, sort_column):
460
  if selected_sources:
461
  filtered_df = filtered_df[filtered_df['Source Title'].isin(selected_sources)]
462
 
463
- elif sort_column == 'Article Title':
464
- # Only alphabetical sorting, no filtering
465
- pass
466
-
467
-
468
  elif sort_column == 'Times Cited':
469
- # Cited count range slider
470
  cited_min = int(df['Times Cited'].min())
471
  cited_max = int(df['Times Cited'].max())
472
  col1, col2 = st.columns(2)
@@ -490,19 +369,16 @@ def create_filter_controls(df, sort_column):
490
  def main():
491
  st.title("🔬 Biomedical Papers Analysis")
492
 
493
- # File upload section
494
  uploaded_file = st.file_uploader(
495
  "Upload Excel file containing papers",
496
  type=['xlsx', 'xls'],
497
  help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
498
  )
499
 
500
- # Question input - moved up but hidden initially
501
  question_container = st.empty()
502
  question = ""
503
 
504
  if uploaded_file is not None:
505
- # Process Excel file
506
  if st.session_state.processed_data is None:
507
  with st.spinner("Processing file..."):
508
  df = process_excel(uploaded_file)
@@ -513,15 +389,14 @@ def main():
513
  df = st.session_state.processed_data
514
  st.write(f"📊 Loaded {len(df)} papers with abstracts")
515
 
516
- # Get question before processing
517
  with question_container:
518
  question = st.text_input(
519
  "Enter your research question (optional):",
520
- help="If provided, a question-focused summary will be generated after individual summaries"
521
  )
522
 
523
  # Single button for both processes
524
- if not st.session_state.get('processing_started', False):
525
  if st.button("Start Analysis"):
526
  st.session_state.processing_started = True
527
 
 
27
  if 'focused_summary_generated' not in st.session_state:
28
  st.session_state.focused_summary_generated = False
29
 
30
+ def preprocess_text(text):
31
+ """Preprocess text for summarization"""
32
+ if not isinstance(text, str) or not text.strip():
33
+ return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ # Clean up whitespace
36
+ text = re.sub(r'\s+', ' ', text)
37
+ text = text.strip()
 
 
 
38
 
39
+ # Fix common formatting issues
40
+ text = re.sub(r'(\d+)\s*%', r'\1%', text) # Fix percentage format
41
+ text = re.sub(r'\(\s*([Nn])\s*=\s*(\d+)\s*\)', r'(n=\2)', text) # Fix sample size format
42
+ text = re.sub(r'([Pp])\s*([<>])\s*(\d)', r'\1\2\3', text) # Fix p-value format
43
 
44
+ return text
45
 
46
+ def verify_facts(summary, original_text):
47
+ """Verify key facts between summary and original text"""
48
+ # Extract numbers and percentages
49
+ def extract_numbers(text):
50
+ return set(re.findall(r'(\d+\.?\d*)%?', text))
51
+
52
+ # Extract relationships
53
+ def extract_relationships(text):
54
+ patterns = [
55
+ r'associated with', r'predicted', r'correlated',
56
+ r'increased', r'decreased', r'significant'
57
+ ]
58
+ found = []
59
+ for pattern in patterns:
60
+ if re.search(pattern, text.lower()):
61
+ found.append(pattern)
62
+ return set(found)
63
+
64
+ # Get facts from both texts
65
+ original_numbers = extract_numbers(original_text)
66
+ summary_numbers = extract_numbers(summary)
67
+ original_relations = extract_relationships(original_text)
68
+ summary_relations = extract_relationships(summary)
69
+
70
+ return {
71
+ 'is_valid': summary_numbers.issubset(original_numbers) and
72
+ summary_relations.issubset(original_relations),
73
+ 'missing_numbers': original_numbers - summary_numbers,
74
+ 'missing_relations': original_relations - summary_relations
75
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  def load_model(model_type):
78
  """Load appropriate model based on type with proper memory management"""
 
141
  st.error(f"Error processing file: {str(e)}")
142
  return None
143
 
144
+ def improve_summary_generation(text, model, tokenizer):
145
  """Generate improved summary with better prompt and validation"""
146
  if not isinstance(text, str) or not text.strip():
147
  return "No abstract available to summarize."
148
 
149
  try:
150
+ # Simplified prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  formatted_text = (
152
+ "Summarize this biomedical abstract into four sections:\n"
153
+ "1. Background/Objectives: State the main purpose and population\n"
154
+ "2. Methods: Describe what was done\n"
155
+ "3. Key findings: Include ALL numerical results and statistical relationships\n"
156
+ "4. Conclusions: State main implications\n\n"
157
+ "Important: Preserve all numbers, measurements, and statistical findings.\n\n"
158
+ "Text: " + preprocess_text(text)
 
 
 
159
  )
160
 
161
  inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
162
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
163
 
164
+ # Single generation attempt with optimized parameters
165
+ with torch.no_grad():
166
+ summary_ids = model.generate(
167
+ **{
168
+ "input_ids": inputs["input_ids"],
169
+ "attention_mask": inputs["attention_mask"],
170
+ "max_length": 300,
171
+ "min_length": 100,
172
+ "num_beams": 5,
173
+ "length_penalty": 2.0,
174
+ "no_repeat_ngram_size": 3,
175
+ "temperature": 0.3,
176
+ "repetition_penalty": 2.5
177
+ }
178
+ )
179
 
180
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
181
+ if not summary:
182
+ return "Error: Could not generate summary."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
+ return post_process_summary(summary)
 
 
 
 
 
 
 
 
185
 
186
  except Exception as e:
187
  print(f"Error in summary generation: {str(e)}")
 
238
 
239
  return '\n\n'.join(final_sections)
240
 
 
241
  def validate_summary(summary, original_text):
242
  """Validate summary content against original text"""
243
  # Perform fact verification
244
  verification = verify_facts(summary, original_text)
245
 
246
+ if not verification.get('is_valid', False):
247
  return False
248
 
249
  # Check for age inconsistencies
 
267
 
268
  def generate_focused_summary(question, abstracts, model, tokenizer):
269
  """Generate focused summary based on question"""
270
+ try:
271
+ # Preprocess each abstract
272
+ formatted_abstracts = [preprocess_text(abstract) for abstract in abstracts]
273
+ combined_input = f"Question: {question}\nSummarize these abstracts to answer the question:\n" + \
274
+ "\n---\n".join(formatted_abstracts)
275
+
276
+ inputs = tokenizer(combined_input, return_tensors="pt", max_length=1024, truncation=True)
277
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
278
+
279
+ with torch.no_grad():
280
+ summary_ids = model.generate(
281
+ **{
282
+ "input_ids": inputs["input_ids"],
283
+ "attention_mask": inputs["attention_mask"],
284
+ "max_length": 300,
285
+ "min_length": 100,
286
+ "num_beams": 5,
287
+ "length_penalty": 2.0,
288
+ "temperature": 0.3,
289
+ "repetition_penalty": 2.5
290
+ }
291
+ )
292
+
293
+ return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
294
+
295
+ except Exception as e:
296
+ print(f"Error in focused summary generation: {str(e)}")
297
+ return "Error generating focused summary."
298
 
299
  def create_filter_controls(df, sort_column):
300
  """Create appropriate filter controls based on the selected column"""
301
  filtered_df = df.copy()
302
 
303
  if sort_column == 'Publication Year':
 
304
  year_min = int(df['Publication Year'].min())
305
  year_max = int(df['Publication Year'].max())
306
  col1, col2 = st.columns(2)
 
320
  ]
321
 
322
  elif sort_column == 'Authors':
 
323
  unique_authors = sorted(set(
324
  author.strip()
325
  for authors in df['Authors'].dropna()
 
337
  ]
338
 
339
  elif sort_column == 'Source Title':
 
340
  unique_sources = sorted(df['Source Title'].unique())
341
  selected_sources = st.multiselect(
342
  'Select Sources',
 
345
  if selected_sources:
346
  filtered_df = filtered_df[filtered_df['Source Title'].isin(selected_sources)]
347
 
 
 
 
 
 
348
  elif sort_column == 'Times Cited':
 
349
  cited_min = int(df['Times Cited'].min())
350
  cited_max = int(df['Times Cited'].max())
351
  col1, col2 = st.columns(2)
 
369
  def main():
370
  st.title("🔬 Biomedical Papers Analysis")
371
 
 
372
  uploaded_file = st.file_uploader(
373
  "Upload Excel file containing papers",
374
  type=['xlsx', 'xls'],
375
  help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
376
  )
377
 
 
378
  question_container = st.empty()
379
  question = ""
380
 
381
  if uploaded_file is not None:
 
382
  if st.session_state.processed_data is None:
383
  with st.spinner("Processing file..."):
384
  df = process_excel(uploaded_file)
 
389
  df = st.session_state.processed_data
390
  st.write(f"📊 Loaded {len(df)} papers with abstracts")
391
 
 
392
  with question_container:
393
  question = st.text_input(
394
  "Enter your research question (optional):",
395
+ help="If provided, a focused summary will be generated after individual summaries"
396
  )
397
 
398
  # Single button for both processes
399
+ if not st.session_state.get('processing_started', False):
400
  if st.button("Start Analysis"):
401
  st.session_state.processing_started = True
402