File size: 3,253 Bytes
74d7b02
 
 
9c2c5e4
74d7b02
041db94
74d7b02
 
 
 
 
 
 
 
 
 
 
 
 
041db94
74d7b02
 
 
 
 
 
 
 
 
 
 
 
8f78001
b48c806
 
 
 
74d7b02
 
041db94
74d7b02
 
 
 
 
 
 
041db94
74d7b02
 
 
041db94
 
74d7b02
 
 
be786c9
041db94
74d7b02
 
be786c9
041db94
9c2c5e4
041db94
 
 
d4461b5
041db94
9c2c5e4
 
 
 
be786c9
041db94
74d7b02
041db94
74d7b02
 
041db94
 
 
 
 
74d7b02
041db94
 
 
74d7b02
 
 
041db94
 
 
 
d4461b5
041db94
74d7b02
cd645f2
74d7b02
 
 
be786c9
041db94
74d7b02
 
 
 
041db94
74d7b02
041db94
 
74d7b02
be786c9
041db94
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import sahi
import torch
from ultralytics import YOLO

# Sample images used by the demo examples: (source URL, local file name)
_SAMPLE_IMAGES = (
    (
        "https://raw.githubusercontent.com/kadirnar/dethub/main/data/images/highway.jpg",
        "highway.jpg",
    ),
    (
        "https://raw.githubusercontent.com/obss/sahi/main/tests/data/small-vehicles1.jpeg",
        "small-vehicles1.jpeg",
    ),
    (
        "https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/zidane.jpg",
        "zidane.jpg",
    ),
)

# Fetch every sample image into the working directory
for _url, _filename in _SAMPLE_IMAGES:
    sahi.utils.file.download_from_url(_url, _filename)

# Weights loaded at start-up; also the initial selection of the model dropdown
current_model_name = "yolov8m-seg.pt"
model = YOLO(current_model_name)

# YOLOv8 segmentation weights offered in the model dropdown
model_names = [
    "yolov8n-seg.pt", "yolov8s-seg.pt", "yolov8m-seg.pt",
    "yolov8l-seg.pt", "yolov8x-seg.pt",
]

def yolov8_inference(
    image: str = None,
    model_name: str = None,
    image_size: int = 640,
    conf_threshold: float = 0.25,
    iou_threshold: float = 0.45,
):
    """
    Run YOLOv8 on an image and return the label and box of every detection.

    Args:
        image: Input image, handed to the model unchanged (the interface
            supplies a file path). When None, no inference is run.
        model_name: Name of the weights file to use. When it differs from the
            currently loaded one, the module-level model is replaced; when
            empty or None, the currently loaded model is kept.
        image_size: Inference image size, forwarded to the model as ``imgsz``.
        conf_threshold: Confidence threshold, forwarded as ``conf``.
        iou_threshold: IOU threshold, forwarded as ``iou``.
    Returns:
        A list with one dict per detected box, each of the form
        ``{"label": <class name>, "bbox_coords": <list of floats>}`` where the
        coordinates come from the box's ``xyxy`` attribute. The list is empty
        when there is no image or nothing was detected.
    """
    global model
    global current_model_name

    # Nothing to run inference on (e.g. the form was submitted without an image)
    if image is None:
        return []

    # Load new weights only when a different, non-empty model name is selected
    if model_name and model_name != current_model_name:
        model = YOLO(model_name)
        current_model_name = model_name

    # Pass the settings per call so the requested image size is actually used;
    # the slider may deliver a float, while ``imgsz`` is an integer size
    results = model.predict(
        image,
        imgsz=int(image_size),
        conf=conf_threshold,
        iou=iou_threshold,
    )
    detections = results[0]

    # Convert tensors to plain Python values so the result is JSON-serialisable
    return [
        {
            "label": detections.names[int(box.cls[0].item())],
            "bbox_coords": box.xyxy[0].tolist(),
        }
        for box in detections.boxes
    ]

# Input widgets, listed in the same order as the parameters of yolov8_inference
inputs = [
    gr.Image(label="Input Image", type="filepath"),
    gr.Dropdown(choices=model_names, label="Model type", value=current_model_name),
    gr.Slider(label="Image Size", minimum=320, maximum=1280, step=32, value=640),
    gr.Slider(
        label="Confidence Threshold", minimum=0.0, maximum=1.0, step=0.05, value=0.25
    ),
    gr.Slider(label="IOU Threshold", minimum=0.0, maximum=1.0, step=0.05, value=0.45),
]

# The value returned by the inference function is displayed as JSON
outputs = gr.JSON(label="Output Masks and Labels")

title = "Ultralytics YOLOv8 Segmentation Demo"

# One row per example, laid out like `inputs`:
# image file, model, image size, confidence threshold, IOU threshold
examples = [
    [image_file, "yolov8m-seg.pt", 640, confidence, 0.45]
    for image_file, confidence in (
        ("zidane.jpg", 0.6),
        ("highway.jpg", 0.25),
        ("small-vehicles1.jpeg", 0.25),
    )
]

# Wire the inference function, widgets and examples into one interface
demo_app = gr.Interface(
    fn=yolov8_inference,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
    cache_examples=False,  # example caching stays off to avoid caching issues
    title=title,
    theme="default",
)

# Enable the queue, then launch the app in debug mode
demo_app.queue()
demo_app.launch(debug=True)