Spaces:
Paused
Paused
File size: 3,292 Bytes
b07764a ace5a98 b07764a ace5a98 b07764a ace5a98 b07764a ace5a98 b07764a ace5a98 b07764a ace5a98 b07764a ace5a98 b07764a ace5a98 b07764a ace5a98 b07764a ace5a98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import torch
from torchvision import transforms
from PIL import Image
import requests
import numpy as np
import gradio as gr
from io import BytesIO
from torchvision.models.segmentation import deeplabv3_resnet101
import cv2
# Step 1: Load the Image from URL
def load_image(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely on a stalled host.

    Returns:
        The decoded image converted to ``"RGB"`` mode.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-success HTTP status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on 4xx/5xx here; otherwise the error page body would be handed to
    # PIL and surface as an unrelated "cannot identify image file" error.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
# Step 2: Crop Image Based on Bounding Box
def crop_image(image, bounding_box):
    """Crop ``image`` to the rectangle described by ``bounding_box``.

    Args:
        image: Object exposing ``crop(box)`` (a PIL image in this file).
        bounding_box: Mapping with ``x_min``, ``y_min``, ``x_max`` and
            ``y_max`` entries. For backward compatibility, a mapping without
            those keys is still accepted if it holds exactly four values in
            (left, top, right, bottom) order.

    Returns:
        Whatever ``image.crop`` returns for the ``(x_min, y_min, x_max,
        y_max)`` box.

    Raises:
        ValueError: If the fallback path is taken and the mapping does not
            contain exactly four values.
    """
    try:
        # Look the corners up by name so the result does not depend on the
        # order in which the caller happened to insert the keys.
        box = (
            bounding_box["x_min"],
            bounding_box["y_min"],
            bounding_box["x_max"],
            bounding_box["y_max"],
        )
    except KeyError:
        # Legacy behaviour: positional unpacking of the mapping's values.
        x_min, y_min, x_max, y_max = bounding_box.values()
        box = (x_min, y_min, x_max, y_max)
    return image.crop(box)
# Step 3: Preprocessing for Segmentation Model
def preprocess_image(image, size=(512, 512)):
    """Turn an image into a normalized, batched tensor for the model.

    The image is resized to ``size``, converted to a tensor, normalized with
    the per-channel mean/std constants below, and given a leading batch
    dimension.

    Args:
        image: Input image accepted by the torchvision transforms.
        size: Target size passed to ``transforms.Resize``.

    Returns:
        The transformed tensor with an extra dimension inserted at index 0.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    pipeline = transforms.Compose(steps)
    tensor = pipeline(image)
    # Add batch dimension
    return tensor.unsqueeze(0)
# Step 4: Load a More Robust Pre-trained Model
# Process-wide cache so the pretrained network is built (and its weights
# loaded) only once, no matter how many requests call load_model().
_MODEL_CACHE = None


def load_model():
    """Return the DeepLabV3-ResNet101 segmentation model in eval mode.

    The model is created on the first call and reused afterwards; building it
    per request would reload the pretrained weights every time the handler
    runs. The model is moved to CUDA when ``torch.cuda.is_available()``.

    Returns:
        The shared model instance, already switched to evaluation mode.
    """
    global _MODEL_CACHE
    if _MODEL_CACHE is None:
        # NOTE(review): recent torchvision releases prefer the `weights=`
        # argument over `pretrained=True` — confirm against the pinned version
        # before switching.
        model = deeplabv3_resnet101(pretrained=True)
        model.eval()  # Inference only: disables dropout / batch-norm updates
        if torch.cuda.is_available():
            model = model.to("cuda")
        _MODEL_CACHE = model
    return _MODEL_CACHE
# Step 5: Perform Segmentation
def segment_image(model, input_tensor):
    """Run the segmentation model and return a per-pixel class-index mask.

    Args:
        model: Callable returning a mapping whose ``'out'`` entry holds the
            per-class scores with classes along dimension 1.
        input_tensor: Batched input tensor for ``model``.

    Returns:
        A NumPy array of class indices (argmax over dimension 1) with the
        leading batch dimension removed when it has size 1.
    """
    try:
        # Follow the model's own device so the input never ends up on a
        # different device than the weights (e.g. a CPU model on a CUDA host).
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        # Parameter-less models or plain callables: keep the original
        # "use CUDA when available" behaviour.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_tensor = input_tensor.to(device)

    with torch.no_grad():
        scores = model(input_tensor)["out"]

    # squeeze(0) removes only the batch axis; a bare squeeze() would also drop
    # a height or width of 1 and hand a 1-D mask to downstream code.
    return scores.argmax(dim=1).squeeze(0).cpu().numpy()
# Step 6: Refine Mask and Extract Object
def apply_mask(image, mask):
    """Return ``image`` as RGBA with the background made transparent.

    Args:
        image: PIL image to cut out; assumed to be 3-channel RGB, since its
            array is copied into the first three channels of the result.
        mask: 2-D array of class indices where 0 means background.

    Returns:
        An RGBA PIL image the same size as ``image`` whose alpha channel is
        255 on foreground pixels and 0 elsewhere.
    """
    # Collapse class labels to a 0/1 foreground mask first. Multiplying raw
    # uint8 class indices by 255 wraps around (class k gives 256 - k), which
    # leaves most classes slightly transparent instead of opaque.
    foreground = (np.asarray(mask) > 0).astype(np.uint8)

    # PIL's image.size is (width, height), the order cv2.resize expects.
    foreground = cv2.resize(foreground, image.size, interpolation=cv2.INTER_NEAREST)

    # Dilate then erode with the same kernel to fill small holes and gaps
    # without growing the overall outline.
    kernel = np.ones((5, 5), np.uint8)
    foreground = cv2.dilate(foreground, kernel, iterations=1)
    foreground = cv2.erode(foreground, kernel, iterations=1)

    # Create RGBA image
    image_np = np.array(image)
    rgba_image = np.zeros((image_np.shape[0], image_np.shape[1], 4), dtype=np.uint8)
    rgba_image[..., :3] = image_np  # Copy RGB channels
    rgba_image[..., 3] = foreground * 255  # 0/1 -> 0/255, no overflow
    return Image.fromarray(rgba_image)
# Gradio Interface to handle input and output
def segment_object(image_url, x_min, y_min, x_max, y_max):
    """Fetch an image, crop it to a box, and cut out the segmented object.

    This is the callback wired into the Gradio interface.

    Args:
        image_url: URL of the image to download.
        x_min: Left edge of the crop box, in pixels.
        y_min: Top edge of the crop box, in pixels.
        x_max: Right edge of the crop box, in pixels.
        y_max: Bottom edge of the crop box, in pixels.

    Returns:
        The RGBA image produced by ``apply_mask`` for the cropped region.

    Raises:
        ValueError: If the URL is empty, a coordinate is missing, or the box
            has no positive width and height.
    """
    # Validate before any network or model work: a live interface calls this
    # with half-filled inputs, and these cases would otherwise fail deep
    # inside requests or PIL with a far less helpful message.
    if not image_url or not str(image_url).strip():
        raise ValueError("Image URL must not be empty.")
    if any(value is None for value in (x_min, y_min, x_max, y_max)):
        raise ValueError("All bounding box coordinates are required.")
    if x_max <= x_min or y_max <= y_min:
        raise ValueError(
            "Bounding box must satisfy x_min < x_max and y_min < y_max."
        )

    bounding_box = {"x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max}

    # Load and process the image
    image = load_image(image_url)
    cropped_image = crop_image(image, bounding_box)
    input_tensor = preprocess_image(cropped_image)

    # Load model and perform segmentation
    model = load_model()
    mask = segment_image(model, input_tensor)

    # Apply mask to extract object
    return apply_mask(cropped_image, mask)
# Set up the Gradio Interface
# Label and initial value for each bounding-box input, in the order the
# values are passed to segment_object after the URL.
_BBOX_FIELDS = (
    ("x_min", 100),
    ("y_min", 100),
    ("x_max", 600),
    ("y_max", 400),
)

# Build the Gradio UI: one URL textbox followed by the four box coordinates.
# NOTE(review): live=True asks Gradio to re-run the function as inputs change;
# confirm that is intended given each run downloads an image and runs the model.
iface = gr.Interface(
    fn=segment_object,
    inputs=[
        gr.Textbox(label="Image URL", placeholder="Enter image URL..."),
        *(gr.Number(label=name, value=default) for name, default in _BBOX_FIELDS),
    ],
    outputs=gr.Image(label="Segmented Image"),
    live=True,
)

# Start serving the interface
iface.launch()
|