Spaces:
Runtime error
Runtime error
Commit
·
bd2df77
1
Parent(s):
6d6e3fa
Update app.py
Browse files
app.py
CHANGED
@@ -53,7 +53,7 @@ app_ui = ui.page_fillable(
|
|
53 |
ui.input_switch("by_species", "Show species", value=True),
|
54 |
ui.input_switch("show_margins", "Show marginal plots", value=True),
|
55 |
),
|
56 |
-
ui.output_image("uploaded_image"), # display the uploaded sidewalk tile image
|
57 |
ui.output_plot("prediction_plots", fill=True),
|
58 |
ui.output_ui("value_boxes"),
|
59 |
ui.output_plot("scatter", fill=True),
|
@@ -106,6 +106,90 @@ def show_mask(mask, ax, random_color=False):
|
|
106 |
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
107 |
ax.imshow(mask_image)
|
108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
### SERVER ###
|
110 |
def server(input: Inputs, output: Outputs, session: Session):
|
111 |
|
@@ -121,100 +205,18 @@ def server(input: Inputs, output: Outputs, session: Session):
|
|
121 |
else:
|
122 |
return "" # No image uploaded
|
123 |
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
"""
|
137 |
-
input_points (torch.FloatTensor of shape (batch_size, num_points, 2)) —
|
138 |
-
Input 2D spatial points, this is used by the prompt encoder to encode the prompt.
|
139 |
-
Generally yields to much better results. The points can be obtained by passing a
|
140 |
-
list of list of list to the processor that will create corresponding torch tensors
|
141 |
-
of dimension 4. The first dimension is the image batch size, the second dimension
|
142 |
-
is the point batch size (i.e. how many segmentation masks do we want the model to
|
143 |
-
predict per input point), the third dimension is the number of points per segmentation
|
144 |
-
mask (it is possible to pass multiple points for a single mask), and the last dimension
|
145 |
-
is the x (vertical) and y (horizontal) coordinates of the point. If a different number
|
146 |
-
of points is passed either for each image, or for each mask, the processor will create
|
147 |
-
“PAD” points that will correspond to the (0, 0) coordinate, and the computation of the
|
148 |
-
embedding will be skipped for these points using the labels.
|
149 |
-
|
150 |
-
"""
|
151 |
-
|
152 |
-
# Get the dimensions of the image
|
153 |
-
array_size = max(image.width, image.height)
|
154 |
-
|
155 |
-
# Generate the grid points
|
156 |
-
x = np.linspace(0, array_size-1, grid_size)
|
157 |
-
y = np.linspace(0, array_size-1, grid_size)
|
158 |
-
|
159 |
-
# Generate a grid of coordinates
|
160 |
-
xv, yv = np.meshgrid(x, y)
|
161 |
-
|
162 |
-
# Convert the numpy arrays to lists
|
163 |
-
xv_list = xv.tolist()
|
164 |
-
yv_list = yv.tolist()
|
165 |
-
|
166 |
-
# Combine the x and y coordinates into a list of list of lists
|
167 |
-
input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)] for x_row, y_row in zip(xv_list, yv_list)]
|
168 |
-
|
169 |
-
#We need to reshape our nxn grid to the expected shape of the input_points tensor
|
170 |
-
# (batch_size, point_batch_size, num_points_per_image, 2),
|
171 |
-
# where the last dimension of 2 represents the x and y coordinates of each point.
|
172 |
-
#batch_size: The number of images you're processing at once.
|
173 |
-
#point_batch_size: The number of point sets you have for each image.
|
174 |
-
#num_points_per_image: The number of points in each set.
|
175 |
-
input_points = torch.tensor(input_points).view(1, 1, grid_size*grid_size, 2)
|
176 |
-
|
177 |
-
return input_points
|
178 |
-
|
179 |
-
def process_image():
|
180 |
-
"""Processes the uploaded image, loads the model, and evaluates to get predictions"""
|
181 |
-
|
182 |
-
""" Get Image """
|
183 |
-
img_src = uploaded_image_path()
|
184 |
-
|
185 |
-
# Read the image bytes from the file
|
186 |
-
with open(img_src, 'rb') as f:
|
187 |
-
image_bytes = f.read()
|
188 |
-
|
189 |
-
# Convert the image bytes to a PIL Image
|
190 |
-
image = bytes_to_pil_image(image_bytes)
|
191 |
-
|
192 |
-
""" Prepare Inputs """
|
193 |
-
# get input points prompt (grid of points)
|
194 |
-
input_points = generate_input_points(image)
|
195 |
-
|
196 |
-
# prepare image and prompt for the model
|
197 |
-
inputs = processor(image, input_points=input_points, return_tensors="pt")
|
198 |
-
|
199 |
-
# # remove batch dimension which the processor adds by default
|
200 |
-
# inputs = {k:v.squeeze(0) for k,v in inputs.items()}
|
201 |
-
|
202 |
-
# Move the input tensor to the GPU if it's not already there
|
203 |
-
inputs = {k: v.to(device) for k, v in inputs.items()}
|
204 |
-
|
205 |
-
""" Get Predictions """
|
206 |
-
# forward pass
|
207 |
-
with torch.no_grad():
|
208 |
-
outputs = model(**inputs, multimask_output=False)
|
209 |
-
|
210 |
-
# apply sigmoid
|
211 |
-
prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
|
212 |
-
# convert soft mask to hard mask
|
213 |
-
prob = prob.cpu().numpy().squeeze()
|
214 |
-
prediction = (prob > 0.5).astype(np.uint8)
|
215 |
-
|
216 |
-
# Return the processed result
|
217 |
-
return image, prob, prediction
|
218 |
|
219 |
@reactive.Calc
|
220 |
def get_predictions():
|
|
|
53 |
ui.input_switch("by_species", "Show species", value=True),
|
54 |
ui.input_switch("show_margins", "Show marginal plots", value=True),
|
55 |
),
|
56 |
+
#ui.output_image("uploaded_image"), # display the uploaded sidewalk tile image, for some reason doesn't work on all accepted files
|
57 |
ui.output_plot("prediction_plots", fill=True),
|
58 |
ui.output_ui("value_boxes"),
|
59 |
ui.output_plot("scatter", fill=True),
|
|
|
106 |
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
107 |
ax.imshow(mask_image)
|
108 |
|
109 |
+
def generate_input_points(image, grid_size=10):
|
110 |
+
"""
|
111 |
+
input_points (torch.FloatTensor of shape (batch_size, num_points, 2)) —
|
112 |
+
Input 2D spatial points, this is used by the prompt encoder to encode the prompt.
|
113 |
+
Generally yields to much better results. The points can be obtained by passing a
|
114 |
+
list of list of list to the processor that will create corresponding torch tensors
|
115 |
+
of dimension 4. The first dimension is the image batch size, the second dimension
|
116 |
+
is the point batch size (i.e. how many segmentation masks do we want the model to
|
117 |
+
predict per input point), the third dimension is the number of points per segmentation
|
118 |
+
mask (it is possible to pass multiple points for a single mask), and the last dimension
|
119 |
+
is the x (vertical) and y (horizontal) coordinates of the point. If a different number
|
120 |
+
of points is passed either for each image, or for each mask, the processor will create
|
121 |
+
“PAD” points that will correspond to the (0, 0) coordinate, and the computation of the
|
122 |
+
embedding will be skipped for these points using the labels.
|
123 |
+
|
124 |
+
"""
|
125 |
+
|
126 |
+
# Get the dimensions of the image
|
127 |
+
array_size = max(image.width, image.height)
|
128 |
+
|
129 |
+
# Generate the grid points
|
130 |
+
x = np.linspace(0, array_size-1, grid_size)
|
131 |
+
y = np.linspace(0, array_size-1, grid_size)
|
132 |
+
|
133 |
+
# Generate a grid of coordinates
|
134 |
+
xv, yv = np.meshgrid(x, y)
|
135 |
+
|
136 |
+
# Convert the numpy arrays to lists
|
137 |
+
xv_list = xv.tolist()
|
138 |
+
yv_list = yv.tolist()
|
139 |
+
|
140 |
+
# Combine the x and y coordinates into a list of list of lists
|
141 |
+
input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)] for x_row, y_row in zip(xv_list, yv_list)]
|
142 |
+
|
143 |
+
#We need to reshape our nxn grid to the expected shape of the input_points tensor
|
144 |
+
# (batch_size, point_batch_size, num_points_per_image, 2),
|
145 |
+
# where the last dimension of 2 represents the x and y coordinates of each point.
|
146 |
+
#batch_size: The number of images you're processing at once.
|
147 |
+
#point_batch_size: The number of point sets you have for each image.
|
148 |
+
#num_points_per_image: The number of points in each set.
|
149 |
+
input_points = torch.tensor(input_points).view(1, 1, grid_size*grid_size, 2)
|
150 |
+
|
151 |
+
return input_points
|
152 |
+
|
153 |
+
def process_image():
|
154 |
+
"""Processes the uploaded image, loads the model, and evaluates to get predictions"""
|
155 |
+
|
156 |
+
""" Get Image """
|
157 |
+
img_src = uploaded_image_path()
|
158 |
+
|
159 |
+
# Read the image bytes from the file
|
160 |
+
with open(img_src, 'rb') as f:
|
161 |
+
image_bytes = f.read()
|
162 |
+
|
163 |
+
# Convert the image bytes to a PIL Image
|
164 |
+
image = bytes_to_pil_image(image_bytes)
|
165 |
+
|
166 |
+
""" Prepare Inputs """
|
167 |
+
# get input points prompt (grid of points)
|
168 |
+
input_points = generate_input_points(image)
|
169 |
+
|
170 |
+
# prepare image and prompt for the model
|
171 |
+
inputs = processor(image, input_points=input_points, return_tensors="pt")
|
172 |
+
|
173 |
+
# # remove batch dimension which the processor adds by default
|
174 |
+
# inputs = {k:v.squeeze(0) for k,v in inputs.items()}
|
175 |
+
|
176 |
+
# Move the input tensor to the GPU if it's not already there
|
177 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
178 |
+
|
179 |
+
""" Get Predictions """
|
180 |
+
# forward pass
|
181 |
+
with torch.no_grad():
|
182 |
+
outputs = model(**inputs, multimask_output=False)
|
183 |
+
|
184 |
+
# apply sigmoid
|
185 |
+
prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
|
186 |
+
# convert soft mask to hard mask
|
187 |
+
prob = prob.cpu().numpy().squeeze()
|
188 |
+
prediction = (prob > 0.5).astype(np.uint8)
|
189 |
+
|
190 |
+
# Return the processed result
|
191 |
+
return image, prob, prediction
|
192 |
+
|
193 |
### SERVER ###
|
194 |
def server(input: Inputs, output: Outputs, session: Session):
|
195 |
|
|
|
205 |
else:
|
206 |
return "" # No image uploaded
|
207 |
|
208 |
+
# for some reason below function does not work on all accepted files
|
209 |
+
# works on one screenshot that was converted to .tif but not another *shrug*
|
210 |
+
# @render.image
|
211 |
+
# def uploaded_image():
|
212 |
+
# """Displays the uploaded image"""
|
213 |
+
# img_src = uploaded_image_path()
|
214 |
+
# if img_src:
|
215 |
+
# img: ImgData = {"src": str(img_src), "width": "200px"}
|
216 |
+
# print("IMAGE", img)
|
217 |
+
# return img
|
218 |
+
# else:
|
219 |
+
# return None # Return an empty string if no image is uploaded
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
@reactive.Calc
|
222 |
def get_predictions():
|