File size: 6,561 Bytes
82925a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import gradio as gr
from transformers import pipeline
from transformers.image_utils import load_image

# Model identifiers offered in the checkpoint dropdown. Every entry shares the
# same repository prefix, so only the distinguishing suffix is spelled out.
# The order is preserved; the first entry is used as the default selection.
checkpoints = [
    f"ustc-community/dfine_{variant}"
    for variant in (
        "n_coco",
        "s_coco",
        "m_coco",
        "l_coco",
        "x_coco",
        "s_obj365",
        "m_obj365",
        "l_obj365",
        "x_obj365",
        "s_obj2coco",
        "m_obj2coco",
        "l_obj2coco_e25",
        "x_obj2coco",
    )
]

# Pipelines already built in this process, keyed by checkpoint name, so that
# repeated requests for the same model do not reload it on every call.
# NOTE: entries are never evicted; one pipeline is kept per distinct checkpoint.
_PIPELINE_CACHE = {}


def _get_pipeline(checkpoint):
    """Return the object-detection pipeline for *checkpoint*, building it on first use."""
    if checkpoint not in _PIPELINE_CACHE:
        _PIPELINE_CACHE[checkpoint] = pipeline(
            "object-detection",
            model=checkpoint,
            image_processor=checkpoint,
            device="cpu",
        )
    return _PIPELINE_CACHE[checkpoint]


def _error_result(message):
    """Build the ``(None, visible Markdown)`` pair used to report a problem to the UI."""
    return None, gr.Markdown(f"**Error**: {message}", visible=True)


def detect_objects(image, checkpoint, confidence_threshold=0.3, use_url=False, url=""):
    """Run object detection on an image and format the result for ``gr.AnnotatedImage``.

    Args:
        image: Image to analyse; its ``.size`` is read as ``(width, height)``.
            May be ``None`` when the image comes from ``url``.
        checkpoint: Model identifier, passed to ``pipeline`` as both the model
            and the image processor.
        confidence_threshold: Minimum score a detection needs to be reported.
        use_url: When true and ``url`` is non-empty, the image is fetched from
            ``url`` and ``image`` is ignored.
        url: ``http://`` or ``https://`` address of the image to fetch.

    Returns:
        A pair ``(annotated, message)``:

        * ``annotated`` is ``(image, annotations)`` where ``annotations`` is a
          list of ``((x1, y1, x2, y2), label)`` tuples, or ``None`` when the
          input could not be obtained.
        * ``message`` is a ``gr.Markdown`` that is visible only when there is
          an error or warning to show.
    """
    # Resolve and validate the input first so that a bad request never pays
    # for loading a model.
    url = (url or "").strip()
    if use_url and url:
        # The value comes from a free-text field. Only remote http(s) images are
        # accepted. NOTE(review): load_image appears to also accept local file
        # paths and base64 payloads, which should not be reachable from the UI.
        if not url.startswith(("http://", "https://")):
            return _error_result("The image URL must start with http:// or https://.")
        try:
            input_image = load_image(url)
        except (ValueError, OSError) as exc:
            # Presumably covers malformed input (ValueError) as well as network
            # and image-decoding failures (OSError subclasses) - verify against
            # the installed transformers / requests / PIL versions.
            return _error_result(f"Could not load an image from the URL ({exc}).")
    elif image is not None:
        input_image = image
    else:
        return _error_result("Please provide an image or URL.")

    # Run detection with a cached pipeline.
    results = _get_pipeline(checkpoint)(input_image, threshold=confidence_threshold)

    # Image dimensions, used to clamp boxes to the visible area.
    img_width, img_height = input_image.size

    # Prepare annotations in the format: list of (bounding_box, label)
    annotations = []
    for result in results:
        score = result["score"]
        # Defensive re-check of the threshold already passed to the pipeline.
        if score < confidence_threshold:
            continue
        box = result["box"]
        # Clamp the box to the image and convert it to (x1, y1, x2, y2).
        x1 = max(0, int(box["xmin"]))
        y1 = max(0, int(box["ymin"]))
        x2 = min(img_width, int(box["xmax"]))
        y2 = min(img_height, int(box["ymax"]))
        # Skip boxes that are empty or inverted after clamping.
        if x2 <= x1 or y2 <= y1:
            continue
        annotations.append(((x1, y1, x2, y2), f"{result['label']} ({score:.2f})"))

    # Handle empty annotations: still show the image, plus a hint.
    if not annotations:
        return (input_image, []), gr.Markdown(
            "**Warning**: No objects detected above the confidence threshold. Try lowering the threshold.",
            visible=True,
        )

    # Return base image and annotations; hide any previous message.
    return (input_image, annotations), gr.Markdown(visible=False)

