# File size: 5,875 Bytes
# 933c40c f3004ad adf5040 933c40c a0761cf adf5040 933c40c adf5040 933c40c 63d8dec adf5040 |
# (extraction residue: line-number index removed from executable scope)
# import gradio as gr
# from gradio_image_prompter import ImagePrompter
# import os
# import torch
# def prompter(prompts):
# image = prompts["image"] # Get the image from prompts
# points = prompts["points"] # Get the points from prompts
# # Print the collected inputs for debugging or logging
# print("Image received:", image)
# print("Points received:", points)
# import torch
# from sam2.sam2_image_predictor import SAM2ImagePredictor
# device = torch.device("cpu")
# predictor = SAM2ImagePredictor.from_pretrained(
# "facebook/sam2-hiera-base-plus", device=device
# )
# with torch.inference_mode():
# predictor.set_image(image)
# # masks, _, _ = predictor.predict([[point[0], point[1]] for point in points])
# input_point = [[point[0], point[1]] for point in points]
# input_label = [1]
# masks, _, _ = predictor.predict(
# point_coords=input_point, point_labels=input_label
# )
# print("Predicted Mask:", masks)
# return image, points
# # Define the Gradio interface
# demo = gr.Interface(
# fn=prompter, # Use the custom prompter function
# inputs=ImagePrompter(
# show_label=False
# ), # ImagePrompter for image input and point selection
# outputs=[
# gr.Image(show_label=False), # Display the image
# gr.Dataframe(label="Points"), # Display the points in a DataFrame
# ],
# title="Image Point Collector",
# description="Upload an image, click on it, and get the coordinates of the clicked points.",
# )
# # Launch the Gradio app
# demo.launch()
# import gradio as gr
# from gradio_image_prompter import ImagePrompter
# import torch
# from sam2.sam2_image_predictor import SAM2ImagePredictor
# def prompter(prompts):
# image = prompts["image"] # Get the image from prompts
# points = prompts["points"] # Get the points from prompts
# # Print the collected inputs for debugging or logging
# print("Image received:", image)
# print("Points received:", points)
# device = torch.device("cpu")
# # Load the SAM2ImagePredictor model
# predictor = SAM2ImagePredictor.from_pretrained(
# "facebook/sam2-hiera-base-plus", device=device
# )
# # Perform inference
# with torch.inference_mode():
# predictor.set_image(image)
# input_point = [[point[0], point[1]] for point in points]
# input_label = [1] * len(points) # Assuming all points are foreground
# masks, _, _ = predictor.predict(
# point_coords=input_point, point_labels=input_label
# )
# # The masks are returned as a list of numpy arrays
# print("Predicted Mask:", masks)
# # Assuming there's only one mask returned, you can adjust if there are multiple
# predicted_mask = masks[0]
# print(len(image))
# print(len(predicted_mask))
# # Create annotations for AnnotatedImage
# annotations = [(predicted_mask, "Predicted Mask")]
# return image, annotations
# # Define the Gradio interface
# demo = gr.Interface(
# fn=prompter, # Use the custom prompter function
# inputs=ImagePrompter(
# show_label=False
# ), # ImagePrompter for image input and point selection
# outputs=gr.AnnotatedImage(), # Display the image with the predicted mask
# title="Image Point Collector with Mask Overlay",
# description="Upload an image, click on it, and get the predicted mask overlayed on the image.",
# )
# # Launch the Gradio app
# demo.launch()
import gradio as gr
from gradio_image_prompter import ImagePrompter
import torch
import numpy as np
from sam2.sam2_image_predictor import SAM2ImagePredictor
from PIL import Image
def prompter(prompts):
    """Run SAM2 point-prompted segmentation and return red mask overlays.

    Parameters
    ----------
    prompts : dict
        Payload from ``ImagePrompter`` with keys:
        - ``"image"``: the uploaded image (converted to a numpy array here).
        - ``"points"``: clicked points; only indices [0] and [1] of each
          entry are used as the (x, y) pixel coordinates.
          # NOTE(review): ImagePrompter point entries carry extra fields
          # beyond x/y — confirm the layout against the widget's docs.

    Returns
    -------
    list
        Exactly three entries (PIL images or ``None`` padding), so the
        result always matches the three fixed ``gr.Image`` outputs even
        when SAM2 yields fewer/more than three masks or no points were
        clicked.  (The original returned a variable-length list, which
        crashes Gradio's output mapping when the count is not 3.)
    """
    image = np.array(prompts["image"])  # Convert the image to a numpy array
    points = prompts["points"] or []  # Get the points from prompts
    # Print the collected inputs for debugging or logging
    print("Image received:", image)
    print("Points received:", points)

    # Robustness: with no clicks there is nothing to predict — skip the
    # model entirely instead of sending empty coords to ``predict``.
    if not points:
        return [None, None, None]

    device = torch.device("cpu")
    # Load the SAM2ImagePredictor model
    predictor = SAM2ImagePredictor.from_pretrained(
        "facebook/sam2-hiera-base-plus", device=device
    )

    # Perform inference with multimask_output=True (SAM2 proposes up to
    # three candidate masks per prompt set).
    with torch.inference_mode():
        predictor.set_image(image)
        input_point = [[point[0], point[1]] for point in points]
        input_label = [1] * len(points)  # Assuming all points are foreground
        masks, _, _ = predictor.predict(
            point_coords=input_point, point_labels=input_label, multimask_output=True
        )

    # Hoisted out of the loop: the source PIL image never changes.
    original_image = Image.fromarray(image)

    # Prepare one blended overlay per predicted mask: mask pixels are
    # painted into the red channel, then alpha-blended 50/50 with the
    # original image.
    overlay_images = []
    for i, mask in enumerate(masks):
        print(f"Predicted Mask {i+1}:", mask)
        red_mask = np.zeros_like(image)
        red_mask[:, :, 0] = mask.astype(np.uint8) * 255  # Apply the red channel
        blended_image = Image.blend(original_image, Image.fromarray(red_mask), alpha=0.5)
        overlay_images.append(blended_image)

    # Pad/truncate to exactly 3 so the fixed 3-slot output list never
    # mismatches the number of masks.
    return (overlay_images + [None, None, None])[:3]
# --- Gradio UI wiring -----------------------------------------------------
# Three fixed image slots receive the overlays produced by ``prompter``.
mask_slots = [
    gr.Image(show_label=False),
    gr.Image(show_label=False),
    gr.Image(show_label=False),
]

demo = gr.Interface(
    # ``prompter`` handles the segmentation and overlay rendering.
    fn=prompter,
    # ImagePrompter supplies both the uploaded image and the clicked points.
    inputs=ImagePrompter(show_label=False),
    outputs=mask_slots,
    title="Image Point Collector with Multiple Separate Mask Overlays",
    description="Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images.",
)

# Start the Gradio app server.
demo.launch()
# |