aagoluoglu commited on
Commit
bd2df77
·
1 Parent(s): 6d6e3fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -95
app.py CHANGED
@@ -53,7 +53,7 @@ app_ui = ui.page_fillable(
53
  ui.input_switch("by_species", "Show species", value=True),
54
  ui.input_switch("show_margins", "Show marginal plots", value=True),
55
  ),
56
- ui.output_image("uploaded_image"), # display the uploaded sidewalk tile image
57
  ui.output_plot("prediction_plots", fill=True),
58
  ui.output_ui("value_boxes"),
59
  ui.output_plot("scatter", fill=True),
@@ -106,6 +106,90 @@ def show_mask(mask, ax, random_color=False):
106
  mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
107
  ax.imshow(mask_image)
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  ### SERVER ###
110
  def server(input: Inputs, output: Outputs, session: Session):
111
 
@@ -121,100 +205,18 @@ def server(input: Inputs, output: Outputs, session: Session):
121
  else:
122
  return "" # No image uploaded
123
 
124
- @render.image
125
- def uploaded_image():
126
- """Displays the uploaded image"""
127
- img_src = uploaded_image_path()
128
- if img_src:
129
- img: ImgData = {"src": str(img_src), "width": "200px"}
130
- print("IMAGE", img)
131
- return img
132
- else:
133
- return None # Return an empty string if no image is uploaded
134
-
135
- def generate_input_points(image, grid_size=10):
136
- """
137
- input_points (torch.FloatTensor of shape (batch_size, num_points, 2)) —
138
- Input 2D spatial points, this is used by the prompt encoder to encode the prompt.
139
- Generally yields to much better results. The points can be obtained by passing a
140
- list of list of list to the processor that will create corresponding torch tensors
141
- of dimension 4. The first dimension is the image batch size, the second dimension
142
- is the point batch size (i.e. how many segmentation masks do we want the model to
143
- predict per input point), the third dimension is the number of points per segmentation
144
- mask (it is possible to pass multiple points for a single mask), and the last dimension
145
- is the x (vertical) and y (horizontal) coordinates of the point. If a different number
146
- of points is passed either for each image, or for each mask, the processor will create
147
- “PAD” points that will correspond to the (0, 0) coordinate, and the computation of the
148
- embedding will be skipped for these points using the labels.
149
-
150
- """
151
-
152
- # Get the dimensions of the image
153
- array_size = max(image.width, image.height)
154
-
155
- # Generate the grid points
156
- x = np.linspace(0, array_size-1, grid_size)
157
- y = np.linspace(0, array_size-1, grid_size)
158
-
159
- # Generate a grid of coordinates
160
- xv, yv = np.meshgrid(x, y)
161
-
162
- # Convert the numpy arrays to lists
163
- xv_list = xv.tolist()
164
- yv_list = yv.tolist()
165
-
166
- # Combine the x and y coordinates into a list of list of lists
167
- input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)] for x_row, y_row in zip(xv_list, yv_list)]
168
-
169
- #We need to reshape our nxn grid to the expected shape of the input_points tensor
170
- # (batch_size, point_batch_size, num_points_per_image, 2),
171
- # where the last dimension of 2 represents the x and y coordinates of each point.
172
- #batch_size: The number of images you're processing at once.
173
- #point_batch_size: The number of point sets you have for each image.
174
- #num_points_per_image: The number of points in each set.
175
- input_points = torch.tensor(input_points).view(1, 1, grid_size*grid_size, 2)
176
-
177
- return input_points
178
-
179
- def process_image():
180
- """Processes the uploaded image, loads the model, and evaluates to get predictions"""
181
-
182
- """ Get Image """
183
- img_src = uploaded_image_path()
184
-
185
- # Read the image bytes from the file
186
- with open(img_src, 'rb') as f:
187
- image_bytes = f.read()
188
-
189
- # Convert the image bytes to a PIL Image
190
- image = bytes_to_pil_image(image_bytes)
191
-
192
- """ Prepare Inputs """
193
- # get input points prompt (grid of points)
194
- input_points = generate_input_points(image)
195
-
196
- # prepare image and prompt for the model
197
- inputs = processor(image, input_points=input_points, return_tensors="pt")
198
-
199
- # # remove batch dimension which the processor adds by default
200
- # inputs = {k:v.squeeze(0) for k,v in inputs.items()}
201
-
202
- # Move the input tensor to the GPU if it's not already there
203
- inputs = {k: v.to(device) for k, v in inputs.items()}
204
-
205
- """ Get Predictions """
206
- # forward pass
207
- with torch.no_grad():
208
- outputs = model(**inputs, multimask_output=False)
209
-
210
- # apply sigmoid
211
- prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
212
- # convert soft mask to hard mask
213
- prob = prob.cpu().numpy().squeeze()
214
- prediction = (prob > 0.5).astype(np.uint8)
215
-
216
- # Return the processed result
217
- return image, prob, prediction
218
 
219
  @reactive.Calc
220
  def get_predictions():
 
