AAAAAAyq commited on
Commit
0eb155c
1 Parent(s): bd6726a

Fix the queue problem to improve stability

Browse files
Files changed (2) hide show
  1. app.py +40 -23
  2. app_copy.py → app_debug.py +6 -8
app.py CHANGED
@@ -6,7 +6,7 @@ import cv2
6
  import torch
7
  # import queue
8
  # import threading
9
- # from PIL import Image
10
 
11
 
12
  model = YOLO('checkpoints/FastSAM.pt') # load a custom model
@@ -18,8 +18,9 @@ def fast_process(annotations, image, high_quality, device):
18
 
19
  original_h = image.height
20
  original_w = image.width
21
- fig = plt.figure(figsize=(10, 10))
22
- plt.imshow(image)
 
23
  if high_quality == True:
24
  if isinstance(annotations[0],torch.Tensor):
25
  annotations = np.array(annotations.cpu())
@@ -28,7 +29,7 @@ def fast_process(annotations, image, high_quality, device):
28
  annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
29
  if device == 'cpu':
30
  annotations = np.array(annotations)
31
- fast_show_mask(annotations,
32
  plt.gca(),
33
  bbox=None,
34
  points=None,
@@ -39,13 +40,14 @@ def fast_process(annotations, image, high_quality, device):
39
  else:
40
  if isinstance(annotations[0],np.ndarray):
41
  annotations = torch.from_numpy(annotations)
42
- fast_show_mask_gpu(annotations,
43
  plt.gca(),
44
  bbox=None,
45
  points=None,
46
  pointlabel=None)
47
  if isinstance(annotations, torch.Tensor):
48
  annotations = annotations.cpu().numpy()
 
49
  if high_quality == True:
50
  contour_all = []
51
  temp = np.zeros((original_h, original_w,1))
@@ -58,12 +60,17 @@ def fast_process(annotations, image, high_quality, device):
58
  contour_all.append(contour)
59
  cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
60
  color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
61
- contour_mask = temp / 225 * color.reshape(1, 1, -1)
62
- plt.imshow(contour_mask)
63
-
64
- plt.axis('off')
65
- plt.tight_layout()
66
- return fig
 
 
 
 
 
67
 
68
 
69
  # CPU post process
@@ -85,12 +92,12 @@ def fast_show_mask(annotation, ax, bbox=None,
85
  visual = np.concatenate([color,transparency],axis=-1)
86
  mask_image = np.expand_dims(annotation,-1) * visual
87
 
88
- show = np.zeros((height,weight,4))
89
 
90
  h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
91
  indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
92
  # 使用向量化索引更新show的值
93
- show[h_indices, w_indices, :] = mask_image[indices]
94
  if bbox is not None:
95
  x1, y1, x2, y2 = bbox
96
  ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
@@ -100,8 +107,10 @@ def fast_show_mask(annotation, ax, bbox=None,
100
  plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
101
 
102
  if retinamask==False:
103
- show = cv2.resize(show,(target_width,target_height),interpolation=cv2.INTER_NEAREST)
104
- ax.imshow(show)
 
 
105
 
106
 
107
  def fast_show_mask_gpu(annotation, ax,
@@ -120,12 +129,12 @@ def fast_show_mask_gpu(annotation, ax,
120
  visual = torch.cat([color,transparency],dim=-1)
121
  mask_image = torch.unsqueeze(annotation,-1) * visual
122
  # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
123
- show = torch.zeros((height,weight,4)).to(annotation.device)
124
  h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
125
  indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
126
  # 使用向量化索引更新show的值
127
- show[h_indices, w_indices, :] = mask_image[indices]
128
- show_cpu = show.cpu().numpy()
129
  if bbox is not None:
130
  x1, y1, x2, y2 = bbox
131
  ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
@@ -133,9 +142,15 @@ def fast_show_mask_gpu(annotation, ax,
133
  if points is not None:
134
  plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==1], [point[1] for i, point in enumerate(points) if pointlabel[i]==1], s=20, c='y')
135
  plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
136
- ax.imshow(show_cpu)
 
 
137
 
 
 
138
 
 
 
139
 
140
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
141
 
@@ -157,12 +172,14 @@ def predict(input, input_size=512, high_visual_quality=False):
157
  # results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
158
  # pil_image = fast_process(annotations=results[0].masks.data,
159
  # image=input, high_quality=high_quality_visual, device=device)
 
160
  app_interface = gr.Interface(fn=predict,
161
- inputs=[gr.components.Image(type='pil'),
162
  gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size'),
163
- gr.components.Checkbox(value=False, label='high_visual_quality')],
164
- outputs=['plot'],
165
- examples=[["assets/sa_8776.jpg", 1024, True]],
 
166
  # # ["assets/sa_1309.jpg", 1024]],
167
  # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
168
  # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
 
6
  import torch
7
  # import queue
8
  # import threading
9
+ from PIL import Image
10
 
11
 
12
  model = YOLO('checkpoints/FastSAM.pt') # load a custom model
 
18
 
19
  original_h = image.height
20
  original_w = image.width
21
+ image = image.convert('RGBA')
22
+ # fig = plt.figure(figsize=(10, 10))
23
+ # plt.imshow(image)
24
  if high_quality == True:
25
  if isinstance(annotations[0],torch.Tensor):
26
  annotations = np.array(annotations.cpu())
 
29
  annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
30
  if device == 'cpu':
31
  annotations = np.array(annotations)
32
+ inner_mask = fast_show_mask(annotations,
33
  plt.gca(),
34
  bbox=None,
35
  points=None,
 
40
  else:
41
  if isinstance(annotations[0],np.ndarray):
42
  annotations = torch.from_numpy(annotations)
43
+ inner_mask = fast_show_mask_gpu(annotations,
44
  plt.gca(),
45
  bbox=None,
46
  points=None,
47
  pointlabel=None)
48
  if isinstance(annotations, torch.Tensor):
49
  annotations = annotations.cpu().numpy()
50
+
51
  if high_quality == True:
52
  contour_all = []
53
  temp = np.zeros((original_h, original_w,1))
 
60
  contour_all.append(contour)
61
  cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
62
  color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
63
+ contour_mask = temp / 255 * color.reshape(1, 1, -1)
64
+ overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
65
+ image.paste(overlay_contour, (0, 0), overlay_contour)
66
+ # plt.imshow(contour_mask)
67
+ overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
68
+ image.paste(overlay_inner, (0, 0), overlay_inner)
69
+
70
+ return image
71
+ # plt.axis('off')
72
+ # plt.tight_layout()
73
+ # return fig
74
 
75
 
76
  # CPU post process
 
92
  visual = np.concatenate([color,transparency],axis=-1)
93
  mask_image = np.expand_dims(annotation,-1) * visual
94
 
95
+ mask = np.zeros((height,weight,4))
96
 
97
  h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
98
  indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
99
  # 使用向量化索引更新show的值
100
+ mask[h_indices, w_indices, :] = mask_image[indices]
101
  if bbox is not None:
102
  x1, y1, x2, y2 = bbox
103
  ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
 
107
  plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
108
 
109
  if retinamask==False:
110
+ mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
111
+ # ax.imshow(mask)
112
+
113
+ return mask
114
 
115
 
116
  def fast_show_mask_gpu(annotation, ax,
 
129
  visual = torch.cat([color,transparency],dim=-1)
130
  mask_image = torch.unsqueeze(annotation,-1) * visual
131
  # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
132
+ mask = torch.zeros((height,weight,4)).to(annotation.device)
133
  h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
134
  indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
135
  # 使用向量化索引更新show的值
136
+ mask[h_indices, w_indices, :] = mask_image[indices]
137
+ mask_cpu = mask.cpu().numpy()
138
  if bbox is not None:
139
  x1, y1, x2, y2 = bbox
140
  ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
 
142
  if points is not None:
143
  plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==1], [point[1] for i, point in enumerate(points) if pointlabel[i]==1], s=20, c='y')
144
  plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
145
+ # ax.imshow(mask_cpu)
146
+ return mask_cpu
147
+
148
 
149
+ # # 预测队列
150
+ # prediction_queue = queue.Queue(maxsize=5)
151
 
152
+ # # 线程锁
153
+ # lock = threading.Lock()
154
 
155
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
156
 
 
172
  # results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
173
  # pil_image = fast_process(annotations=results[0].masks.data,
174
  # image=input, high_quality=high_quality_visual, device=device)
175
+
176
  app_interface = gr.Interface(fn=predict,
177
+ inputs=[gr.Image(type='pil'),
178
  gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size'),
179
+ gr.components.Checkbox(value=True, label='high_visual_quality')],
180
+ # outputs=['plot'],
181
+ outputs=gr.Image(type='pil'),
182
+ examples=[["assets/sa_8776.jpg"]],
183
  # # ["assets/sa_1309.jpg", 1024]],
184
  # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
185
  # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
app_copy.py → app_debug.py RENAMED
@@ -18,6 +18,7 @@ def fast_process(annotations, image, high_quality, device):
18
 
19
  original_h = image.height
20
  original_w = image.width
 
21
  # fig = plt.figure(figsize=(10, 10))
22
  # plt.imshow(image)
23
  if high_quality == True:
@@ -46,6 +47,7 @@ def fast_process(annotations, image, high_quality, device):
46
  pointlabel=None)
47
  if isinstance(annotations, torch.Tensor):
48
  annotations = annotations.cpu().numpy()
 
49
  if high_quality == True:
50
  contour_all = []
51
  temp = np.zeros((original_h, original_w,1))
@@ -59,15 +61,11 @@ def fast_process(annotations, image, high_quality, device):
59
  cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
60
  color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
61
  contour_mask = temp / 255 * color.reshape(1, 1, -1)
 
 
62
  # plt.imshow(contour_mask)
63
- image = image.convert('RGBA')
64
  overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
65
- overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
66
- # image = image.convert('RGBA')
67
- # image = Image.alpha_composite(image, overlay_inner)
68
- # image = Image.alpha_composite(image, overlay_contour)
69
  image.paste(overlay_inner, (0, 0), overlay_inner)
70
- image.paste(overlay_contour, (0, 0), overlay_contour)
71
 
72
  return image
73
  # plt.axis('off')
@@ -176,11 +174,11 @@ def predict(input, input_size=512, high_visual_quality=False):
176
  # image=input, high_quality=high_quality_visual, device=device)
177
 
178
  app_interface = gr.Interface(fn=predict,
179
- inputs=[gr.components.Image(type='pil'),
180
  gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size'),
181
  gr.components.Checkbox(value=False, label='high_visual_quality')],
182
  # outputs=['plot'],
183
- outputs=gr.components.Image(type='pil'),
184
  examples=[["assets/sa_8776.jpg", 1024, True]],
185
  # # ["assets/sa_1309.jpg", 1024]],
186
  # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
 
18
 
19
  original_h = image.height
20
  original_w = image.width
21
+ image = image.convert('RGBA')
22
  # fig = plt.figure(figsize=(10, 10))
23
  # plt.imshow(image)
24
  if high_quality == True:
 
47
  pointlabel=None)
48
  if isinstance(annotations, torch.Tensor):
49
  annotations = annotations.cpu().numpy()
50
+
51
  if high_quality == True:
52
  contour_all = []
53
  temp = np.zeros((original_h, original_w,1))
 
61
  cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
62
  color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
63
  contour_mask = temp / 255 * color.reshape(1, 1, -1)
64
+ overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
65
+ image.paste(overlay_contour, (0, 0), overlay_contour)
66
  # plt.imshow(contour_mask)
 
67
  overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
 
 
 
 
68
  image.paste(overlay_inner, (0, 0), overlay_inner)
 
69
 
70
  return image
71
  # plt.axis('off')
 
174
  # image=input, high_quality=high_quality_visual, device=device)
175
 
176
  app_interface = gr.Interface(fn=predict,
177
+ inputs=[gr.Image(type='pil'),
178
  gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size'),
179
  gr.components.Checkbox(value=False, label='high_visual_quality')],
180
  # outputs=['plot'],
181
+ outputs=gr.Image(type='pil'),
182
  examples=[["assets/sa_8776.jpg", 1024, True]],
183
  # # ["assets/sa_1309.jpg", 1024]],
184
  # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],