# bg_removal/app.py
import torch
from torchvision import transforms
from PIL import Image
import requests
import numpy as np
import gradio as gr
from io import BytesIO
from torchvision.models.segmentation import deeplabv3_resnet101
import cv2

# Step 1: Load the image from a URL
def load_image(url):
    response = requests.get(url, timeout=30)
    response.raise_for_status()  # Fail fast on HTTP errors instead of handing bad bytes to PIL
    image = Image.open(BytesIO(response.content)).convert("RGB")
    return image
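
# Example usage (hypothetical URL, kept as a comment so it does not run at import time):
#   img = load_image("https://example.com/photo.jpg")
#   img.size  # (width, height) of the downloaded image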

# Step 2: Crop the image to the bounding box
def crop_image(image, bounding_box):
    # Use explicit keys (not dict ordering) and cast to int, since Gradio Number inputs arrive as floats
    box = (int(bounding_box["x_min"]), int(bounding_box["y_min"]),
           int(bounding_box["x_max"]), int(bounding_box["y_max"]))
    return image.crop(box)
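
# Note: PIL's crop box is (left, upper, right, lower) in pixels, so the interface's
# default box keeps a 500x300 region:
#   crop = crop_image(img, {"x_min": 100, "y_min": 100, "x_max": 600, "y_max": 400})
#   crop.size  # (500, 300)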

# Step 3: Preprocess the crop for the segmentation model
def preprocess_image(image, size=(512, 512)):
    preprocess = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        # ImageNet mean/std, matching the backbone's pretraining
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return preprocess(image).unsqueeze(0)  # Add a batch dimension
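
# Quick sanity check (comment-only sketch, assuming the default 512x512 target size):
#   tensor = preprocess_image(crop)
#   tensor.shape  # torch.Size([1, 3, 512, 512]) -- batch, channels, height, width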

# Step 4: Load a more robust pre-trained model
def load_model():
    model = deeplabv3_resnet101(pretrained=True)  # ResNet-101 backbone for better feature extraction
    model.eval()  # Set the model to evaluation mode
    if torch.cuda.is_available():
        model = model.to("cuda")
    return model
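
# Note: `pretrained=True` is deprecated in recent torchvision releases; on
# torchvision >= 0.13 the equivalent call would be (sketch, not required for this app to run):
#   from torchvision.models.segmentation import DeepLabV3_ResNet101_Weights
#   model = deeplabv3_resnet101(weights=DeepLabV3_ResNet101_Weights.DEFAULT)
# Caching the model at module level would also avoid rebuilding it on every request,
# since segment_object() below calls load_model() each time.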

# Step 5: Perform segmentation
def segment_image(model, input_tensor):
    if torch.cuda.is_available():
        input_tensor = input_tensor.to("cuda")
    with torch.no_grad():
        output = model(input_tensor)['out']  # Raw per-class logits
    mask = output.argmax(dim=1).squeeze().cpu().numpy()  # Per-pixel class indices
    return mask
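
# The mask holds per-pixel class indices from the PASCAL VOC label set the model
# was trained on (0 is background; e.g. 15 is "person"). To isolate the foreground
# regardless of class, a binary mask can be derived like this (sketch):
#   binary_mask = (mask > 0).astype(np.uint8)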

# Step 6: Refine the mask and extract the object as RGBA
def apply_mask(image, mask):
    # Binarize: treat any non-background class as foreground
    # (the raw class indices would overflow uint8 when multiplied by 255)
    mask = (mask > 0).astype(np.uint8)
    mask = cv2.resize(mask, image.size, interpolation=cv2.INTER_NEAREST)
    # Morphological close (dilate then erode) for a cleaner mask
    kernel = np.ones((5, 5), np.uint8)
    mask = cv2.dilate(mask, kernel, iterations=1)
    mask = cv2.erode(mask, kernel, iterations=1)
    # Build an RGBA image with the refined mask as the alpha channel
    image_np = np.array(image)
    rgba_image = np.zeros((image_np.shape[0], image_np.shape[1], 4), dtype=np.uint8)
    rgba_image[..., :3] = image_np  # Copy RGB channels
    rgba_image[..., 3] = mask * 255  # Alpha: 255 for foreground, 0 for background
    return Image.fromarray(rgba_image)
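
# Example (comment-only): the result keeps transparency only in formats with an
# alpha channel, so save it as PNG rather than JPEG:
#   cutout = apply_mask(crop, mask)
#   cutout.save("cutout.png")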

# Gradio handler: tie the pipeline together
def segment_object(image_url, x_min, y_min, x_max, y_max):
    bounding_box = {"x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max}
    # Load and crop the image
    image = load_image(image_url)
    cropped_image = crop_image(image, bounding_box)
    input_tensor = preprocess_image(cropped_image)
    # Load the model and run segmentation
    model = load_model()
    mask = segment_image(model, input_tensor)
    # Apply the mask to cut out the object
    result_image = apply_mask(cropped_image, mask)
    return result_image
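
# The handler can also be called directly, without the Gradio UI (hypothetical URL):
#   result = segment_object("https://example.com/photo.jpg", 100, 100, 600, 400)
#   result.save("segmented.png")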

# Set up the Gradio interface
iface = gr.Interface(
    fn=segment_object,
    inputs=[
        gr.Textbox(label="Image URL", placeholder="Enter image URL..."),
        gr.Number(label="x_min", value=100),
        gr.Number(label="y_min", value=100),
        gr.Number(label="x_max", value=600),
        gr.Number(label="y_max", value=400),
    ],
    outputs=gr.Image(label="Segmented Image"),
    live=True,
)

# Launch the interface
iface.launch()
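
# For deployment tweaks, launch() accepts options such as server_name and
# server_port (sketch, not used here):
#   iface.launch(server_name="0.0.0.0", server_port=7860)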