Prasanna Sridhar committed
Commit aedd89b · 1 Parent(s): 2f1d1a1

Refactor app.py - extract reusable functions

Files changed (3)
  1. .gitignore +2 -2
  2. app.py +116 -148
  3. requirements.txt +2 -0
.gitignore CHANGED
@@ -2,7 +2,7 @@
 env/
 __pycache__
 .python-version
-
+*.py[od]
 
 # vim
-*.sw[op]
+*.sw[op]
app.py CHANGED
@@ -14,11 +14,6 @@ import matplotlib.pyplot as plt
 import io
 from enum import Enum
 import os
-import subprocess
-from subprocess import call
-import shlex
-import shutil
-#os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), "tmp")
 cwd = os.getcwd()
 # Suppress warnings to avoid overflowing the log.
 import warnings
@@ -145,22 +140,6 @@ def build_model_and_transforms(args):
 
     return model, data_transform
 
-examples = [
-    ["strawberry.jpg", "strawberry", {"image": "strawberry.jpg"}],
-    ["strawberry.jpg", "blueberry", {"image": "strawberry.jpg"}],
-    ["bird-1.JPG", "bird", {"image": "bird-2.JPG"}],
-    ["fish.jpg", "fish", {"image": "fish.jpg"}],
-    ["women.jpg", "girl", {"image": "women.jpg"}],
-    ["women.jpg", "boy", {"image": "women.jpg"}],
-    ["balloon.jpg", "hot air balloon", {"image": "balloon.jpg"}],
-    ["deer.jpg", "deer", {"image": "deer.jpg"}],
-    ["apple.jpg", "apple", {"image": "apple.jpg"}],
-    ["egg.jpg", "egg", {"image": "egg.jpg"}],
-    ["stamp.jpg", "stamp", {"image": "stamp.jpg"}],
-    ["green-pea.jpg", "green pea", {"image": "green-pea.jpg"}],
-    ["lego.jpg", "lego", {"image": "lego.jpg"}]
-]
-
 # APP:
 def get_box_inputs(prompts):
     box_inputs = []
@@ -197,6 +176,107 @@ def get_ind_to_filter(text, word_ids, keywords):
 
     return inds_to_filter
 
+def generate_heatmap(image, boxes):
+    # Plot results.
+    (w, h) = image.size
+    det_map = np.zeros((h, w))
+    det_map[(h * boxes[:, 1]).astype(int), (w * boxes[:, 0]).astype(int)] = 1
+    det_map = ndimage.gaussian_filter(
+        det_map, sigma=(w // 200, w // 200), order=0
+    )
+    plt.imshow(image)
+    plt.imshow(det_map[None, :].transpose(1, 2, 0), 'jet', interpolation='none', alpha=0.7)
+    plt.axis('off')
+    img_buf = io.BytesIO()
+    plt.savefig(img_buf, format='png', bbox_inches='tight')
+    plt.close()
+
+    output_img = Image.open(img_buf)
+    return output_img
+
+def generate_output_label(text, num_exemplars):
+    out_label = "Detected instances predicted with"
+    if len(text.strip()) > 0:
+        out_label += " text"
+        if num_exemplars == 1:
+            out_label += " and " + str(num_exemplars) + " visual exemplar."
+        elif num_exemplars > 1:
+            out_label += " and " + str(num_exemplars) + " visual exemplars."
+        else:
+            out_label += "."
+    elif num_exemplars > 0:
+        if num_exemplars == 1:
+            out_label += " " + str(num_exemplars) + " visual exemplar."
+        else:
+            out_label += " " + str(num_exemplars) + " visual exemplars."
+    else:
+        out_label = "Nothing specified to detect."
+
+    return out_label
+
+def preprocess(image, input_prompts = None):
+    if input_prompts == None:
+        prompts = {"image": image, "points": []}
+    else:
+        prompts = input_prompts
+
+    input_image, _ = transform(image, None)
+    exemplar = get_box_inputs(prompts["points"])
+    # Wrapping exemplar in a dictionary to apply only relevant transforms
+    input_image_exemplar, exemplar = transform(prompts['image'], {"exemplars": torch.tensor(exemplar)})
+    exemplar = exemplar["exemplars"]
+
+    return input_image, input_image_exemplar, exemplar
+
+def get_boxes_from_prediction(model_output, text, keywords = ""):
+    ind_to_filter = get_ind_to_filter(text, model_output["token"][0].word_ids, keywords)
+    logits = model_output["pred_logits"].sigmoid()[0][:, ind_to_filter]
+    boxes = model_output["pred_boxes"][0]
+    if len(keywords.strip()) > 0:
+        box_mask = (logits > CONF_THRESH).sum(dim=-1) == len(ind_to_filter)
+    else:
+        box_mask = logits.max(dim=-1).values > CONF_THRESH
+    boxes = boxes[box_mask, :].cpu().numpy()
+    logits = logits[box_mask, :].cpu().numpy()
+    return boxes, logits
+
+def predict(model, image, text, prompts, device):
+    keywords = "" # do not handle this for now
+    input_image, input_image_exemplar, exemplar = preprocess(image, prompts)
+
+    input_images = input_image.unsqueeze(0).to(device)
+    input_image_exemplars = input_image_exemplar.unsqueeze(0).to(device)
+    exemplars = [exemplar.to(device)]
+
+    with torch.no_grad():
+        model_output = model(
+            nested_tensor_from_tensor_list(input_images),
+            nested_tensor_from_tensor_list(input_image_exemplars),
+            exemplars,
+            [torch.tensor([0]).to(device) for _ in range(len(input_images))],
+            captions=[text + " ."] * len(input_images),
+        )
+
+    keywords = ""
+    return get_boxes_from_prediction(model_output, text, keywords)
+
+examples = [
+    ["strawberry.jpg", "strawberry", {"image": "strawberry.jpg"}],
+    ["strawberry.jpg", "blueberry", {"image": "strawberry.jpg"}],
+    ["bird-1.JPG", "bird", {"image": "bird-2.JPG"}],
+    ["fish.jpg", "fish", {"image": "fish.jpg"}],
+    ["women.jpg", "girl", {"image": "women.jpg"}],
+    ["women.jpg", "boy", {"image": "women.jpg"}],
+    ["balloon.jpg", "hot air balloon", {"image": "balloon.jpg"}],
+    ["deer.jpg", "deer", {"image": "deer.jpg"}],
+    ["apple.jpg", "apple", {"image": "apple.jpg"}],
+    ["egg.jpg", "egg", {"image": "egg.jpg"}],
+    ["stamp.jpg", "stamp", {"image": "stamp.jpg"}],
+    ["green-pea.jpg", "green pea", {"image": "green-pea.jpg"}],
+    ["lego.jpg", "lego", {"image": "lego.jpg"}]
+]
+
+
 if __name__ == '__main__':
 
     parser = argparse.ArgumentParser("Counting Application", parents=[get_args_parser()])
@@ -207,54 +287,15 @@ if __name__ == '__main__':
 
     @spaces.GPU(duration=120)
     def count(image, text, prompts, state, device):
-
-        keywords = "" # do not handle this for now
-
-        # Handle no prompt case.
         if prompts is None:
             prompts = {"image": image, "points": []}
-        input_image, _ = transform(image, {"exemplars": torch.tensor([])})
-        input_image = input_image.unsqueeze(0).to(device)
-        exemplars = get_box_inputs(prompts["points"])
-
-        input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
-        input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
-        exemplars = [exemplars["exemplars"].to(device)]
-
-        with torch.no_grad():
-            model_output = model(
-                nested_tensor_from_tensor_list(input_image),
-                nested_tensor_from_tensor_list(input_image_exemplars),
-                exemplars,
-                [torch.tensor([0]).to(device) for _ in range(len(input_image))],
-                captions=[text + " ."] * len(input_image),
-            )
+
+        boxes, _ = predict(model, image, text, prompts, device)
+        count = len(boxes)
+        output_img = generate_heatmap(image, boxes)
 
-        ind_to_filter = get_ind_to_filter(text, model_output["token"][0].word_ids, keywords)
-        logits = model_output["pred_logits"].sigmoid()[0][:, ind_to_filter]
-        boxes = model_output["pred_boxes"][0]
-        if len(keywords.strip()) > 0:
-            box_mask = (logits > CONF_THRESH).sum(dim=-1) == len(ind_to_filter)
-        else:
-            box_mask = logits.max(dim=-1).values > CONF_THRESH
-        logits = logits[box_mask, :].cpu().numpy()
-        boxes = boxes[box_mask, :].cpu().numpy()
-
-        # Plot results.
-        (w, h) = image.size
-        det_map = np.zeros((h, w))
-        det_map[(h * boxes[:, 1]).astype(int), (w * boxes[:, 0]).astype(int)] = 1
-        det_map = ndimage.gaussian_filter(
-            det_map, sigma=(w // 200, w // 200), order=0
-        )
-        plt.imshow(image)
-        plt.imshow(det_map[None, :].transpose(1, 2, 0), 'jet', interpolation='none', alpha=0.7)
-        plt.axis('off')
-        img_buf = io.BytesIO()
-        plt.savefig(img_buf, format='png', bbox_inches='tight')
-        plt.close()
-
-        output_img = Image.open(img_buf)
+        num_exemplars = len(get_box_inputs(prompts["points"]))
+        out_label = generate_output_label(text, num_exemplars)
 
         if AppSteps.TEXT_AND_EXEMPLARS not in state:
             exemplar_image = ImagePrompter(type='pil', label='Visual Exemplar Image', value=prompts, interactive=True, visible=True)
@@ -274,92 +315,19 @@ if __name__ == '__main__':
             main_instructions_comp = gr.Markdown(visible=True)
            step_3 = gr.Tab(visible=True)
 
-        out_label = "Detected instances predicted with"
-        if len(text.strip()) > 0:
-            out_label += " text"
-            if exemplars[0].size()[0] == 1:
-                out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar."
-            elif exemplars[0].size()[0] > 1:
-                out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplars."
-            else:
-                out_label += "."
-        elif exemplars[0].size()[0] > 0:
-            if exemplars[0].size()[0] == 1:
-                out_label += " " + str(exemplars[0].size()[0]) + " visual exemplar."
-            else:
-                out_label += " " + str(exemplars[0].size()[0]) + " visual exemplars."
-        else:
-            out_label = "Nothing specified to detect."
-
-        return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]), new_submit_btn, gr.Tab(visible=True), step_3, state)
+        return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=count), new_submit_btn, gr.Tab(visible=True), step_3, state)
 
     @spaces.GPU
     def count_main(image, text, prompts, device):
-        keywords = "" # do not handle this for now
-        # Handle no prompt case.
        if prompts is None:
            prompts = {"image": image, "points": []}
-        input_image, _ = transform(image, {"exemplars": torch.tensor([])})
-        input_image = input_image.unsqueeze(0).to(device)
-        exemplars = get_box_inputs(prompts["points"])
-
-        input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
-        input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
-        exemplars = [exemplars["exemplars"].to(device)]
-
-        with torch.no_grad():
-            model_output = model(
-                nested_tensor_from_tensor_list(input_image),
-                nested_tensor_from_tensor_list(input_image_exemplars),
-                exemplars,
-                [torch.tensor([0]).to(device) for _ in range(len(input_image))],
-                captions=[text + " ."] * len(input_image),
-            )
-
-        ind_to_filter = get_ind_to_filter(text, model_output["token"][0].word_ids, keywords)
-        logits = model_output["pred_logits"].sigmoid()[0][:, ind_to_filter]
-        boxes = model_output["pred_boxes"][0]
-        if len(keywords.strip()) > 0:
-            box_mask = (logits > CONF_THRESH).sum(dim=-1) == len(ind_to_filter)
-        else:
-            box_mask = logits.max(dim=-1).values > CONF_THRESH
-        logits = logits[box_mask, :].cpu().numpy()
-        boxes = boxes[box_mask, :].cpu().numpy()
-
-        # Plot results.
-        (w, h) = image.size
-        det_map = np.zeros((h, w))
-        det_map[(h * boxes[:, 1]).astype(int), (w * boxes[:, 0]).astype(int)] = 1
-        det_map = ndimage.gaussian_filter(
-            det_map, sigma=(w // 200, w // 200), order=0
-        )
-        plt.imshow(image)
-        plt.imshow(det_map[None, :].transpose(1, 2, 0), 'jet', interpolation='none', alpha=0.7)
-        plt.axis('off')
-        img_buf = io.BytesIO()
-        plt.savefig(img_buf, format='png', bbox_inches='tight')
-        plt.close()
-
-        output_img = Image.open(img_buf)
-
-        out_label = "Detected instances predicted with"
-        if len(text.strip()) > 0:
-            out_label += " text"
-            if exemplars[0].size()[0] == 1:
-                out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar."
-            elif exemplars[0].size()[0] > 1:
-                out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplars."
-            else:
-                out_label += "."
-        elif exemplars[0].size()[0] > 0:
-            if exemplars[0].size()[0] == 1:
-                out_label += " " + str(exemplars[0].size()[0]) + " visual exemplar."
-            else:
-                out_label += " " + str(exemplars[0].size()[0]) + " visual exemplars."
-        else:
-            out_label = "Nothing specified to detect."
+        boxes, _ = predict(model, image, text, prompts, device)
+        count = len(boxes)
+        output_img = generate_heatmap(image, boxes)
+        num_exemplars = len(get_box_inputs(prompts["points"]))
+        out_label = generate_output_label(text, num_exemplars)
 
-        return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]))
+        return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=count))
 
     def remove_label(image):
        return gr.Image(show_label=False)
@@ -401,12 +369,12 @@ if __name__ == '__main__':
            with gr.Accordion("Open for Further Information", open=False):
                gr.Markdown(exemplar_img_drawing_instructions_part_2)
            with gr.Tab("Step 1", visible=True) as step_1:
-                input_image = gr.Image(type='pil', label='Input Image', show_label='True', value="strawberry.jpg", interactive=False, width="30vw")
+                input_image = gr.Image(type='pil', label='Input Image', show_label='True', value="strawberry.jpg", interactive=False)
                gr.Markdown('# Click "Count" to count the strawberries.')
 
        with gr.Column():
            with gr.Tab("Output Image"):
-                detected_instances = gr.Image(label="Detected Instances", show_label='True', interactive=False, visible=True, width="40vw")
+                detected_instances = gr.Image(label="Detected Instances", show_label='True', interactive=False, visible=True)
 
        with gr.Row():
            input_text = gr.Textbox(label="What would you like to count?", value="strawberry", interactive=True)
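Taken together, the extracted helpers reduce both Gradio callbacks to the same short pipeline. A minimal usage sketch of how they compose, mirroring the new count_main (illustrative only, not part of the commit; it assumes model, transform, and device have already been set up as in the __main__ block of app.py):

    # Illustrative sketch: calls the helpers added in this commit directly.
    # Assumes model, transform, and device are initialized as in __main__.
    from PIL import Image

    image = Image.open("strawberry.jpg")          # one of the bundled example images
    text = "strawberry"                           # text prompt describing what to count
    prompts = {"image": image, "points": []}      # no visual exemplar boxes drawn

    boxes, _ = predict(model, image, text, prompts, device)  # preprocess + forward pass + confidence filtering
    predicted_count = len(boxes)                              # value shown in the "Predicted Count" field
    heatmap = generate_heatmap(image, boxes)                  # PIL image with the detection density overlay
    label = generate_output_label(text, len(get_box_inputs(prompts["points"])))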
requirements.txt CHANGED
@@ -12,6 +12,8 @@ ushlex
 gradio>=4.0.0,<5
 gradio_image_prompter-0.1.0-py3-none-any.whl
 spaces
+filetype
+tqdm
 --extra-index-url https://download.pytorch.org/whl/cu121
 torch<2.6
 torchvision