Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
from transformers import pipeline | |
from transformers.image_utils import load_image | |
# D-FINE checkpoints offered in the model dropdown. The first entry is used as
# the default selection by the UI.
checkpoints = [
    # *_coco variants
    "ustc-community/dfine_n_coco",
    "ustc-community/dfine_s_coco",
    "ustc-community/dfine_m_coco",
    "ustc-community/dfine_l_coco",
    "ustc-community/dfine_x_coco",
    # *_obj365 variants
    "ustc-community/dfine_s_obj365",
    "ustc-community/dfine_m_obj365",
    "ustc-community/dfine_l_obj365",
    "ustc-community/dfine_x_obj365",
    # *_obj2coco variants
    "ustc-community/dfine_s_obj2coco",
    "ustc-community/dfine_m_obj2coco",
    "ustc-community/dfine_l_obj2coco_e25",
    "ustc-community/dfine_x_obj2coco",
]
# Most recently used pipeline, keyed by checkpoint name. Only a single entry is
# kept so that switching models does not accumulate model weights in memory.
_pipeline_cache = {}


def _get_pipeline(checkpoint):
    """Return the object-detection pipeline for *checkpoint*, reusing the last one loaded."""
    pipe = _pipeline_cache.get(checkpoint)
    if pipe is None:
        pipe = pipeline(
            "object-detection",
            model=checkpoint,
            image_processor=checkpoint,
            device="cpu",
        )
        # Replace the single cached entry. The local `pipe` reference stays
        # valid even if another request swaps the cache in the meantime.
        _pipeline_cache.clear()
        _pipeline_cache[checkpoint] = pipe
    return pipe


def _build_annotations(results, img_width, img_height, confidence_threshold):
    """Convert pipeline results into ``[((x1, y1, x2, y2), label), ...]``.

    Each result is read as a dict with ``"score"``, ``"label"`` and a ``"box"``
    dict holding ``xmin``/``ymin``/``xmax``/``ymax``. Boxes are clamped to the
    image bounds; low-score and empty boxes are dropped.
    """
    annotations = []
    for result in results:
        score = result["score"]
        if score < confidence_threshold:
            continue
        box = result["box"]
        x1 = max(0, int(box["xmin"]))
        y1 = max(0, int(box["ymin"]))
        x2 = min(img_width, int(box["xmax"]))
        y2 = min(img_height, int(box["ymax"]))
        # A box that collapses after clamping cannot be drawn.
        if x2 <= x1 or y2 <= y1:
            continue
        annotations.append(((x1, y1, x2, y2), f"{result['label']} ({score:.2f})"))
    return annotations


def detect_objects(image, checkpoint, confidence_threshold=0.3, use_url=False, url=""):
    """Run object detection and return values for the annotated image and message outputs.

    Args:
        image: Uploaded image (an object exposing ``.size``), or ``None``.
        checkpoint: Model identifier passed to the ``object-detection`` pipeline.
        confidence_threshold: Minimum score for a detection to be kept.
        use_url: When true and ``url`` is non-empty, the image is fetched from
            ``url`` instead of using ``image``.
        url: http(s) address of the image to fetch.

    Returns:
        A 2-tuple ``(annotated, message)``. ``annotated`` is
        ``(input_image, annotations)`` on success or ``None`` when no usable
        input was given; ``message`` is a ``gr.Markdown`` that is visible only
        when there is an error or warning to show.
    """
    url = (url or "").strip()

    # Resolve the input first so that no model is loaded for an invalid request.
    if use_url and url:
        # The string comes from the user; `load_image` also accepts non-URL
        # strings (e.g. local paths), so only remote http(s) addresses are allowed.
        if not url.lower().startswith(("http://", "https://")):
            return None, gr.Markdown(
                "**Error**: The image URL must start with http:// or https://.",
                visible=True,
            )
        try:
            input_image = load_image(url)
        except Exception:
            # UI boundary: download/decoding failures raise different exception
            # types depending on the installed backend, so report them uniformly.
            return None, gr.Markdown(
                "**Error**: Could not load an image from the provided URL.",
                visible=True,
            )
    elif image is not None:
        input_image = image
    else:
        return None, gr.Markdown("**Error**: Please provide an image or URL.", visible=True)

    # Run detection
    pipe = _get_pipeline(checkpoint)
    results = pipe(input_image, threshold=confidence_threshold)

    img_width, img_height = input_image.size
    annotations = _build_annotations(results, img_width, img_height, confidence_threshold)

    # Handle empty annotations
    if not annotations:
        return (input_image, []), gr.Markdown(
            "**Warning**: No objects detected above the confidence threshold. Try lowering the threshold.",
            visible=True,
        )

    # Return base image and annotations
    return (input_image, annotations), gr.Markdown(visible=False)