# Gradio interface: inputs on the left, annotated detection results on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Real-Time Object Detection Demo
        Experience state-of-the-art object detection with USTC's Dfine models. Upload an image, provide a URL, or try an example below. Select a model and adjust the confidence threshold to see detections in real time!

        **Instructions**:
        - Upload an image or enter a URL.
        - Choose a model checkpoint from the dropdown.
        - Adjust the confidence threshold (0.1 to 1.0).
        - Click "Detect Objects" to view results, or select an example.
        - Use "Clear" to reset inputs and outputs.
        """,
        elem_classes="header-text"
    )

    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            with gr.Group():
                image_input = gr.Image(
                    label="Upload Image",
                    type="pil",
                    sources=["upload", "webcam"],
                    interactive=True,
                    elem_classes="input-component",
                )
                use_url = gr.Checkbox(label="Use Image URL Instead", value=False)
                # Hidden until the checkbox above is ticked (see use_url.change below).
                url_input = gr.Textbox(
                    label="Image URL",
                    placeholder="https://example.com/image.jpg",
                    visible=False,
                    elem_classes="input-component",
                )
                checkpoint = gr.Dropdown(
                    choices=checkpoints,
                    label="Select Model Checkpoint",
                    value=checkpoints[0],
                    elem_classes="input-component",
                )
                confidence_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                    label="Confidence Threshold",
                    elem_classes="input-component",
                )
                with gr.Row():
                    detect_button = gr.Button(
                        "Detect Objects",
                        variant="primary",
                        elem_classes="action-button",
                    )
                    clear_button = gr.Button(
                        "Clear",
                        variant="secondary",
                        elem_classes="action-button",
                    )

        with gr.Column(scale=2):
            output_annotated = gr.AnnotatedImage(
                label="Detection Results",
                show_label=True,
                color_map=None,  # Let Gradio assign colors
                elem_classes="output-component",
            )
            error_message = gr.Markdown(visible=False, elem_classes="error-text")

    # Component order shared by every caller of detect_objects; it must match
    # the function's positional parameters:
    # (image, checkpoint, confidence_threshold, use_url, url).
    detection_inputs = [image_input, checkpoint, confidence_threshold, use_url, url_input]
    detection_outputs = [output_annotated, error_message]

    gr.Examples(
        # Each row follows the order of detection_inputs.
        examples=[
            ["./image.jpg", checkpoints[0], 0.3, False, ""],
            [None, checkpoints[0], 0.3, True, "https://live.staticflickr.com/65535/33021460783_1646d43c54_b.jpg"],
        ],
        inputs=detection_inputs,
        outputs=detection_outputs,
        fn=detect_objects,
        cache_examples=False,  # Avoid caching due to model size
        # Without caching, examples only fill the inputs unless run_on_click is
        # set; enable it so selecting an example runs the model as the label says.
        run_on_click=True,
        label="Select an example to run the model",
    )

    # Dynamic visibility for URL input: show the textbox only while the checkbox is ticked.
    use_url.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_url,
        outputs=url_input,
    )

    # Clear button functionality: restore every input to its initial value and
    # blank both outputs. The returned tuple follows the order of `outputs`.
    clear_button.click(
        fn=lambda: (
            None,  # image_input
            False,  # use_url
            "",  # url_input
            checkpoints[0],  # checkpoint
            0.3,  # confidence_threshold
            None,  # output_annotated
            gr.Markdown(visible=False),  # error_message
        ),
        outputs=[
            image_input,
            use_url,
            url_input,
            checkpoint,
            confidence_threshold,
            output_annotated,
            error_message,
        ],
    )

    # Detect button event
    detect_button.click(
        fn=detect_objects,
        inputs=detection_inputs,
        outputs=detection_outputs,
    )

# Launch the demo only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    demo.launch()