Spaces:
Runtime error
Runtime error
Commit
·
27540d1
1
Parent(s):
bd2df77
Update app.py
Browse files
app.py
CHANGED
@@ -149,46 +149,6 @@ def generate_input_points(image, grid_size=10):
|
|
149 |
input_points = torch.tensor(input_points).view(1, 1, grid_size*grid_size, 2)
|
150 |
|
151 |
return input_points
|
152 |
-
|
153 |
-
def process_image():
|
154 |
-
"""Processes the uploaded image, loads the model, and evaluates to get predictions"""
|
155 |
-
|
156 |
-
""" Get Image """
|
157 |
-
img_src = uploaded_image_path()
|
158 |
-
|
159 |
-
# Read the image bytes from the file
|
160 |
-
with open(img_src, 'rb') as f:
|
161 |
-
image_bytes = f.read()
|
162 |
-
|
163 |
-
# Convert the image bytes to a PIL Image
|
164 |
-
image = bytes_to_pil_image(image_bytes)
|
165 |
-
|
166 |
-
""" Prepare Inputs """
|
167 |
-
# get input points prompt (grid of points)
|
168 |
-
input_points = generate_input_points(image)
|
169 |
-
|
170 |
-
# prepare image and prompt for the model
|
171 |
-
inputs = processor(image, input_points=input_points, return_tensors="pt")
|
172 |
-
|
173 |
-
# # remove batch dimension which the processor adds by default
|
174 |
-
# inputs = {k:v.squeeze(0) for k,v in inputs.items()}
|
175 |
-
|
176 |
-
# Move the input tensor to the GPU if it's not already there
|
177 |
-
inputs = {k: v.to(device) for k, v in inputs.items()}
|
178 |
-
|
179 |
-
""" Get Predictions """
|
180 |
-
# forward pass
|
181 |
-
with torch.no_grad():
|
182 |
-
outputs = model(**inputs, multimask_output=False)
|
183 |
-
|
184 |
-
# apply sigmoid
|
185 |
-
prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
|
186 |
-
# convert soft mask to hard mask
|
187 |
-
prob = prob.cpu().numpy().squeeze()
|
188 |
-
prediction = (prob > 0.5).astype(np.uint8)
|
189 |
-
|
190 |
-
# Return the processed result
|
191 |
-
return image, prob, prediction
|
192 |
|
193 |
### SERVER ###
|
194 |
def server(input: Inputs, output: Outputs, session: Session):
|
@@ -218,6 +178,46 @@ def server(input: Inputs, output: Outputs, session: Session):
|
|
218 |
# else:
|
219 |
# return None # Return an empty string if no image is uploaded
|
220 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
@reactive.Calc
|
222 |
def get_predictions():
|
223 |
"""Processes the image when uploaded to get predictions"""
|
|
|
149 |
input_points = torch.tensor(input_points).view(1, 1, grid_size*grid_size, 2)
|
150 |
|
151 |
return input_points
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
|
153 |
### SERVER ###
|
154 |
def server(input: Inputs, output: Outputs, session: Session):
|
|
|
178 |
# else:
|
179 |
# return None # Return an empty string if no image is uploaded
|
180 |
|
181 |
+
def process_image():
|
182 |
+
"""Processes the uploaded image, loads the model, and evaluates to get predictions"""
|
183 |
+
|
184 |
+
""" Get Image """
|
185 |
+
img_src = uploaded_image_path()
|
186 |
+
|
187 |
+
# Read the image bytes from the file
|
188 |
+
with open(img_src, 'rb') as f:
|
189 |
+
image_bytes = f.read()
|
190 |
+
|
191 |
+
# Convert the image bytes to a PIL Image
|
192 |
+
image = bytes_to_pil_image(image_bytes)
|
193 |
+
|
194 |
+
""" Prepare Inputs """
|
195 |
+
# get input points prompt (grid of points)
|
196 |
+
input_points = generate_input_points(image)
|
197 |
+
|
198 |
+
# prepare image and prompt for the model
|
199 |
+
inputs = processor(image, input_points=input_points, return_tensors="pt")
|
200 |
+
|
201 |
+
# # remove batch dimension which the processor adds by default
|
202 |
+
# inputs = {k:v.squeeze(0) for k,v in inputs.items()}
|
203 |
+
|
204 |
+
# Move the input tensor to the GPU if it's not already there
|
205 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
206 |
+
|
207 |
+
""" Get Predictions """
|
208 |
+
# forward pass
|
209 |
+
with torch.no_grad():
|
210 |
+
outputs = model(**inputs, multimask_output=False)
|
211 |
+
|
212 |
+
# apply sigmoid
|
213 |
+
prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
|
214 |
+
# convert soft mask to hard mask
|
215 |
+
prob = prob.cpu().numpy().squeeze()
|
216 |
+
prediction = (prob > 0.5).astype(np.uint8)
|
217 |
+
|
218 |
+
# Return the processed result
|
219 |
+
return image, prob, prediction
|
220 |
+
|
221 |
@reactive.Calc
|
222 |
def get_predictions():
|
223 |
"""Processes the image when uploaded to get predictions"""
|