pendar02 committed on
Commit
7ab41f7
·
verified ·
1 Parent(s): 234816f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -76
app.py CHANGED
@@ -118,121 +118,115 @@ def preprocess_text(text):
118
  return formatted_text
119
 
120
def post_process_summary(summary):
    """Clean up and improve summary coherence.

    Splits the model output into sentences, removes redundant phrases,
    collapses runs of whitespace, capitalizes each sentence, and re-joins
    with terminal punctuation.

    Args:
        summary: Raw summary text from the model (may be empty/None).

    Returns:
        The cleaned summary string, or the input unchanged if it is falsy.
    """
    if not summary:
        return summary

    # Split into sentences and drop empty fragments
    sentences = [s.strip() for s in summary.split('.')]
    sentences = [s for s in sentences if s]

    processed_sentences = []
    for sentence in sentences:
        # Remove redundant words/phrases
        sentence = sentence.replace(" and and ", " and ")
        sentence = sentence.replace("appointment and appointment", "appointment")

        # Fix common grammatical issues
        sentence = sentence.replace("Cancers distress", "Cancer distress")
        # BUG FIX: the previous code did sentence.replace(" ", " ") — a no-op
        # (single space replaced by single space) that never removed double
        # spaces as its comment claimed. Collapse whitespace runs for real.
        sentence = re.sub(r"\s{2,}", " ", sentence)

        # Capitalize first letter of each sentence
        sentence = sentence.capitalize()

        # Add to processed sentences if not empty
        if sentence.strip():
            processed_sentences.append(sentence)

    # Join sentences with proper spacing and punctuation
    cleaned_summary = '. '.join(processed_sentences)
    if cleaned_summary and not cleaned_summary.endswith('.'):
        cleaned_summary += '.'

    return cleaned_summary
153
 
154
def improve_summary_generation(text, model, tokenizer):
    """Generate improved summary with better prompt and validation"""
    if not isinstance(text, str) or not text.strip():
        return "No abstract available to summarize."

    # Structured prompt steering the model toward a four-part summary
    prompt = (
        "Summarize this medical research paper following this structure exactly:\n"
        "1. Background and objectives\n"
        "2. Methods\n"
        "3. Key findings with specific numbers/percentages\n"
        "4. Main conclusions\n"
        "Original text: " + preprocess_text(text)
    )

    encoded = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    encoded = {key: tensor.to(model.device) for key, tensor in encoded.items()}

    def _run_generation(beams, len_penalty, ngram, temp, rep_penalty):
        # One beam-search decoding pass with the given parameters.
        with torch.no_grad():
            ids = model.generate(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                max_length=200,
                min_length=50,
                num_beams=beams,
                length_penalty=len_penalty,
                no_repeat_ngram_size=ngram,
                temperature=temp,
                repetition_penalty=rep_penalty,
            )
        return tokenizer.decode(ids[0], skip_special_tokens=True)

    # First attempt with the default decoding parameters
    processed_summary = post_process_summary(_run_generation(5, 1.5, 3, 0.7, 1.5))

    # If validation fails, retry once with alternative decoding parameters
    if not validate_summary(processed_summary, text):
        processed_summary = post_process_summary(_run_generation(4, 2.0, 4, 0.8, 2.0))

    return processed_summary
214
 
215
def validate_summary(summary, original_text):
    """Validate summary content against original text"""
    # Reject summaries that mention more than one distinct age value
    lowered = summary.lower()
    if len(re.findall(r'(\d+\.?\d*)\s*years?', lowered)) > 1:
        return False

    # Reject when more than one split fragment duplicates another
    # (empty fragments count toward the total but not the unique set)
    fragments = summary.split('.')
    distinct = {frag.strip().lower() for frag in fragments if frag.strip()}
    if len(fragments) - len(distinct) > 1:
        return False

    # Length sanity check relative to the original text
    n_summary = len(summary.split())
    n_original = len(original_text.split())
    return not (n_summary < 20 or n_summary > n_original * 0.8)
235
 
 
236
  def generate_focused_summary(question, abstracts, model, tokenizer):
237
  """Generate focused summary based on question"""
238
  # Preprocess each abstract
 
118
  return formatted_text
119
 
