pendar02 committed on
Commit
0d0c8c3
·
verified ·
1 Parent(s): 7bd75d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -92
app.py CHANGED
@@ -117,122 +117,173 @@ def preprocess_text(text):
117
 
118
  return formatted_text
119
 
120
- def post_process_summary(summary):
121
- """Clean up and improve summary coherence"""
122
- if not summary:
123
- return summary
124
-
125
- # Split into sentences
126
- sentences = [s.strip() for s in summary.split('.')]
127
- sentences = [s for s in sentences if s] # Remove empty sentences
128
-
129
- # Fix common issues
130
- processed_sentences = []
131
- for i, sentence in enumerate(sentences):
132
- # Remove redundant words/phrases
133
- sentence = sentence.replace(" and and ", " and ")
134
- sentence = sentence.replace("appointment and appointment", "appointment")
135
-
136
- # Fix common grammatical issues
137
- sentence = sentence.replace("Cancers distress", "Cancer distress")
138
- sentence = sentence.replace(" ", " ") # Remove double spaces
139
-
140
- # Capitalize first letter of each sentence
141
- sentence = sentence.capitalize()
142
-
143
- # Add to processed sentences if not empty
144
- if sentence.strip():
145
- processed_sentences.append(sentence)
146
-
147
- # Join sentences with proper spacing and punctuation
148
- cleaned_summary = '. '.join(processed_sentences)
149
- if cleaned_summary and not cleaned_summary.endswith('.'):
150
- cleaned_summary += '.'
151
-
152
- return cleaned_summary
153
 
154
  def improve_summary_generation(text, model, tokenizer):
155
- """Generate improved summary with better prompt and validation"""
156
  if not isinstance(text, str) or not text.strip():
157
  return "No abstract available to summarize."
158
 
159
- # Add a more specific prompt
160
  formatted_text = (
161
- "Summarize this medical research paper following this structure exactly:\n"
162
- "1. Background and objectives\n"
163
- "2. Methods\n"
164
- "3. Key findings with specific numbers/percentages\n"
165
- "4. Main conclusions\n"
 
 
166
  "Original text: " + preprocess_text(text)
167
  )
168
 
169
- # Adjust generation parameters
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
171
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
172
 
173
- with torch.no_grad():
174
- summary_ids = model.generate(
175
- **{
176
- "input_ids": inputs["input_ids"],
177
- "attention_mask": inputs["attention_mask"],
178
- "max_length": 200,
179
- "min_length": 50,
180
- "num_beams": 5,
181
- "length_penalty": 1.5,
182
- "no_repeat_ngram_size": 3,
183
- "temperature": 0.7,
184
- "repetition_penalty": 1.5
185
- }
186
- )
187
 
188
- summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
189
-
190
- # Post-process the summary
191
- processed_summary = post_process_summary(summary)
192
-
193
- # Validate the summary
194
- if not validate_summary(processed_summary, text):
195
- # If validation fails, try one more time with different parameters
196
- with torch.no_grad():
197
- summary_ids = model.generate(
198
- **{
199
- "input_ids": inputs["input_ids"],
200
- "attention_mask": inputs["attention_mask"],
201
- "max_length": 200,
202
- "min_length": 50,
203
- "num_beams": 4,
204
- "length_penalty": 2.0,
205
- "no_repeat_ngram_size": 4,
206
- "temperature": 0.8,
207
- "repetition_penalty": 2.0
208
- }
209
- )
210
- summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
211
- processed_summary = post_process_summary(summary)
212
 
213
- return processed_summary
214
 
215
  def validate_summary(summary, original_text):
216
- """Validate summary content against original text"""
217
- # Check for age inconsistencies
218
- age_mentions = re.findall(r'(\d+\.?\d*)\s*years?', summary.lower())
219
- if len(age_mentions) > 1: # Multiple age mentions
220
  return False
 
 
 
 
221
 
222
- # Check for repetitive sentences
223
- sentences = summary.split('.')
224
- unique_sentences = set(s.strip().lower() for s in sentences if s.strip())
225
- if len(sentences) - len(unique_sentences) > 1: # More than one duplicate
226
  return False
227
 
228
- # Check summary isn't too long or too short compared to original
229
- summary_words = len(summary.split())
230
- original_words = len(original_text.split())
231
- if summary_words < 20 or summary_words > original_words * 0.8:
232
  return False
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  return True
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  def generate_focused_summary(question, abstracts, model, tokenizer):
237
  """Generate focused summary based on question"""
238
  # Preprocess each abstract
 
117
 
