"""Gradio demo that runs an MMDetection ``DetInferencer`` on images and videos."""

import os
import tempfile

import cv2
import gradio as gr
import numpy as np
from mmdet.apis import DetInferencer

# Frame rate used when the input container does not report a usable one.
_FALLBACK_FPS = 25.0

# Currently loaded inferencer; stays None until load_model() succeeds.
inferencer = None


def _as_path(uploaded):
    """Return the filesystem path behind a ``gr.File`` value, or None.

    Depending on the Gradio version/configuration, a file component hands the
    callback either a path string or a temp-file wrapper exposing ``.name``.
    """
    if uploaded is None:
        return None
    if isinstance(uploaded, (str, os.PathLike)):
        return os.fspath(uploaded)
    return getattr(uploaded, "name", None)


def _draw_detections(bgr_image: np.ndarray) -> np.ndarray:
    """Run the loaded inferencer on one image and return the annotated image.

    The result keeps the channel order of ``bgr_image``.
    """
    # return_vis makes the drawn image part of the result explicitly;
    # no_save_vis stops the inferencer from also dumping every image/frame
    # into its default output directory on disk.
    result = inferencer(bgr_image, return_vis=True, no_save_vis=True)
    vis = result["visualization"]
    if isinstance(vis, list):
        vis = vis[0]
    return vis


def load_model(config_path, checkpoint_path):
    """Build the global inferencer from an uploaded config and checkpoint.

    Args:
        config_path: Value of the config ``gr.File`` component.
        checkpoint_path: Value of the checkpoint ``gr.File`` component.

    Returns:
        A status message for the UI.
    """
    global inferencer
    config_file = _as_path(config_path)
    weights_file = _as_path(checkpoint_path)
    if not config_file or not weights_file:
        return "Please upload both a config file and a checkpoint."
    inferencer = DetInferencer(model=config_file, weights=weights_file)
    return "Model loaded."


def infer_image(image):
    """Detect objects on a single image coming from ``gr.Image``.

    Args:
        image: Image array as delivered by ``gr.Image(type="numpy")``, or
            None when nothing was uploaded.

    Returns:
        ``(status, annotated_image)``; the image is None when detection could
        not run.
    """
    if inferencer is None:
        return "Please load a model first.", None
    if image is None:
        return "Please provide an image.", None
    # NOTE(review): assumes Gradio delivers RGB while the MMDetection ndarray
    # pipeline follows the OpenCV/BGR convention -- confirm against the
    # installed versions. Convert on the way in and back on the way out.
    bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    vis = _draw_detections(bgr)
    return "", cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)


def infer_video(video):
    """Detect objects on every frame of a video and write an annotated copy.

    Args:
        video: Path of the uploaded video as delivered by ``gr.Video``, or
            None when nothing was uploaded.

    Returns:
        ``(status, output_path)``; the path is None when no result video
        could be produced.
    """
    if inferencer is None:
        return "Please load a model first.", None
    if not video:
        return "Please provide a video.", None

    cap = cv2.VideoCapture(video)
    if not cap.isOpened():
        cap.release()
        return "Could not open the video.", None

    out = None
    frames_written = 0
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or NaN) FPS, which yields a broken file.
        if not fps or fps != fps:
            fps = _FALLBACK_FPS
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # The directory is intentionally left in place: the UI still has to
        # read the file after this function returns.
        temp_dir = tempfile.mkdtemp()
        out_path = os.path.join(temp_dir, "result.mp4")
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
        if not out.isOpened():
            return "Could not create the output video.", None

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Frames are BGR and the annotated frame keeps that order, which
            # is what VideoWriter expects -- no channel flip here.
            out.write(_draw_detections(frame))
            frames_written += 1
    finally:
        # Release native handles even if inference fails mid-video.
        cap.release()
        if out is not None:
            out.release()

    if frames_written == 0:
        return "No frames could be read from the video.", None
    return "", out_path


def ui():
    """Assemble and return the Gradio Blocks application."""
    with gr.Blocks() as demo:
        gr.Markdown("# SpecDETR Demo: Image and Video Detection\nUpload your config (.py) and checkpoint (.pth) to start.")
        with gr.Row():
            config = gr.File(label="Config File (.py)")
            checkpoint = gr.File(label="Checkpoint (.pth)")
            load_btn = gr.Button("Load Model")
        load_status = gr.Textbox(label="Status", interactive=False)
        load_btn.click(load_model, inputs=[config, checkpoint], outputs=load_status)
        with gr.Tab("Image"):
            img_input = gr.Image(type="numpy")
            img_output = gr.Image()
            img_btn = gr.Button("Detect on Image")
            img_status = gr.Textbox(label="Status", interactive=False)
            img_btn.click(infer_image, inputs=img_input, outputs=[img_status, img_output])
        with gr.Tab("Video"):
            vid_input = gr.Video()
            vid_output = gr.Video()
            vid_btn = gr.Button("Detect on Video")
            vid_status = gr.Textbox(label="Status", interactive=False)
            vid_btn.click(infer_video, inputs=vid_input, outputs=[vid_status, vid_output])
    return demo


# Kept at module level so tools that import this file can find the app object.
demo = ui()


def main():
    """Launch the demo server."""
    demo.launch()


if __name__ == "__main__":
    main()