aagoluoglu commited on
Commit
27540d1
·
1 Parent(s): bd2df77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -40
app.py CHANGED
@@ -149,46 +149,6 @@ def generate_input_points(image, grid_size=10):
149
  input_points = torch.tensor(input_points).view(1, 1, grid_size*grid_size, 2)
150
 
151
  return input_points
152
-
153
- def process_image():
154
- """Processes the uploaded image, loads the model, and evaluates to get predictions"""
155
-
156
- """ Get Image """
157
- img_src = uploaded_image_path()
158
-
159
- # Read the image bytes from the file
160
- with open(img_src, 'rb') as f:
161
- image_bytes = f.read()
162
-
163
- # Convert the image bytes to a PIL Image
164
- image = bytes_to_pil_image(image_bytes)
165
-
166
- """ Prepare Inputs """
167
- # get input points prompt (grid of points)
168
- input_points = generate_input_points(image)
169
-
170
- # prepare image and prompt for the model
171
- inputs = processor(image, input_points=input_points, return_tensors="pt")
172
-
173
- # # remove batch dimension which the processor adds by default
174
- # inputs = {k:v.squeeze(0) for k,v in inputs.items()}
175
-
176
- # Move the input tensor to the GPU if it's not already there
177
- inputs = {k: v.to(device) for k, v in inputs.items()}
178
-
179
- """ Get Predictions """
180
- # forward pass
181
- with torch.no_grad():
182
- outputs = model(**inputs, multimask_output=False)
183
-
184
- # apply sigmoid
185
- prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
186
- # convert soft mask to hard mask
187
- prob = prob.cpu().numpy().squeeze()
188
- prediction = (prob > 0.5).astype(np.uint8)
189
-
190
- # Return the processed result
191
- return image, prob, prediction
192
 
193
  ### SERVER ###
194
  def server(input: Inputs, output: Outputs, session: Session):
@@ -218,6 +178,46 @@ def server(input: Inputs, output: Outputs, session: Session):
218
  # else:
219
  # return None # Return an empty string if no image is uploaded
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  @reactive.Calc
222
  def get_predictions():
223
  """Processes the image when uploaded to get predictions"""
 
149
  input_points = torch.tensor(input_points).view(1, 1, grid_size*grid_size, 2)
150
 
151
  return input_points
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  ### SERVER ###
154
  def server(input: Inputs, output: Outputs, session: Session):
 
178
  # else:
179
  # return None # Return an empty string if no image is uploaded
180
 
181
+ def process_image():
182
+ """Processes the uploaded image, loads the model, and evaluates to get predictions"""
183
+
184
+ """ Get Image """
185
+ img_src = uploaded_image_path()
186
+
187
+ # Read the image bytes from the file
188
+ with open(img_src, 'rb') as f:
189
+ image_bytes = f.read()
190
+
191
+ # Convert the image bytes to a PIL Image
192
+ image = bytes_to_pil_image(image_bytes)
193
+
194
+ """ Prepare Inputs """
195
+ # get input points prompt (grid of points)
196
+ input_points = generate_input_points(image)
197
+
198
+ # prepare image and prompt for the model
199
+ inputs = processor(image, input_points=input_points, return_tensors="pt")
200
+
201
+ # # remove batch dimension which the processor adds by default
202
+ # inputs = {k:v.squeeze(0) for k,v in inputs.items()}
203
+
204
+ # Move the input tensor to the GPU if it's not already there
205
+ inputs = {k: v.to(device) for k, v in inputs.items()}
206
+
207
+ """ Get Predictions """
208
+ # forward pass
209
+ with torch.no_grad():
210
+ outputs = model(**inputs, multimask_output=False)
211
+
212
+ # apply sigmoid
213
+ prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
214
+ # convert soft mask to hard mask
215
+ prob = prob.cpu().numpy().squeeze()
216
+ prediction = (prob > 0.5).astype(np.uint8)
217
+
218
+ # Return the processed result
219
+ return image, prob, prediction
220
+
221
  @reactive.Calc
222
  def get_predictions():
223
  """Processes the image when uploaded to get predictions"""