File size: 3,292 Bytes
b07764a
 
 
 
 
 
 
ace5a98
 
b07764a
 
 
 
 
 
 
 
 
 
 
 
ace5a98
 
b07764a
 
 
 
 
 
 
ace5a98
b07764a
ace5a98
b07764a
ace5a98
 
b07764a
 
 
 
ace5a98
 
b07764a
 
 
 
 
ace5a98
 
 
 
 
 
 
 
b07764a
ace5a98
 
b07764a
 
ace5a98
 
b07764a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ace5a98
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
from torchvision import transforms
from PIL import Image
import requests
import numpy as np
import gradio as gr
from io import BytesIO
from torchvision.models.segmentation import deeplabv3_resnet101
import cv2

# Step 1: Load the Image from URL
def load_image(url, timeout=10):
    """Download the image at *url* and return it as an RGB PIL image.

    Args:
        url: HTTP(S) URL of the image.
        timeout: Seconds to wait for the server (passed straight to
            ``requests.get``). Without a timeout a stalled server can hang the
            calling Gradio worker for a very long time.

    Returns:
        The decoded image converted to mode ``"RGB"``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status. This
            surfaces the real problem instead of passing an error page to PIL
            and failing with an unrelated decode error.
        requests.Timeout: If the server does not respond within *timeout*.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")

# Step 2: Crop Image Based on Bounding Box
def crop_image(image, bounding_box):
    """Crop *image* to the rectangle described by *bounding_box*.

    Args:
        image: PIL image (anything exposing a PIL-style ``crop(box)`` method).
        bounding_box: Mapping with keys ``"x_min"``, ``"y_min"``, ``"x_max"``
            and ``"y_max"`` in pixel coordinates. Coordinates are looked up by
            key, so the mapping's insertion order does not matter.

    Returns:
        Whatever ``image.crop`` returns for the box
        ``(x_min, y_min, x_max, y_max)``.

    Raises:
        KeyError: If one of the four coordinate keys is missing.
        ValueError: If the box is empty (``x_max <= x_min`` or
            ``y_max <= y_min``), which would otherwise yield a zero-sized crop.
    """
    x_min = bounding_box["x_min"]
    y_min = bounding_box["y_min"]
    x_max = bounding_box["x_max"]
    y_max = bounding_box["y_max"]

    if x_max <= x_min or y_max <= y_min:
        raise ValueError(
            f"Empty bounding box ({x_min}, {y_min}, {x_max}, {y_max}): "
            "x_max must exceed x_min and y_max must exceed y_min"
        )
    return image.crop((x_min, y_min, x_max, y_max))

# Step 3: Preprocessing for Segmentation Model
def preprocess_image(image, size=(512, 512)):
    """Convert a PIL image into a normalized, batched tensor for the model.

    The image is resized to *size*, turned into a float tensor, normalized
    with the standard ImageNet channel statistics, and given a leading batch
    axis of length 1.

    Args:
        image: PIL image to preprocess.
        size: Target size handed to ``transforms.Resize``.

    Returns:
        A tensor with one extra leading (batch) dimension.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    pipeline = transforms.Compose(steps)

    single = pipeline(image)
    # Indexing with None inserts the batch axis, equivalent to unsqueeze(0).
    return single[None, ...]

# Step 4: Load a More Robust Pre-trained Model
# Module-level cache for the segmentation model; filled on the first load_model() call.
_MODEL = None


