File size: 3,782 Bytes
b07764a
2d2f43a
b07764a
 
 
 
 
ace5a98
b07764a
 
 
 
 
 
 
975418f
 
 
 
 
 
 
 
 
 
b07764a
 
 
 
975418f
2d2f43a
 
 
b07764a
2d2f43a
b07764a
2d2f43a
b07764a
2d2f43a
b07764a
ace5a98
 
b07764a
 
2d2f43a
 
ace5a98
 
b07764a
2d2f43a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b07764a
ace5a98
 
b07764a
 
2d2f43a
ace5a98
b07764a
 
 
 
975418f
2d2f43a
b07764a
 
 
 
 
 
 
2d2f43a
b07764a
2d2f43a
 
b07764a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
from torchvision import models, transforms
from PIL import Image
import requests
import numpy as np
import gradio as gr
from io import BytesIO
import cv2

# Step 1: Load the Image from URL
def load_image(url, timeout=10):
    """Download the image at *url* and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` may block indefinitely on a stalled
            connection.

    Returns:
        The downloaded image converted to "RGB" mode.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on 4xx/5xx here instead of handing an error page to the image decoder.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")

# Step 2: Adjust Bounding Box to Add Margin
def adjust_bounding_box(bounding_box, margin=20):
    """Return a copy of *bounding_box* grown by *margin* on every side.

    Args:
        bounding_box: Mapping with "x_min", "y_min", "x_max" and "y_max".
        margin: Amount added around the box (default 20).

    Returns:
        A new dict with keys in the order x_min, y_min, x_max, y_max. The
        minimum edges never go below 0; the maximum edges are not clamped.
    """
    left = bounding_box["x_min"] - margin
    top = bounding_box["y_min"] - margin
    right = bounding_box["x_max"] + margin
    bottom = bounding_box["y_max"] + margin
    # Only the low edges are clamped; the image size is not known here.
    return dict(x_min=max(0, left), y_min=max(0, top), x_max=right, y_max=bottom)

# Step 3: Crop Image Based on Bounding Box
def crop_image(image, bounding_box):
    """Crop *image* to the rectangle described by *bounding_box*.

    Args:
        image: Object exposing ``crop((left, upper, right, lower))``; in this
            file it is the PIL image returned by ``load_image``.
        bounding_box: Mapping with "x_min", "y_min", "x_max" and "y_max".

    Returns:
        Whatever ``image.crop`` returns for that rectangle.

    Raises:
        KeyError: If one of the four coordinate keys is missing.
    """
    # Look coordinates up by name: unpacking ``.values()`` would depend on the
    # dict's insertion order and silently swap edges for a reordered dict.
    region = (
        bounding_box["x_min"],
        bounding_box["y_min"],
        bounding_box["x_max"],
        bounding_box["y_max"],
    )
    return image.crop(region)

# Step 4: Preprocessing for Segmentation Model
def preprocess_image(image):
    """Turn *image* into a batched tensor for the segmentation model.

    The image goes through ``transforms.ToTensor`` and then gets a leading
    batch dimension of size one.
    """
    pipeline = transforms.Compose([transforms.ToTensor()])
    tensor = pipeline(image)
    # The model is called on a batch, so wrap the single image in one.
    return tensor.unsqueeze(0)

# Step 5: Load Mask R-CNN Model
# Holds the model once it has been built, so repeated calls (one per request
# in segment_object) do not reload the weights each time.
_MODEL_CACHE = None

def load_model():
    """Return the pre-trained Mask R-CNN model, building it on first use.

    The model is put in evaluation mode and moved to CUDA when available.
    Later calls return the same cached instance.
    """
    global _MODEL_CACHE
    if _MODEL_CACHE is None:
        model = models.detection.maskrcnn_resnet50_fpn(pretrained=True)  # Pre-trained Mask R-CNN
        model.eval()  # Inference only: disable training-time behaviour
        if torch.cuda.is_available():
            model = model.to("cuda")
        _MODEL_CACHE = model
    return _MODEL_CACHE

# Step 6: Perform Object Segmentation
def _inference_device(model):
    """Return the device of *model*'s parameters, or the default device."""
    try:
        return next(model.parameters()).device
    except (AttributeError, TypeError, StopIteration):
        # Plain callables or parameter-less modules: fall back to the
        # "CUDA if present" rule.
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def segment_image(model, input_tensor, confidence_threshold=0.6):
    """Run *model* on *input_tensor* and return the confident masks.

    Args:
        model: Callable returning a sequence whose first element is a mapping
            with "scores" and "masks" tensors.
        input_tensor: Batched input passed to the model.
        confidence_threshold: Only detections with a score strictly greater
            than this value are kept.

    Returns:
        A list of NumPy arrays, one per kept detection, each being index 0
        along the second axis of that detection's mask.
    """
    # Put the input where the model actually lives; a CPU model on a CUDA
    # machine would otherwise get a CUDA tensor.
    input_tensor = input_tensor.to(_inference_device(model))
    with torch.no_grad():
        outputs = model(input_tensor)  # Perform inference

    detection = outputs[0]
    scores = detection["scores"].cpu().numpy()
    masks = detection["masks"].cpu().numpy()

    # Compare element by element to keep NumPy's scalar comparison semantics.
    return [
        mask[0]
        for score, mask in zip(scores, masks)
        if score > confidence_threshold
    ]