53
  ui.input_switch("by_species", "Show species", value=True),
54
  ui.input_switch("show_margins", "Show marginal plots", value=True),
55
  ),
56
+ #ui.output_image("uploaded_image"), # display the uploaded sidewalk tile image, for some reason doesn't work on all accepted files
57
  ui.output_plot("prediction_plots", fill=True),
58
  ui.output_ui("value_boxes"),
59
  ui.output_plot("scatter", fill=True),
 
106
  mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
107
  ax.imshow(mask_image)
108
 
109
+ def generate_input_points(image, grid_size=10):
110
+ """
111
+ input_points (torch.FloatTensor of shape (batch_size, num_points, 2)) —
112
+ Input 2D spatial points, this is used by the prompt encoder to encode the prompt.
113
+ Generally yields to much better results. The points can be obtained by passing a
114
+ list of list of list to the processor that will create corresponding torch tensors
115
+ of dimension 4. The first dimension is the image batch size, the second dimension
116
+ is the point batch size (i.e. how many segmentation masks do we want the model to
117
+ predict per input point), the third dimension is the number of points per segmentation
118
+ mask (it is possible to pass multiple points for a single mask), and the last dimension
119
+ is the x (vertical) and y (horizontal) coordinates of the point. If a different number
120
+ of points is passed either for each image, or for each mask, the processor will create
121
+ “PAD” points that will correspond to the (0, 0) coordinate, and the computation of the
122
+ embedding will be skipped for these points using the labels.
123
+
124
+ """
125
+
126
+ # Get the dimensions of the image
127
+ array_size = max(image.width, image.height)
128
+
129
+ # Generate the grid points
130
+ x = np.linspace(0, array_size-1, grid_size)
131
+ y = np.linspace(0, array_size-1, grid_size)
132
+
133
+ # Generate a grid of coordinates
134
+ xv, yv = np.meshgrid(x, y)
135
+
136
+ # Convert the numpy arrays to lists
137
+ xv_list = xv.tolist()
138
+ yv_list = yv.tolist()
139
+
140
+ # Combine the x and y coordinates into a list of list of lists
141
+ input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)] for x_row, y_row in zip(xv_list, yv_list)]
142
+
143
+ #We need to reshape our nxn grid to the expected shape of the input_points tensor
144
+ # (batch_size, point_batch_size, num_points_per_image, 2),
145
+ # where the last dimension of 2 represents the x and y coordinates of each point.
146
+ #batch_size: The number of images you're processing at once.
147
+ #point_batch_size: The number of point sets you have for each image.
148
+ #num_points_per_image: The number of points in each set.
149
+ input_points = torch.tensor(input_points).view(1, 1, grid_size*grid_size, 2)
150
+
151
+ return input_points
152
+
153
+ def process_image():
154
+ """Processes the uploaded image, loads the model, and evaluates to get predictions"""
155
+
156
+ """ Get Image """
157
+ img_src = uploaded_image_path()
158
+
159
+ # Read the image bytes from the file
160
+ with open(img_src, 'rb') as f:
161
+ image_bytes = f.read()
162
+
163
+ # Convert the image bytes to a PIL Image
164
+ image = bytes_to_pil_image(image_bytes)
165
+
166
+ """ Prepare Inputs """
167
+ # get input points prompt (grid of points)
168
+ input_points = generate_input_points(image)
169
+
170
+ # prepare image and prompt for the model
171
+ inputs = processor(image, input_points=input_points, return_tensors="pt")
172
+
173
+ # # remove batch dimension which the processor adds by default
174
+ # inputs = {k:v.squeeze(0) for k,v in inputs.items()}
175
+
176
+ # Move the input tensor to the GPU if it's not already there
177
+ inputs = {k: v.to(device) for k, v in inputs.items()}
178
+
179
+ """ Get Predictions """
180
+ # forward pass
181
+ with torch.no_grad():
182
+ outputs = model(**inputs, multimask_output=False)
183
+
184
+ # apply sigmoid
185
+ prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
186
+ # convert soft mask to hard mask
187
+ prob = prob.cpu().numpy().squeeze()
188
+ prediction = (prob > 0.5).astype(np.uint8)
189
+
190
+ # Return the processed result
191
+ return image, prob, prediction
192
+
193
  ### SERVER ###
194
  def server(input: Inputs, output: Outputs, session: Session):
195
 
 
205
  else:
206
  return "" # No image uploaded
207
 
208
+ # for some reason below function does not work on all accepted files
209
+ # works on one screenshot that was converted to .tif but not another *shrug*
210
+ # @render.image
211
+ # def uploaded_image():
212
+ # """Displays the uploaded image"""
213
+ # img_src = uploaded_image_path()
214
+ # if img_src:
215
+ # img: ImgData = {"src": str(img_src), "width": "200px"}
216
+ # print("IMAGE", img)
217
+ # return img
218
+ # else:
219
+ # return None # Return an empty string if no image is uploaded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
  @reactive.Calc
222
  def get_predictions():