#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Combined Medical-VLM, **SAM-2 automatic masking**, and CheXagent demo.

Changes
-------
1. Fixed SAM-2 installation and import issues
2. Added proper error handling for missing dependencies
3. Made SAM-2 functionality optional with graceful fallback
4. Added installation instructions and requirements check
"""
# ---------------------------------------------------------------------
# Standard libs
# ---------------------------------------------------------------------
import os
import sys
import uuid
import tempfile
import subprocess
import warnings
from threading import Thread

# Environment setup
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
warnings.filterwarnings("ignore", message=r".*upsample_bicubic2d.*")

# ---------------------------------------------------------------------
# Third-party libs
# ---------------------------------------------------------------------
import torch
import numpy as np
from PIL import Image, ImageDraw
import gradio as gr
# =============================================================================
# Dependency checker and installer
# =============================================================================
def check_and_install_sam2():
    """Check if SAM-2 is available and attempt installation if needed."""
    try:
        print("[SAM-2 Debug] Attempting to import SAM-2 modules...")
        from sam2.build_sam import build_sam2
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        print("[SAM-2 Debug] Successfully imported SAM-2 modules")
        return True, "SAM-2 already available"
    except ImportError as e:
        print(f"[SAM-2 Debug] Import error: {str(e)}")
        print("[SAM-2 Debug] Attempting to install SAM-2...")
        try:
            # Clone SAM-2 repository
            if not os.path.exists("segment-anything-2"):
                print("[SAM-2 Debug] Cloning SAM-2 repository...")
                subprocess.run([
                    "git", "clone",
                    "https://github.com/facebookresearch/segment-anything-2.git",
                ], check=True)
                print("[SAM-2 Debug] Repository cloned successfully")
            # Install SAM-2 in editable mode
            print("[SAM-2 Debug] Installing SAM-2...")
            original_dir = os.getcwd()
            os.chdir("segment-anything-2")
            subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
            os.chdir(original_dir)
            print("[SAM-2 Debug] Installation completed")
            # Add to Python path
            sam2_path = os.path.abspath("segment-anything-2")
            if sam2_path not in sys.path:
                sys.path.insert(0, sam2_path)
                print(f"[SAM-2 Debug] Added {sam2_path} to Python path")
            # Try importing again
            print("[SAM-2 Debug] Attempting to import SAM-2 modules again...")
            from sam2.build_sam import build_sam2
            from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
            print("[SAM-2 Debug] Successfully imported SAM-2 modules after installation")
            return True, "SAM-2 installed successfully"
        except Exception as e:
            print(f"[SAM-2 Debug] Installation failed: {str(e)}")
            print(f"[SAM-2 Debug] Error type: {type(e).__name__}")
            return False, f"SAM-2 installation failed: {e}"
# Check SAM-2 availability
SAM2_AVAILABLE, SAM2_STATUS = check_and_install_sam2()
print(f"SAM-2 Status: {SAM2_STATUS}")

# =============================================================================
# SAM-2 imports (conditional)
# =============================================================================
if SAM2_AVAILABLE:
    try:
        from sam2.build_sam import build_sam2
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        from sam2.modeling.sam2_base import SAM2Base
    except ImportError as e:
        print(f"SAM-2 import error: {e}")
        SAM2_AVAILABLE = False

# =============================================================================
# Qwen-VLM imports & helper
# =============================================================================
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
# =============================================================================
# CheXagent imports
# =============================================================================
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# ---------------------------------------------------------------------
# Devices
# ---------------------------------------------------------------------
def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
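# Device preference is CUDA > MPS > CPU. PYTORCH_ENABLE_MPS_FALLBACK=1 (set
# above) lets ops that the MPS backend does not implement fall back to the
# CPU instead of raising, which matters for some vision models on Apple Silicon.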
# =============================================================================
# Qwen-VLM model & agent
# =============================================================================
_qwen_model = None
_qwen_processor = None
_qwen_device = None

def load_qwen_model_and_processor(hf_token=None):
    global _qwen_model, _qwen_processor, _qwen_device
    if _qwen_model is None:
        _qwen_device = "mps" if torch.backends.mps.is_available() else "cpu"
        print(f"[Qwen] loading model on {_qwen_device}")
        # `token` replaces the deprecated `use_auth_token` kwarg in recent
        # transformers releases.
        auth_kwargs = {"token": hf_token} if hf_token else {}
        _qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2.5-VL-3B-Instruct",
            trust_remote_code=True,
            attn_implementation="eager",
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            device_map=None,
            **auth_kwargs,
        ).to(_qwen_device)
        _qwen_processor = AutoProcessor.from_pretrained(
            "Qwen/Qwen2.5-VL-3B-Instruct",
            trust_remote_code=True,
            **auth_kwargs,
        )
    return _qwen_model, _qwen_processor, _qwen_device
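# The module-level globals above make load_qwen_model_and_processor() a lazy
# singleton: the ~3B-parameter model is loaded on the first call and reused
# for every subsequent Gradio request instead of being reloaded each time.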
class MedicalVLMAgent:
    """Light wrapper around Qwen-VLM with an optional image."""

    def __init__(self, model, processor, device):
        self.model = model
        self.processor = processor
        self.device = device
        self.system_prompt = (
            "You are a medical information assistant with vision capabilities.\n"
            "Disclaimer: I am not a licensed medical professional. "
            "The information provided is for reference only and should not be taken as medical advice."
        )

    def run(self, user_text: str, image: Image.Image | None = None) -> str:
        messages = [
            {"role": "system", "content": [{"type": "text", "text": self.system_prompt}]}
        ]
        user_content = []
        if image is not None:
            tmp = f"/tmp/{uuid.uuid4()}.png"
            image.save(tmp)
            user_content.append({"type": "image", "image": tmp})
        user_content.append({"type": "text", "text": user_text or "Please describe the image."})
        messages.append({"role": "user", "content": user_content})
        prompt_text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        img_inputs, vid_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[prompt_text],
            images=img_inputs,
            videos=vid_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            out = self.model.generate(**inputs, max_new_tokens=128)
        # Drop the prompt tokens so only the newly generated answer is decoded.
        trimmed = out[0][inputs.input_ids.shape[1]:]
        return self.processor.decode(trimmed, skip_special_tokens=True).strip()
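# Illustrative usage sketch outside Gradio (assumes the Qwen weights are
# already downloaded; the file name is hypothetical):
#   model, proc, dev = load_qwen_model_and_processor()
#   agent = MedicalVLMAgent(model, proc, dev)
#   answer = agent.run("Is this X-ray normal?", Image.open("cxr.png"))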
# =============================================================================
# SAM-2 model + AutomaticMaskGenerator (final minimal version)
# =============================================================================
def initialize_sam2():
    # These two files are already in your repo
    CKPT = "checkpoints/sam2.1_hiera_large.pt"   # ≈2.7 GB
    CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
    # One chdir so Hydra's search path starts inside sam2/sam2/; restore the
    # original cwd afterwards so later relative paths keep working.
    original_dir = os.getcwd()
    os.chdir("sam2/sam2")
    try:
        device = get_device()
        print(f"[SAM-2] building model on {device}")
        sam2_model = build_sam2(
            CFG,   # relative to sam2/sam2/
            CKPT,  # relative after chdir
            device=device,
            apply_postprocessing=False,
        )
        mask_gen = SAM2AutomaticMaskGenerator(
            model=sam2_model,
            points_per_side=32,
            pred_iou_thresh=0.86,
            stability_score_thresh=0.92,
            crop_n_layers=0,
        )
    finally:
        os.chdir(original_dir)
    return sam2_model, mask_gen
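# Tuning notes (rules of thumb, not benchmarked here): raising pred_iou_thresh
# or stability_score_thresh yields fewer, higher-confidence masks; halving
# points_per_side (32 -> 16) samples 4x fewer point prompts, trading recall
# for speed; crop_n_layers > 0 adds multi-crop passes for small objects.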
# ---------------------- build once ----------------------
if SAM2_AVAILABLE:
    try:
        _sam2_model, _mask_generator = initialize_sam2()
        print("[SAM-2] Successfully initialized!")
    except Exception as e:
        print(f"[SAM-2] Failed to initialize: {e}")
        _sam2_model, _mask_generator = None, None
else:
    _sam2_model, _mask_generator = None, None
def automatic_mask_overlay(image_np: np.ndarray) -> tuple[np.ndarray, int]:
    """Generate masks and alpha-blend them on top of the original image.

    Returns the blended image and the number of masks found.
    """
    if _mask_generator is None:
        raise RuntimeError("SAM-2 mask generator not initialized")
    anns = _mask_generator.generate(image_np)
    if not anns:
        return image_np, 0
    overlay = image_np.copy()
    if overlay.ndim == 2:  # grayscale -> RGB
        overlay = np.stack([overlay] * 3, axis=2)
    for ann in sorted(anns, key=lambda x: x["area"], reverse=True):
        m = ann["segmentation"]
        color = np.random.randint(0, 256, 3, dtype=np.uint8)
        overlay[m] = (overlay[m] * 0.5 + color * 0.5).astype(np.uint8)
    return overlay, len(anns)

def tumor_segmentation_interface(image: Image.Image | None):
    if image is None:
        return None, "Please upload an image."
    if _mask_generator is None:
        return None, "SAM-2 not properly initialized. Check the console for errors."
    try:
        img_np = np.array(image.convert("RGB"))
        # Generate masks once and reuse the count (previously SAM-2 ran twice
        # per request: once for the overlay and once just to count masks).
        out_np, n_masks = automatic_mask_overlay(img_np)
        return Image.fromarray(out_np), f"{n_masks} masks found."
    except Exception as e:
        return None, f"SAM-2 error: {e}"
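# Illustrative call outside Gradio (assumes SAM-2 initialized; the file name
# is hypothetical):
#   overlay_img, status = tumor_segmentation_interface(Image.open("slice.png"))
#   overlay_img.save("slice_masked.png")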
# =============================================================================
# Simple fallback segmentation (when SAM-2 is not available)
# =============================================================================
def simple_segmentation_fallback(image: Image.Image | None):
    """Simple fallback segmentation using basic image processing."""
    if image is None:
        return None, "Please upload an image."
    try:
        import cv2
        # Convert to numpy array
        img_np = np.array(image.convert("RGB"))
        # Otsu threshold (inverted) to split foreground from background
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        # Remove noise with morphological opening
        kernel = np.ones((3, 3), np.uint8)
        opening = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=2)
        # Distance transform picks out sure-foreground regions
        dist_transform = cv2.distanceTransform(opening, cv2.DIST_L2, 5)
        _, sure_fg = cv2.threshold(dist_transform, 0.7 * dist_transform.max(), 255, 0)
        # Create overlay
        overlay = img_np.copy()
        overlay[sure_fg > 0] = [255, 0, 0]  # Red overlay
        # Alpha blend
        result = cv2.addWeighted(img_np, 0.7, overlay, 0.3, 0)
        return Image.fromarray(result), "Simple segmentation applied (SAM-2 not available)"
    except Exception as e:
        return None, f"Fallback segmentation error: {e}"
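# Note: this fallback stops short of a full watershed; it only extracts
# Otsu + distance-transform foreground markers and tints them red. If
# separating touching regions mattered, cv2.watershed() could be layered on
# top using these markers (sketch of a possible extension, not implemented).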
# =============================================================================
# CheXagent set-up
# =============================================================================
try:
    print("[CheXagent] Starting initialization...")
    chex_name = "StanfordAIMI/CheXagent-2-3b"
    print(f"[CheXagent] Loading tokenizer from {chex_name}")
    chex_tok = AutoTokenizer.from_pretrained(chex_name, trust_remote_code=True)
    print("[CheXagent] Tokenizer loaded successfully")
    print("[CheXagent] Loading model...")
    # Precision is chosen once here (fp16 on GPU, fp32 on CPU), so no extra
    # .half()/.float() conversion is needed after loading.
    chex_model = AutoModelForCausalLM.from_pretrained(
        chex_name,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    print("[CheXagent] Model loaded successfully")
    chex_model.eval()
    CHEXAGENT_AVAILABLE = True
    print("[CheXagent] Initialization complete")
except Exception as e:
    print(f"[CheXagent] Initialization failed: {str(e)}")
    print(f"[CheXagent] Error type: {type(e).__name__}")
    CHEXAGENT_AVAILABLE = False
    chex_tok, chex_model = None, None

def get_model_device(model):
    """Return the device of the model's first parameter (CPU fallback)."""
    if model is None:
        return torch.device("cpu")
    for p in model.parameters():
        return p.device
    return torch.device("cpu")
def clean_text(text):
    return text.replace("</s>", "")

def response_report_generation(pil_image_1, pil_image_2):
    """Structured chest-X-ray report (streaming)."""
    if not CHEXAGENT_AVAILABLE:
        yield "CheXagent is not available. Please check installation."
        return
    streamer = TextIteratorStreamer(chex_tok, skip_prompt=True, skip_special_tokens=True)
    paths = []
    for im in [pil_image_1, pil_image_2]:
        if im is None:
            continue
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tfile:
            im.save(tfile.name)
            paths.append(tfile.name)
    if not paths:
        yield "Please upload at least one image."
        return
    device = get_model_device(chex_model)
    anatomies = [
        "View",
        "Airway",
        "Breathing",
        "Cardiac",
        "Diaphragm",
        "Everything else (e.g., mediastinal contours, bones, soft tissues, tubes, valves, pacemakers)",
    ]
    prompts = [
        "Determine the view of this CXR",
        *[
            f'Provide a detailed description of "{a}" in the chest X-ray'
            for a in anatomies[1:]
        ],
    ]
    findings = ""
    partial = "## Generating Findings (step-by-step):\n\n"
    for idx, (anat, prompt) in enumerate(zip(anatomies, prompts)):
        query = chex_tok.from_list_format(
            [*[{"image": p} for p in paths], {"text": prompt}]
        )
        conv = [
            {"from": "system", "value": "You are a helpful assistant."},
            {"from": "human", "value": query},
        ]
        inp = chex_tok.apply_chat_template(
            conv, add_generation_prompt=True, return_tensors="pt"
        ).to(device)
        generate_kwargs = dict(
            input_ids=inp,
            max_new_tokens=512,
            do_sample=False,
            num_beams=1,
            streamer=streamer,
        )
        # Run generation in a background thread so the streamer can be
        # consumed (and yielded to Gradio) while tokens are still arriving.
        Thread(target=chex_model.generate, kwargs=generate_kwargs).start()
        partial += f"**Step {idx}: {anat}...**\n\n"
        for tok in streamer:
            if idx:  # skip step 0 ("View") when accumulating the Findings text
                findings += tok
            partial += tok
            yield clean_text(partial)
        partial += "\n\n"
        findings += " "
    findings = findings.strip()
    # Impression
    partial += "## Generating Impression\n\n"
    prompt = f"Write the Impression section for the following Findings: {findings}"
    conv = [
        {"from": "system", "value": "You are a helpful assistant."},
        {"from": "human", "value": chex_tok.from_list_format([{"text": prompt}])},
    ]
    inp = chex_tok.apply_chat_template(
        conv, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    Thread(
        target=chex_model.generate,
        kwargs=dict(
            input_ids=inp,
            do_sample=False,
            num_beams=1,
            max_new_tokens=512,
            streamer=streamer,
        ),
    ).start()
    for tok in streamer:
        partial += tok
        yield clean_text(partial)
    yield clean_text(partial)
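# Streaming pattern used above: each generate() call runs in a background
# Thread while TextIteratorStreamer exposes the decoded tokens as an
# iterator, so the partial report can be yielded to Gradio token by token
# instead of only after the full 512-token generation finishes.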
def response_phrase_grounding(pil_image, prompt_text):
    """Very simple visual-grounding placeholder."""
    if not CHEXAGENT_AVAILABLE:
        return "CheXagent is not available. Please check installation.", None
    if pil_image is None:
        return "Please upload an image.", None
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tfile:
        pil_image.save(tfile.name)
        img_path = tfile.name
    device = get_model_device(chex_model)
    query = chex_tok.from_list_format([{"image": img_path}, {"text": prompt_text}])
    conv = [
        {"from": "system", "value": "You are a helpful assistant."},
        {"from": "human", "value": query},
    ]
    inp = chex_tok.apply_chat_template(
        conv, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    out = chex_model.generate(
        input_ids=inp, do_sample=False, num_beams=1, max_new_tokens=512
    )
    resp = clean_text(chex_tok.decode(out[0][inp.shape[1]:]))
    # Simple center box (placeholder); draw on a copy so the caller's image
    # is not mutated between Gradio calls.
    annotated = pil_image.copy()
    w, h = annotated.size
    cx, cy, sz = w // 2, h // 2, min(w, h) // 4
    draw = ImageDraw.Draw(annotated)
    draw.rectangle([(cx - sz, cy - sz), (cx + sz, cy + sz)], outline="red", width=3)
    return resp, annotated
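# The red rectangle above is a fixed center-of-image placeholder, not a model
# prediction. A real grounding path would parse box coordinates out of the
# CheXagent response (the exact format depends on the checkpoint's remote
# code) and draw those instead.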
# =============================================================================
# Gradio UI
# =============================================================================
def create_ui():
    """Create the Gradio interface."""
    # Load Qwen model
    try:
        qwen_model, qwen_proc, qwen_dev = load_qwen_model_and_processor()
        med_agent = MedicalVLMAgent(qwen_model, qwen_proc, qwen_dev)
        qwen_available = True
    except Exception as e:
        print(f"Qwen model not available: {e}")
        qwen_available = False
        med_agent = None

    with gr.Blocks(title="Medical AI Assistant") as demo:
        gr.Markdown("# Combined Medical Q&A · SAM-2 Automatic Masking · CheXagent")
        # Status information
        with gr.Row():
            gr.Markdown(f"""
            **System Status:**
            - Qwen VLM: {'✅ Available' if qwen_available else '❌ Not Available'}
            - SAM-2: {'✅ Available' if SAM2_AVAILABLE else '❌ Not Available'}
            - CheXagent: {'✅ Available' if CHEXAGENT_AVAILABLE else '❌ Not Available'}
            """)
        # Medical Q&A Tab
        with gr.Tab("Medical Q&A"):
            if qwen_available:
                q_in = gr.Textbox(label="Question / description", lines=3)
                q_img = gr.Image(label="Optional image", type="pil")
                q_btn = gr.Button("Submit")
                q_out = gr.Textbox(label="Answer")
                q_btn.click(fn=med_agent.run, inputs=[q_in, q_img], outputs=q_out)
            else:
                gr.Markdown("❌ Medical Q&A is not available. Qwen model failed to load.")
        # Segmentation Tab
        with gr.Tab("Automatic masking"):
            seg_img = gr.Image(label="Upload medical image", type="pil")
            seg_btn = gr.Button("Run segmentation")
            seg_out = gr.Image(label="Segmentation result", type="pil")
            seg_status = gr.Textbox(label="Status", interactive=False)
            if SAM2_AVAILABLE and _mask_generator is not None:
                seg_btn.click(
                    fn=tumor_segmentation_interface,
                    inputs=seg_img,
                    outputs=[seg_out, seg_status],
                )
            else:
                seg_btn.click(
                    fn=simple_segmentation_fallback,
                    inputs=seg_img,
                    outputs=[seg_out, seg_status],
                )
        # CheXagent Tabs
        with gr.Tab("CheXagent – Structured report"):
            if CHEXAGENT_AVAILABLE:
                gr.Markdown("Upload one or two chest X-ray images; the report streams live.")
                cx1 = gr.Image(label="Image 1", image_mode="L", type="pil")
                cx2 = gr.Image(label="Image 2", image_mode="L", type="pil")
                cx_report = gr.Markdown()
                gr.Interface(
                    fn=response_report_generation,
                    inputs=[cx1, cx2],
                    outputs=cx_report,
                    live=True,
                ).render()
            else:
                gr.Markdown("❌ CheXagent structured report is not available.")
        with gr.Tab("CheXagent – Visual grounding"):
            if CHEXAGENT_AVAILABLE:
                vg_img = gr.Image(image_mode="L", type="pil")
                vg_prompt = gr.Textbox(value="Locate the highlighted finding:")
                vg_text = gr.Markdown()
                vg_out_img = gr.Image()
                gr.Interface(
                    fn=response_phrase_grounding,
                    inputs=[vg_img, vg_prompt],
                    outputs=[vg_text, vg_out_img],
                ).render()
            else:
                gr.Markdown("❌ CheXagent visual grounding is not available.")
    return demo

if __name__ == "__main__":
    demo = create_ui()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
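# If the streamed CheXagent report ever stalls, calling demo.queue() before
# launch() is the usual fix on older Gradio releases; Gradio 4.x enables the
# queue by default (version behavior assumed — verify against your install).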