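"""Gradio demo for point-prompted ROI segmentation with a MedSAM (SAM ViT-B) model.

Assumes the fine-tuned checkpoint ``medsam_point_prompt_flare22.pth`` is available in
the working directory and that the ``gradio_image_prompter`` custom component is installed.
"""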
import gradio as gr
import numpy as np
import pydicom
import torch
from segment_anything import sam_model_registry
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
import io
import cv2
import nrrd
from gradio_image_prompter import ImagePrompter
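
# Wraps a SAM-style model so that a single foreground click can be turned into a segmentation mask.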
class PointPromptDemo:
    def __init__(self, model):
        self.model = model
        self.model.eval()
        self.image = None
        self.image_embeddings = None
        self.img_size = None
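
    # Run a single forward pass: encode the clicked point, decode a low-resolution
    # mask, and upsample it back to the original image size.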
    def infer(self, x, y):
        # Map the clicked point from original image coordinates to the
        # 1024x1024 input space expected by the prompt encoder.
        coords_1024 = np.array([[[
            x * 1024 / self.img_size[1],
            y * 1024 / self.img_size[0]
        ]]])
        coords_torch = torch.tensor(coords_1024, dtype=torch.float32).to(self.model.device)
        labels_torch = torch.tensor([[1]], dtype=torch.long).to(self.model.device)
        point_prompt = (coords_torch, labels_torch)
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=point_prompt,
            boxes=None,
            masks=None,
        )
        low_res_logits, _ = self.model.mask_decoder(
            image_embeddings=self.image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        low_res_probs = torch.sigmoid(low_res_logits)
        # Upsample the low-resolution probability map back to the original image size.
        low_res_pred = F.interpolate(
            low_res_probs,
            size=self.img_size,
            mode='bilinear',
            align_corners=False
        )
        low_res_pred = low_res_pred.detach().cpu().numpy().squeeze()
        seg = np.uint8(low_res_pred > 0.5)
        return seg
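
    # Store the image (replicated to 3 channels if grayscale) and precompute its
    # image-encoder embedding.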
    def set_image(self, image):
        self.img_size = image.shape[:2]
        if len(image.shape) == 2:
            image = np.repeat(image[:, :, None], 3, -1)
        self.image = image
        image_preprocess = self.preprocess_image(self.image)
        with torch.no_grad():
            self.image_embeddings = self.model.image_encoder(image_preprocess)
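
    # Resize to 1024x1024, min-max normalize to [0, 1], and convert to a
    # (1, 3, 1024, 1024) float tensor on the model's device.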
    def preprocess_image(self, image):
        img_resize = cv2.resize(
            image,
            (1024, 1024),
            interpolation=cv2.INTER_CUBIC
        )
        # Min-max normalize; the clip guards against division by zero on constant-valued images.
        img_resize = (img_resize - img_resize.min()) / np.clip(
            img_resize.max() - img_resize.min(), a_min=1e-8, a_max=None
        )
        assert np.max(img_resize) <= 1.0 and np.min(img_resize) >= 0.0, \
            'image should be normalized to [0, 1]'
        img_tensor = torch.tensor(img_resize).float().permute(2, 0, 1).unsqueeze(0).to(self.model.device)
        return img_tensor
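
# Read a DICOM, NRRD, or standard image file into a numpy array
# (grayscale inputs are stacked to 3 channels).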
def load_image(file_path):
    if file_path.endswith(".dcm"):
        ds = pydicom.dcmread(file_path)
        img = ds.pixel_array
    elif file_path.endswith(".nrrd"):
        img, _ = nrrd.read(file_path)
    else:
        img = np.array(Image.open(file_path))
    if len(img.shape) == 2:
        img = np.stack((img,) * 3, axis=-1)
    return img
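
# Render the original image next to an overlay of the predicted mask and return the figure as a PIL image.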
def visualize(image, mask):
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image)
    ax[1].imshow(image)
    ax[1].imshow(mask, alpha=0.5, cmap="jet")
    plt.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    pil_img = Image.open(buf)
    return pil_img
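
# Gradio callback: load the MedSAM checkpoint, embed the uploaded image, and
# segment the region around the first prompted point.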
def process_images(img_dict):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    img = img_dict['image']
    # Guard against submissions without any prompt, which would otherwise raise an IndexError.
    points = img_dict['points']
    if not points or len(points[0]) < 2:
        raise ValueError("At least one point is required for ROI selection.")
    # Each ImagePrompter prompt entry starts with the clicked (x, y) coordinates.
    x, y = points[0][0], points[0][1]
    model_checkpoint_path = "medsam_point_prompt_flare22.pth"
    medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
    medsam_model = medsam_model.to(device)
    medsam_model.eval()
    point_prompt_demo = PointPromptDemo(medsam_model)
    point_prompt_demo.set_image(img)
    mask = point_prompt_demo.infer(x, y)
    visualization = visualize(img, mask)
    return visualization
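
# Build the Gradio UI: ImagePrompter lets the user click a point on the uploaded image.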
iface = gr.Interface(
    fn=process_images,
    inputs=[
        ImagePrompter(label="Image")
    ],
    outputs=[
        gr.Image(type="pil", label="Processed Image")
    ],
    title="ROI Selection with MedSAM",
    description="Upload an image (including NRRD files) and select a point for ROI processing."
)

iface.launch()