def load_model():
    """Return the pretrained DeepLabV3-ResNet101 model in eval mode.

    The model is built only on the first call: weights are loaded (torchvision
    may download them), the model is switched to evaluation mode, and it is
    moved to CUDA when a GPU is available. Every later call returns that same
    object, so the expensive construction is not repeated per request.

    Returns:
        The cached ``torch.nn.Module``.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL

    try:
        # torchvision >= 0.13 exposes the multi-weight API; `pretrained=True`
        # is deprecated there. COCO_WITH_VOC_LABELS_V1 is exactly the
        # checkpoint the legacy `pretrained=True` flag selected.
        from torchvision.models.segmentation import DeepLabV3_ResNet101_Weights
    except ImportError:
        # Older torchvision only understands the boolean flag.
        model = deeplabv3_resnet101(pretrained=True)
    else:
        model = deeplabv3_resnet101(
            weights=DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1
        )

    model.eval()  # Disable dropout / use running batch-norm statistics.
    if torch.cuda.is_available():
        model = model.to("cuda")

    _MODEL = model
    return model

# Step 5: Perform Segmentation
def segment_image(model, input_tensor):
    """Run *model* on *input_tensor* and return the per-pixel class label map.

    The input is moved to CUDA when a GPU is available. The model's ``"out"``
    logits are reduced with an arg-max over dimension 1. The result is
    squeezed, copied to the CPU, and returned as a NumPy array.

    Args:
        model: Callable returning a mapping with an ``"out"`` tensor.
        input_tensor: Batched input tensor for the model.

    Returns:
        NumPy array of integer class indices with singleton dims removed.
    """
    use_gpu = torch.cuda.is_available()
    batch = input_tensor.to("cuda") if use_gpu else input_tensor

    with torch.no_grad():
        logits = model(batch)["out"]
        labels = torch.argmax(logits, dim=1)

    return labels.squeeze().cpu().numpy()

# Step 6: Refine Mask and Extract Object
def apply_mask(image, mask):
    """Cut the segmented foreground out of *image* as an RGBA picture.

    Args:
        image: PIL image the mask was predicted for. It is converted to RGB,
            so any PIL mode is accepted.
        mask: 2-D array of per-pixel class labels (as produced by
            ``segment_image``). Label 0 is background; every other label is
            foreground. Index 0 is ``__background__`` in torchvision's
            pretrained DeepLabV3 label set.

    Returns:
        RGBA PIL image the same size as *image*. Its alpha is 255 on
        foreground pixels and 0 elsewhere.
    """
    rgb = np.asarray(image.convert("RGB"))
    width, height = image.size

    # Collapse the multi-class label map to a binary foreground mask first.
    # Multiplying raw labels (e.g. 15) by 255 in uint8 wraps around and gives
    # arbitrary alpha values (15 * 255 -> 241) instead of a fully opaque 255.
    foreground = (np.asarray(mask) > 0).astype(np.uint8)

    # cv2 expects dsize as (width, height), which is PIL's image.size order.
    foreground = cv2.resize(foreground, (width, height), interpolation=cv2.INTER_NEAREST)

    # Dilate then erode (a morphological closing) to fill small holes and gaps.
    kernel = np.ones((5, 5), np.uint8)
    foreground = cv2.dilate(foreground, kernel, iterations=1)
    foreground = cv2.erode(foreground, kernel, iterations=1)

    rgba = np.empty((height, width, 4), dtype=np.uint8)
    rgba[..., :3] = rgb
    rgba[..., 3] = foreground * 255  # foreground is 0/1, so alpha is exactly 0/255.

    return Image.fromarray(rgba)

# Gradio Interface to handle input and output
def segment_object(image_url, x_min, y_min, x_max, y_max):
    """Gradio callback: download, crop, segment and cut out the object.

    Args:
        image_url: URL of the source image. Surrounding whitespace is ignored.
        x_min, y_min, x_max, y_max: Bounding box in pixel coordinates of the
            source image.

    Returns:
        The RGBA cut-out produced by ``apply_mask``, or ``None`` (an empty
        output) while the URL is blank or any coordinate is unset. A live
        interface invokes this callback on every input change, including
        while the user is still typing or has cleared a field.
    """
    url = image_url.strip() if isinstance(image_url, str) else ""
    if not url or any(v is None for v in (x_min, y_min, x_max, y_max)):
        return None

    bounding_box = {"x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max}

    # Load and process the image
    image = load_image(url)
    cropped_image = crop_image(image, bounding_box)
    input_tensor = preprocess_image(cropped_image)

    # Load model and perform segmentation
    model = load_model()
    mask = segment_image(model, input_tensor)

    # Apply mask to extract object
    return apply_mask(cropped_image, mask)

# Set up the Gradio Interface
iface = gr.Interface(
    fn=segment_object,
    inputs=[
        gr.Textbox(label="Image URL", placeholder="Enter image URL..."),
        # precision=0 makes Gradio deliver integer pixel coordinates.
        gr.Number(label="x_min", value=100, precision=0),
        gr.Number(label="y_min", value=100, precision=0),
        gr.Number(label="x_max", value=600, precision=0),
        gr.Number(label="y_max", value=400, precision=0),
    ],
    outputs=gr.Image(label="Segmented Image"),
    live=True,
)

# Launch the web server only when this file is run as a script, so importing
# the module (e.g. from tests or another app) has no side effects.
if __name__ == "__main__":
    iface.launch()