# bg_removal / app.py
# Last commit: 2d2f43a "Update app.py" by Saarthak2002 (verified)
import functools
from io import BytesIO

import torch
from torchvision import models, transforms
from PIL import Image
import requests
import numpy as np
import gradio as gr
import cv2
# Step 1: Load the Image from URL
def load_image(url, timeout=10):
    """Download ``url`` and decode the response body as an RGB PIL image.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds to wait on the server (used for both connect and read)
            before giving up.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On network failure or timeout.
        PIL.UnidentifiedImageError: If the body is not a decodable image.
    """
    # Without a timeout, requests can block forever on an unresponsive host,
    # tying up the Gradio worker that called us.
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx; otherwise the HTML error page would be handed to
    # PIL and surface as a confusing "cannot identify image file" error.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB")
    return image
# Step 2: Adjust Bounding Box to Add Margin
def adjust_bounding_box(bounding_box, margin=20):
    """Return a copy of ``bounding_box`` grown by ``margin`` pixels on every side.

    The minimum corner is floored at 0 so it never goes negative. The maximum
    corner is left unclamped because no image size is known here. The input
    dict is not modified, and the result keeps the key order
    x_min, y_min, x_max, y_max.
    """
    grown = {}
    for low_key in ("x_min", "y_min"):
        grown[low_key] = max(0, bounding_box[low_key] - margin)
    for high_key in ("x_max", "y_max"):
        grown[high_key] = bounding_box[high_key] + margin
    return grown
# Step 3: Crop Image Based on Bounding Box
def crop_image(image, bounding_box):
    """Crop ``image`` to ``bounding_box``, clamped to the image's bounds.

    Args:
        image: PIL image to crop.
        bounding_box: Dict with "x_min", "y_min", "x_max", "y_max" pixel
            coordinates. Key order does not matter.

    Returns:
        The cropped PIL image.

    Raises:
        ValueError: If the box, after clamping, has zero or negative area.
    """
    # Look the corners up by name. Unpacking ``.values()`` silently depended on
    # the dict's insertion order.
    # Clamp to the image so a margin that overshoots the edge does not make
    # PIL pad the crop with filler (zero) pixels.
    x_min = max(0, bounding_box["x_min"])
    y_min = max(0, bounding_box["y_min"])
    x_max = min(image.width, bounding_box["x_max"])
    y_max = min(image.height, bounding_box["y_max"])
    if x_max <= x_min or y_max <= y_min:
        raise ValueError(
            f"empty crop box ({x_min}, {y_min}, {x_max}, {y_max}) for image "
            f"of size {image.width}x{image.height}"
        )
    return image.crop((x_min, y_min, x_max, y_max))
# Step 4: Preprocessing for Segmentation Model
def preprocess_image(image):
    """Convert a PIL image into a one-image batch tensor for the segmentation model.

    ``transforms.ToTensor`` produces a channels-first float tensor. For 8-bit
    images it scales values into [0, 1]. A leading batch axis of size 1 is
    then prepended.
    """
    to_tensor = transforms.ToTensor()
    chw = to_tensor(image)
    return chw.unsqueeze(0)
# Step 5: Load Mask R-CNN Model
@functools.lru_cache(maxsize=1)
def load_model():
    """Build the pre-trained Mask R-CNN (ResNet-50 FPN) detector in eval mode.

    The model is constructed once and cached. ``segment_object`` calls this on
    every request, and rebuilding the network and reloading its pretrained
    weights each time is by far the most expensive step. Every caller receives
    the same shared instance, so callers must not mutate it (for example, by
    switching it back to train mode).

    Returns:
        The ``torch.nn.Module``, moved to CUDA when available, otherwise on CPU.
    """
    # NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in
    # favour of ``weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT``. It is kept here
    # so older torchvision installs keep working.
    model = models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    model.eval()  # inference behaviour for dropout/batch-norm
    if torch.cuda.is_available():
        model = model.to("cuda")
    return model