118
  return formatted_text
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
def improve_summary_generation(text, model, tokenizer):
    """Summarize one abstract using a structured prompt and one validated retry.

    Args:
        text: Abstract to summarize.
        model: Generation model, forwarded unchanged to
            ``generate_summary_attempt``.
        tokenizer: Tokenizer, forwarded unchanged to
            ``generate_summary_attempt``.

    Returns:
        The post-processed summary, or a fixed placeholder sentence when
        ``text`` is not a string or contains only whitespace.
    """
    has_content = isinstance(text, str) and bool(text.strip())
    if not has_content:
        return "No abstract available to summarize."

    # Instruction header that asks for a four-part, fact-only summary.
    instructions = (
        "Summarize this medical research paper accurately and concisely. "
        "Include only factual information from the text. "
        "Structure the summary as follows:\n"
        "1. OBJECTIVE: State the main purpose and study population\n"
        "2. METHODS: Describe key methodological elements\n"
        "3. RESULTS: Report specific findings with exact numbers/percentages\n"
        "4. CONCLUSION: State main implications\n\n"
    )
    prompt = instructions + "Original text: " + preprocess_text(text)

    # Conservative settings go first. The alternative settings are used only
    # when the first draft fails validation; the retry is not validated again.
    draft = generate_summary_attempt(prompt, model, tokenizer,
                                     conservative_params=True)
    if validate_summary(draft, text):
        return post_process_summary(draft)

    retry = generate_summary_attempt(prompt, model, tokenizer,
                                     conservative_params=False)
    return post_process_summary(retry)
149
+
150
def generate_summary_attempt(formatted_text, model, tokenizer, conservative_params=True):
    """Run one generation pass and return the decoded text.

    Args:
        formatted_text: Full prompt to feed to the model.
        model: Object exposing ``device`` and ``generate(**kwargs)``.
        tokenizer: Callable tokenizer that also exposes ``decode``.
        conservative_params: Selects the first tuning preset when true and
            the alternative preset when false.

    Returns:
        The first generated sequence decoded with special tokens skipped.
    """
    encoded = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
    # Move every encoded tensor to the device the model lives on.
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    # Settings shared by both presets.
    shared = {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "max_length": 250,
        "min_length": 100,
        "early_stopping": True,
        "no_repeat_ngram_size": 3,
    }

    # NOTE(review): temperature/top_p are presumably sampling-only options and
    # no do_sample flag is passed here, so they may have no effect -- confirm
    # against the model's generation config.
    if conservative_params:
        tuning = {
            "num_beams": 5,
            "length_penalty": 1.5,
            "temperature": 0.7,
            "top_p": 0.9,
            "repetition_penalty": 1.5,
        }
    else:
        tuning = {
            "num_beams": 4,
            "length_penalty": 2.0,
            "temperature": 0.8,
            "top_p": 0.95,
            "repetition_penalty": 2.0,
        }

    generation_kwargs = {**shared, **tuning}

    with torch.no_grad():
        generated = model.generate(**generation_kwargs)

    return tokenizer.decode(generated[0], skip_special_tokens=True)
185
 
def validate_summary(summary, original_text):
    """Check a generated summary against the text it was produced from.

    The summary is rejected when any of these hold:

    * either argument is empty;
    * it quotes a percentage that does not appear in ``original_text``;
    * it contains method phrases (see ``extract_methods``) and none of them
      occurs in the lower-cased original;
    * any of its '.'-separated fragments is not supported according to
      ``is_supported_by_original``.

    Args:
        summary: Generated summary text.
        original_text: Source text the summary was generated from.

    Returns:
        ``True`` when every check passes, otherwise ``False``.
    """
    if not summary or not original_text:
        return False

    # Percentages are compared as the literal digit strings preceding '%'.
    percent_pattern = r'(\d+(?:\.\d+)?)\s*%'
    original_numbers = set(re.findall(percent_pattern, original_text))
    summary_numbers = set(re.findall(percent_pattern, summary))
    if not summary_numbers.issubset(original_numbers):
        return False

    # Only the summary's method phrases are needed; they are looked up as
    # substrings of the lower-cased original, which is computed once.
    methods_summary = extract_methods(summary)
    if methods_summary:
        original_lower = original_text.lower()
        if not any(method in original_lower for method in methods_summary):
            return False

    # Every non-blank fragment must be backed by the original text.
    return all(
        is_supported_by_original(fragment, original_text)
        for fragment in summary.split('.')
        if fragment.strip()
    )
