"""Gradio demo: cut an object out of a remote image with Mask R-CNN.

The user supplies an image URL and a bounding box. The box is padded with a
margin, the image is cropped to it, Mask R-CNN is run on the crop, and every
confident instance mask is merged into the alpha channel of the returned
RGBA image.
"""

from functools import lru_cache
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from torchvision import models, transforms


# Step 1: Load the Image from URL
def load_image(url, timeout=30):
    """Download *url* and return its content as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up, so a stalled
            host cannot hang the request forever.

    Raises:
        requests.RequestException: On connection problems, timeouts, or a
            non-2xx HTTP status.
    """
    # NOTE(review): the URL comes straight from the UI, so this process will
    # fetch whatever address a user types (including internal hosts). Restrict
    # the allowed schemes/hosts before exposing the app publicly.
    response = requests.get(url, timeout=timeout)
    # Fail on HTTP errors here instead of handing an error page to PIL.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


# Step 2: Adjust Bounding Box to Add Margin
def adjust_bounding_box(bounding_box, margin=20, image_size=None):
    """Return a copy of *bounding_box* grown by *margin* on every side.

    Args:
        bounding_box: Dict with "x_min", "y_min", "x_max" and "y_max" keys.
        margin: Amount added around the box.
        image_size: Optional ``(width, height)``. When given, the right and
            bottom edges are clamped to it so the padded box never extends
            past the image. When omitted, only the left/top edges are clamped
            (to 0).

    Returns:
        A new dict with the same four keys.
    """
    x_max = bounding_box["x_max"] + margin
    y_max = bounding_box["y_max"] + margin
    if image_size is not None:
        width, height = image_size
        x_max = min(width, x_max)
        y_max = min(height, y_max)
    return {
        "x_min": max(0, bounding_box["x_min"] - margin),
        "y_min": max(0, bounding_box["y_min"] - margin),
        "x_max": x_max,
        "y_max": y_max,
    }


# Step 3: Crop Image Based on Bounding Box
def crop_image(image, bounding_box):
    """Crop *image* to *bounding_box*.

    Args:
        image: PIL image to crop.
        bounding_box: Dict with "x_min", "y_min", "x_max" and "y_max" keys.

    Raises:
        ValueError: If the box has no area.
    """
    # Read the coordinates by key rather than relying on dict ordering.
    x_min = bounding_box["x_min"]
    y_min = bounding_box["y_min"]
    x_max = bounding_box["x_max"]
    y_max = bounding_box["y_max"]
    if x_max <= x_min or y_max <= y_min:
        raise ValueError(f"Bounding box is empty: {bounding_box}")
    return image.crop((x_min, y_min, x_max, y_max))


# Step 4: Preprocessing for Segmentation Model
def preprocess_image(image):
    """Convert a PIL image to a tensor with a leading batch dimension."""
    to_tensor = transforms.ToTensor()
    return to_tensor(image).unsqueeze(0)


# Step 5: Load Mask R-CNN Model
@lru_cache(maxsize=1)
def load_model():
    """Return the pre-trained Mask R-CNN in eval mode.

    The model is built once and cached, so repeated calls (one per UI
    callback) reuse the same instance instead of reloading the weights.
    It is placed on the GPU when CUDA is available.
    """
    # NOTE(review): recent torchvision releases deprecate `pretrained=` in
    # favour of `weights=`; kept as-is so older installs still work.
    model = models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return model.to(device)


# Step 6: Perform Object Segmentation
def segment_image(model, input_tensor, confidence_threshold=0.6):
    """Run *model* on *input_tensor* and return the confident instance masks.

    Args:
        model: Detection model whose first output is a dict containing
            "scores" and "masks".
        input_tensor: Batched image tensor as produced by `preprocess_image`.
        confidence_threshold: Detections scoring at or below this are dropped.

    Returns:
        A list of 2-D NumPy arrays, one per kept detection (possibly empty).
    """
    # Follow the model's own device rather than assuming it lives on CUDA
    # whenever CUDA exists.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    with torch.no_grad():
        outputs = model(input_tensor)

    detections = outputs[0]
    keep = detections["scores"] > confidence_threshold
    # Channel 0 is selected per detection, leaving one 2-D mask each.
    kept_masks = detections["masks"][keep, 0].cpu().numpy()
    return list(kept_masks)


# Step 7: Combine Masks and Extract Object
def apply_masks(image, masks):
    """Return *image* as RGBA, opaque only where any mask exceeds 0.5.

    Args:
        image: PIL image the masks refer to.
        masks: Iterable of 2-D arrays; each is resized to the image size
            before thresholding. An empty iterable yields a fully
            transparent image.
    """
    target_size = (image.width, image.height)  # cv2 expects (width, height)
    combined = np.zeros((image.height, image.width), dtype=bool)
    for mask in masks:
        resized = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
        combined |= resized > 0.5

    # Force three colour channels so the alpha plane can always be appended.
    rgb = np.asarray(image.convert("RGB"), dtype=np.uint8)
    alpha = combined.astype(np.uint8) * 255
    return Image.fromarray(np.dstack((rgb, alpha)))


# Gradio Interface to handle input and output
def segment_object(image_url, x_min, y_min, x_max, y_max):
    """Gradio callback: fetch the image, crop to the padded box, and segment it.

    Returns:
        The RGBA cut-out, or ``None`` while the inputs are incomplete (blank
        URL or a cleared number field), which happens routinely because the
        interface is live.
    """
    url = (image_url or "").strip()
    coordinates = (x_min, y_min, x_max, y_max)
    if not url or any(value is None for value in coordinates):
        return None

    # Load first so the padded box can be clamped to the real image size.
    image = load_image(url)
    bounding_box = adjust_bounding_box(
        {"x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max},
        image_size=image.size,
    )
    cropped_image = crop_image(image, bounding_box)
    input_tensor = preprocess_image(cropped_image)

    # Cached model: loaded on the first call only.
    model = load_model()
    masks = segment_image(model, input_tensor)

    return apply_masks(cropped_image, masks)


# Set up the Gradio Interface
iface = gr.Interface(
    fn=segment_object,
    inputs=[
        gr.Textbox(label="Image URL", placeholder="Enter image URL..."),
        gr.Number(label="x_min", value=100),
        gr.Number(label="y_min", value=100),
        gr.Number(label="x_max", value=600),
        gr.Number(label="y_max", value=400),
    ],
    outputs=gr.Image(label="Segmented Image"),
    live=True,
)

# Launch the interface
if __name__ == "__main__":
    iface.launch()