import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io
from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt
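# NOTE: FastSAMPrompt ships with older ultralytics releases; newer versions fold
# prompting into the predict call itself, so pin ultralytics if this import fails.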

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load FastSAM model (weights are downloaded automatically on first use)
model = FastSAM("FastSAM-s.pt")  # or "FastSAM-x.pt" for higher accuracy at lower speed

def fig2img(fig):
    """Render a Matplotlib figure to a PIL Image via an in-memory PNG buffer."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf)

def plot(annotations, prompt_process, mask_random_color=True, better_quality=True, retina=True, with_contours=True):
    """Overlay segmentation masks (and optional contours) on the original image; returns a PIL Image for the first annotation."""
    for ann in annotations:
        image = ann.orig_img[..., ::-1]  # BGR to RGB
        original_h, original_w = ann.orig_shape
        fig = plt.figure(figsize=(original_w / 100, original_h / 100))
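        # figsize in inches at Matplotlib's default 100 dpi reproduces the original pixel dimensions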
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(image)

        if ann.masks is not None:
            masks = ann.masks.data
            if better_quality:
                if isinstance(masks[0], torch.Tensor):
                    masks = np.array(masks.cpu())
                # Smooth mask edges: close small holes, then open away speckle noise
                for i, mask in enumerate(masks):
                    mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
                    masks[i] = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))

            prompt_process.fast_show_mask(
                masks,
                plt.gca(),
                random_color=mask_random_color,
                bbox=None,
                points=None,
                pointlabel=None,
                retinamask=retina,
                target_height=original_h,
                target_width=original_w,
            )

            if with_contours:
                contour_all = []
                temp = np.zeros((original_h, original_w, 1))
                for i, mask in enumerate(masks):
                    mask = mask.astype(np.uint8)
                    if not retina:
                        mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
                    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                    contour_all.extend(contours)
                cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
                color = np.array([0.0, 0.0, 1.0, 0.8])  # blue contour overlay at 80% opacity
                contour_mask = temp / 255 * color.reshape(1, 1, -1)
                plt.imshow(contour_mask)

        plt.axis("off")
        img = fig2img(fig)
        plt.close(fig)
        return img  # only the first annotation is rendered

def segment_image(input_image, object_name):
    """Run FastSAM with a text prompt and return (segmented image, status message)."""
    try:
        if input_image is None:
            return None, "Please upload an image before submitting."
        
        input_image = Image.fromarray(input_image).convert("RGB")
        
        # Run FastSAM model with adjusted parameters
        everything_results = model(input_image, device=device, retina_masks=True, imgsz=1024, conf=0.25, iou=0.7)
        
        # Prepare a Prompt Process object
        prompt_process = FastSAMPrompt(input_image, everything_results, device=device)
        
        # Use text prompt to segment the specified object
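        # NOTE: text prompting relies on OpenAI CLIP to match mask crops against the text;
        # ultralytics may attempt to install the clip package automatically if it is missing.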
        results = prompt_process.text_prompt(text=object_name)
        
        if not results:
            return input_image, f"Could not find '{object_name}' in the image."
        
        # Post-process the masks
        for ann in results:
            if ann.masks is not None:
                masks = ann.masks.data
                if isinstance(masks[0], torch.Tensor):
                    masks = np.array(masks.cpu())
                kernel = np.ones((5, 5), np.uint8)
                for i, mask in enumerate(masks):
                    # Aggressive clean-up: close small holes, open away speckle, then dilate to thicken the mask
                    mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
                    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
                    masks[i] = cv2.dilate(mask, kernel, iterations=2)
                ann.masks.data = masks
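                # NOTE: plot() accepts NumPy masks directly, so no torch round-trip is needed here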
        
        # Plot the results
        result_image = plot(annotations=results, prompt_process=prompt_process)
        
        return result_image, f"Segmented '{object_name}' in the image."
    
    except Exception as e:
        return None, f"An error occurred: {str(e)}"

# Create Gradio interface
iface = gr.Interface(
    fn=segment_image,
    inputs=[
        gr.Image(type="numpy", label="Upload an image"),
        gr.Textbox(label="Specify object to segment (e.g., dog, cat, grass)")
    ],
    outputs=[
        gr.Image(type="pil", label="Segmented Image"),
        gr.Textbox(label="Status")
    ],
    title="FastSAM Segmentation with Object Specification",
    description="Upload an image and specify an object to segment using FastSAM."
)

# Launch the interface
if __name__ == "__main__":
    iface.launch()
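# Tip: iface.launch(share=True) serves a temporary public URL (handy on hosted notebooks).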