"""Gradio demo: object detection with USTC's D-FINE checkpoints.

The user uploads an image, captures one from the webcam, or supplies an
http(s) URL. The selected checkpoint runs on CPU through a ``transformers``
object-detection pipeline, and the boxes are drawn with ``gr.AnnotatedImage``.
"""

import functools

import gradio as gr
from transformers import pipeline
from transformers.image_utils import load_image

checkpoints = [
    "ustc-community/dfine_n_coco",
    "ustc-community/dfine_s_coco",
    "ustc-community/dfine_m_coco",
    "ustc-community/dfine_l_coco",
    "ustc-community/dfine_x_coco",
    "ustc-community/dfine_s_obj365",
    "ustc-community/dfine_m_obj365",
    "ustc-community/dfine_l_obj365",
    "ustc-community/dfine_x_obj365",
    "ustc-community/dfine_s_obj2coco",
    "ustc-community/dfine_m_obj2coco",
    "ustc-community/dfine_l_obj2coco_e25",
    "ustc-community/dfine_x_obj2coco",
]

# Defaults shared by the dropdown/slider, the examples and the "Clear" button.
DEFAULT_CHECKPOINT = checkpoints[0]
DEFAULT_THRESHOLD = 0.3

# Seconds to wait for a remote image before reporting an error to the user.
URL_TIMEOUT_SECONDS = 10.0

# load_image() also accepts local file paths and base64 payloads; the URL box
# is user-controlled, so only remote http(s) images may be fetched through it.
_ALLOWED_URL_PREFIXES = ("http://", "https://")

HEADER_MARKDOWN = """
# Real-Time Object Detection Demo
Experience state-of-the-art object detection with USTC's Dfine models. Upload an image, provide a URL, or try an example below. Select a model and adjust the confidence threshold to see detections in real time!

**Instructions**:
- Upload an image or enter a URL.
- Choose a model checkpoint from the dropdown.
- Adjust the confidence threshold (0.1 to 1.0).
- Click "Detect Objects" to view results, or select an example.
- Use "Clear" to reset inputs and outputs.
"""


@functools.lru_cache(maxsize=2)
def _get_pipeline(checkpoint):
    """Return a CPU object-detection pipeline for *checkpoint*, built once.

    Loading weights dominates request latency, so recently used pipelines are
    kept in memory. The cache is bounded because each entry holds a full model.
    """
    return pipeline(
        "object-detection",
        model=checkpoint,
        image_processor=checkpoint,
        device="cpu",
    )


def _resolve_input_image(image, use_url, url):
    """Pick the image to run on.

    Returns:
        ``(image, None)`` on success, or ``(None, error_markdown)`` when no
        usable input was provided or the URL could not be loaded.
    """
    url = (url or "").strip()
    if use_url and url:
        if not url.lower().startswith(_ALLOWED_URL_PREFIXES):
            return None, "**Error**: The image URL must start with http:// or https://."
        try:
            return load_image(url, timeout=URL_TIMEOUT_SECONDS), None
        except (OSError, ValueError) as err:
            # Network errors (requests' exceptions subclass OSError), undecodable
            # images (PIL.UnidentifiedImageError) and unsupported inputs
            # (ValueError) are shown to the user instead of crashing the handler.
            return None, f"**Error**: Could not load image from URL ({err})."
    if image is not None:
        return image, None
    return None, "**Error**: Please provide an image or URL."


def _to_annotations(results, image_size, confidence_threshold):
    """Convert pipeline detections into ``gr.AnnotatedImage`` annotations.

    Each annotation is ``((x1, y1, x2, y2), "label (score)")``. Boxes are
    clipped to the image bounds and dropped if they become empty. Detections
    below *confidence_threshold* are skipped.
    """
    img_width, img_height = image_size
    annotations = []
    for result in results:
        score = result["score"]
        if score < confidence_threshold:
            continue
        box = result["box"]
        x1 = max(0, int(box["xmin"]))
        y1 = max(0, int(box["ymin"]))
        x2 = min(img_width, int(box["xmax"]))
        y2 = min(img_height, int(box["ymax"]))
        if x2 <= x1 or y2 <= y1:
            continue
        annotations.append(((x1, y1, x2, y2), f"{result['label']} ({score:.2f})"))
    return annotations


