AAAAAAyq
commited on
Commit
·
f2149ba
1
Parent(s):
8406393
Update requirements
Browse files
app.py
CHANGED
|
@@ -13,7 +13,6 @@ def format_results(result,filter = 0):
|
|
| 13 |
annotation = {}
|
| 14 |
mask = result.masks.data[i] == 1.0
|
| 15 |
|
| 16 |
-
|
| 17 |
if torch.sum(mask) < filter:
|
| 18 |
continue
|
| 19 |
annotation['id'] = i
|
|
@@ -50,51 +49,16 @@ def post_process(annotations, image, mask_random_color=True, bbox=None, points=N
|
|
| 50 |
for i, mask in enumerate(annotations):
|
| 51 |
show_mask(mask, plt.gca(),random_color=mask_random_color,bbox=bbox,points=points)
|
| 52 |
plt.axis('off')
|
| 53 |
-
# # create a BytesIO object
|
| 54 |
-
# buf = io.BytesIO()
|
| 55 |
-
|
| 56 |
-
# # save plot to buf
|
| 57 |
-
# plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.0)
|
| 58 |
-
|
| 59 |
-
# # use PIL to open the image
|
| 60 |
-
# img = Image.open(buf)
|
| 61 |
|
| 62 |
-
# # copy the image data
|
| 63 |
-
# img_copy = img.copy()
|
| 64 |
plt.tight_layout()
|
| 65 |
-
|
| 66 |
-
# # don't forget to close the buffer
|
| 67 |
-
# buf.close()
|
| 68 |
return fig
|
| 69 |
|
| 70 |
|
| 71 |
-
# def show_mask(annotation, ax, random_color=False):
|
| 72 |
-
# if random_color : # 掩膜颜色是否随机决定
|
| 73 |
-
# color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
| 74 |
-
# else:
|
| 75 |
-
# color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
|
| 76 |
-
# mask = annotation.cpu().numpy()
|
| 77 |
-
# h, w = mask.shape[-2:]
|
| 78 |
-
# mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
| 79 |
-
# ax.imshow(mask_image)
|
| 80 |
-
|
| 81 |
-
# def post_process(annotations, image):
|
| 82 |
-
# plt.figure(figsize=(10, 10))
|
| 83 |
-
# plt.imshow(image)
|
| 84 |
-
# for i, mask in enumerate(annotations):
|
| 85 |
-
# show_mask(mask.data, plt.gca(),random_color=True)
|
| 86 |
-
# plt.axis('off')
|
| 87 |
-
|
| 88 |
-
# 获取渲染后的像素数据并转换为PIL图像
|
| 89 |
-
|
| 90 |
-
return pil_image
|
| 91 |
-
|
| 92 |
-
|
| 93 |
# post_process(results[0].masks, Image.open("../data/cake.png"))
|
| 94 |
|
| 95 |
-
def predict(inp,
|
| 96 |
-
|
| 97 |
-
results = model(inp, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=
|
| 98 |
results = format_results(results[0], 100)
|
| 99 |
results.sort(key=lambda x: x['area'], reverse=True)
|
| 100 |
pil_image = post_process(annotations=results, image=inp)
|
|
@@ -106,10 +70,10 @@ def predict(inp, imgsz):
|
|
| 106 |
# post_process(annotations=results, image_path=inp)
|
| 107 |
|
| 108 |
demo = gr.Interface(fn=predict,
|
| 109 |
-
inputs=[gr.inputs.Image(type='pil'), gr.inputs.Dropdown(choices=[
|
| 110 |
outputs=['plot'],
|
| 111 |
-
examples=[["assets/sa_8776.jpg", 1024],
|
| 112 |
-
|
| 113 |
# examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
|
| 114 |
# ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
|
| 115 |
# ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
|
|
|
|
| 13 |
annotation = {}
|
| 14 |
mask = result.masks.data[i] == 1.0
|
| 15 |
|
|
|
|
| 16 |
if torch.sum(mask) < filter:
|
| 17 |
continue
|
| 18 |
annotation['id'] = i
|
|
|
|
| 49 |
for i, mask in enumerate(annotations):
|
| 50 |
show_mask(mask, plt.gca(),random_color=mask_random_color,bbox=bbox,points=points)
|
| 51 |
plt.axis('off')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
|
|
|
|
|
|
| 53 |
plt.tight_layout()
|
|
|
|
|
|
|
|
|
|
| 54 |
return fig
|
| 55 |
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
# post_process(results[0].masks, Image.open("../data/cake.png"))
|
| 58 |
|
| 59 |
+
def predict(inp, input_size):
|
| 60 |
+
input_size = int(input_size) # 确保 imgsz 是整数
|
| 61 |
+
results = model(inp, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
|
| 62 |
results = format_results(results[0], 100)
|
| 63 |
results.sort(key=lambda x: x['area'], reverse=True)
|
| 64 |
pil_image = post_process(annotations=results, image=inp)
|
|
|
|
| 70 |
# post_process(annotations=results, image_path=inp)
|
| 71 |
|
| 72 |
demo = gr.Interface(fn=predict,
|
| 73 |
+
inputs=[gr.inputs.Image(type='pil'), gr.inputs.Dropdown(choices=[512, 800, 1024])],
|
| 74 |
outputs=['plot'],
|
| 75 |
+
examples=[["assets/sa_8776.jpg", 1024]],
|
| 76 |
+
# ["assets/sa_1309.jpg", 1024]],
|
| 77 |
# examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
|
| 78 |
# ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
|
| 79 |
# ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
|