# Scraped page header (Hugging Face file view): author ariG23498 (HF Staff),
# commit 82925a6 "chore: adding demo", file size 6.56 kB.
import functools

import gradio as gr
from transformers import pipeline
from transformers.image_utils import load_image
# Model checkpoint identifiers offered in the UI dropdown. Order matters:
# the first entry is used as the default selection and in the examples.
checkpoints = [
    f"ustc-community/dfine_{variant}"
    for variant in (
        # "*_coco" variants
        "n_coco",
        "s_coco",
        "m_coco",
        "l_coco",
        "x_coco",
        # "*_obj365" variants
        "s_obj365",
        "m_obj365",
        "l_obj365",
        "x_obj365",
        # "*_obj2coco" variants (the "l" checkpoint carries an "_e25" suffix)
        "s_obj2coco",
        "m_obj2coco",
        "l_obj2coco_e25",
        "x_obj2coco",
    )
]
@functools.lru_cache(maxsize=2)
def _get_pipeline(checkpoint):
    """Return the object-detection pipeline for ``checkpoint``, building it on first use.

    Constructing a pipeline loads the model, so instances are memoized per
    checkpoint name instead of being rebuilt on every request. The cache is
    kept small so that switching between checkpoints does not keep every
    model in memory at once.
    """
    return pipeline(
        "object-detection",
        model=checkpoint,
        image_processor=checkpoint,
        device="cpu",
    )


def detect_objects(image, checkpoint, confidence_threshold=0.3, use_url=False, url=""):
    """Run object detection and format the result for the Gradio outputs.

    Args:
        image: Image supplied by the upload component, or ``None``. It must
            expose a ``size`` attribute of ``(width, height)``.
        checkpoint: Model identifier, used for both the model and the image
            processor of the pipeline.
        confidence_threshold: Minimum score a detection needs to be kept.
        use_url: When true and ``url`` is non-empty, the image is loaded from
            ``url`` instead of using ``image``.
        url: Location of the image to load when ``use_url`` is set.

    Returns:
        A 2-tuple ``(annotated, message)``. ``annotated`` is
        ``(input_image, annotations)``, where each annotation is
        ``((x1, y1, x2, y2), label)``, or ``None`` when no usable input was
        provided. ``message`` is a ``gr.Markdown`` that is visible only when
        there is an error or warning to display.
    """
    # A whitespace-only URL is treated the same as an empty one.
    url = url.strip() if isinstance(url, str) else ""

    # Resolve the input first so that a missing or unreadable image is
    # reported without paying for a model load.
    if use_url and url:
        try:
            input_image = load_image(url)
        except Exception as exc:
            # UI boundary: surface the failure through the error output
            # rather than letting the handler crash.
            return None, gr.Markdown(
                f"**Error**: Could not load an image from the URL ({exc}).",
                visible=True,
            )
    elif image is not None:
        input_image = image
    else:
        return None, gr.Markdown("**Error**: Please provide an image or URL.", visible=True)

    # Run detection
    pipe = _get_pipeline(checkpoint)
    results = pipe(input_image, threshold=confidence_threshold)

    # Image dimensions are used to clamp boxes to the visible area.
    img_width, img_height = input_image.size

    # Prepare annotations in the format: list of (bounding_box, label)
    annotations = []
    for result in results:
        score = result["score"]
        if score < confidence_threshold:
            continue
        label = f"{result['label']} ({score:.2f})"
        box = result["box"]

        # Clamp the box to the image and convert it to (x1, y1, x2, y2).
        x1 = max(0, int(box["xmin"]))
        y1 = max(0, int(box["ymin"]))
        x2 = min(img_width, int(box["xmax"]))
        y2 = min(img_height, int(box["ymax"]))

        # Drop boxes that are empty or inverted after clamping.
        if x2 <= x1 or y2 <= y1:
            continue
        annotations.append(((x1, y1, x2, y2), label))

    # Handle empty annotations
    if not annotations:
        return (input_image, []), gr.Markdown(
            "**Warning**: No objects detected above the confidence threshold. Try lowering the threshold.",
            visible=True,
        )

    # Return base image and annotations
    return (input_image, annotations), gr.Markdown(visible=False)