def detect_objects(image, checkpoint, confidence_threshold=DEFAULT_THRESHOLD, use_url=False, url=""):
    """Run object detection and build the values for the two output components.

    Args:
        image: Uploaded/webcam image (PIL, per the ``gr.Image(type="pil")``
            input), or ``None``.
        checkpoint: Hugging Face model id to run.
        confidence_threshold: Minimum detection score to keep.
        use_url: When true and *url* is non-empty, *url* is used instead of *image*.
        url: http(s) URL of the image to fetch.

    Returns:
        A pair for ``[gr.AnnotatedImage, gr.Markdown]``: ``((image, annotations),
        markdown)``. The markdown is hidden on success and shows a warning when
        nothing was detected. If the input is missing or invalid, the pair is
        ``(None, error markdown)``.
    """
    # Validate the input before loading a model, so bad requests stay cheap.
    input_image, error = _resolve_input_image(image, use_url, url)
    if error is not None:
        return None, gr.Markdown(error, visible=True)

    pipe = _get_pipeline(checkpoint)
    results = pipe(input_image, threshold=confidence_threshold)
    annotations = _to_annotations(results, input_image.size, confidence_threshold)

    if not annotations:
        return (input_image, []), gr.Markdown(
            "**Warning**: No objects detected above the confidence threshold. "
            "Try lowering the threshold.",
            visible=True,
        )

    # Return the base image; AnnotatedImage draws the boxes itself.
    return (input_image, annotations), gr.Markdown(visible=False)


# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(HEADER_MARKDOWN, elem_classes="header-text")

    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            with gr.Group():
                image_input = gr.Image(
                    label="Upload Image",
                    type="pil",
                    sources=["upload", "webcam"],
                    interactive=True,
                    elem_classes="input-component",
                )
                use_url = gr.Checkbox(label="Use Image URL Instead", value=False)
                url_input = gr.Textbox(
                    label="Image URL",
                    placeholder="https://example.com/image.jpg",
                    visible=False,
                    elem_classes="input-component",
                )
                checkpoint = gr.Dropdown(
                    choices=checkpoints,
                    label="Select Model Checkpoint",
                    value=DEFAULT_CHECKPOINT,
                    elem_classes="input-component",
                )
                confidence_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=DEFAULT_THRESHOLD,
                    step=0.1,
                    label="Confidence Threshold",
                    elem_classes="input-component",
                )
            with gr.Row():
                detect_button = gr.Button(
                    "Detect Objects",
                    variant="primary",
                    elem_classes="action-button",
                )
                clear_button = gr.Button(
                    "Clear",
                    variant="secondary",
                    elem_classes="action-button",
                )
        with gr.Column(scale=2):
            output_annotated = gr.AnnotatedImage(
                label="Detection Results",
                show_label=True,
                color_map=None,  # Let Gradio assign colors
                elem_classes="output-component",
            )
            error_message = gr.Markdown(visible=False, elem_classes="error-text")

    # Gradio passes inputs positionally, so this list must follow the parameter
    # order of detect_objects(image, checkpoint, confidence_threshold, use_url, url).
    detect_inputs = [image_input, checkpoint, confidence_threshold, use_url, url_input]
    detect_outputs = [output_annotated, error_message]

    gr.Examples(
        examples=[
            ["./image.jpg", DEFAULT_CHECKPOINT, DEFAULT_THRESHOLD, False, ""],
            [
                None,
                DEFAULT_CHECKPOINT,
                DEFAULT_THRESHOLD,
                True,
                "https://live.staticflickr.com/65535/33021460783_1646d43c54_b.jpg",
            ],
        ],
        inputs=detect_inputs,
        outputs=detect_outputs,
        fn=detect_objects,
        # Without caching, Gradio only fills the inputs on selection unless
        # run_on_click is set; the label promises the example runs the model.
        run_on_click=True,
        cache_examples=False,  # Avoid caching due to model size
        label="Select an example to run the model",
    )

    # Show the URL textbox only while "Use Image URL Instead" is checked.
    use_url.change(
        fn=lambda checked: gr.update(visible=checked),
        inputs=use_url,
        outputs=url_input,
    )

    # Reset every input to its default and clear both outputs.
    clear_button.click(
        fn=lambda: (
            None,  # image_input
            False,  # use_url
            "",  # url_input
            DEFAULT_CHECKPOINT,  # checkpoint
            DEFAULT_THRESHOLD,  # confidence_threshold
            None,  # output_annotated
            gr.Markdown(visible=False),  # error_message
        ),
        outputs=[
            image_input,
            use_url,
            url_input,
            checkpoint,
            confidence_threshold,
            output_annotated,
            error_message,
        ],
    )

    detect_button.click(
        fn=detect_objects,
        inputs=detect_inputs,
        outputs=detect_outputs,
    )

if __name__ == "__main__":
    demo.launch()