# Gradio interface
# NOTE(review): the nesting of the layout containers below was reconstructed
# from statement order; confirm it against the intended page layout.
DEFAULT_CHECKPOINT = checkpoints[0]
DEFAULT_THRESHOLD = 0.3

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Real-Time Object Detection Demo
        Experience state-of-the-art object detection with USTC's Dfine models. Upload an image, provide a URL, or try an example below. Select a model and adjust the confidence threshold to see detections in real time!
        **Instructions**:
        - Upload an image or enter a URL.
        - Choose a model checkpoint from the dropdown.
        - Adjust the confidence threshold (0.1 to 1.0).
        - Click "Detect Objects" to view results, or select an example.
        - Use "Clear" to reset inputs and outputs.
        """,
        elem_classes="header-text",
    )

    with gr.Row():
        # Left column: inputs and action buttons.
        with gr.Column(scale=1, min_width=300):
            with gr.Group():
                image_input = gr.Image(
                    label="Upload Image",
                    type="pil",
                    sources=["upload", "webcam"],
                    interactive=True,
                    elem_classes="input-component",
                )
                use_url = gr.Checkbox(label="Use Image URL Instead", value=False)
                url_input = gr.Textbox(
                    label="Image URL",
                    placeholder="https://example.com/image.jpg",
                    visible=False,  # shown only while the checkbox is ticked
                    elem_classes="input-component",
                )
                checkpoint = gr.Dropdown(
                    choices=checkpoints,
                    label="Select Model Checkpoint",
                    value=DEFAULT_CHECKPOINT,
                    elem_classes="input-component",
                )
                confidence_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=DEFAULT_THRESHOLD,
                    step=0.1,
                    label="Confidence Threshold",
                    elem_classes="input-component",
                )
            with gr.Row():
                detect_button = gr.Button(
                    "Detect Objects",
                    variant="primary",
                    elem_classes="action-button",
                )
                clear_button = gr.Button(
                    "Clear",
                    variant="secondary",
                    elem_classes="action-button",
                )

        # Right column: detection output and status message.
        with gr.Column(scale=2):
            output_annotated = gr.AnnotatedImage(
                label="Detection Results",
                show_label=True,
                color_map=None,  # Let Gradio assign colors
                elem_classes="output-component",
            )
            error_message = gr.Markdown(visible=False, elem_classes="error-text")

    # Example rows and `inputs` follow the positional parameter order of
    # `detect_objects` (image, checkpoint, confidence_threshold, use_url, url)
    # so the function receives each value in the right slot.
    gr.Examples(
        examples=[
            ["./image.jpg", DEFAULT_CHECKPOINT, DEFAULT_THRESHOLD, False, ""],
            [
                None,
                DEFAULT_CHECKPOINT,
                DEFAULT_THRESHOLD,
                True,
                "https://live.staticflickr.com/65535/33021460783_1646d43c54_b.jpg",
            ],
        ],
        inputs=[image_input, checkpoint, confidence_threshold, use_url, url_input],
        outputs=[output_annotated, error_message],
        fn=detect_objects,
        cache_examples=False,  # Avoid caching due to model size
        run_on_click=True,  # without caching, examples only run the model when this is set
        label="Select an example to run the model",
    )

    # Dynamic visibility for URL input
    use_url.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_url,
        outputs=url_input,
    )

    # Clear button functionality: restore every input to its default and
    # empty both outputs.
    clear_button.click(
        fn=lambda: (
            None,  # image_input
            False,  # use_url
            "",  # url_input
            DEFAULT_CHECKPOINT,  # checkpoint
            DEFAULT_THRESHOLD,  # confidence_threshold
            None,  # output_annotated
            gr.Markdown(visible=False),  # error_message
        ),
        outputs=[
            image_input,
            use_url,
            url_input,
            checkpoint,
            confidence_threshold,
            output_annotated,
            error_message,
        ],
    )

    # Detect button event
    detect_button.click(
        fn=detect_objects,
        inputs=[image_input, checkpoint, confidence_threshold, use_url, url_input],
        outputs=[output_annotated, error_message],
    )

if __name__ == "__main__":
    demo.launch()