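# Gradio app for a Hugging Face Space: a CLIPSeg-based "Segment Anything"-style
# demo. Requires gradio, torch, transformers, and Pillow; the event wiring
# below assumes the Gradio 3.x Blocks API (with .then() event chaining).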
import gradio as gr
import numpy as np
from PIL import Image
import torch
from transformers import AutoProcessor, CLIPSegForImageSegmentation

# Load the CLIPSeg model and processor
processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

# Move the model to the GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
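
# "Everything" mode: prompt CLIPSeg with the generic text "object" to get a
# single foreground-probability map for the whole image. CLIPSeg predicts at
# a fixed low resolution (352x352 for this checkpoint's processor defaults),
# so the mask is upscaled later when it is blended over the input.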
def segment_everything(image):
    # Gradio may hand us a list (e.g. from a gallery) or a raw numpy array
    if isinstance(image, list):
        image = image[0]
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    inputs = processor(text=["object"], images=image, padding="max_length", return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Sigmoid maps logits to per-pixel probabilities in [0, 1]
    preds = outputs.logits.squeeze().sigmoid().cpu()
    segmentation = (preds.numpy() * 255).astype(np.uint8)
    return Image.fromarray(segmentation)
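
# "Box" mode: crop the user-specified bounding box, segment just the crop,
# and paste the result back into an otherwise-empty full-size mask.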
def segment_box(image, x1, y1, x2, y2):
    if isinstance(image, list):
        image = image[0]
    if isinstance(image, Image.Image):
        image = np.array(image)
    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
    cropped_image = image[y1:y2, x1:x2]
    inputs = processor(text=["object"], images=Image.fromarray(cropped_image), padding="max_length", return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.squeeze().sigmoid().cpu()
    # The model predicts at its own fixed resolution, so resize the mask to
    # the box dimensions before pasting it into the full-size mask (pasting
    # the raw prediction would raise a shape-mismatch error)
    box_mask = Image.fromarray((preds.numpy() * 255).astype(np.uint8)).resize((x2 - x1, y2 - y1), Image.BILINEAR)
    segmentation = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
    segmentation[y1:y2, x1:x2] = np.array(box_mask)
    return Image.fromarray(segmentation)
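
# Blend the grayscale mask over the original image at 50% opacity so the
# segmentation is shown in context rather than as a bare mask.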
def update_image(image, segmentation):
    if segmentation is None:
        return image
    if isinstance(image, list):
        image = image[0]
    if isinstance(image, np.ndarray):
        image_pil = Image.fromarray(image)
    else:
        image_pil = image
    seg_pil = Image.fromarray(segmentation).convert('RGBA')
    if image_pil.size != seg_pil.size:
        seg_pil = seg_pil.resize(image_pil.size, Image.NEAREST)
    blended = Image.blend(image_pil.convert('RGBA'), seg_pil, 0.5)
    return np.array(blended)
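
# UI: input image plus box coordinates and buttons on the left, segmentation
# result on the right.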
with gr.Blocks() as demo:
    gr.Markdown("# Segment Anything-like Demo")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image")
            with gr.Row():
                x1_input = gr.Number(label="X1")
                y1_input = gr.Number(label="Y1")
                x2_input = gr.Number(label="X2")
                y2_input = gr.Number(label="Y2")
            with gr.Row():
                everything_btn = gr.Button("Everything")
                box_btn = gr.Button("Box")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Segmentation Result")
    # Run segmentation, then blend the mask over the input image. Chaining
    # with .then() (Gradio >= 3.20) replaces a .change listener on
    # output_image, which would re-fire each time the handler itself updated
    # the component and keep re-blending the result indefinitely.
    everything_btn.click(
        fn=segment_everything,
        inputs=[input_image],
        outputs=[output_image]
    ).then(
        fn=update_image,
        inputs=[input_image, output_image],
        outputs=[output_image]
    )
    box_btn.click(
        fn=segment_box,
        inputs=[input_image, x1_input, y1_input, x2_input, y2_input],
        outputs=[output_image]
    ).then(
        fn=update_image,
        inputs=[input_image, output_image],
        outputs=[output_image]
    )

demo.launch()