ZhengPeng7 commited on
Commit
b96d57e
·
verified ·
1 Parent(s): a7920cd

Change to tab_batch to take dynamic number of images as the input.

Browse files
Files changed (1) hide show
  1. app.py +39 -20
app.py CHANGED
@@ -36,8 +36,24 @@ model.eval()
36
 
37
 
38
  @spaces.GPU
39
- def pred_maps(image_1, image_2, image_3, image_4):
40
- images = [image_1, image_2, image_3, image_4]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  image_shapes = [image.shape[:2] for image in images]
42
  images = [Image.fromarray(image) for image in images]
43
 
@@ -50,30 +66,33 @@ def pred_maps(image_1, image_2, image_3, image_4):
50
  with torch.no_grad():
51
  scaled_preds_tensor = model(images_proc.to(device))[-1]
52
  preds = []
53
- for image_shape, pred_tensor in zip(image_shapes, scaled_preds_tensor):
54
  if device == 'cuda':
55
  pred_tensor = pred_tensor.cpu()
56
- preds.append(torch.nn.functional.interpolate(pred_tensor.unsqueeze(0), size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy())
57
- image_preds = []
58
- for image, pred in zip(images, preds):
59
- image_preds.append(
60
- np.hstack([np.array(image.convert('RGB')), cv2.cvtColor((pred*255).astype(np.uint8), cv2.COLOR_GRAY2RGB)])
61
- )
62
- return image_preds[:]
 
63
 
64
 
65
- N = 4
66
  # examples = [[_] for _ in glob('example_images/butterfly/*')][:N]
67
 
68
- ipt = [gr.Image(width=600, height=300) for _ in range(N)]
69
- opt = [gr.Image(width=600, height=300) for _ in range(N)]
70
- demo = gr.Interface(
71
  fn=pred_maps,
72
- inputs=ipt,
73
- outputs=opt,
74
- # examples=examples,
75
- # interpretation='default',
76
- title='Online demo for `GCoNet+: A Stronger Group Collaborative Co-Salient Object Detector (T-PAMI 2023)`',
77
- description='Upload pictures, most of which contain salient objects of the same class. Our demo will give you the binary maps of these co-salient objects :)\n**********Example images need to be dropped into each block, instead of click.**********'
 
 
 
 
78
  )
79
  demo.launch(debug=True)
 
36
 
37
 
38
  @spaces.GPU
39
+ def pred_maps(images):
40
+ assert (images is not None), 'AssertionError: images cannot be None.'
41
+ # For tab_batch
42
+ save_paths = []
43
+ save_dir = 'preds-GCoNet_plus'
44
+ if not os.path.exists(save_dir):
45
+ os.makedirs(save_dir)
46
+
47
+ image_array_lst = []
48
+ for idx_image, image_src in enumerate(images):
49
+ save_paths.append(os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0])))
50
+ if isinstance(image_src, str):
51
+ image = np.array(Image.open(image_src))
52
+ else:
53
+ image = image_src
54
+ image_array_lst.append(image)
55
+ images = image_array_lst
56
+
57
  image_shapes = [image.shape[:2] for image in images]
58
  images = [Image.fromarray(image) for image in images]
59
 
 
66
  with torch.no_grad():
67
  scaled_preds_tensor = model(images_proc.to(device))[-1]
68
  preds = []
69
+ for image_shape, pred_tensor, save_path in zip(image_shapes, scaled_preds_tensor, save_paths):
70
  if device == 'cuda':
71
  pred_tensor = pred_tensor.cpu()
72
+ pred_tensor = torch.nn.functional.interpolate(pred_tensor.unsqueeze(0), size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
73
+ cv2.imwrite(save_path, pred_tensor)
74
+
75
+ zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
76
+ with zipfile.ZipFile(zip_file_path, 'w') as zipf:
77
+ for file in save_paths:
78
+ zipf.write(file, os.path.basename(file))
79
+ return save_paths, zip_file_path
80
 
81
 
82
+ # N = 4
83
  # examples = [[_] for _ in glob('example_images/butterfly/*')][:N]
84
 
85
+ tab_batch = gr.Interface(
 
 
86
  fn=pred_maps,
87
+ inputs=gr.File(label="Upload multiple images in a group", type="filepath", file_count="multiple"),
88
+ outputs=[gr.Gallery(label="GCoNet+'s predictions"), gr.File(label="Download predicted maps.")],
89
+ api_name="batch",
90
+ description='Upload pictures, most of which contain salient objects of the same class. Our demo will give you the binary maps of these co-salient objects :)',
91
+ )
92
+
93
+ demo = gr.TabbedInterface(
94
+ [tab_batch],
95
+ ['batch'],
96
+ title="Online demo for `GCoNet+: A Stronger Group Collaborative Co-Salient Object Detector (T-PAMI 2023)`",
97
  )
98
  demo.launch(debug=True)