import gradio as gr
import numpy as np
from PIL import Image
import torch
from transformers import AutoProcessor, CLIPSegForImageSegmentation
# Load the CLIPSeg model and processor
processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
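
# Illustrative sketch (not wired into the UI below): the core CLIPSeg call that both
# callbacks repeat. CLIPSeg is text-prompted, so the mask depends on the prompt; this
# app uses the generic prompt "object" as a rough "segment anything" stand-in. For
# this checkpoint the probability map comes back at a fixed low resolution (352x352)
# regardless of the input size, so callers resize it before display.
def _clipseg_probability_map(pil_image, prompt="object"):
    inputs = processor(text=[prompt], images=[pil_image], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits.squeeze().sigmoid().numpy()
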
def segment_everything(image):
    # Check if image is a list and extract the actual image data
    if isinstance(image, list):
        image = image[0]
    # Convert numpy array to PIL Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Prompt CLIPSeg with a generic "object" query over the whole image
    inputs = processor(text=["object"], images=[image], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.squeeze().sigmoid()
    # Scale the probability map to 0-255; it is resized to the input image at display time
    segmentation = (preds.numpy() * 255).astype(np.uint8)
    return Image.fromarray(segmentation)

def segment_box(image, x1, y1, x2, y2):
    # Check if image is a list and extract the actual image data
    if isinstance(image, list):
        image = image[0]
    # Convert PIL Image to numpy array if necessary
    if isinstance(image, Image.Image):
        image = np.array(image)
    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
    cropped_image = image[y1:y2, x1:x2]
    inputs = processor(text=["object"], images=[Image.fromarray(cropped_image)], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.squeeze().sigmoid()
    # The model's mask is low resolution, so resize it back to the box size
    # before pasting it into an otherwise empty full-size mask
    mask = (preds.numpy() * 255).astype(np.uint8)
    mask = np.array(Image.fromarray(mask).resize((x2 - x1, y2 - y1), Image.NEAREST))
    segmentation = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
    segmentation[y1:y2, x1:x2] = mask
    return Image.fromarray(segmentation)
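
# Optional helper (an assumption, not wired into the UI): clamp user-entered box
# coordinates to the image bounds so an out-of-range or inverted box cannot produce
# an empty crop or an indexing error in segment_box.
def _clamp_box(image, x1, y1, x2, y2):
    h, w = image.shape[:2]
    x1, x2 = sorted((max(0, min(int(x1), w)), max(0, min(int(x2), w))))
    y1, y2 = sorted((max(0, min(int(y1), h)), max(0, min(int(y2), h))))
    return x1, y1, x2, y2
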
def update_image(image, segmentation):
    if segmentation is None:
        return image
    # Check if image is a list and extract the actual image data
    if isinstance(image, list):
        image = image[0]
    # Ensure image is in the correct format (PIL Image)
    if isinstance(image, np.ndarray):
        image_pil = Image.fromarray(image)
    else:
        image_pil = image
    # Convert segmentation to RGBA
    seg_pil = Image.fromarray(segmentation).convert('RGBA')
    # Resize segmentation to match input image if necessary
    if image_pil.size != seg_pil.size:
        seg_pil = seg_pil.resize(image_pil.size, Image.NEAREST)
    # Blend images
    blended = Image.blend(image_pil.convert('RGBA'), seg_pil, 0.5)
    return np.array(blended)
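
# Optional variant (an assumption, not used above): tint the mask a single color before
# compositing, so the result reads as a highlight on the photo rather than a gray wash.
def _red_overlay(image_pil, mask_pil, alpha=0.5):
    mask_l = mask_pil.convert('L').resize(image_pil.size, Image.NEAREST)
    red = Image.new('RGBA', image_pil.size, (255, 0, 0, 0))
    red.putalpha(mask_l.point(lambda v: int(v * alpha)))
    return Image.alpha_composite(image_pil.convert('RGBA'), red)
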
with gr.Blocks() as demo:
    gr.Markdown("# Segment Anything-like Demo")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image")
            with gr.Row():
                x1_input = gr.Number(label="X1")
                y1_input = gr.Number(label="Y1")
                x2_input = gr.Number(label="X2")
                y2_input = gr.Number(label="Y2")
            with gr.Row():
                everything_btn = gr.Button("Everything")
                box_btn = gr.Button("Box")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Segmentation Result")
    # Run segmentation, then blend the resulting mask over the input image for display.
    # The blend is chained with .then() instead of an output_image.change() handler,
    # since a change handler that writes back to output_image would keep re-triggering itself.
    everything_btn.click(
        fn=segment_everything,
        inputs=[input_image],
        outputs=[output_image]
    ).then(
        fn=update_image,
        inputs=[input_image, output_image],
        outputs=[output_image]
    )
    box_btn.click(
        fn=segment_box,
        inputs=[input_image, x1_input, y1_input, x2_input, y2_input],
        outputs=[output_image]
    ).then(
        fn=update_image,
        inputs=[input_image, output_image],
        outputs=[output_image]
    )

demo.launch()