import gradio as gr
import numpy as np
from PIL import Image
import torch
from transformers import AutoProcessor, CLIPSegForImageSegmentation

# Load the CLIPSeg model and processor
processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")


def segment_everything(image):
    # Segment the whole image with a generic text prompt.
    inputs = processor(text=["object"], images=[image], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.squeeze().sigmoid()
    segmentation = (preds.numpy() * 255).astype(np.uint8)
    # CLIPSeg predicts at a fixed low resolution, so resize the mask
    # back up to the input image size.
    mask = Image.fromarray(segmentation)
    return mask.resize((image.shape[1], image.shape[0]), Image.BILINEAR)


def segment_box(image, x1, y1, x2, y2):
    # Segment only the user-specified box, then paste the result into a
    # full-size mask.
    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
    cropped_image = image[y1:y2, x1:x2]
    inputs = processor(text=["object"], images=[cropped_image], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    preds = outputs.logits.squeeze().sigmoid()
    # Resize the fixed-resolution prediction to the box size; without this
    # the assignment below fails on a shape mismatch.
    box_mask = Image.fromarray((preds.numpy() * 255).astype(np.uint8))
    box_mask = box_mask.resize((x2 - x1, y2 - y1), Image.BILINEAR)
    segmentation = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
    segmentation[y1:y2, x1:x2] = np.array(box_mask)
    return Image.fromarray(segmentation)


def update_image(image, segmentation):
    # Overlay the segmentation mask on the input image at 50% opacity.
    if segmentation is None:
        return image
    # gr.Image passes a uint8 array already scaled to 0-255, so no rescaling
    # is needed when converting to a PIL Image.
    if isinstance(image, np.ndarray):
        image_pil = Image.fromarray(image.astype(np.uint8))
    else:
        image_pil = image
    # Convert the segmentation to RGBA for blending
    seg_pil = Image.fromarray(segmentation).convert("RGBA")
    # Resize the segmentation to match the input image if necessary
    if image_pil.size != seg_pil.size:
        seg_pil = seg_pil.resize(image_pil.size, Image.NEAREST)
    # Blend the two images
    blended = Image.blend(image_pil.convert("RGBA"), seg_pil, 0.5)
    return np.array(blended)


with gr.Blocks() as demo:
    gr.Markdown("# Segment Anything-like Demo")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image")
            with gr.Row():
                x1_input = gr.Number(label="X1")
                y1_input = gr.Number(label="Y1")
                x2_input = gr.Number(label="X2")
                y2_input = gr.Number(label="Y2")
            with gr.Row():
                everything_btn = gr.Button("Everything")
                box_btn = gr.Button("Box")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Segmentation Result")

    # Chain the overlay step after each segmentation instead of listening on
    # output_image.change, which would retrigger itself and loop forever.
    everything_btn.click(
        fn=segment_everything,
        inputs=[input_image],
        outputs=[output_image],
    ).then(
        fn=update_image,
        inputs=[input_image, output_image],
        outputs=[output_image],
    )
    box_btn.click(
        fn=segment_box,
        inputs=[input_image, x1_input, y1_input, x2_input, y2_input],
        outputs=[output_image],
    ).then(
        fn=update_image,
        inputs=[input_image, output_image],
        outputs=[output_image],
    )

demo.launch()
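
# --- Optional sketch, not part of the demo above ---
# CLIPSeg is text-conditioned, so the same processor/model pair loaded above
# can segment arbitrary prompts rather than the generic "object". The helper
# name `segment_prompt` is hypothetical; if wired into the UI, define it
# before `demo.launch()`, since launch() blocks while the server runs.
def segment_prompt(image, prompt):
    # Tokenize the prompt and preprocess the image together.
    inputs = processor(text=[prompt], images=[image], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Sigmoid turns the raw logits into per-pixel probabilities in [0, 1].
    preds = outputs.logits.squeeze().sigmoid()
    return Image.fromarray((preds.numpy() * 255).astype(np.uint8))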