"""Gradio demo: cut regions out of an image from text prompts using CLIPSeg.

Positive prompts select regions to keep, negative prompts select regions to
remove; the result is the input image with everything else made transparent.
"""
import threading

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor

# Loaded once at import time and shared by every request.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")


def process_image(image, prompt):
    """Return a min-max normalised CLIPSeg relevance map for one prompt.

    Args:
        image: PIL image to segment.
        prompt: Text describing the region of interest.

    Returns:
        A float array of shape ``(image.height, image.width)`` scaled so its
        minimum is 0 and its maximum is 1. If the map is flat (minimum equals
        maximum) an all-zero array is returned instead of dividing by zero.
    """
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # NOTE(review): the logits appear to carry a leading batch axis in some
    # transformers releases and not in others; squeeze() is a no-op on a 2-D
    # tensor, so it is safe either way.
    probs = torch.sigmoid(outputs.logits).squeeze().cpu().numpy()

    # Quantise to 8 bits so PIL can resize the map to the input resolution.
    resized = Image.fromarray(np.uint8(probs * 255), "L").resize(image.size)
    mask = np.asarray(resized, dtype=np.float64)

    low, high = mask.min(), mask.max()
    if high == low:
        # A flat map has no contrast to normalise.
        return np.zeros_like(mask)
    return (mask - low) / (high - low)


def get_masks(prompts, img, threshold):
    """Build one boolean mask per comma-separated prompt.

    Args:
        prompts: Comma-separated prompt string. Surrounding whitespace is
            stripped from each prompt and empty entries are skipped, so an
            empty or ``None`` value yields no masks.
        img: PIL image to segment.
        threshold: Pixels whose normalised score is strictly greater than
            this value are set in the mask.

    Returns:
        A list of boolean arrays, one per non-empty prompt.
    """
    cleaned = (part.strip() for part in (prompts or "").split(","))
    return [process_image(img, prompt) > threshold for prompt in cleaned if prompt]


def _union(masks, shape):
    """OR a list of boolean masks together; an empty list gives all-False."""
    if not masks:
        return np.zeros(shape, dtype=bool)
    return np.any(np.stack(masks), axis=0)


def extract_image(pos_prompts, neg_prompts, img, threshold):
    """Extract the regions matching the positive but not the negative prompts.

    Args:
        pos_prompts: Comma-separated prompts for regions to keep.
        neg_prompts: Comma-separated prompts for regions to remove; may be
            empty.
        img: PIL input image.
        threshold: Score cut-off applied to every prompt's mask.

    Returns:
        A tuple ``(output_image, final_mask)``: an RGBA image that is
        transparent outside the selected region, and the selection itself as
        an 8-bit ``"L"`` image (255 = selected).
    """
    shape = (img.height, img.width)
    # Honour the caller's threshold instead of a hard-coded 0.5.
    keep = _union(get_masks(pos_prompts, img, threshold), shape)
    drop = _union(get_masks(neg_prompts, img, threshold), shape)

    final_mask = Image.fromarray((keep & ~drop).astype(np.uint8) * 255, "L")
    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=final_mask)
    return output_image, final_mask


def _build_interface(live=False):
    """Create the Gradio interface wrapping ``extract_image``."""
    return gr.Interface(
        fn=extract_image,
        # Gradio passes component values positionally, so the order here must
        # match the parameter order of extract_image.
        inputs=[
            gr.Textbox(label="Positive Prompts (comma separated)"),
            gr.Textbox(label="Negative Prompts (comma separated)"),
            gr.Image(type="pil", label="Input Image"),
            gr.Slider(minimum=0, maximum=1, value=0.4, label="Threshold"),
        ],
        outputs=[
            gr.Image(type="pil", label="Output Image"),
            gr.Image(type="pil", label="Output Mask"),
        ],
        live=live,
    )


iface = _build_interface()

# Launch Gradio UI without blocking, so the second launch below is reached.
iface.launch(prevent_thread_lock=True)

# Define API interface. live=True makes the interface re-run whenever an input
# changes; it is not what exposes the API endpoint.
api_interface = _build_interface(live=True)

# Launch API; share=True creates a public link for external access.
api_interface.launch(share=True)