dennistrujillo committed
Commit e399e14
1 Parent(s): 0bdc5b8

Create app.py

Files changed (1): app.py (+145, -0)
app.py ADDED
@@ -0,0 +1,145 @@
+ import gradio as gr
+ import pandas as pd
+ import numpy as np
+ import pydicom
+ import os
+ from skimage import transform
+ import torch
+ from segment_anything import sam_model_registry
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ import torch.nn.functional as F
+ import io
+ import cv2
+ import nrrd
+ from gradio_image_prompter import ImagePrompter
+
+ class PointPromptDemo:
+     def __init__(self, model):
+         self.model = model
+         self.model.eval()
+         self.image = None
+         self.image_embeddings = None
+         self.img_size = None
+
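+     # Given a click at (x, y) in original-image coordinates, infer rescales it
+     # into the model's 1024x1024 input space, encodes it as a single foreground
+     # point (label 1), decodes a low-resolution mask, upsamples it back to the
+     # original image size, and thresholds the probabilities at 0.5.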
+     @torch.no_grad()
+     def infer(self, x, y):
+         coords_1024 = np.array([[[
+             x * 1024 / self.img_size[1],
+             y * 1024 / self.img_size[0]
+         ]]])
+         coords_torch = torch.tensor(coords_1024, dtype=torch.float32).to(self.model.device)
+         labels_torch = torch.tensor([[1]], dtype=torch.long).to(self.model.device)
+         point_prompt = (coords_torch, labels_torch)
+
+         sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
+             points=point_prompt,
+             boxes=None,
+             masks=None,
+         )
+         low_res_logits, _ = self.model.mask_decoder(
+             image_embeddings=self.image_embeddings,
+             image_pe=self.model.prompt_encoder.get_dense_pe(),
+             sparse_prompt_embeddings=sparse_embeddings,
+             dense_prompt_embeddings=dense_embeddings,
+             multimask_output=False,
+         )
+
+         low_res_probs = torch.sigmoid(low_res_logits)
+         low_res_pred = F.interpolate(
+             low_res_probs,
+             size=self.img_size,
+             mode='bilinear',
+             align_corners=False
+         )
+         low_res_pred = low_res_pred.detach().cpu().numpy().squeeze()
+
+         seg = np.uint8(low_res_pred > 0.5)
+
+         return seg
+
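+     # Cache the image and its encoder embedding so repeated point prompts on
+     # the same image reuse a single pass through the (expensive) image encoder.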
+     def set_image(self, image):
+         self.img_size = image.shape[:2]
+         if len(image.shape) == 2:
+             image = np.repeat(image[:, :, None], 3, -1)
+         self.image = image
+         image_preprocess = self.preprocess_image(self.image)
+         with torch.no_grad():
+             self.image_embeddings = self.model.image_encoder(image_preprocess)
+
+     def preprocess_image(self, image):
+         img_resize = cv2.resize(
+             image,
+             (1024, 1024),
+             interpolation=cv2.INTER_CUBIC
+         )
+         img_resize = (img_resize - img_resize.min()) / np.clip(img_resize.max() - img_resize.min(), a_min=1e-8, a_max=None)
+         assert np.max(img_resize) <= 1.0 and np.min(img_resize) >= 0.0, 'image should be normalized to [0, 1]'
+         img_tensor = torch.tensor(img_resize).float().permute(2, 0, 1).unsqueeze(0).to(self.model.device)
+         return img_tensor
+
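+ # Helper for reading DICOM (.dcm), NRRD (.nrrd), or standard image files into
+ # an H x W x 3 array. Note: it is not called by the Gradio interface below,
+ # which receives its image directly from ImagePrompter.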
+ def load_image(file_path):
+     if file_path.endswith(".dcm"):
+         ds = pydicom.dcmread(file_path)
+         img = ds.pixel_array
+     elif file_path.endswith(".nrrd"):
+         img, _ = nrrd.read(file_path)
+     else:
+         img = np.array(Image.open(file_path))
+
+     if len(img.shape) == 2:
+         img = np.stack((img,) * 3, axis=-1)
+
+     return img
+
+ def visualize(image, mask):
+     fig, ax = plt.subplots(1, 2, figsize=(10, 5))
+     ax[0].imshow(image)
+     ax[1].imshow(image)
+     ax[1].imshow(mask, alpha=0.5, cmap="jet")
+     plt.tight_layout()
+
+     buf = io.BytesIO()
+     fig.savefig(buf, format='png')
+     plt.close(fig)
+     buf.seek(0)
+     pil_img = Image.open(buf)
+
+     return pil_img
+
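+ # Gradio callback. ImagePrompter is expected to supply a dict with an 'image'
+ # array and a 'points' list whose entries begin with the click's (x, y)
+ # coordinates; only the first point is used here.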
+ def process_images(img_dict):
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+     img = img_dict['image']
+     if not img_dict['points']:
+         raise ValueError("At least one point is required for ROI selection.")
+     points = img_dict['points'][0]
+     x, y = points[0], points[1]
+
+     # Note: the checkpoint is reloaded on every request; acceptable for a
+     # demo, but it could be loaded once at module level instead.
+     model_checkpoint_path = "medsam_point_prompt_flare22.pth"
+     medsam_model = sam_model_registry['vit_b'](checkpoint=model_checkpoint_path)
+     medsam_model = medsam_model.to(device)
+     medsam_model.eval()
+
+     point_prompt_demo = PointPromptDemo(medsam_model)
+     point_prompt_demo.set_image(img)
+
+     mask = point_prompt_demo.infer(x, y)
+
+     visualization = visualize(img, mask)
+     return visualization
+
+ iface = gr.Interface(
+     fn=process_images,
+     inputs=[
+         ImagePrompter(label="Image")
+     ],
+     outputs=[
+         gr.Image(type="pil", label="Processed Image")
+     ],
+     title="ROI Selection with MEDSAM",
+     description="Upload an image (including NRRD files) and select a point for ROI processing."
+ )
+
+ iface.launch()
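
For reference, a minimal sketch of driving PointPromptDemo without the Gradio UI, assuming the class and the load_image helper above are in scope and the checkpoint file is present; the file path and click coordinates below are placeholders:

    import torch
    from segment_anything import sam_model_registry

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = sam_model_registry['vit_b'](checkpoint="medsam_point_prompt_flare22.pth")
    demo = PointPromptDemo(model.to(device))

    img = load_image("scan.dcm")    # placeholder path
    demo.set_image(img)
    mask = demo.infer(x=128, y=96)  # placeholder click location
    print(mask.shape, mask.dtype)   # (H, W) uint8 binary mask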