# bg_removal /
# Saarthak2002's picture
# 2d2f43a verified
import torch
from torchvision import models, transforms
from PIL import Image
import requests
import numpy as np
import gradio as gr
from io import BytesIO
import cv2
# Step 1: Load the Image from URL
def load_image(url):
    """Download the image at *url* and return it as an RGB PIL image.

    Args:
        url: HTTP(S) address of the image to fetch.

    Returns:
        A ``PIL.Image.Image`` converted to ``"RGB"`` mode.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an error status.
    """
    # NOTE(review): the URL comes straight from the UI textbox, so this
    # fetches arbitrary user-supplied addresses -- restrict if that matters.
    # A timeout keeps an unresponsive host from hanging the request forever.
    response = requests.get(url, timeout=30)
    # Fail on 4xx/5xx here rather than handing an error page to the decoder.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
# Step 2: Adjust Bounding Box to Add Margin
def adjust_bounding_box(bounding_box, margin=20):
    """Return a copy of *bounding_box* grown by *margin* on every side.

    Args:
        bounding_box: Mapping with ``"x_min"``, ``"y_min"``, ``"x_max"`` and
            ``"y_max"`` entries.
        margin: Amount added around the box (default 20).

    Returns:
        A new dict with the same four keys, in that order. The minimum
        corner is clamped at 0; the maximum corner is not clamped because
        the image size is not known here.
    """
    return {
        "x_min": max(0, bounding_box["x_min"] - margin),
        "y_min": max(0, bounding_box["y_min"] - margin),
        "x_max": bounding_box["x_max"] + margin,
        "y_max": bounding_box["y_max"] + margin,
    }
# Step 3: Crop Image Based on Bounding Box
def crop_image(image, bounding_box):
    """Crop *image* to the rectangle described by *bounding_box*.

    Args:
        image: Object exposing a PIL-style ``crop((left, upper, right, lower))``.
        bounding_box: Mapping with ``"x_min"``, ``"y_min"``, ``"x_max"`` and
            ``"y_max"`` entries.

    Returns:
        Whatever ``image.crop`` returns for that rectangle.
    """
    # Read the keys by name: unpacking ``.values()`` would silently swap
    # coordinates if the dict was built in a different key order.
    left = bounding_box["x_min"]
    upper = bounding_box["y_min"]
    right = bounding_box["x_max"]
    lower = bounding_box["y_max"]
    return image.crop((left, upper, right, lower))
# Step 4: Preprocessing for Segmentation Model
def preprocess_image(image):
    """Convert *image* into a batched tensor for the segmentation model.

    Args:
        image: The (cropped) image to feed to the model.

    Returns:
        The ``ToTensor`` result with an extra leading batch dimension.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),  # Convert to Tensor
    ])
    return transform(image).unsqueeze(0)  # Add batch dimension
# Step 5: Load Mask R-CNN Model
# Holds the loaded model so repeated requests do not rebuild it.
_MODEL_CACHE = None


def load_model():
    """Return the pre-trained Mask R-CNN model, ready for inference.

    The model is built on the first call, switched to evaluation mode, moved
    to CUDA when available, and cached; later calls return the same instance.

    Returns:
        The cached ``maskrcnn_resnet50_fpn`` model.
    """
    global _MODEL_CACHE
    if _MODEL_CACHE is None:
        # NOTE(review): newer torchvision releases prefer a ``weights=``
        # argument over ``pretrained=True``; kept as-is so older versions
        # still work -- confirm against the pinned torchvision version.
        model = models.detection.maskrcnn_resnet50_fpn(pretrained=True)
        model.eval()  # Set the model to evaluation mode
        if torch.cuda.is_available():
            model = model.to("cuda")
        _MODEL_CACHE = model
    return _MODEL_CACHE
# Step 6: Perform Object Segmentation
def segment_image(model, input_tensor, confidence_threshold=0.6):
    """Run *model* on *input_tensor* and return the confident instance masks.

    Args:
        model: Callable detection model; its first output must be a mapping
            with ``"scores"`` and ``"masks"`` tensors.
        input_tensor: Batched image tensor; moved to CUDA when available,
            mirroring where ``load_model`` places the model.
        confidence_threshold: Detections scoring at or below this value are
            dropped.

    Returns:
        A list of 2-D NumPy arrays, one per kept detection (the ``[i, 0]``
        slice of the ``"masks"`` output), in the model's output order.
    """
    if torch.cuda.is_available():
        input_tensor = input_tensor.to("cuda")

    with torch.no_grad():
        outputs = model(input_tensor)  # Perform inference

    # Only the first (and only) image of the batch is inspected.
    detections = outputs[0]
    scores = detections["scores"].cpu().numpy()
    masks = detections["masks"].cpu().numpy()

    # Keep the single mask channel of every detection above the threshold.
    keep = scores > confidence_threshold
    return list(masks[:, 0][keep])
# Step 7: Combine Masks and Extract Object
def apply_masks(image, masks):
    """Build an RGBA copy of *image* that is opaque only where a mask fires.

    Every mask is resized to the image size with nearest-neighbour
    interpolation, thresholded at 0.5, and merged into one alpha channel
    (255 inside any mask, 0 elsewhere).

    Args:
        image: Image exposing ``width``/``height`` and convertible via
            ``np.array``.
        masks: Iterable of 2-D mask arrays.

    Returns:
        A PIL image built from the RGBA array.
    """
    width, height = image.width, image.height
    alpha = np.zeros((height, width), dtype=np.uint8)

    for mask in masks:
        fitted = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        binary = (fitted > 0.5).astype(np.uint8)
        alpha = np.maximum(alpha, binary)  # Union with the masks seen so far

    # Assemble the RGBA pixels: colour from the image, alpha from the masks.
    rgb = np.array(image)
    rows, cols = rgb.shape[0], rgb.shape[1]
    rgba = np.zeros((rows, cols, 4), dtype=np.uint8)
    rgba[..., :3] = rgb
    rgba[..., 3] = alpha * 255
    return Image.fromarray(rgba)
# Gradio Interface to handle input and output
def segment_object(image_url, x_min, y_min, x_max, y_max):
    """Gradio callback: cut the object inside a bounding box out of an image.

    Args:
        image_url: Address of the source image.
        x_min, y_min, x_max, y_max: Corners of the bounding box; a margin is
            added around it before cropping.

    Returns:
        The RGBA image produced by ``apply_masks`` for the cropped region.
    """
    raw_box = {"x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max}
    padded_box = adjust_bounding_box(raw_box)

    # Fetch the picture and narrow it down to the padded region.
    source = load_image(image_url)
    region = crop_image(source, padded_box)
    batch = preprocess_image(region)

    # Run the detector on the region.
    detector = load_model()
    object_masks = segment_image(detector, batch)

    # Everything outside the detected masks becomes transparent.
    return apply_masks(region, object_masks)
# Set up the Gradio Interface
# Wire the callback to its inputs: an image URL plus the four bounding-box
# coordinates, in the same order as the parameters of ``segment_object``.
iface = gr.Interface(
    fn=segment_object,
    inputs=[
        gr.Textbox(label="Image URL", placeholder="Enter image URL..."),
        gr.Number(label="x_min", value=100),
        gr.Number(label="y_min", value=100),
        gr.Number(label="x_max", value=600),
        gr.Number(label="y_max", value=400),
    ],
    outputs=gr.Image(label="Segmented Image"),
)

# Launch the interface
if __name__ == "__main__":
    iface.launch()