aagoluoglu commited on
Commit
db5d62c
·
verified ·
1 Parent(s): d927c07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -7
app.py CHANGED
@@ -25,6 +25,7 @@ numeric_cols: List[str] = df.select_dtypes(include=["float64"]).columns.tolist()
25
  species: List[str] = df["Species"].unique().tolist()
26
  species.sort()
27
 
 
28
  app_ui = ui.page_fillable(
29
  shinyswatch.theme.minty(),
30
  ui.layout_sidebar(
@@ -62,6 +63,7 @@ app_ui = ui.page_fillable(
62
  ),
63
  )
64
 
 
65
  def tif_bytes_to_pil_image(tif_bytes):
66
  # Create a BytesIO object from the TIFF bytes
67
  bytes_io = io.BytesIO(tif_bytes)
@@ -89,6 +91,7 @@ def load_model():
89
 
90
  return model, processor, device
91
 
 
92
  def server(input: Inputs, output: Outputs, session: Session):
93
 
94
  # set model, processor, device once
@@ -163,11 +166,14 @@ def server(input: Inputs, output: Outputs, session: Session):
163
  """Processes the uploaded image, loads the model, and evaluates to get predictions"""
164
 
165
  """ Get Image """
166
- # Load the uploaded image
167
- uploaded_image_bytes = input.tile_image()[0].read()
168
-
169
- # Convert the uploaded TIFF bytes to a PIL Image object
170
- uploaded_image = tif_bytes_to_pil_image(uploaded_image_bytes)
 
 
 
171
 
172
  """ Prepare Inputs """
173
  # get input points prompt (grid of points)
@@ -176,10 +182,64 @@ def server(input: Inputs, output: Outputs, session: Session):
176
  # prepare image and prompt for the model
177
  inputs = processor(image, input_points=input_points, return_tensors="pt")
178
 
179
- # remove batch dimension which the processor adds by default
180
- inputs = {k:v.squeeze(0) for k,v in inputs.items()}
181
 
 
 
 
182
  """ Get Predictions """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  # Evaluate the image with the model
184
  # Example: predictions = model.predict(image_array)
185
 
 
25
  species: List[str] = df["Species"].unique().tolist()
26
  species.sort()
27
 
28
+ ### UI ###
29
  app_ui = ui.page_fillable(
30
  shinyswatch.theme.minty(),
31
  ui.layout_sidebar(
 
63
  ),
64
  )
65
 
66
+ ### HELPER FUNCTIONS ###
67
  def tif_bytes_to_pil_image(tif_bytes):
68
  # Create a BytesIO object from the TIFF bytes
69
  bytes_io = io.BytesIO(tif_bytes)
 
91
 
92
  return model, processor, device
93
 
94
+ ### SERVER ###
95
  def server(input: Inputs, output: Outputs, session: Session):
96
 
97
  # set model, processor, device once
 
166
  """Processes the uploaded image, loads the model, and evaluates to get predictions"""
167
 
168
  """ Get Image """
169
+ img_src = uploaded_image_path()
170
+
171
+ # Read the image bytes from the file
172
+ with open(img_src, 'rb') as f:
173
+ image_bytes = f.read()
174
+
175
+ # Convert the image bytes to a PIL Image
176
+ image = tif_bytes_to_pil_image(image_bytes)
177
 
178
  """ Prepare Inputs """
179
  # get input points prompt (grid of points)
 
182
  # prepare image and prompt for the model
183
  inputs = processor(image, input_points=input_points, return_tensors="pt")
184
 
185
+ # # remove batch dimension which the processor adds by default
186
+ # inputs = {k:v.squeeze(0) for k,v in inputs.items()}
187
 
188
+ # Move the input tensor to the GPU if it's not already there
189
+ inputs = {k: v.to(device) for k, v in inputs.items()}
190
+
191
  """ Get Predictions """
192
+ # forward pass
193
+ with torch.no_grad():
194
+ outputs = model(**inputs, multimask_output=False)
195
+
196
+ # apply sigmoid
197
+ prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
198
+ # convert soft mask to hard mask
199
+ prob = prob.cpu().numpy().squeeze()
200
+ prediction = (prob > 0.5).astype(np.uint8)
201
+
202
+ # fig, axes = plt.subplots(1, 5, figsize=(15, 5))
203
+
204
+ # # Extract the image data from the batch
205
+ # image_data = batch['image'].cpu().detach().numpy()[0] # Assuming batch size is 1
206
+
207
+ # # Plot the first image on the left
208
+ # axes[0].imshow(image_data)
209
+ # axes[0].set_title("Image")
210
+
211
+ # # Plot the second image on the right
212
+ # axes[1].imshow(prob)
213
+ # axes[1].set_title("Probability Map")
214
+
215
+ # # Plot the prediction image on the right
216
+ # axes[2].imshow(prediction)
217
+ # axes[2].set_title("Prediction")
218
+
219
+ # # Plot the predicted mask on the right
220
+ # axes[3].imshow(image_data)
221
+ # show_mask(prediction, axes[3])
222
+ # axes[3].set_title("Predicted Mask")
223
+
224
+ # # Extract the ground truth mask data from the batch
225
+ # ground_truth_mask_data = inputs['ground_truth_mask'].cpu().detach().numpy()[0] # Assuming batch size is 1
226
+
227
+ # # Plot the ground truth mask on the right
228
+ # axes[4].imshow(image_data)
229
+ # axes[4].imshow(ground_truth_mask_data)
230
+ # #show_mask(inputs['ground_truth_mask'], axes[4])
231
+ # axes[4].set_title("Ground Truth Mask")
232
+
233
+ # # Hide axis ticks and labels
234
+ # for ax in axes:
235
+ # ax.set_xticks([])
236
+ # ax.set_yticks([])
237
+ # ax.set_xticklabels([])
238
+ # ax.set_yticklabels([])
239
+
240
+ # # Display the images side by side
241
+ # plt.show()
242
+
243
  # Evaluate the image with the model
244
  # Example: predictions = model.predict(image_array)
245