File size: 5,875 Bytes
933c40c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3004ad
adf5040
933c40c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0761cf
adf5040
 
933c40c
adf5040
 
 
 
933c40c
 
 
 
63d8dec
 
 
adf5040
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# import gradio as gr
# from gradio_image_prompter import ImagePrompter

# import os
# import torch


# def prompter(prompts):
#     image = prompts["image"]  # Get the image from prompts
#     points = prompts["points"]  # Get the points from prompts

#     # Print the collected inputs for debugging or logging
#     print("Image received:", image)
#     print("Points received:", points)

#     import torch
#     from sam2.sam2_image_predictor import SAM2ImagePredictor

#     device = torch.device("cpu")

#     predictor = SAM2ImagePredictor.from_pretrained(
#         "facebook/sam2-hiera-base-plus", device=device
#     )

#     with torch.inference_mode():
#         predictor.set_image(image)
#         # masks, _, _ = predictor.predict([[point[0], point[1]] for point in points])
#         input_point = [[point[0], point[1]] for point in points]
#         input_label = [1]
#         masks, _, _ = predictor.predict(
#             point_coords=input_point, point_labels=input_label
#         )
#     print("Predicted Mask:", masks)

#     return image, points


# # Define the Gradio interface
# demo = gr.Interface(
#     fn=prompter,  # Use the custom prompter function
#     inputs=ImagePrompter(
#         show_label=False
#     ),  # ImagePrompter for image input and point selection
#     outputs=[
#         gr.Image(show_label=False),  # Display the image
#         gr.Dataframe(label="Points"),  # Display the points in a DataFrame
#     ],
#     title="Image Point Collector",
#     description="Upload an image, click on it, and get the coordinates of the clicked points.",
# )

# # Launch the Gradio app
# demo.launch()


# import gradio as gr
# from gradio_image_prompter import ImagePrompter
# import torch
# from sam2.sam2_image_predictor import SAM2ImagePredictor


# def prompter(prompts):
#     image = prompts["image"]  # Get the image from prompts
#     points = prompts["points"]  # Get the points from prompts

#     # Print the collected inputs for debugging or logging
#     print("Image received:", image)
#     print("Points received:", points)

#     device = torch.device("cpu")

#     # Load the SAM2ImagePredictor model
#     predictor = SAM2ImagePredictor.from_pretrained(
#         "facebook/sam2-hiera-base-plus", device=device
#     )

#     # Perform inference
#     with torch.inference_mode():
#         predictor.set_image(image)
#         input_point = [[point[0], point[1]] for point in points]
#         input_label = [1] * len(points)  # Assuming all points are foreground
#         masks, _, _ = predictor.predict(
#             point_coords=input_point, point_labels=input_label
#         )

#     # The masks are returned as a list of numpy arrays
#     print("Predicted Mask:", masks)

#     # Assuming there's only one mask returned, you can adjust if there are multiple
#     predicted_mask = masks[0]

#     print(len(image))

#     print(len(predicted_mask))

#     # Create annotations for AnnotatedImage
#     annotations = [(predicted_mask, "Predicted Mask")]

#     return image, annotations


# # Define the Gradio interface
# demo = gr.Interface(
#     fn=prompter,  # Use the custom prompter function
#     inputs=ImagePrompter(
#         show_label=False
#     ),  # ImagePrompter for image input and point selection
#     outputs=gr.AnnotatedImage(),  # Display the image with the predicted mask
#     title="Image Point Collector with Mask Overlay",
#     description="Upload an image, click on it, and get the predicted mask overlayed on the image.",
# )

# # Launch the Gradio app
# demo.launch()


import gradio as gr
from gradio_image_prompter import ImagePrompter
import torch
import numpy as np
from sam2.sam2_image_predictor import SAM2ImagePredictor
from PIL import Image


def prompter(prompts):
    """Run SAM2 point-prompted segmentation and return red mask overlays.

    Args:
        prompts: dict produced by ImagePrompter with keys "image" (the
            uploaded image, convertible via np.array) and "points" (list of
            click records whose first two entries are x, y coordinates).

    Returns:
        A list of exactly 3 PIL images (the Gradio interface declares three
        image outputs); trailing slots without a predicted mask are None.
    """
    image = np.array(prompts["image"])  # assumes an H x W x C color image — TODO confirm for grayscale uploads
    points = prompts["points"]

    # Debug/log the collected inputs.
    print("Image received:", image)
    print("Points received:", points)

    # Load the predictor once and cache it on the function object: model
    # download/instantiation is far too expensive to repeat per click.
    if not hasattr(prompter, "_predictor"):
        prompter._predictor = SAM2ImagePredictor.from_pretrained(
            "facebook/sam2-hiera-base-plus", device=torch.device("cpu")
        )
    predictor = prompter._predictor

    # Perform inference with multimask_output=True (SAM2 proposes candidates).
    with torch.inference_mode():
        predictor.set_image(image)
        input_point = [[point[0], point[1]] for point in points]
        input_label = [1] * len(points)  # all clicked points treated as foreground
        masks, _, _ = predictor.predict(
            point_coords=input_point, point_labels=input_label, multimask_output=True
        )

    # Build one red-overlay image per predicted mask.
    original_image = Image.fromarray(image)  # hoisted: identical for every mask
    overlay_images = []
    for i, mask in enumerate(masks):
        print(f"Predicted Mask {i+1}:", mask)
        red_mask = np.zeros_like(image)
        # Threshold before scaling so boolean AND float masks both map to 0/255
        # (a raw .astype(np.uint8) on float values would truncate, not binarize).
        red_mask[:, :, 0] = (mask > 0).astype(np.uint8) * 255
        blended_image = Image.blend(original_image, Image.fromarray(red_mask), alpha=0.5)
        overlay_images.append(blended_image)

    # The Gradio interface declares exactly 3 image outputs: pad missing slots
    # with None (and cap at 3) so the callback never raises on a count mismatch.
    return (overlay_images + [None] * 3)[:3]


# Define the Gradio interface
demo = gr.Interface(
    fn=prompter,  # Use the custom prompter function
    inputs=ImagePrompter(
        show_label=False
    ),  # ImagePrompter for image input and point selection
    outputs=[
        gr.Image(show_label=False) for _ in range(3)
    ],  # Display up to 3 overlay images
    title="Image Point Collector with Multiple Separate Mask Overlays",
    description="Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images.",
)

# Launch the Gradio app only when executed as a script, so importing this
# module (e.g. for testing or reuse of `prompter`) does not start a server.
if __name__ == "__main__":
    demo.launch()