# Step 7: Combine Masks and Extract Object
def apply_masks(image, masks):
    """Return *image* as RGBA, opaque only where some mask is set.

    Args:
        image: Image whose ``width``/``height`` and pixel data are used.
        masks: Iterable of 2-D arrays; values above 0.5 count as "inside".

    Returns:
        A PIL image built from an (H, W, 4) uint8 array: the first three
        channels are the input pixels, the fourth is 255 inside the union of
        the masks and 0 elsewhere.
    """
    width, height = image.width, image.height
    alpha = np.zeros((height, width), dtype=np.uint8)

    for mask in masks:
        # Bring each mask to the image size before thresholding it.
        scaled = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        binary = (scaled > 0.5).astype(np.uint8)
        alpha = np.maximum(alpha, binary)  # Union with the masks seen so far

    pixels = np.array(image)
    rows, cols = pixels.shape[0], pixels.shape[1]
    rgba = np.zeros((rows, cols, 4), dtype=np.uint8)
    rgba[..., :3] = pixels  # Colour channels
    rgba[..., 3] = alpha * 255  # 0/1 mask -> 0/255 alpha

    return Image.fromarray(rgba)

# Gradio Interface to handle input and output
def segment_object(image_url, x_min, y_min, x_max, y_max):
    """Fetch an image, crop it to a padded box and cut out detected objects.

    Args:
        image_url: Address of the image to download.
        x_min, y_min, x_max, y_max: Bounding box in pixel coordinates of the
            downloaded image, before the margin is added.

    Returns:
        The RGBA image produced by ``apply_masks`` for the cropped region.

    Raises:
        ValueError: If the padded box, limited to the image, has no area.
    """
    padded = adjust_bounding_box({"x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max})

    # Load the image
    image = load_image(image_url)

    # adjust_bounding_box clamps only the low edges at 0. Clamp the high edges
    # to the image size so the crop does not extend past the picture. Key
    # order stays x_min, y_min, x_max, y_max.
    bounding_box = {
        "x_min": padded["x_min"],
        "y_min": padded["y_min"],
        "x_max": min(padded["x_max"], image.width),
        "y_max": min(padded["y_max"], image.height),
    }
    if (
        bounding_box["x_max"] <= bounding_box["x_min"]
        or bounding_box["y_max"] <= bounding_box["y_min"]
    ):
        raise ValueError(
            f"Bounding box {bounding_box} does not overlap the "
            f"{image.width}x{image.height} image"
        )

    cropped_image = crop_image(image, bounding_box)
    input_tensor = preprocess_image(cropped_image)

    # Load model and perform segmentation
    model = load_model()
    masks = segment_image(model, input_tensor)

    # Apply masks to extract objects
    return apply_masks(cropped_image, masks)

# Set up the Gradio Interface
# Default bounding-box values shown in the form, listed in the order in which
# segment_object takes its coordinate arguments.
_DEFAULT_BOX = {"x_min": 100, "y_min": 100, "x_max": 600, "y_max": 400}

iface = gr.Interface(
    fn=segment_object,
    inputs=[
        gr.Textbox(label="Image URL", placeholder="Enter image URL..."),
        # One numeric field per box coordinate.
        *(gr.Number(label=name, value=default) for name, default in _DEFAULT_BOX.items()),
    ],
    outputs=gr.Image(label="Segmented Image"),
    live=True,
)

# Launch the interface
iface.launch()