cvips committed on
Commit
7d2f14f
·
1 Parent(s): 4d2300f

biomed-llama_multimodal

Browse files
Files changed (1) hide show
  1. app.py +107 -28
app.py CHANGED
@@ -135,6 +135,78 @@ MODALITY_PROMPTS = {
135
  "OCT": ["edema"]
136
  }
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  def extract_modality_from_llm(llm_output):
140
  """Extract modality from LLM output and map it to BIOMEDPARSE_MODES"""
@@ -278,7 +350,7 @@ def process_image(image_path, user_prompt, modality=None):
278
  # f"Analyze this medical image considering the following context: {user_prompt}. "
279
  # "Include modality, anatomical structures, and any abnormalities."
280
  # )
281
- question = 'modality?'
282
  msgs = [{'role': 'user', 'content': [pil_image, question]}]
283
 
284
  llm_response = ""
@@ -299,32 +371,29 @@ def process_image(image_path, user_prompt, modality=None):
299
  else:
300
  llm_response = "LLM not available. Please check LLM initialization logs."
301
 
302
- detected_modality = extract_modality_from_llm(llm_response)
303
  if not detected_modality:
304
  detected_modality = "X-Ray-Chest" # Fallback modality
305
-
306
- clinical_findings = extract_clinical_findings(llm_response, detected_modality)
307
- if not clinical_findings:
308
- clinical_findings = [detected_modality.split("-")[-1].lower()]
309
 
310
  results = []
311
  analysis_results = []
312
  colors = [(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]
313
 
314
- for idx, finding in enumerate(clinical_findings):
315
  try:
316
- mask_list = interactive_infer_image(model, pil_image, [finding])
317
  if not mask_list or len(mask_list) == 0:
318
- analysis_results.append(f"No mask generated for '{finding}'")
319
  continue
320
 
321
  pred_mask = mask_list[0]
322
- if pred_mask is None or not pred_mask.any():
323
- analysis_results.append(f"Empty mask generated for '{finding}'")
324
- continue
325
 
326
- p_value = check_mask_stats(image, pred_mask.astype(np.uint8) * 255, detected_modality, finding)
327
- analysis_results.append(f"P-value for '{finding}' ({detected_modality}): {p_value:.4f}")
328
 
329
  overlay_image = image.copy()
330
  color = colors[idx % len(colors)]
@@ -333,26 +402,36 @@ def process_image(image_path, user_prompt, modality=None):
333
  overlay_image[mask_indices] = color
334
  results.append(overlay_image)
335
  except Exception as e:
336
- print(f"Error processing finding {finding}: {str(e)}")
337
- analysis_results.append(f"Failed to process '{finding}': {str(e)}")
338
 
339
  if not results:
340
  results = [image] # Return original image if no overlays were created
341
 
342
- enhanced_response = llm_response + "\n\nSegmentation Results:\n"
343
- for idx, finding in enumerate(clinical_findings):
344
- color_name = ["red", "green", "blue", "yellow", "magenta"][idx % len(colors)]
345
- enhanced_response += f"- {finding} (shown in {color_name})\n"
346
 
347
- combined_analysis = "\n\n" + "="*50 + "\n"
348
- combined_analysis += "BiomedParse Analysis:\n"
349
- combined_analysis += "\n".join(analysis_results)
350
- combined_analysis += "\n\n" + "="*50 + "\n"
351
- combined_analysis += "Enhanced LLM Analysis:\n"
352
- combined_analysis += enhanced_response
353
- combined_analysis += "\n" + "="*50
 
 
 
 
 
 
 
 
 
 
354
 
355
- return results, combined_analysis, detected_modality
356
 
357
  except Exception as e:
358
  error_msg = f"⚠️ An error occurred: {str(e)}"
 
135
  "OCT": ["edema"]
136
  }
137
 
