# MedSAM_Demo / app.py
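# Gradio demo: single-point ROI segmentation with a MedSAM checkpoint.
# The user clicks a point on an uploaded image and the predicted mask is
# shown as an overlay next to the original image.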

import io

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import nrrd
import numpy as np
import pydicom
import torch
import torch.nn.functional as F
from gradio_image_prompter import ImagePrompter
from PIL import Image
from segment_anything import sam_model_registry


class PointPromptDemo:
    """Single-point-prompt segmentation with a SAM-style (MedSAM) model."""

    def __init__(self, model):
        self.model = model
        self.model.eval()
        self.image = None
        self.image_embeddings = None
        self.img_size = None

    @torch.no_grad()
    def infer(self, x, y):
        # Map the clicked point from original-image pixel coordinates to the
        # 1024x1024 grid that the SAM image encoder operates on.
        coords_1024 = np.array([[[
            x * 1024 / self.img_size[1],
            y * 1024 / self.img_size[0],
        ]]])
        coords_torch = torch.tensor(coords_1024, dtype=torch.float32).to(self.model.device)
        labels_torch = torch.tensor([[1]], dtype=torch.long).to(self.model.device)  # 1 = foreground point
        point_prompt = (coords_torch, labels_torch)
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=point_prompt,
            boxes=None,
            masks=None,
        )
        low_res_logits, _ = self.model.mask_decoder(
            image_embeddings=self.image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        low_res_probs = torch.sigmoid(low_res_logits)
        # Upsample the low-resolution probability map back to the original
        # image size, then threshold it into a binary mask.
        low_res_pred = F.interpolate(
            low_res_probs,
            size=self.img_size,
            mode='bilinear',
            align_corners=False,
        )
        low_res_pred = low_res_pred.detach().cpu().numpy().squeeze()
        seg = np.uint8(low_res_pred > 0.5)
        return seg

    def set_image(self, image):
        self.img_size = image.shape[:2]
        # Grayscale inputs are replicated to three channels for the encoder.
        if len(image.shape) == 2:
            image = np.repeat(image[:, :, None], 3, -1)
        self.image = image
        image_preprocess = self.preprocess_image(self.image)
        with torch.no_grad():
            # Embeddings are computed once per image and reused for every prompt.
            self.image_embeddings = self.model.image_encoder(image_preprocess)

    def preprocess_image(self, image):
        # Resize to the 1024x1024 input expected by the SAM image encoder.
        img_resize = cv2.resize(
            image,
            (1024, 1024),
            interpolation=cv2.INTER_CUBIC,
        )
        # Min-max normalize to [0, 1]; the clip guards against division by
        # zero on constant images.
        img_resize = (img_resize - img_resize.min()) / np.clip(
            img_resize.max() - img_resize.min(), a_min=1e-8, a_max=None
        )
        assert np.max(img_resize) <= 1.0 and np.min(img_resize) >= 0.0, \
            'image should be normalized to [0, 1]'
        img_tensor = torch.tensor(img_resize).float().permute(2, 0, 1).unsqueeze(0).to(self.model.device)
        return img_tensor
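

# Usage sketch (illustrative, assuming a loaded SAM-style model `model`):
#   demo = PointPromptDemo(model)
#   demo.set_image(img)       # img: HxW or HxWx3 numpy array
#   seg = demo.infer(x, y)    # x, y in original-image pixel coordinates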


def load_image(file_path):
    """Read a DICOM, NRRD, or standard image file into an HxWx3 numpy array."""
    # Note: helper for formats Gradio does not decode natively; it is not
    # currently called by the Gradio pipeline below.
    if file_path.endswith(".dcm"):
        ds = pydicom.dcmread(file_path)
        img = ds.pixel_array
    elif file_path.endswith(".nrrd"):
        img, _ = nrrd.read(file_path)
    else:
        img = np.array(Image.open(file_path))
    if len(img.shape) == 2:
        img = np.stack((img,) * 3, axis=-1)
    return img
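

# Sketch (an assumption, not part of the original app): nrrd.read can return a
# 3-D volume, which the 2-D pipeline here cannot encode directly. One
# hypothetical fix is to keep only the central slice along the last axis:
def middle_slice(volume):
    # Reduce an (H, W, D) volume to its central 2-D slice; pass RGB(A) through.
    if volume.ndim == 3 and volume.shape[-1] not in (3, 4):
        return volume[..., volume.shape[-1] // 2]
    return volume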


def visualize(image, mask):
    """Render the image next to a jet-colormap mask overlay; return a PIL image."""
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image)
    ax[1].imshow(image)
    ax[1].imshow(mask, alpha=0.5, cmap="jet")
    plt.tight_layout()
    # Render the figure into an in-memory PNG so Gradio can display it.
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    pil_img = Image.open(buf)
    return pil_img


def process_images(img_dict):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    img = img_dict['image']
    # ImagePrompter returns a list of prompt entries; guard against an empty
    # list before indexing into it (the original check ran after indexing and
    # would have raised an IndexError instead).
    if not img_dict['points']:
        raise ValueError("At least one point is required for ROI selection.")
    points = img_dict['points'][0]
    x, y = points[0], points[1]
    # Load the MedSAM point-prompt checkpoint (SAM ViT-B architecture).
    model_checkpoint_path = "medsam_point_prompt_flare22.pth"
    medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
    medsam_model = medsam_model.to(device)
    medsam_model.eval()
    point_prompt_demo = PointPromptDemo(medsam_model)
    point_prompt_demo.set_image(img)
    mask = point_prompt_demo.infer(x, y)
    visualization = visualize(img, mask)
    return visualization
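

# Optional sketch (not part of the original app): process_images reloads the
# checkpoint on every request, which is slow. A minimal module-level cache,
# assuming the same checkpoint path, could avoid that:
_MODEL_CACHE = {}

def _get_medsam(checkpoint_path, device):
    # Load the SAM ViT-B checkpoint once and reuse it across requests.
    if checkpoint_path not in _MODEL_CACHE:
        model = sam_model_registry['vit_b'](checkpoint=checkpoint_path)
        _MODEL_CACHE[checkpoint_path] = model.to(device).eval()
    return _MODEL_CACHE[checkpoint_path]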


iface = gr.Interface(
    fn=process_images,
    inputs=[
        ImagePrompter(label="Image"),
    ],
    outputs=[
        gr.Image(type="pil", label="Processed Image"),
    ],
    title="ROI Selection with MedSAM",
    description="Upload an image (including NRRD files) and click a point to segment the region of interest.",
)

iface.launch()
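# launch() starts a local server; on Hugging Face Spaces app.py is executed
# directly, so no __main__ guard is required. For LAN access one could pass
# server_name="0.0.0.0" (an optional tweak, not in the original).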