{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: rt-detr-object-detection\n",
    "\n",
    "Streams object-detection results for an uploaded video using the RT-DETR checkpoint `PekingU/rtdetr_r50vd` from Hugging Face Transformers. Frames are subsampled and downscaled, annotated in batches, and played back as a sequence of short MP4 segments. A CUDA GPU is required."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q gradio spaces safetensors==0.4.3 opencv-python torch \"transformers>=4.43.0\" Pillow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Downloading files from the demo repo\n",
    "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/rt-detr-object-detection/draw_boxes.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44380577570523278879349135829904343037",
   "metadata": {},
   "outputs": [],
   "source": [
    "import spaces\n",
    "import gradio as gr\n",
    "import cv2\n",
    "from PIL import Image\n",
    "import torch\n",
    "import os\n",
    "import time\n",
    "import numpy as np\n",
    "import uuid\n",
    "\n",
    "from transformers import RTDetrForObjectDetection, RTDetrImageProcessor  # type: ignore\n",
    "\n",
    "from draw_boxes import draw_bounding_boxes\n",
    "\n",
    "# Checkpoint used for inference; the page header links to the same model.\n",
    "MODEL_ID = \"PekingU/rtdetr_r50vd\"\n",
    "DEVICE = \"cuda\"\n",
    "# Only every SUBSAMPLE-th input frame is annotated; output fps = input fps // SUBSAMPLE.\n",
    "SUBSAMPLE = 2\n",
    "# Frames are shrunk by this integer factor in both dimensions before inference.\n",
    "DOWNSCALE = 2\n",
    "# Seconds of output video per streamed segment (one inference batch each).\n",
    "SEGMENT_SECONDS = 2\n",
    "# Optional example clip. It is not downloaded above, so the UI lists it only if present.\n",
    "EXAMPLE_VIDEO = \"3285790-hd_1920_1080_30fps.mp4\"\n",
    "\n",
    "image_processor = RTDetrImageProcessor.from_pretrained(MODEL_ID)\n",
    "model = RTDetrForObjectDetection.from_pretrained(MODEL_ID).to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "rtdetr-pipeline",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _process_segment(rgb_frames, fps, width, height, conf_threshold):\n",
    "    \"\"\"Detect objects in `rgb_frames`, draw the boxes, and encode them to a new MP4.\n",
    "\n",
    "    Args:\n",
    "        rgb_frames: list of RGB frames, each already resized to (width, height).\n",
    "        fps: frame rate written into the output file.\n",
    "        width, height: frame size in pixels; detections are rescaled to this size.\n",
    "        conf_threshold: minimum detection score kept by post-processing and drawing.\n",
    "\n",
    "    Returns:\n",
    "        Path of the newly written ``output_<uuid>.mp4`` file.\n",
    "    \"\"\"\n",
    "    inputs = image_processor(images=rgb_frames, return_tensors=\"pt\").to(DEVICE)\n",
    "\n",
    "    print(f\"starting batch of size {len(rgb_frames)}\")\n",
    "    start = time.time()\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "    print(\"time taken for inference\", time.time() - start)\n",
    "\n",
    "    start = time.time()\n",
    "    detections = image_processor.post_process_object_detection(\n",
    "        outputs,\n",
    "        target_sizes=torch.tensor([(height, width)] * len(rgb_frames)),\n",
    "        threshold=conf_threshold,\n",
    "    )\n",
    "\n",
    "    name = f\"output_{uuid.uuid4()}.mp4\"\n",
    "    video_codec = cv2.VideoWriter_fourcc(*\"mp4v\")  # type: ignore\n",
    "    writer = cv2.VideoWriter(name, video_codec, fps, (width, height))  # type: ignore\n",
    "    try:\n",
    "        for rgb_frame, frame_detections in zip(rgb_frames, detections):\n",
    "            pil_image = draw_bounding_boxes(\n",
    "                Image.fromarray(rgb_frame), frame_detections, model, conf_threshold\n",
    "            )\n",
    "            # RGB -> BGR for OpenCV; copy() makes the reversed view contiguous.\n",
    "            writer.write(np.array(pil_image)[:, :, ::-1].copy())\n",
    "    finally:\n",
    "        # Release even if drawing fails so the file handle is not leaked.\n",
    "        writer.release()\n",
    "    print(\"time taken for processing boxes\", time.time() - start)\n",
    "    return name\n",
    "\n",
    "\n",
    "@spaces.GPU\n",
    "def stream_object_detection(video, conf_threshold):\n",
    "    \"\"\"Yield annotated MP4 segments of `video` for a streaming gr.Video output.\n",
    "\n",
    "    Every SUBSAMPLE-th frame is downscaled by DOWNSCALE, converted to RGB and\n",
    "    buffered. Each time SEGMENT_SECONDS worth of output frames are buffered,\n",
    "    they are passed to _process_segment and the resulting file path is yielded.\n",
    "    Frames left over at the end of the video are flushed as a final, shorter\n",
    "    segment.\n",
    "\n",
    "    Args:\n",
    "        video: path of the input video, as passed by gr.Video.\n",
    "        conf_threshold: minimum detection score to keep (slider value in [0, 1]).\n",
    "\n",
    "    Yields:\n",
    "        Path of each completed segment file.\n",
    "    \"\"\"\n",
    "    cap = cv2.VideoCapture(video)\n",
    "    try:\n",
    "        fps = int(cap.get(cv2.CAP_PROP_FPS))\n",
    "        # Clamp to 1: with fps < SUBSAMPLE (or 0 when OpenCV cannot read it) the\n",
    "        # batch size would be 0 and no segment would ever be emitted.\n",
    "        desired_fps = max(1, fps // SUBSAMPLE)\n",
    "        batch_size = SEGMENT_SECONDS * desired_fps\n",
    "        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) // DOWNSCALE\n",
    "        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) // DOWNSCALE\n",
    "\n",
    "        batch = []\n",
    "        n_frames = 0\n",
    "        while True:\n",
    "            ok, frame = cap.read()\n",
    "            if not ok:\n",
    "                break\n",
    "            if n_frames % SUBSAMPLE == 0:\n",
    "                # Resize to the exact writer size. fx/fy scaling rounds odd\n",
    "                # dimensions (e.g. 1083 * 0.5 -> 542), while // floors them (541),\n",
    "                # and VideoWriter expects every frame to match (width, height).\n",
    "                frame = cv2.resize(frame, (width, height))\n",
    "                batch.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n",
    "            n_frames += 1\n",
    "\n",
    "            if len(batch) == batch_size:\n",
    "                yield _process_segment(batch, desired_fps, width, height, conf_threshold)\n",
    "                batch = []\n",
    "\n",
    "        # Flush frames that did not fill a whole segment so the end of the video is kept.\n",
    "        if batch:\n",
    "            yield _process_segment(batch, desired_fps, width, height, conf_threshold)\n",
    "    finally:\n",
    "        # Runs on normal exit and when the consumer closes the generator early.\n",
    "        cap.release()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "rtdetr-ui",
   "metadata": {},
   "outputs": [],
   "source": [
    "with gr.Blocks() as demo:\n",
    "    gr.HTML(\n",
    "        f\"\"\"\n",
    "    <h1 style='text-align: center'>\n",
    "    Video Object Detection with <a href='https://huggingface.co/{MODEL_ID}' target='_blank'>RT-DETR</a>\n",
    "    </h1>\n",
    "    \"\"\"\n",
    "    )\n",
    "    with gr.Row():\n",
    "        with gr.Column():\n",
    "            video = gr.Video(label=\"Video Source\")\n",
    "            conf_threshold = gr.Slider(\n",
    "                label=\"Confidence Threshold\",\n",
    "                minimum=0.0,\n",
    "                maximum=1.0,\n",
    "                step=0.05,\n",
    "                value=0.30,\n",
    "            )\n",
    "        with gr.Column():\n",
    "            output_video = gr.Video(\n",
    "                label=\"Processed Video\", streaming=True, autoplay=True\n",
    "            )\n",
    "\n",
    "    video.upload(\n",
    "        fn=stream_object_detection,\n",
    "        inputs=[video, conf_threshold],\n",
    "        outputs=[output_video],\n",
    "    )\n",
    "\n",
    "    # The example clip is not fetched by this notebook; only offer it if present.\n",
    "    if os.path.exists(EXAMPLE_VIDEO):\n",
    "        gr.Examples(\n",
    "            examples=[EXAMPLE_VIDEO],\n",
    "            inputs=[video],\n",
    "        )\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    demo.launch()"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}