# Step 6: Perform Object Segmentation
def _model_device(model):
    """Return the device holding ``model``'s first parameter, else CUDA-if-available."""
    parameters = getattr(model, "parameters", None)
    if parameters is not None:
        for param in parameters():
            return param.device
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def segment_image(model, input_tensor, confidence_threshold=0.6):
    """Run the detector on a single-image batch and keep the confident instance masks.

    Args:
        model: Detection model in eval mode (e.g. from ``load_model``). It is
            called as ``model(input_tensor)`` and must return a list whose first
            element is a dict with "scores" (N,) and "masks" (N, 1, H, W) tensors.
        input_tensor: Batched image tensor (``preprocess_image`` adds the batch axis).
        confidence_threshold: Instances scoring strictly above this are kept.

    Returns:
        A list of 2-D float numpy arrays, one soft mask per kept instance.
        The list may be empty.
    """
    # Send the input to wherever the model's weights live. The old rule,
    # "cuda whenever CUDA exists", broke for a CPU model on a CUDA host.
    input_tensor = input_tensor.to(_model_device(model))
    with torch.no_grad():
        outputs = model(input_tensor)
    detections = outputs[0]  # the batch contains exactly one image
    scores = detections["scores"].cpu().numpy()
    masks = detections["masks"].cpu().numpy()
    keep = scores > confidence_threshold
    # masks[keep, 0] drops the singleton channel axis -> (kept, H, W).
    return list(masks[keep, 0])
# Step 7: Combine Masks and Extract Object
def apply_masks(image, masks):
    """Cut the masked objects out of ``image`` onto a transparent background.

    Args:
        image: PIL image of any mode. It is converted to RGB for the colour channels.
        masks: Iterable of 2-D soft masks (e.g. from ``segment_image``). Each is
            resized to the image if needed and thresholded at 0.5.

    Returns:
        An RGBA PIL image. Pixels inside any mask are opaque (alpha 255) and
        all others are fully transparent. With no masks, the whole image is
        transparent.
    """
    # Normalise to RGB first. Copying a grayscale or RGBA array into a
    # 3-channel slot would otherwise fail with a broadcasting error.
    rgb = np.asarray(image.convert("RGB"), dtype=np.uint8)
    height, width = rgb.shape[:2]
    alpha = np.zeros((height, width), dtype=np.uint8)
    for mask in masks:
        if mask.shape != (height, width):
            # cv2 takes dsize as (width, height).
            mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        alpha[mask > 0.5] = 255  # union of all thresholded masks
    return Image.fromarray(np.dstack((rgb, alpha)))
# Gradio Interface to handle input and output
def segment_object(image_url, x_min, y_min, x_max, y_max):
    """Gradio handler: cut the object(s) inside a bounding box out of a remote image.

    The image is fetched from ``image_url`` and cropped to the box, widened by
    ``adjust_bounding_box``'s default margin. The detector then runs on the
    crop, and the crop is returned with the detected objects opaque and
    everything else transparent.

    Returns:
        An RGBA PIL image, or ``None`` while the URL or any coordinate field is
        still empty.
    """
    # The interface is live, so this runs on every edit. Wait for complete
    # input instead of raising on a blank URL or a cleared number field.
    image_url = (image_url or "").strip()
    if not image_url or any(v is None for v in (x_min, y_min, x_max, y_max)):
        return None

    bounding_box = adjust_bounding_box(
        {"x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max}
    )
    # Load and preprocess the image.
    image = load_image(image_url)
    cropped_image = crop_image(image, bounding_box)
    input_tensor = preprocess_image(cropped_image)
    # Load the model and perform segmentation.
    model = load_model()
    masks = segment_image(model, input_tensor)
    # Apply the masks to extract the objects.
    return apply_masks(cropped_image, masks)
# Set up the Gradio Interface
iface = gr.Interface(
    fn=segment_object,
    inputs=[
        gr.Textbox(label="Image URL", placeholder="Enter image URL..."),
        # precision=0 makes Gradio pass ints, matching pixel coordinates.
        gr.Number(label="x_min", value=100, precision=0),
        gr.Number(label="y_min", value=100, precision=0),
        gr.Number(label="x_max", value=600, precision=0),
        gr.Number(label="y_max", value=400, precision=0),
    ],
    outputs=gr.Image(label="Segmented Image"),
    # Re-runs on every input change; each run downloads the image and runs
    # the detector, so this is expensive.
    live=True,
)

# Launch the interface only when run as a script, so importing this module
# (e.g. from tests) does not start a web server.
if __name__ == "__main__":
    iface.launch()