File size: 3,994 Bytes
4d26566
 
 
 
87c6f54
4d26566
 
 
87c6f54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d26566
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87c6f54
30e0f74
4d26566
 
 
 
30e0f74
 
4d26566
30e0f74
 
4d26566
30e0f74
 
4d26566
30e0f74
 
 
87c6f54
30e0f74
 
 
4d26566
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87c6f54
 
 
4d26566
 
87c6f54
 
 
 
4d26566
 
 
30e0f74
4d26566
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from ultralytics import YOLO
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
import torch

# Load the FastSAM checkpoint once at import time; ``predict`` below reuses this
# single module-level instance for every call.
model = YOLO('checkpoints/FastSAM.pt')  # load a custom model

def format_results(result, filter=0):
    """Convert one FastSAM/YOLO result into a list of per-mask annotation dicts.

    Args:
        result: A single prediction result exposing ``masks.data`` (one mask per
            detection), plus ``boxes.data`` and ``boxes.conf`` indexed in parallel.
        filter: Minimum number of foreground pixels a mask must contain to be
            kept; smaller masks are skipped. (The name shadows the builtin but is
            kept for caller compatibility.)

    Returns:
        A list of dicts with keys ``id`` (index of the mask in ``result``),
        ``segmentation`` (boolean numpy array), ``bbox`` (the matching row of
        ``result.boxes.data``), ``score`` (``result.boxes.conf`` entry) and
        ``area`` (foreground pixel count). Empty when the result carries no masks.
    """
    # A result with no detections can carry ``masks=None`` (ultralytics
    # convention -- confirm for the pinned version); treat it as "nothing found"
    # instead of crashing on ``None.data``.
    if result.masks is None:
        return []

    annotations = []
    for i, mask_data in enumerate(result.masks.data):
        # Only pixels exactly equal to 1.0 count as foreground.
        mask = mask_data == 1.0
        if torch.sum(mask) < filter:
            continue
        segmentation = mask.cpu().numpy()
        annotations.append({
            'id': i,
            'segmentation': segmentation,
            # NOTE(review): boxes.data rows presumably hold xyxy followed by
            # conf/cls -- confirm before unpacking this as a 4-value box.
            'bbox': result.boxes.data[i],
            'score': result.boxes.conf[i],
            'area': segmentation.sum(),
        })
    return annotations

def show_mask(annotation, ax, random_color=True, bbox=None, points=None):
    """Overlay one mask (plus optional box / point prompts) on a matplotlib axes.

    Args:
        annotation: Either an annotation dict with a ``'segmentation'`` entry, or
            the mask array itself. Its last two dimensions are taken as
            (height, width).
        ax: Matplotlib axes to draw on.
        random_color: If true, use a random RGB colour; otherwise a fixed blue.
            Alpha is always 0.6.
        bbox: Optional ``(x1, y1, x2, y2)`` box, drawn as a blue outline.
        points: Optional iterable of ``(x, y)`` points, drawn as small green dots.

    Returns:
        The ``(h, w, 4)`` RGBA overlay image that was drawn.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    # isinstance (not an exact type check) so dict subclasses such as
    # OrderedDict are unwrapped too instead of being treated as a mask.
    mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
    h, w = mask.shape[-2:]
    # Broadcast the (h, w, 1) mask against the RGBA colour: background pixels
    # become all zeros (fully transparent), foreground pixels take the colour.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    # draw box
    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
    # draw points
    if points is not None:
        ax.scatter([point[0] for point in points], [point[1] for point in points], s=10, c='g')
    ax.imshow(mask_image)
    return mask_image

def post_process(annotations, image, mask_random_color=True, bbox=None, points=None):
    """Render ``image`` with every annotation mask overlaid and return the figure.

    Args:
        annotations: Iterable of items accepted by ``show_mask`` (dicts with a
            ``'segmentation'`` entry, or raw mask arrays).
        image: Background image passed to ``imshow``.
        mask_random_color: Forwarded to ``show_mask`` as ``random_color``.
        bbox: Optional ``(x1, y1, x2, y2)`` prompt box to draw.
        points: Optional iterable of ``(x, y)`` prompt points to draw.

    Returns:
        The matplotlib ``Figure`` containing the rendering.
    """
    fig = plt.figure(figsize=(10, 10))
    # Draw through this figure's own axes rather than pyplot's process-global
    # "current figure" state, which is not thread-safe.
    ax = fig.add_subplot(111)
    ax.imshow(image)
    for i, annotation in enumerate(annotations):
        # The prompt box/points are identical for every mask; draw them only with
        # the first mask instead of stacking one copy per mask.
        first = i == 0
        show_mask(annotation, ax, random_color=mask_random_color,
                  bbox=bbox if first else None,
                  points=points if first else None)
    ax.axis('off')
    fig.tight_layout()
    # Unregister the figure from pyplot so repeated calls (e.g. one per request)
    # do not accumulate open figures; the returned Figure can still be saved.
    plt.close(fig)
    return fig


# post_process(results[0].masks, Image.open("../data/cake.png"))

def predict(inp):
    """Run FastSAM on one input image and return a matplotlib figure of its masks.

    Masks with fewer than 100 foreground pixels are dropped before drawing.
    """
    min_mask_area = 100  # pixels
    first_result = model(inp, device='cpu', retina_masks=True, iou=0.7,
                         conf=0.25, imgsz=1024)[0]
    kept = format_results(first_result, min_mask_area)
    return post_process(annotations=kept, image=inp)

# inp = 'assets/sa_192.jpg'
# results = model(inp, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=1024)
# results = format_results(results[0], 100)
# post_process(annotations=results, image_path=inp)

# Gradio UI: one image input (handed to ``predict`` as a PIL image) and a
# matplotlib plot output, with bundled example images.
demo = gr.Interface(fn=predict,
                    # ``gr.inputs.*`` is deprecated in Gradio 3.x and removed in
                    # 4.x; ``gr.Image`` is the supported component in both.
                    inputs=gr.Image(type='pil'),
                    outputs=['plot'],
                    examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
                              ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
                              ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
                              ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
                    )

demo.launch()