Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -25,6 +25,7 @@ numeric_cols: List[str] = df.select_dtypes(include=["float64"]).columns.tolist()
|
|
25 |
species: List[str] = df["Species"].unique().tolist()
|
26 |
species.sort()
|
27 |
|
|
|
28 |
app_ui = ui.page_fillable(
|
29 |
shinyswatch.theme.minty(),
|
30 |
ui.layout_sidebar(
|
@@ -62,6 +63,7 @@ app_ui = ui.page_fillable(
|
|
62 |
),
|
63 |
)
|
64 |
|
|
|
65 |
def tif_bytes_to_pil_image(tif_bytes):
|
66 |
# Create a BytesIO object from the TIFF bytes
|
67 |
bytes_io = io.BytesIO(tif_bytes)
|
@@ -89,6 +91,7 @@ def load_model():
|
|
89 |
|
90 |
return model, processor, device
|
91 |
|
|
|
92 |
def server(input: Inputs, output: Outputs, session: Session):
|
93 |
|
94 |
# set model, processor, device once
|
@@ -163,11 +166,14 @@ def server(input: Inputs, output: Outputs, session: Session):
|
|
163 |
"""Processes the uploaded image, loads the model, and evaluates to get predictions"""
|
164 |
|
165 |
""" Get Image """
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
|
|
|
|
|
|
171 |
|
172 |
""" Prepare Inputs """
|
173 |
# get input points prompt (grid of points)
|
@@ -176,10 +182,64 @@ def server(input: Inputs, output: Outputs, session: Session):
|
|
176 |
# prepare image and prompt for the model
|
177 |
inputs = processor(image, input_points=input_points, return_tensors="pt")
|
178 |
|
179 |
-
# remove batch dimension which the processor adds by default
|
180 |
-
inputs = {k:v.squeeze(0) for k,v in inputs.items()}
|
181 |
|
|
|
|
|
|
|
182 |
""" Get Predictions """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
# Evaluate the image with the model
|
184 |
# Example: predictions = model.predict(image_array)
|
185 |
|
|
|
25 |
species: List[str] = df["Species"].unique().tolist()
|
26 |
species.sort()
|
27 |
|
28 |
+
### UI ###
|
29 |
app_ui = ui.page_fillable(
|
30 |
shinyswatch.theme.minty(),
|
31 |
ui.layout_sidebar(
|
|
|
63 |
),
|
64 |
)
|
65 |
|
66 |
+
### HELPER FUNCTIONS ###
|
67 |
def tif_bytes_to_pil_image(tif_bytes):
|
68 |
# Create a BytesIO object from the TIFF bytes
|
69 |
bytes_io = io.BytesIO(tif_bytes)
|
|
|
91 |
|
92 |
return model, processor, device
|
93 |
|
94 |
+
### SERVER ###
|
95 |
def server(input: Inputs, output: Outputs, session: Session):
|
96 |
|
97 |
# set model, processor, device once
|
|
|
166 |
"""Processes the uploaded image, loads the model, and evaluates to get predictions"""
|
167 |
|
168 |
""" Get Image """
|
169 |
+
img_src = uploaded_image_path()
|
170 |
+
|
171 |
+
# Read the image bytes from the file
|
172 |
+
with open(img_src, 'rb') as f:
|
173 |
+
image_bytes = f.read()
|
174 |
+
|
175 |
+
# Convert the image bytes to a PIL Image
|
176 |
+
image = tif_bytes_to_pil_image(image_bytes)
|
177 |
|
178 |
""" Prepare Inputs """
|
179 |
# get input points prompt (grid of points)
|
|
|
182 |
# prepare image and prompt for the model
|
183 |
inputs = processor(image, input_points=input_points, return_tensors="pt")
|
184 |
|
185 |
+
# # remove batch dimension which the processor adds by default
|
186 |
+
# inputs = {k:v.squeeze(0) for k,v in inputs.items()}
|
187 |
|
188 |
+
# Move the input tensor to the GPU if it's not already there
|
189 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
190 |
+
|
191 |
""" Get Predictions """
|
192 |
+
# forward pass
|
193 |
+
with torch.no_grad():
|
194 |
+
outputs = model(**inputs, multimask_output=False)
|
195 |
+
|
196 |
+
# apply sigmoid
|
197 |
+
prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
|
198 |
+
# convert soft mask to hard mask
|
199 |
+
prob = prob.cpu().numpy().squeeze()
|
200 |
+
prediction = (prob > 0.5).astype(np.uint8)
|
201 |
+
|
202 |
+
# fig, axes = plt.subplots(1, 5, figsize=(15, 5))
|
203 |
+
|
204 |
+
# # Extract the image data from the batch
|
205 |
+
# image_data = batch['image'].cpu().detach().numpy()[0] # Assuming batch size is 1
|
206 |
+
|
207 |
+
# # Plot the first image on the left
|
208 |
+
# axes[0].imshow(image_data)
|
209 |
+
# axes[0].set_title("Image")
|
210 |
+
|
211 |
+
# # Plot the second image on the right
|
212 |
+
# axes[1].imshow(prob)
|
213 |
+
# axes[1].set_title("Probability Map")
|
214 |
+
|
215 |
+
# # Plot the prediction image on the right
|
216 |
+
# axes[2].imshow(prediction)
|
217 |
+
# axes[2].set_title("Prediction")
|
218 |
+
|
219 |
+
# # Plot the predicted mask on the right
|
220 |
+
# axes[3].imshow(image_data)
|
221 |
+
# show_mask(prediction, axes[3])
|
222 |
+
# axes[3].set_title("Predicted Mask")
|
223 |
+
|
224 |
+
# # Extract the ground truth mask data from the batch
|
225 |
+
# ground_truth_mask_data = inputs['ground_truth_mask'].cpu().detach().numpy()[0] # Assuming batch size is 1
|
226 |
+
|
227 |
+
# # Plot the ground truth mask on the right
|
228 |
+
# axes[4].imshow(image_data)
|
229 |
+
# axes[4].imshow(ground_truth_mask_data)
|
230 |
+
# #show_mask(inputs['ground_truth_mask'], axes[4])
|
231 |
+
# axes[4].set_title("Ground Truth Mask")
|
232 |
+
|
233 |
+
# # Hide axis ticks and labels
|
234 |
+
# for ax in axes:
|
235 |
+
# ax.set_xticks([])
|
236 |
+
# ax.set_yticks([])
|
237 |
+
# ax.set_xticklabels([])
|
238 |
+
# ax.set_yticklabels([])
|
239 |
+
|
240 |
+
# # Display the images side by side
|
241 |
+
# plt.show()
|
242 |
+
|
243 |
# Evaluate the image with the model
|
244 |
# Example: predictions = model.predict(image_array)
|
245 |
|