#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Combined Medical-VLM, **SAM-2 automatic masking**, and CheXagent demo.
⭑ Changes ⭑
-----------
1. Fixed SAM-2 installation and import issues
2. Added proper error handling for missing dependencies
3. Made SAM-2 functionality optional with graceful fallback
4. Added installation instructions and requirements check
"""
# ---------------------------------------------------------------------
# Standard libs
# ---------------------------------------------------------------------
import os
import sys
import uuid
import tempfile
import subprocess
import warnings
from threading import Thread
# Environment setup
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
warnings.filterwarnings("ignore", message=r".*upsample_bicubic2d.*")
# ---------------------------------------------------------------------
# Third-party libs
# ---------------------------------------------------------------------
import torch
import numpy as np
from PIL import Image, ImageDraw
import gradio as gr
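# Assumed runtime dependencies (a sketch, not pinned by this file):
#   torch, numpy, pillow, gradio, transformers, qwen-vl-utils,
#   opencv-python (fallback segmentation only), and sam2 (cloned and
#   installed on the fly below when missing).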
# =============================================================================
# Dependency checker and installer
# =============================================================================
def check_and_install_sam2():
"""Check if SAM-2 is available and attempt installation if needed."""
try:
print("[SAM-2 Debug] Attempting to import SAM-2 modules...")
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
print("[SAM-2 Debug] Successfully imported SAM-2 modules")
return True, "SAM-2 already available"
except ImportError as e:
print(f"[SAM-2 Debug] Import error: {str(e)}")
print("[SAM-2 Debug] Attempting to install SAM-2...")
try:
# Clone SAM-2 repository
if not os.path.exists("segment-anything-2"):
print("[SAM-2 Debug] Cloning SAM-2 repository...")
subprocess.run([
"git", "clone",
"https://github.com/facebookresearch/segment-anything-2.git"
], check=True)
print("[SAM-2 Debug] Repository cloned successfully")
            # Install SAM-2 in editable mode (cwd= avoids a global chdir
            # that would never be restored if pip failed)
            print("[SAM-2 Debug] Installing SAM-2...")
            subprocess.run(
                [sys.executable, "-m", "pip", "install", "-e", "."],
                check=True,
                cwd="segment-anything-2",
            )
print("[SAM-2 Debug] Installation completed")
# Add to Python path
sam2_path = os.path.abspath("segment-anything-2")
if sam2_path not in sys.path:
sys.path.insert(0, sam2_path)
print(f"[SAM-2 Debug] Added {sam2_path} to Python path")
# Try importing again
print("[SAM-2 Debug] Attempting to import SAM-2 modules again...")
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
print("[SAM-2 Debug] Successfully imported SAM-2 modules after installation")
return True, "SAM-2 installed successfully"
except Exception as e:
print(f"[SAM-2 Debug] Installation failed: {str(e)}")
print(f"[SAM-2 Debug] Error type: {type(e).__name__}")
return False, f"SAM-2 installation failed: {e}"
# Check SAM-2 availability
SAM2_AVAILABLE, SAM2_STATUS = check_and_install_sam2()
print(f"SAM-2 Status: {SAM2_STATUS}")
# =============================================================================
# SAM-2 imports (conditional)
# =============================================================================
if SAM2_AVAILABLE:
try:
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.modeling.sam2_base import SAM2Base
except ImportError as e:
print(f"SAM-2 import error: {e}")
SAM2_AVAILABLE = False
# =============================================================================
# Qwen-VLM imports & helper
# =============================================================================
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
# =============================================================================
# CheXagent imports
# =============================================================================
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
# ---------------------------------------------------------------------
# Devices
# ---------------------------------------------------------------------
def get_device():
if torch.cuda.is_available():
return torch.device("cuda")
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
# =============================================================================
# Qwen-VLM model & agent
# =============================================================================
_qwen_model = None
_qwen_processor = None
_qwen_device = None
def load_qwen_model_and_processor(hf_token=None):
global _qwen_model, _qwen_processor, _qwen_device
if _qwen_model is None:
_qwen_device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"[Qwen] loading model on {_qwen_device}")
auth_kwargs = {"use_auth_token": hf_token} if hf_token else {}
_qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
trust_remote_code=True,
attn_implementation="eager",
torch_dtype=torch.float32,
low_cpu_mem_usage=True,
device_map=None,
**auth_kwargs,
).to(_qwen_device)
_qwen_processor = AutoProcessor.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
trust_remote_code=True,
**auth_kwargs,
)
return _qwen_model, _qwen_processor, _qwen_device
class MedicalVLMAgent:
"""Light wrapper around Qwen-VLM with an optional image."""
def __init__(self, model, processor, device):
self.model = model
self.processor = processor
self.device = device
self.system_prompt = (
"You are a medical information assistant with vision capabilities.\n"
"Disclaimer: I am not a licensed medical professional. "
"The information provided is for reference only and should not be taken as medical advice."
)
def run(self, user_text: str, image: Image.Image | None = None) -> str:
messages = [
{"role": "system", "content": [{"type": "text", "text": self.system_prompt}]}
]
user_content = []
if image is not None:
tmp = f"/tmp/{uuid.uuid4()}.png"
image.save(tmp)
user_content.append({"type": "image", "image": tmp})
user_content.append({"type": "text", "text": user_text or "Please describe the image."})
messages.append({"role": "user", "content": user_content})
prompt_text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
img_inputs, vid_inputs = process_vision_info(messages)
inputs = self.processor(
text=[prompt_text],
images=img_inputs,
videos=vid_inputs,
padding=True,
return_tensors="pt",
).to(self.device)
with torch.no_grad():
out = self.model.generate(**inputs, max_new_tokens=128)
trimmed = out[0][inputs.input_ids.shape[1] :]
return self.processor.decode(trimmed, skip_special_tokens=True).strip()
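# Usage sketch (mirrors what create_ui() wires up below):
#   model, proc, dev = load_qwen_model_and_processor()
#   agent = MedicalVLMAgent(model, proc, dev)
#   answer = agent.run("Is this X-ray normal?", Image.open("xray.png"))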
# =============================================================================
# SAM-2 model + AutomaticMaskGenerator
# =============================================================================
# build_sam2 / SAM2AutomaticMaskGenerator come from the conditional import
# block above; re-importing them unconditionally here would crash the whole
# app whenever SAM-2 is unavailable and defeat the graceful fallback.
def initialize_sam2():
    """Build the SAM-2 model and automatic mask generator."""
    if not SAM2_AVAILABLE:
        raise RuntimeError("SAM-2 is not installed")
    # These two files are expected to already be in the repo
    CKPT = "checkpoints/sam2.1_hiera_large.pt"   # ≈2.7 GB
    CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
    # One chdir so Hydra's search path starts inside sam2/sam2/; restore the
    # original cwd afterwards so relative paths elsewhere keep working.
    original_dir = os.getcwd()
    os.chdir("sam2/sam2")
    try:
        device = get_device()
        print(f"[SAM-2] building model on {device}")
        sam2_model = build_sam2(
            CFG,   # resolved relative to sam2/sam2/
            CKPT,  # likewise, relative after the chdir
            device=device,
            apply_postprocessing=False,
        )
        mask_gen = SAM2AutomaticMaskGenerator(
            model=sam2_model,
            points_per_side=32,
            pred_iou_thresh=0.86,
            stability_score_thresh=0.92,
            crop_n_layers=0,
        )
    finally:
        os.chdir(original_dir)
    return sam2_model, mask_gen
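# Expected on-disk layout (an assumption, inferred from the paths above):
#   sam2/sam2/configs/sam2.1/sam2.1_hiera_l.yaml
#   sam2/sam2/checkpoints/sam2.1_hiera_large.pt
# Note this is a checked-in sam2/ tree, distinct from the segment-anything-2/
# clone that check_and_install_sam2() may create.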
# ---------------------- build once ----------------------
try:
_sam2_model, _mask_generator = initialize_sam2()
print("[SAM-2] Successfully initialized!")
except Exception as e:
print(f"[SAM-2] Failed to initialize: {e}")
_sam2_model, _mask_generator = None, None
def automatic_mask_overlay(image_np: np.ndarray) -> tuple[np.ndarray, int]:
    """Generate masks, alpha-blend them onto the image, and return the count."""
    if _mask_generator is None:
        raise RuntimeError("SAM-2 mask generator not initialized")
    anns = _mask_generator.generate(image_np)
    if not anns:
        return image_np, 0
    overlay = image_np.copy()
    if overlay.ndim == 2:  # grayscale → RGB
        overlay = np.stack([overlay] * 3, axis=2)
    # Paint the largest masks first so smaller ones remain visible on top
    for ann in sorted(anns, key=lambda x: x["area"], reverse=True):
        m = ann["segmentation"]
        color = np.random.randint(0, 256, 3, dtype=np.uint8)
        overlay[m] = (overlay[m] * 0.5 + color * 0.5).astype(np.uint8)
    return overlay, len(anns)

def tumor_segmentation_interface(image: Image.Image | None):
    if image is None:
        return None, "Please upload an image."
    if _mask_generator is None:
        return None, "SAM-2 not properly initialized. Check the console for errors."
    try:
        img_np = np.array(image.convert("RGB"))
        # A single SAM-2 pass: the overlay helper also reports the mask count
        out_np, n_masks = automatic_mask_overlay(img_np)
        return Image.fromarray(out_np), f"{n_masks} masks found."
    except Exception as e:
        return None, f"SAM-2 error: {e}"
# =============================================================================
# Simple fallback segmentation (when SAM-2 is not available)
# =============================================================================
def simple_segmentation_fallback(image: Image.Image | None):
    """Fallback segmentation: Otsu thresholding plus a distance transform."""
    if image is None:
        return None, "Please upload an image."
    try:
        import cv2
        # Convert to a numpy array
        img_np = np.array(image.convert("RGB"))
        # Otsu threshold on the grayscale image
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        # Remove noise with a morphological opening
        kernel = np.ones((3, 3), np.uint8)
        opening = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=2)
        # Keep only confident foreground: pixels far from any background
        dist_transform = cv2.distanceTransform(opening, cv2.DIST_L2, 5)
        _, sure_fg = cv2.threshold(dist_transform, 0.7 * dist_transform.max(), 255, 0)
        # Red overlay on the foreground, alpha-blended with the original
        overlay = img_np.copy()
        overlay[sure_fg > 0] = [255, 0, 0]
        result = cv2.addWeighted(img_np, 0.7, overlay, 0.3, 0)
        return Image.fromarray(result), "Simple segmentation applied (SAM-2 not available)"
    except Exception as e:
        return None, f"Fallback segmentation error: {e}"
# =============================================================================
# CheXagent set-up
# =============================================================================
try:
print("[CheXagent] Starting initialization...")
chex_name = "StanfordAIMI/CheXagent-2-3b"
print(f"[CheXagent] Loading tokenizer from {chex_name}")
chex_tok = AutoTokenizer.from_pretrained(chex_name, trust_remote_code=True)
print("[CheXagent] Tokenizer loaded successfully")
print("[CheXagent] Loading model...")
    chex_model = AutoModelForCausalLM.from_pretrained(
        chex_name,
        device_map="auto",
        trust_remote_code=True,
        # fp16 on GPU, fp32 on CPU; a later .half()/.float() pass is redundant
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    print("[CheXagent] Model loaded successfully")
chex_model.eval()
CHEXAGENT_AVAILABLE = True
print("[CheXagent] Initialization complete")
except Exception as e:
print(f"[CheXagent] Initialization failed: {str(e)}")
print(f"[CheXagent] Error type: {type(e).__name__}")
CHEXAGENT_AVAILABLE = False
chex_tok, chex_model = None, None
def get_model_device(model):
if model is None:
return torch.device("cpu")
for p in model.parameters():
return p.device
return torch.device("cpu")
def clean_text(text):
return text.replace("</s>", "")
@torch.no_grad()
def response_report_generation(pil_image_1, pil_image_2):
"""Structured chest-X-ray report (streaming)."""
if not CHEXAGENT_AVAILABLE:
yield "CheXagent is not available. Please check installation."
return
streamer = TextIteratorStreamer(chex_tok, skip_prompt=True, skip_special_tokens=True)
paths = []
for im in [pil_image_1, pil_image_2]:
if im is None:
continue
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tfile:
im.save(tfile.name)
paths.append(tfile.name)
if not paths:
yield "Please upload at least one image."
return
device = get_model_device(chex_model)
anatomies = [
"View",
"Airway",
"Breathing",
"Cardiac",
"Diaphragm",
"Everything else (e.g., mediastinal contours, bones, soft tissues, tubes, valves, pacemakers)",
]
prompts = [
"Determine the view of this CXR",
*[
f'Provide a detailed description of "{a}" in the chest X-ray'
for a in anatomies[1:]
],
]
findings = ""
partial = "## Generating Findings (step-by-step):\n\n"
for idx, (anat, prompt) in enumerate(zip(anatomies, prompts)):
query = chex_tok.from_list_format(
[*[{"image": p} for p in paths], {"text": prompt}]
)
conv = [
{"from": "system", "value": "You are a helpful assistant."},
{"from": "human", "value": query},
]
inp = chex_tok.apply_chat_template(
conv, add_generation_prompt=True, return_tensors="pt"
).to(device)
generate_kwargs = dict(
input_ids=inp,
max_new_tokens=512,
do_sample=False,
num_beams=1,
streamer=streamer,
)
Thread(target=chex_model.generate, kwargs=generate_kwargs).start()
partial += f"**Step {idx}: {anat}...**\n\n"
for tok in streamer:
if idx:
findings += tok
partial += tok
yield clean_text(partial)
partial += "\n\n"
findings += " "
findings = findings.strip()
# Impression
partial += "## Generating Impression\n\n"
prompt = f"Write the Impression section for the following Findings: {findings}"
conv = [
{"from": "system", "value": "You are a helpful assistant."},
{"from": "human", "value": chex_tok.from_list_format([{"text": prompt}])},
]
inp = chex_tok.apply_chat_template(
conv, add_generation_prompt=True, return_tensors="pt"
).to(device)
Thread(
target=chex_model.generate,
kwargs=dict(
input_ids=inp,
do_sample=False,
num_beams=1,
max_new_tokens=512,
streamer=streamer,
),
).start()
for tok in streamer:
partial += tok
yield clean_text(partial)
yield clean_text(partial)
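    # Best-effort cleanup of the temporary image files written above
    for p in paths:
        try:
            os.remove(p)
        except OSError:
            pass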
@torch.no_grad()
def response_phrase_grounding(pil_image, prompt_text):
"""Very simple visual-grounding placeholder."""
if not CHEXAGENT_AVAILABLE:
return "CheXagent is not available. Please check installation.", None
if pil_image is None:
return "Please upload an image.", None
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tfile:
pil_image.save(tfile.name)
img_path = tfile.name
device = get_model_device(chex_model)
query = chex_tok.from_list_format([{"image": img_path}, {"text": prompt_text}])
conv = [
{"from": "system", "value": "You are a helpful assistant."},
{"from": "human", "value": query},
]
inp = chex_tok.apply_chat_template(
conv, add_generation_prompt=True, return_tensors="pt"
).to(device)
out = chex_model.generate(
input_ids=inp, do_sample=False, num_beams=1, max_new_tokens=512
)
resp = clean_text(chex_tok.decode(out[0][inp.shape[1] :]))
# simple center box (placeholder)
w, h = pil_image.size
cx, cy, sz = w // 2, h // 2, min(w, h) // 4
draw = ImageDraw.Draw(pil_image)
draw.rectangle([(cx - sz, cy - sz), (cx + sz, cy + sz)], outline="red", width=3)
return resp, pil_image
# =============================================================================
# Gradio UI
# =============================================================================
def create_ui():
"""Create the Gradio interface."""
# Load Qwen model
try:
qwen_model, qwen_proc, qwen_dev = load_qwen_model_and_processor()
med_agent = MedicalVLMAgent(qwen_model, qwen_proc, qwen_dev)
qwen_available = True
except Exception as e:
print(f"Qwen model not available: {e}")
qwen_available = False
med_agent = None
with gr.Blocks(title="Medical AI Assistant") as demo:
gr.Markdown("# Combined Medical Q&A Β· SAM-2 Automatic Masking Β· CheXagent")
# Status information
with gr.Row():
gr.Markdown(f"""
**System Status:**
            - Qwen VLM: {'✅ Available' if qwen_available else '❌ Not Available'}
            - SAM-2: {'✅ Available' if SAM2_AVAILABLE else '❌ Not Available'}
            - CheXagent: {'✅ Available' if CHEXAGENT_AVAILABLE else '❌ Not Available'}
""")
# Medical Q&A Tab
with gr.Tab("Medical Q&A"):
if qwen_available:
q_in = gr.Textbox(label="Question / description", lines=3)
q_img = gr.Image(label="Optional image", type="pil")
q_btn = gr.Button("Submit")
q_out = gr.Textbox(label="Answer")
q_btn.click(fn=med_agent.run, inputs=[q_in, q_img], outputs=q_out)
else:
gr.Markdown("❌ Medical Q&A is not available. Qwen model failed to load.")
# Segmentation Tab
with gr.Tab("Automatic masking"):
seg_img = gr.Image(label="Upload medical image", type="pil")
seg_btn = gr.Button("Run segmentation")
seg_out = gr.Image(label="Segmentation result", type="pil")
seg_status = gr.Textbox(label="Status", interactive=False)
if SAM2_AVAILABLE and _mask_generator is not None:
seg_btn.click(
fn=tumor_segmentation_interface,
inputs=seg_img,
outputs=[seg_out, seg_status],
)
else:
seg_btn.click(
fn=simple_segmentation_fallback,
inputs=seg_img,
outputs=[seg_out, seg_status],
)
# CheXagent Tabs
with gr.Tab("CheXagent – Structured report"):
if CHEXAGENT_AVAILABLE:
gr.Markdown("Upload one or two chest X-ray images; the report streams live.")
cx1 = gr.Image(label="Image 1", image_mode="L", type="pil")
cx2 = gr.Image(label="Image 2", image_mode="L", type="pil")
cx_report = gr.Markdown()
gr.Interface(
fn=response_report_generation,
inputs=[cx1, cx2],
outputs=cx_report,
live=True,
).render()
else:
gr.Markdown("❌ CheXagent structured report is not available.")
with gr.Tab("CheXagent – Visual grounding"):
if CHEXAGENT_AVAILABLE:
vg_img = gr.Image(image_mode="L", type="pil")
vg_prompt = gr.Textbox(value="Locate the highlighted finding:")
vg_text = gr.Markdown()
vg_out_img = gr.Image()
gr.Interface(
fn=response_phrase_grounding,
inputs=[vg_img, vg_prompt],
outputs=[vg_text, vg_out_img],
).render()
else:
gr.Markdown("❌ CheXagent visual grounding is not available.")
return demo
if __name__ == "__main__":
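    # Note: on Hugging Face Spaces, Gradio ignores share=True (the Space is
    # already public); the flag only matters for local runs.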
demo = create_ui()
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)