213
+
214
def extract_methods(text):
    """Return lower-cased "<keyword> <next word>" phrases found in ``text``.

    The keywords are 'study', 'survey', 'analysis', 'trial' and
    'experiment'. A keyword only counts when it starts a word, so
    "industrial waste" does not produce "trial waste".

    Args:
        text: Text to scan.

    Returns:
        A list of matched phrases, grouped by keyword in the order listed
        above and, within a keyword, in order of appearance.
    """
    method_keywords = ('study', 'survey', 'analysis', 'trial', 'experiment')
    lowered = text.lower()  # lower-case once instead of once per keyword

    methods = []
    for keyword in method_keywords:
        # \b anchors the keyword to a word start so it cannot match the tail
        # of a longer word.
        methods.extend(re.findall(rf'\b{keyword}\s+\w+', lowered))
    return methods
223
+
224
def is_supported_by_original(claim, original):
    """Tell whether every part of a summary claim is backed by the original.

    The claim is lower-cased, a few filler phrases are removed, and the rest
    is split on ' and '. Each non-empty part is checked with
    ``has_supporting_evidence`` against the lower-cased original.

    Args:
        claim: One fragment of the summary.
        original: Source text.

    Returns:
        ``True`` when no part lacks support (including when nothing is left
        to check), otherwise ``False``.
    """
    filler = r'(this study|the study|results show|we found that)'
    reduced = re.sub(filler, '', claim.lower()).strip()

    parts = (piece.strip() for piece in reduced.split(' and '))
    return all(
        has_supporting_evidence(piece, original.lower())
        for piece in parts
        if piece
    )
237
 
238
def has_supporting_evidence(phrase, original):
    """Check whether some sentence of ``original`` covers most of ``phrase``.

    Both inputs are lower-cased and reduced to sets of word tokens (runs of
    word characters), so attached punctuation such as "patients," does not
    prevent a match. ``original`` is split into sentences on '.'.

    Args:
        phrase: Phrase taken from the summary.
        original: Source text to search.

    Returns:
        ``True`` when at least one sentence shares 70% or more of the
        phrase's distinct words. A phrase with no word tokens is treated as
        supported.
    """
    def _words(fragment):
        # Tokenise on word characters so punctuation never sticks to a word.
        return set(re.findall(r'\w+', fragment.lower()))

    phrase_words = _words(phrase)
    required = len(phrase_words) * 0.7

    return any(
        len(phrase_words & _words(sentence)) >= required
        for sentence in original.split('.')
    )
247
+
248
def post_process_summary(summary):
    """Tidy a generated summary into one cleaned line per section.

    Lines are grouped into sections: a line containing one of the markers
    'OBJECTIVE:', 'METHODS:', 'RESULTS:' or 'CONCLUSION:' (case-insensitive)
    starts a new section, and other non-blank lines are appended to the
    current one. Each section then has whitespace collapsed, spaces before
    '%' removed, and a missing space restored after '.' or ',' when a digit
    follows directly. Decimals ("12.5") and thousands separators ("1,200")
    are left intact.

    Args:
        summary: Raw generated summary.

    Returns:
        The cleaned summary with sections joined by newlines and a trailing
        '.' ensured; a falsy ``summary`` is returned unchanged.
    """
    if not summary:
        return summary

    markers = ('OBJECTIVE:', 'METHODS:', 'RESULTS:', 'CONCLUSION:')

    sections = []
    current_section = []
    for raw_line in summary.split('\n'):
        line = raw_line.strip()
        if not line:
            continue
        upper_line = line.upper()  # upper-case once per line, not per marker
        if any(marker in upper_line for marker in markers):
            if current_section:
                sections.append(' '.join(current_section))
            current_section = [line]
        else:
            current_section.append(line)
    if current_section:
        sections.append(' '.join(current_section))

    cleaned_sections = []
    for section in sections:
        section = re.sub(r'\s+', ' ', section)           # collapse whitespace
        section = re.sub(r'(\d+)\s*%', r'\1%', section)  # "45 %" -> "45%"
        # Restore a missing space after sentence punctuation glued to a
        # number ("adults.5 sites"). The lookbehind requires a letter, a
        # closing bracket or '%', so digits on both sides stay untouched.
        section = re.sub(r'(?<=[A-Za-z)\]%])([.,])(?=\d)', r'\1 ', section)
        cleaned_sections.append(section)

    final_summary = '\n'.join(cleaned_sections)

    # Make sure the summary ends with a full stop.
    if final_summary and not final_summary.endswith('.'):
        final_summary += '.'

    return final_summary
286
+
287
  def generate_focused_summary(question, abstracts, model, tokenizer):
288
  """Generate focused summary based on question"""
289
  # Preprocess each abstract