138
+ def extract_modality_and_prompts(llm_output):
139
+ """
140
+ Extract modality and relevant prompts from LLM output
141
+ Returns: (modality_type, list_of_prompts)
142
+ """
143
+ llm_output = llm_output.lower()
144
+
145
+ # Dictionary mapping keywords to modalities
146
+ modality_indicators = {
147
+ 'dermatoscop': 'Dermoscopy',
148
+ 'skin lesion': 'Dermoscopy',
149
+ 'oct': 'OCT',
150
+ 'optical coherence': 'OCT',
151
+ 'fundus': 'Fundus',
152
+ 'retina': 'Fundus',
153
+ 'endoscop': 'Endoscopy',
154
+ 'colon': 'Endoscopy',
155
+ 'patholog': 'Pathology',
156
+ 'tissue': 'Pathology',
157
+ 'histolog': 'Pathology',
158
+ 'x-ray': 'X-Ray-Chest',
159
+ 'xray': 'X-Ray-Chest',
160
+ 'chest radiograph': 'X-Ray-Chest',
161
+ 'mri': None, # Will be refined below
162
+ 'magnetic resonance': None, # Will be refined below
163
+ 'ct': None, # Will be refined below
164
+ 'computed tomography': None, # Will be refined below
165
+ 'ultrasound': 'Ultrasound-Cardiac',
166
+ 'sonograph': 'Ultrasound-Cardiac'
167
+ }
168
+
169
+ # First pass: Detect base modality
170
+ detected_modality = None
171
+ for keyword, modality in modality_indicators.items():
172
+ if keyword in llm_output:
173
+ detected_modality = modality
174
+ break
175
+
176
+ # Second pass: Refine MRI and CT if detected
177
+ if detected_modality is None and ('mri' in llm_output or 'magnetic resonance' in llm_output):
178
+ if 'brain' in llm_output or 'flair' in llm_output:
179
+ detected_modality = 'MRI-FLAIR-Brain'
180
+ elif 'cardiac' in llm_output or 'heart' in llm_output:
181
+ detected_modality = 'MRI-Cardiac'
182
+ elif 'abdomen' in llm_output:
183
+ detected_modality = 'MRI-Abdomen'
184
+ elif 't1' in llm_output or 'contrast' in llm_output:
185
+ detected_modality = 'MRI-T1-Gd-Brain'
186
+ else:
187
+ detected_modality = 'MRI'
188
+
189
+ if detected_modality is None and ('ct' in llm_output or 'computed tomography' in llm_output):
190
+ if 'chest' in llm_output or 'lung' in llm_output:
191
+ detected_modality = 'CT-Chest'
192
+ elif 'liver' in llm_output:
193
+ detected_modality = 'CT-Liver'
194
+ elif 'abdomen' in llm_output:
195
+ detected_modality = 'CT-Abdomen'
196
+ else:
197
+ detected_modality = 'CT'
198
+
199
+ # If still no modality detected, return None
200
+ if not detected_modality:
201
+ return "", []
202
+
203
+ # Get relevant prompts for the detected modality
204
+ if detected_modality in MODALITY_PROMPTS:
205
+ relevant_prompts = MODALITY_PROMPTS[detected_modality]
206
+ else:
207
+ relevant_prompts = []
208
+
209
+ return detected_modality, relevant_prompts
210
 
211
  def extract_modality_from_llm(llm_output):
212
  """Extract modality from LLM output and map it to BIOMEDPARSE_MODES"""
 
350
  # f"Analyze this medical image considering the following context: {user_prompt}. "
351
  # "Include modality, anatomical structures, and any abnormalities."
352
  # )
353
+ question = 'What type of medical imaging modality is this? Please be specific.'
354
  msgs = [{'role': 'user', 'content': [pil_image, question]}]
355
 
356
  llm_response = ""
 
371
  else:
372
  llm_response = "LLM not available. Please check LLM initialization logs."
373
 
374
+ detected_modality, relevant_prompts = extract_modality_and_prompts(llm_response)
375
  if not detected_modality:
376
  detected_modality = "X-Ray-Chest" # Fallback modality
377
+ relevant_prompts = MODALITY_PROMPTS["X-Ray-Chest"]
 
 
 
378
 
379
  results = []
380
  analysis_results = []
381
  colors = [(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]
382
 
383
+ for idx, prompt in enumerate(relevant_prompts):
384
  try:
385
+ mask_list = interactive_infer_image(model, pil_image, [prompt])
386
  if not mask_list or len(mask_list) == 0:
387
+ analysis_results.append(f"No mask generated for '{prompt}'")
388
  continue
389
 
390
  pred_mask = mask_list[0]
391
+ # if pred_mask is None or not pred_mask.any():
392
+ # analysis_results.append(f"Empty mask generated for '{finding}'")
393
+ # continue
394
 
395
+ # p_value = check_mask_stats(image, pred_mask.astype(np.uint8) * 255, detected_modality, finding)
396
+ # analysis_results.append(f"P-value for '{finding}' ({detected_modality}): {p_value:.4f}")
397
 
398
  overlay_image = image.copy()
399
  color = colors[idx % len(colors)]
 
402
  overlay_image[mask_indices] = color
403
  results.append(overlay_image)
404
  except Exception as e:
405
+ print(f"Error processing finding {prompt}: {str(e)}")
406
+ analysis_results.append(f"Failed to process '{prompt}': {str(e)}")
407
 
408
  if not results:
409
  results = [image] # Return original image if no overlays were created
410
 
411
+ detailed_analysis = ""
412
+ # try:
413
+ analysis_prompt = f"Give the modality, organ, analysis, abnormalities (if any), treatment (if abnormalities are present) for this image. Focus more on the user question. which is: {user_prompt}"
414
+ msgs = [{'role': 'user', 'content': [pil_image, analysis_prompt]}]
415
 
416
+ # llm_response = ""
417
+ if llm_model and llm_tokenizer:
418
+ try:
419
+ for new_text in llm_model.chat(
420
+ image=pil_image,
421
+ msgs=msgs,
422
+ tokenizer=llm_tokenizer,
423
+ sampling=True,
424
+ temperature=0.95,
425
+ stream=True
426
+ ):
427
+ detailed_analysis += new_text
428
+ except Exception as e:
429
+ print(f"LLM chat error: {str(e)}")
430
+ detailed_analysis = "LLM analysis failed. Proceeding with basic analysis."
431
+ else:
432
+ detailed_analysis = "LLM not available. Please check LLM initialization logs."
433
 
434
+ return results, detailed_analysis, detected_modality
435
 
436
  except Exception as e:
437
  error_msg = f"⚠️ An error occurred: {str(e)}"