# bg_removal / app.py
# Author: Saarthak2002 - last commit f62b518 ("Update app.py", verified)
# (Hugging Face file-viewer residue: raw | history | blame, 3.88 kB)
import functools
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet101
# Step 1: Load the Image from URL
def load_image(url, timeout=30):
    """Download an image and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds ``requests`` waits for the server before giving up.
            Without a timeout the request can block indefinitely on an
            unresponsive host, freezing the callback that called it.

    Returns:
        ``PIL.Image.Image`` in RGB mode.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status. This
            stops an HTML error page from being handed to PIL, which would
            otherwise fail with a less helpful "cannot identify image" error.
        requests.RequestException: On connection failures or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # convert() returns a new, fully loaded image, so the source can be
    # closed as soon as the conversion is done.
    with Image.open(BytesIO(response.content)) as raw:
        return raw.convert("RGB")
# Step 2: Adjust Bounding Box to Add Margin
def adjust_bounding_box(bounding_box, margin=20):
    """Grow a bounding box outward by ``margin`` pixels on every side.

    The minimum edges are floored at 0 so the box never starts at a negative
    coordinate. The maximum edges are not capped, because the image size is
    not known here.

    Args:
        bounding_box: dict with "x_min", "y_min", "x_max" and "y_max" keys.
        margin: Number of pixels to add on each side.

    Returns:
        A new dict with the keys in x_min, y_min, x_max, y_max order.
    """
    lower_edges = {key: max(0, bounding_box[key] - margin) for key in ("x_min", "y_min")}
    upper_edges = {key: bounding_box[key] + margin for key in ("x_max", "y_max")}
    return {**lower_edges, **upper_edges}
# Step 3: Crop Image Based on Bounding Box
def crop_image(image, bounding_box):
    """Crop ``image`` to a bounding box, clipped to the image's extent.

    Args:
        image: PIL image, or anything with ``size`` as ``(width, height)`` and
            a ``crop(box)`` method.
        bounding_box: dict with "x_min", "y_min", "x_max" and "y_max" keys.
            Keys are read by name, so their order in the dict is irrelevant.

    Returns:
        The cropped image.

    Raises:
        ValueError: If the clipped box is empty, i.e. it does not overlap the
            image or has x_max <= x_min or y_max <= y_min.
    """
    width, height = image.size
    # Clip to the image. PIL fills any part of a crop box that lies outside
    # the image with black pixels, which would then be segmented as if they
    # were scene content. Margin padding near the right/bottom edge is the
    # usual way a box ends up out of bounds.
    x_min = min(max(bounding_box["x_min"], 0), width)
    y_min = min(max(bounding_box["y_min"], 0), height)
    x_max = min(max(bounding_box["x_max"], 0), width)
    y_max = min(max(bounding_box["y_max"], 0), height)
    if x_max <= x_min or y_max <= y_min:
        raise ValueError(
            f"Bounding box {bounding_box!r} is empty or outside the "
            f"{width}x{height} image"
        )
    return image.crop((x_min, y_min, x_max, y_max))
# Step 4: Preprocessing for Segmentation Model
def preprocess_image(image, size=(1024, 1024)):
    """Turn a PIL image into a normalized single-image batch for the model.

    Args:
        image: PIL image to convert. This app passes RGB crops.
        size: Target ``(height, width)`` handed to ``transforms.Resize``. A
            2-tuple resizes to exactly that shape, so the aspect ratio is not
            preserved.

    Returns:
        Float tensor of shape ``(1, C, height, width)``, where C is 3 for RGB
        input.
    """
    # Per-channel statistics used by torchvision's ImageNet-pretrained models.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(size),
        transforms.ToTensor(),  # HWC uint8 -> CHW float in [0, 1]
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    tensor = transforms.Compose(steps)(image)
    # Prepend the batch axis the model expects.
    return tensor.unsqueeze(0)
# Step 5: Load a More Robust Pre-trained Model
@functools.lru_cache(maxsize=1)
def load_model():
    """Build the pretrained DeepLabV3-ResNet101 model once and reuse it.

    Callers may invoke this for every request. Without the cache, the
    ResNet101 network and its weights would be rebuilt (and moved to the GPU)
    each time. The cache holds exactly one model.

    Returns:
        The model in eval mode, on CUDA when available. Every call returns
        the same instance, so callers must not modify it.
    """
    # NOTE: ``pretrained=True`` is deprecated since torchvision 0.13 in favour
    # of ``weights=``. It is kept here for compatibility with older installs.
    model = deeplabv3_resnet101(pretrained=True)
    model.eval()  # inference behaviour for dropout / batch-norm layers
    if torch.cuda.is_available():
        model = model.to("cuda")
    return model
# Step 6: Perform Segmentation with Soft Masking
def segment_image(model, input_tensor, class_index=None):
    """Run the segmentation model and return a per-pixel score map.

    With torchvision's pretrained ``deeplabv3_resnet101`` (Pascal VOC label
    set), channel 0 is background and channel 1 is "aeroplane". Channel 1
    is therefore not a generic "object" class. For background removal, the
    default is the probability that a pixel is *not* background,
    ``1 - P(background)``.

    Args:
        model: Segmentation network whose output dict has an ``"out"`` entry
            of logits shaped ``(batch, classes, H, W)``.
        input_tensor: Preprocessed batch. Only the first item's map is
            returned.
        class_index: If given, return the probability of this single class
            instead of the foreground probability.

    Returns:
        ``numpy.ndarray`` of shape ``(H, W)`` with values in ``[0, 1]``.
    """
    # Move the input to wherever the model's weights live, so the two always
    # match even if the model was left on the CPU.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_tensor = input_tensor.to(device)

    with torch.no_grad():
        logits = model(input_tensor)["out"]
        probabilities = torch.softmax(logits, dim=1)  # normalize over classes

    if class_index is None:
        scores = 1.0 - probabilities[0, 0]  # anything that is not background
    else:
        scores = probabilities[0, class_index]
    return scores.cpu().numpy()
# Step 7: Refine Mask and Extract Object
def apply_mask(image, mask, threshold=0.75):
    """Use a thresholded score map as the alpha channel of ``image``.

    Args:
        image: PIL image to cut out. Its three colour channels are copied into
            the result, and its ``size`` (width, height) is the resolution the
            mask is resized to.
        mask: 2-D array of per-pixel scores. Pixels strictly greater than
            ``threshold`` count as foreground.
        threshold: Score cut-off applied before resizing.

    Returns:
        ``PIL.Image`` built from an RGBA array: the pixels of ``image`` with
        alpha 255 on the refined foreground and 0 elsewhere.
    """
    binary = (mask > threshold).astype(np.uint8)

    # cv2 takes dsize as (width, height), which is exactly PIL's ``size``.
    # Nearest-neighbour interpolation keeps the values strictly 0/1.
    binary = cv2.resize(binary, image.size, interpolation=cv2.INTER_NEAREST)

    # Two dilations then one erosion with a 5x5 square: bridges small gaps and
    # leaves the mask slightly larger overall.
    square = np.ones((5, 5), np.uint8)
    binary = cv2.erode(
        cv2.dilate(binary, square, iterations=2), square, iterations=1
    )

    pixels = np.array(image)
    rows, cols = pixels.shape[:2]
    rgba = np.zeros((rows, cols, 4), dtype=np.uint8)
    rgba[..., :3] = pixels
    rgba[..., 3] = binary * 255  # 0/1 -> fully transparent / fully opaque
    return Image.fromarray(rgba)
# Gradio Interface to handle input and output
def segment_object(image_url, x_min, y_min, x_max, y_max):
    """Gradio callback: cut the object inside a box out of an image at a URL.

    Steps: pad the box with a margin, download the image, crop it to the
    padded box, segment the crop, and use the mask as the crop's alpha
    channel.

    Args:
        image_url: URL of the source image.
        x_min, y_min, x_max, y_max: Box corners in pixel coordinates of the
            source image.

    Returns:
        RGBA ``PIL.Image`` of the cropped region, or ``None`` when the inputs
        are incomplete: a blank URL, or a coordinate that is ``None`` (e.g. a
        cleared number field). The interface is live, so this runs while the
        user is still typing. Returning ``None`` clears the output instead of
        raising a request error.
    """
    url = (image_url or "").strip()
    corners = (x_min, y_min, x_max, y_max)
    if not url or any(value is None for value in corners):
        return None

    bounding_box = adjust_bounding_box(
        {"x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max}
    )

    # Load the image, crop it to the padded box and prepare it for the model.
    image = load_image(url)
    cropped_image = crop_image(image, bounding_box)
    input_tensor = preprocess_image(cropped_image)

    # Segment the crop and turn the mask into transparency.
    model = load_model()
    mask = segment_image(model, input_tensor)
    return apply_mask(cropped_image, mask)
# Set up the Gradio Interface
# NOTE(review): ``live=True`` re-runs the whole download + segmentation every
# time any input changes. Consider ``live=False`` (adds a Submit button) if
# the app feels sluggish.
iface = gr.Interface(
    fn=segment_object,
    inputs=[
        gr.Textbox(label="Image URL", placeholder="Enter image URL..."),
        gr.Number(label="x_min", value=100),
        gr.Number(label="y_min", value=100),
        gr.Number(label="x_max", value=600),
        gr.Number(label="y_max", value=400),
    ],
    outputs=gr.Image(label="Segmented Image"),
    live=True,
)

# Launch only when run as a script (``python app.py``), so importing this
# module, e.g. from tests or another app, does not start a web server.
if __name__ == "__main__":
    iface.launch()