cvips
commited on
Commit
·
7d2f14f
1
Parent(s):
4d2300f
biomed-llama_multimodal
Browse files
app.py
CHANGED
@@ -135,6 +135,78 @@ MODALITY_PROMPTS = {
|
|
135 |
"OCT": ["edema"]
|
136 |
}
|
137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
|
139 |
def extract_modality_from_llm(llm_output):
|
140 |
"""Extract modality from LLM output and map it to BIOMEDPARSE_MODES"""
|
@@ -278,7 +350,7 @@ def process_image(image_path, user_prompt, modality=None):
|
|
278 |
# f"Analyze this medical image considering the following context: {user_prompt}. "
|
279 |
# "Include modality, anatomical structures, and any abnormalities."
|
280 |
# )
|
281 |
-
question = 'modality?'
|
282 |
msgs = [{'role': 'user', 'content': [pil_image, question]}]
|
283 |
|
284 |
llm_response = ""
|
@@ -299,32 +371,29 @@ def process_image(image_path, user_prompt, modality=None):
|
|
299 |
else:
|
300 |
llm_response = "LLM not available. Please check LLM initialization logs."
|
301 |
|
302 |
-
detected_modality =
|
303 |
if not detected_modality:
|
304 |
detected_modality = "X-Ray-Chest" # Fallback modality
|
305 |
-
|
306 |
-
clinical_findings = extract_clinical_findings(llm_response, detected_modality)
|
307 |
-
if not clinical_findings:
|
308 |
-
clinical_findings = [detected_modality.split("-")[-1].lower()]
|
309 |
|
310 |
results = []
|
311 |
analysis_results = []
|
312 |
colors = [(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]
|
313 |
|
314 |
-
for idx,
|
315 |
try:
|
316 |
-
mask_list = interactive_infer_image(model, pil_image, [
|
317 |
if not mask_list or len(mask_list) == 0:
|
318 |
-
analysis_results.append(f"No mask generated for '{
|
319 |
continue
|
320 |
|
321 |
pred_mask = mask_list[0]
|
322 |
-
if pred_mask is None or not pred_mask.any():
|
323 |
-
|
324 |
-
|
325 |
|
326 |
-
p_value = check_mask_stats(image, pred_mask.astype(np.uint8) * 255, detected_modality, finding)
|
327 |
-
analysis_results.append(f"P-value for '{finding}' ({detected_modality}): {p_value:.4f}")
|
328 |
|
329 |
overlay_image = image.copy()
|
330 |
color = colors[idx % len(colors)]
|
@@ -333,26 +402,36 @@ def process_image(image_path, user_prompt, modality=None):
|
|
333 |
overlay_image[mask_indices] = color
|
334 |
results.append(overlay_image)
|
335 |
except Exception as e:
|
336 |
-
print(f"Error processing finding {
|
337 |
-
analysis_results.append(f"Failed to process '{
|
338 |
|
339 |
if not results:
|
340 |
results = [image] # Return original image if no overlays were created
|
341 |
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
354 |
|
355 |
-
return results,
|
356 |
|
357 |
except Exception as e:
|
358 |
error_msg = f"⚠️ An error occurred: {str(e)}"
|
|
|
135 |
"OCT": ["edema"]
|
136 |
}
|
137 |
|
138 |
+
def extract_modality_and_prompts(llm_output):
|
139 |
+
"""
|
140 |
+
Extract modality and relevant prompts from LLM output
|
141 |
+
Returns: (modality_type, list_of_prompts)
|
142 |
+
"""
|
143 |
+
llm_output = llm_output.lower()
|
144 |
+
|
145 |
+
# Dictionary mapping keywords to modalities
|
146 |
+
modality_indicators = {
|
147 |
+
'dermatoscop': 'Dermoscopy',
|
148 |
+
'skin lesion': 'Dermoscopy',
|
149 |
+
'oct': 'OCT',
|
150 |
+
'optical coherence': 'OCT',
|
151 |
+
'fundus': 'Fundus',
|
152 |
+
'retina': 'Fundus',
|
153 |
+
'endoscop': 'Endoscopy',
|
154 |
+
'colon': 'Endoscopy',
|
155 |
+
'patholog': 'Pathology',
|
156 |
+
'tissue': 'Pathology',
|
157 |
+
'histolog': 'Pathology',
|
158 |
+
'x-ray': 'X-Ray-Chest',
|
159 |
+
'xray': 'X-Ray-Chest',
|
160 |
+
'chest radiograph': 'X-Ray-Chest',
|
161 |
+
'mri': None, # Will be refined below
|
162 |
+
'magnetic resonance': None, # Will be refined below
|
163 |
+
'ct': None, # Will be refined below
|
164 |
+
'computed tomography': None, # Will be refined below
|
165 |
+
'ultrasound': 'Ultrasound-Cardiac',
|
166 |
+
'sonograph': 'Ultrasound-Cardiac'
|
167 |
+
}
|
168 |
+
|
169 |
+
# First pass: Detect base modality
|
170 |
+
detected_modality = None
|
171 |
+
for keyword, modality in modality_indicators.items():
|
172 |
+
if keyword in llm_output:
|
173 |
+
detected_modality = modality
|
174 |
+
break
|
175 |
+
|
176 |
+
# Second pass: Refine MRI and CT if detected
|
177 |
+
if detected_modality is None and ('mri' in llm_output or 'magnetic resonance' in llm_output):
|
178 |
+
if 'brain' in llm_output or 'flair' in llm_output:
|
179 |
+
detected_modality = 'MRI-FLAIR-Brain'
|
180 |
+
elif 'cardiac' in llm_output or 'heart' in llm_output:
|
181 |
+
detected_modality = 'MRI-Cardiac'
|
182 |
+
elif 'abdomen' in llm_output:
|
183 |
+
detected_modality = 'MRI-Abdomen'
|
184 |
+
elif 't1' in llm_output or 'contrast' in llm_output:
|
185 |
+
detected_modality = 'MRI-T1-Gd-Brain'
|
186 |
+
else:
|
187 |
+
detected_modality = 'MRI'
|
188 |
+
|
189 |
+
if detected_modality is None and ('ct' in llm_output or 'computed tomography' in llm_output):
|
190 |
+
if 'chest' in llm_output or 'lung' in llm_output:
|
191 |
+
detected_modality = 'CT-Chest'
|
192 |
+
elif 'liver' in llm_output:
|
193 |
+
detected_modality = 'CT-Liver'
|
194 |
+
elif 'abdomen' in llm_output:
|
195 |
+
detected_modality = 'CT-Abdomen'
|
196 |
+
else:
|
197 |
+
detected_modality = 'CT'
|
198 |
+
|
199 |
+
# If still no modality detected, return None
|
200 |
+
if not detected_modality:
|
201 |
+
return "", []
|
202 |
+
|
203 |
+
# Get relevant prompts for the detected modality
|
204 |
+
if detected_modality in MODALITY_PROMPTS:
|
205 |
+
relevant_prompts = MODALITY_PROMPTS[detected_modality]
|
206 |
+
else:
|
207 |
+
relevant_prompts = []
|
208 |
+
|
209 |
+
return detected_modality, relevant_prompts
|
210 |
|
211 |
def extract_modality_from_llm(llm_output):
|
212 |
"""Extract modality from LLM output and map it to BIOMEDPARSE_MODES"""
|
|
|
350 |
# f"Analyze this medical image considering the following context: {user_prompt}. "
|
351 |
# "Include modality, anatomical structures, and any abnormalities."
|
352 |
# )
|
353 |
+
question = 'What type of medical imaging modality is this? Please be specific.'
|
354 |
msgs = [{'role': 'user', 'content': [pil_image, question]}]
|
355 |
|
356 |
llm_response = ""
|
|
|
371 |
else:
|
372 |
llm_response = "LLM not available. Please check LLM initialization logs."
|
373 |
|
374 |
+
detected_modality, relevant_prompts = extract_modality_and_prompts(llm_response)
|
375 |
if not detected_modality:
|
376 |
detected_modality = "X-Ray-Chest" # Fallback modality
|
377 |
+
relevant_prompts = MODALITY_PROMPTS["X-Ray-Chest"]
|
|
|
|
|
|
|
378 |
|
379 |
results = []
|
380 |
analysis_results = []
|
381 |
colors = [(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]
|
382 |
|
383 |
+
for idx, prompt in enumerate(relevant_prompts):
|
384 |
try:
|
385 |
+
mask_list = interactive_infer_image(model, pil_image, [prompt])
|
386 |
if not mask_list or len(mask_list) == 0:
|
387 |
+
analysis_results.append(f"No mask generated for '{prompt}'")
|
388 |
continue
|
389 |
|
390 |
pred_mask = mask_list[0]
|
391 |
+
# if pred_mask is None or not pred_mask.any():
|
392 |
+
# analysis_results.append(f"Empty mask generated for '{finding}'")
|
393 |
+
# continue
|
394 |
|
395 |
+
# p_value = check_mask_stats(image, pred_mask.astype(np.uint8) * 255, detected_modality, finding)
|
396 |
+
# analysis_results.append(f"P-value for '{finding}' ({detected_modality}): {p_value:.4f}")
|
397 |
|
398 |
overlay_image = image.copy()
|
399 |
color = colors[idx % len(colors)]
|
|
|
402 |
overlay_image[mask_indices] = color
|
403 |
results.append(overlay_image)
|
404 |
except Exception as e:
|
405 |
+
print(f"Error processing finding {prompt}: {str(e)}")
|
406 |
+
analysis_results.append(f"Failed to process '{prompt}': {str(e)}")
|
407 |
|
408 |
if not results:
|
409 |
results = [image] # Return original image if no overlays were created
|
410 |
|
411 |
+
detailed_analysis = ""
|
412 |
+
# try:
|
413 |
+
analysis_prompt = f"Give the modality, organ, analysis, abnormalities (if any), treatment (if abnormalities are present) for this image. Focus more on the user question. which is: {user_prompt}"
|
414 |
+
msgs = [{'role': 'user', 'content': [pil_image, analysis_prompt]}]
|
415 |
|
416 |
+
# llm_response = ""
|
417 |
+
if llm_model and llm_tokenizer:
|
418 |
+
try:
|
419 |
+
for new_text in llm_model.chat(
|
420 |
+
image=pil_image,
|
421 |
+
msgs=msgs,
|
422 |
+
tokenizer=llm_tokenizer,
|
423 |
+
sampling=True,
|
424 |
+
temperature=0.95,
|
425 |
+
stream=True
|
426 |
+
):
|
427 |
+
detailed_analysis += new_text
|
428 |
+
except Exception as e:
|
429 |
+
print(f"LLM chat error: {str(e)}")
|
430 |
+
detailed_analysis = "LLM analysis failed. Proceeding with basic analysis."
|
431 |
+
else:
|
432 |
+
detailed_analysis = "LLM not available. Please check LLM initialization logs."
|
433 |
|
434 |
+
return results, detailed_analysis, detected_modality
|
435 |
|
436 |
except Exception as e:
|
437 |
error_msg = f"⚠️ An error occurred: {str(e)}"
|