Harshithtd committed on
Commit
777385b
·
verified ·
1 Parent(s): 86945bf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import os
3
+ import numpy as np
4
+ import torch
5
+ import gradio as gr
6
+ from PIL import Image
7
+ from transformers import AutoImageProcessor, AutoModelForObjectDetection
8
+ import supervision as sv
9
+ import spaces
10
+
11
# Run on the GPU when torch reports one as available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The image processor and the detection model are loaded from the same checkpoint.
CHECKPOINT = "PekingU/rtdetr_r50vd_coco_o365"
processor = AutoImageProcessor.from_pretrained(CHECKPOINT)
model = AutoModelForObjectDetection.from_pretrained(CHECKPOINT)
model = model.to(device)

# Shared supervision helpers, created once at import time.
MASK_ANNOTATOR = sv.MaskAnnotator()
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
TRACKER = sv.ByteTrack()
20
+
21
def annotate_image(
    input_image,
    detections,
    labels
) -> np.ndarray:
    """Draw the detections onto an image.

    The mask, bounding-box and label annotators are applied in that order,
    each one receiving the output of the previous one.

    Args:
        input_image: Scene handed to the first annotator.
        detections: Detections passed to every annotator.
        labels: Label strings passed to the label annotator.

    Returns:
        The scene returned by the label annotator.
    """
    with_masks = MASK_ANNOTATOR.annotate(scene=input_image, detections=detections)
    with_boxes = BOUNDING_BOX_ANNOTATOR.annotate(scene=with_masks, detections=detections)
    return LABEL_ANNOTATOR.annotate(scene=with_boxes, detections=detections, labels=labels)
30
+
31
@spaces.GPU
def process_image(
    input_image,
    confidence_threshold,
):
    """Detect objects in a single image and return the annotated image.

    Args:
        input_image: Image to analyse; forwarded unchanged to ``query`` and
            ``annotate_image``. ``None`` (no image supplied) is passed through.
        confidence_threshold: Minimum score forwarded to ``query``.

    Returns:
        The annotated image produced by ``annotate_image``, or ``None`` when
        no input image was given.
    """
    # Nothing to detect when the caller supplied no image.
    if input_image is None:
        return None

    results = query(input_image, confidence_threshold)
    detections = sv.Detections.from_transformers(results[0])

    # NOTE: detections are deliberately NOT routed through the module-level
    # ByteTrack tracker. A tracker associates objects across consecutive video
    # frames and keeps state between calls; every call here handles an
    # unrelated still image, so tracking could drop valid detections and the
    # tracker ids were never used in the labels anyway.
    final_labels = [
        model.config.id2label[class_id]
        for class_id in detections.class_id.tolist()
    ]
    return annotate_image(input_image, detections, final_labels)
42
+
43
def query(image, confidence_threshold):
    """Run the detection model on one image and post-process its output.

    Args:
        image: Input image. Its ``size`` attribute is read and reversed to
            build the target size (assumes a PIL image whose ``size`` is
            ``(width, height)`` -- TODO confirm against the caller).
        confidence_threshold: Passed as ``threshold`` to the processor's
            post-processing step.

    Returns:
        The value returned by ``processor.post_process_object_detection``;
        the caller in this file reads element ``0`` of it.
    """
    model_inputs = processor(images=image, return_tensors="pt")
    model_inputs = model_inputs.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        raw_outputs = model(**model_inputs)

    reversed_size = image.size[::-1]
    return processor.post_process_object_detection(
        outputs=raw_outputs,
        threshold=confidence_threshold,
        target_sizes=torch.tensor([reversed_size]),
    )
50
+
51
def run_demo():
    """Build the Gradio interface for the detector and start serving it.

    The interface takes an image and a confidence threshold, calls
    ``process_image`` with them, and displays the returned image.
    """
    # Components come from the top-level ``gr`` namespace; the legacy
    # ``gr.inputs`` / ``gr.outputs`` namespaces no longer exist in current
    # Gradio releases. ``type="pil"`` is required because ``query`` reads
    # ``image.size[::-1]``, which is only a (height, width) pair for PIL images.
    image_input = gr.Image(type="pil", label="Input Image")
    threshold_slider = gr.Slider(
        label="Confidence Threshold",
        minimum=0.1,
        maximum=1.0,
        value=0.6,
        step=0.05,
    )
    image_output = gr.Image(label="Output Image")

    # ``process_image`` already has the (image, threshold) signature, so it is
    # used directly as the callback. The obsolete ``capture_session`` argument
    # is not passed because current ``gr.Interface`` does not accept it.
    gr.Interface(
        fn=process_image,
        inputs=[image_input, threshold_slider],
        outputs=image_output,
        title="Real Time Object Detection with RT-DETR",
        description="This demo uses RT-DETR for object detection in images. Adjust the confidence threshold to see different results.",
    ).launch()


# Start the app when the file is executed as a script; without this call the
# interface is defined but never launched.
if __name__ == "__main__":
    run_demo()