# Gradio interface: inputs on the left, annotated result on the right,
# followed by clickable examples and the event wiring.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# Real-Time Object Detection Demo
Experience state-of-the-art object detection with USTC's Dfine models. Upload an image, provide a URL, or try an example below. Select a model and adjust the confidence threshold to see detections in real time!
**Instructions**:
- Upload an image or enter a URL.
- Choose a model checkpoint from the dropdown.
- Adjust the confidence threshold (0.1 to 1.0).
- Click "Detect Objects" to view results, or select an example.
- Use "Clear" to reset inputs and outputs.
""",
        elem_classes="header-text"
    )

    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            with gr.Group():
                image_input = gr.Image(
                    label="Upload Image",
                    type="pil",
                    sources=["upload", "webcam"],
                    interactive=True,
                    elem_classes="input-component",
                )
                use_url = gr.Checkbox(label="Use Image URL Instead", value=False)
                # Hidden until the checkbox above is ticked (see use_url.change below).
                url_input = gr.Textbox(
                    label="Image URL",
                    placeholder="https://example.com/image.jpg",
                    visible=False,
                    elem_classes="input-component",
                )
                checkpoint = gr.Dropdown(
                    choices=checkpoints,
                    label="Select Model Checkpoint",
                    value=checkpoints[0],
                    elem_classes="input-component",
                )
                confidence_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                    label="Confidence Threshold",
                    elem_classes="input-component",
                )
            with gr.Row():
                detect_button = gr.Button(
                    "Detect Objects",
                    variant="primary",
                    elem_classes="action-button",
                )
                clear_button = gr.Button(
                    "Clear",
                    variant="secondary",
                    elem_classes="action-button",
                )

        with gr.Column(scale=2):
            output_annotated = gr.AnnotatedImage(
                label="Detection Results",
                show_label=True,
                color_map=None,  # Let Gradio assign colors
                elem_classes="output-component",
            )
            error_message = gr.Markdown(visible=False, elem_classes="error-text")

    # Example rows and `inputs` follow the positional parameter order of
    # detect_objects (image, checkpoint, confidence_threshold, use_url, url),
    # so the values reach the right parameters when the function is invoked.
    gr.Examples(
        examples=[
            ["./image.jpg", checkpoints[0], 0.3, False, ""],
            [
                None,
                checkpoints[0],
                0.3,
                True,
                "https://live.staticflickr.com/65535/33021460783_1646d43c54_b.jpg",
            ],
        ],
        inputs=[image_input, checkpoint, confidence_threshold, use_url, url_input],
        outputs=[output_annotated, error_message],
        fn=detect_objects,
        cache_examples=False,  # Avoid caching due to model size
        # Without caching, the function only runs on selection if this is set;
        # the label below promises that selecting an example runs the model.
        run_on_click=True,
        label="Select an example to run the model",
    )

    # Dynamic visibility for URL input
    use_url.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_url,
        outputs=url_input,
    )

    # Clear button functionality: reset every input to its initial value and
    # empty both outputs. Values are listed in the same order as `outputs`.
    clear_button.click(
        fn=lambda: (
            None,  # image_input
            False,  # use_url
            "",  # url_input
            checkpoints[0],  # checkpoint
            0.3,  # confidence_threshold
            None,  # output_annotated
            gr.Markdown(visible=False),  # error_message
        ),
        outputs=[
            image_input,
            use_url,
            url_input,
            checkpoint,
            confidence_threshold,
            output_annotated,
            error_message,
        ],
    )

    # Detect button event; `inputs` matches the parameter order of detect_objects.
    detect_button.click(
        fn=detect_objects,
        inputs=[image_input, checkpoint, confidence_threshold, use_url, url_input],
        outputs=[output_annotated, error_message],
    )

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()