120
def post_process_summary(summary):
    """Clean up and improve summary coherence.

    Splits the model output into sentences, collapses duplicated phrases,
    capitalizes each sentence, drops duplicate sentences, and re-joins with
    a terminal period.

    Args:
        summary: Raw summary text from the model (may be empty/None).

    Returns:
        The cleaned summary string, or the input unchanged if it is falsy.
    """
    if not summary:
        return summary

    # Split into sentences
    sentences = [s.strip() for s in summary.split('.')]
    sentences = [s for s in sentences if s]  # Remove empty sentences

    # Correct common issues
    processed_sentences = []
    for sentence in sentences:
        # BUG FIX: a single regex replaced "appointment and appointment"
        # with "and", losing the word "appointment" entirely. Collapse each
        # duplication to its correct reduced form instead.
        sentence = re.sub(r"\band and\b", "and", sentence)
        sentence = re.sub(r"\bappointment and appointment\b", "appointment", sentence)

        # Ensure first letter capitalization
        sentence = sentence.capitalize()

        # Avoid duplicate sentences
        if sentence not in processed_sentences:
            processed_sentences.append(sentence)

    # Join sentences with proper punctuation
    cleaned_summary = '. '.join(processed_sentences)
    if not cleaned_summary:
        return cleaned_summary
    return cleaned_summary if cleaned_summary.endswith('.') else cleaned_summary + '.'
 
 
146
 
147
def improve_summary_generation(text, model, tokenizer):
    """Generate improved summary with better prompt and validation.

    Builds a structured prompt, runs beam-search generation, post-processes
    the output, and retries once with alternate decoding parameters if the
    result fails validation.

    Args:
        text: The abstract text to summarize.
        model: A seq2seq language model exposing `.generate()` and `.device`.
        tokenizer: The matching tokenizer.

    Returns:
        The processed summary string, a fixed message for empty input, or an
        error message string if generation raises.
    """
    if not isinstance(text, str) or not text.strip():
        return "No abstract available to summarize."

    # Add a structured prompt for summarization
    formatted_text = (
        "Summarize this biomedical research abstract into the following structure:\n"
        "1. Background and Objectives\n"
        "2. Methods\n"
        "3. Key Findings (include any percentages or numbers)\n"
        "4. Conclusions\n"
        f"Abstract:\n{text.strip()}"
    )

    # Prepare input tokens
    inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate summary with adjusted parameters
    try:
        with torch.no_grad():
            summary_ids = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=300,  # Increased for more detailed summaries
                min_length=100,  # Ensure summaries are not too short
                num_beams=5,
                length_penalty=1.5,
                no_repeat_ngram_size=3,
                temperature=0.7,
                repetition_penalty=1.3,
            )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error in generation: {str(e)}"

    # BUG FIX: the previous version returned post_process_summary(summary)
    # here unconditionally, which made the validation/retry branch below
    # unreachable dead code that also referenced an undefined
    # `processed_summary`. Assign and fall through instead.
    processed_summary = post_process_summary(summary)

    # Validate the summary; on failure, retry with alternate parameters
    if not validate_summary(processed_summary, text):
        try:
            with torch.no_grad():
                summary_ids = model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=250,
                    min_length=50,
                    num_beams=4,
                    length_penalty=2.0,
                    no_repeat_ngram_size=4,
                    temperature=0.8,
                    repetition_penalty=1.5,
                )
            summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            processed_summary = post_process_summary(summary)
        except Exception as e:
            # Mirror the first pass's error handling for consistency
            return f"Error in generation: {str(e)}"

    return processed_summary
208
 
209
def validate_summary(summary, original_text):
    """Validate summary content against original text.

    Checks length bounds, the presence of the four required section
    headings, and rejects summaries containing repeated sentences.

    Args:
        summary: The candidate summary text.
        original_text: The source text the summary was generated from.

    Returns:
        True if the summary passes all checks, False otherwise.
    """
    # Check for common validation points
    if not summary or len(summary.split()) < 20:
        return False  # Too short
    if len(summary.split()) > len(original_text.split()) * 0.8:
        return False  # Too long

    # Ensure structure is maintained (e.g., headings are present)
    lowered = summary.lower()
    required_sections = ["background and objectives", "methods", "key findings", "conclusions"]
    if not all(section in lowered for section in required_sections):
        return False

    # Ensure no repetitive sentences. BUG FIX: the previous raw split('.')
    # comparison never caught real duplicates (the leading space after
    # ". " made fragments unequal) and false-positived on two empty
    # fragments (e.g. ".."). Normalize and drop empties before comparing.
    sentences = [s.strip().lower() for s in summary.split('.')]
    sentences = [s for s in sentences if s]
    if len(sentences) != len(set(sentences)):
        return False

    return True
228
 
229
+
230
  def generate_focused_summary(question, abstracts, model, tokenizer):
231
  """Generate focused summary based on question"""
232
  # Preprocess each abstract