minar09 commited on
Commit
6c6cd1e
·
verified ·
1 Parent(s): da73307

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +231 -0
app.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import gradio as gr
4
+ import numpy as np
5
+ import supervision as sv
6
+ from pathlib import Path
7
+ from dds_cloudapi_sdk import Config, Client, TextPrompt
8
+ from dds_cloudapi_sdk.tasks.dinox import DinoxTask
9
+ from dds_cloudapi_sdk.tasks.detection import DetectionTask
10
+ from dds_cloudapi_sdk.tasks.types import DetectionTarget
11
+
12
+ # Constants
13
+ API_TOKEN = "361d32fa5ce22649133660c65cfcaf22"
14
+ TEXT_PROMPT = "wheel . eye . helmet . mouse . mouth . vehicle . steering wheel . ear . nose"
15
+ TEMP_DIR = "./temp"
16
+ OUTPUT_DIR = "./outputs"
17
+
18
+ # Ensure directories exist
19
+ os.makedirs(TEMP_DIR, exist_ok=True)
20
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
21
+
22
+ def initialize_dino_client():
23
+ """Initialize the DINO-X client"""
24
+ config = Config(API_TOKEN)
25
+ return Client(config)
26
+
27
+ def get_class_mappings(text_prompt):
28
+ """Create class name to ID mappings"""
29
+ classes = [x.strip().lower() for x in text_prompt.split('.') if x]
30
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
31
+ return classes, class_name_to_id
32
+
33
+ def process_predictions(predictions, class_name_to_id):
34
+ """Process DINO-X predictions into detection format"""
35
+ boxes = []
36
+ masks = []
37
+ confidences = []
38
+ class_names = []
39
+ class_ids = []
40
+
41
+ for obj in predictions:
42
+ boxes.append(obj.bbox)
43
+ if hasattr(obj, 'mask') and obj.mask:
44
+ masks.append(DetectionTask.rle2mask(
45
+ DetectionTask.string2rle(obj.mask.counts),
46
+ obj.mask.size
47
+ ))
48
+ cls_name = obj.category.lower().strip()
49
+ class_names.append(cls_name)
50
+ class_ids.append(class_name_to_id[cls_name])
51
+ confidences.append(obj.score)
52
+
53
+ return {
54
+ 'boxes': np.array(boxes),
55
+ 'masks': np.array(masks) if masks else None,
56
+ 'class_ids': np.array(class_ids),
57
+ 'class_names': class_names,
58
+ 'confidences': confidences
59
+ }
60
+
61
+ def process_image(image_path, prompt=TEXT_PROMPT):
62
+ """Process a single image with DINO-X"""
63
+ try:
64
+ client = initialize_dino_client()
65
+ _, class_name_to_id = get_class_mappings(prompt)
66
+
67
+ # Upload and process image
68
+ image_url = client.upload_file(image_path)
69
+ task = DinoxTask(
70
+ image_url=image_url,
71
+ prompts=[TextPrompt(text=prompt)],
72
+ bbox_threshold=0.25,
73
+ targets=[DetectionTarget.BBox, DetectionTarget.Mask]
74
+ )
75
+ client.run_task(task)
76
+
77
+ # Process predictions
78
+ results = process_predictions(task.result.objects, class_name_to_id)
79
+
80
+ # Annotate image
81
+ img = cv2.imread(image_path)
82
+ detections = sv.Detections(
83
+ xyxy=results['boxes'],
84
+ mask=results['masks'].astype(bool) if results['masks'] is not None else None,
85
+ class_id=results['class_ids']
86
+ )
87
+
88
+ labels = [
89
+ f"{name} {conf:.2f}"
90
+ for name, conf in zip(results['class_names'], results['confidences'])
91
+ ]
92
+
93
+ # Apply annotations
94
+ annotator = sv.BoxAnnotator()
95
+ annotated_frame = annotator.annotate(scene=img.copy(), detections=detections)
96
+
97
+ label_annotator = sv.LabelAnnotator()
98
+ annotated_frame = label_annotator.annotate(
99
+ scene=annotated_frame,
100
+ detections=detections,
101
+ labels=labels
102
+ )
103
+
104
+ if results['masks'] is not None:
105
+ mask_annotator = sv.MaskAnnotator()
106
+ annotated_frame = mask_annotator.annotate(
107
+ scene=annotated_frame,
108
+ detections=detections
109
+ )
110
+
111
+ output_path = os.path.join(OUTPUT_DIR, "result.jpg")
112
+ cv2.imwrite(output_path, annotated_frame)
113
+
114
+ return output_path
115
+
116
+ except Exception as e:
117
+ return f"Error processing image: {str(e)}"
118
+
119
+ def process_video(video_path, prompt=TEXT_PROMPT):
120
+ """Process a video with DINO-X"""
121
+ try:
122
+ client = initialize_dino_client()
123
+ _, class_name_to_id = get_class_mappings(prompt)
124
+
125
+ cap = cv2.VideoCapture(video_path)
126
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
127
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
128
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
129
+
130
+ output_path = os.path.join(OUTPUT_DIR, "result.mp4")
131
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
132
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
133
+
134
+ frame_count = 0
135
+ temp_frame_path = os.path.join(TEMP_DIR, "temp_frame.jpg")
136
+
137
+ while cap.isOpened():
138
+ ret, frame = cap.read()
139
+ if not ret:
140
+ break
141
+
142
+ frame_count += 1
143
+ if frame_count % 3 != 0: # Process every 3rd frame for speed
144
+ continue
145
+
146
+ cv2.imwrite(temp_frame_path, frame)
147
+ image_url = client.upload_file(temp_frame_path)
148
+
149
+ task = DinoxTask(
150
+ image_url=image_url,
151
+ prompts=[TextPrompt(text=prompt)],
152
+ bbox_threshold=0.25
153
+ )
154
+ client.run_task(task)
155
+
156
+ results = process_predictions(task.result.objects, class_name_to_id)
157
+
158
+ detections = sv.Detections(
159
+ xyxy=results['boxes'],
160
+ class_id=results['class_ids']
161
+ )
162
+
163
+ labels = [
164
+ f"{name} {conf:.2f}"
165
+ for name, conf in zip(results['class_names'], results['confidences'])
166
+ ]
167
+
168
+ annotator = sv.BoxAnnotator()
169
+ annotated_frame = annotator.annotate(scene=frame.copy(), detections=detections)
170
+
171
+ label_annotator = sv.LabelAnnotator()
172
+ annotated_frame = label_annotator.annotate(
173
+ scene=annotated_frame,
174
+ detections=detections,
175
+ labels=labels
176
+ )
177
+
178
+ out.write(annotated_frame)
179
+
180
+ cap.release()
181
+ out.release()
182
+
183
+ if os.path.exists(temp_frame_path):
184
+ os.remove(temp_frame_path)
185
+
186
+ return output_path
187
+
188
+ except Exception as e:
189
+ return f"Error processing video: {str(e)}"
190
+
191
+ def process_input(input_file, prompt=TEXT_PROMPT):
192
+ """Process either image or video input"""
193
+ if input_file is None:
194
+ return "Please provide an input file"
195
+
196
+ file_path = input_file.name
197
+ extension = os.path.splitext(file_path)[1].lower()
198
+
199
+ if extension in ['.jpg', '.jpeg', '.png']:
200
+ return process_image(file_path, prompt)
201
+ elif extension in ['.mp4', '.avi', '.mov']:
202
+ return process_video(file_path, prompt)
203
+ else:
204
+ return "Unsupported file format. Please use jpg/jpeg/png for images or mp4/avi/mov for videos."
205
+
206
+ # Create Gradio interface
207
+ demo = gr.Interface(
208
+ fn=process_input,
209
+ inputs=[
210
+ gr.File(
211
+ label="Upload Image/Video",
212
+ file_types=["image", "video"]
213
+ ),
214
+ gr.Textbox(
215
+ label="Detection Prompt",
216
+ value=TEXT_PROMPT,
217
+ lines=2
218
+ )
219
+ ],
220
+ outputs=gr.Image(label="Detection Result"),
221
+ title="DINO-X Object Detection",
222
+ description="Upload an image or video to detect objects using DINO-X. You can modify the detection prompt to specify what objects to look for.",
223
+ examples=[
224
+ ["assets/demo.png", TEXT_PROMPT],
225
+ ["assets/demo.mp4", TEXT_PROMPT]
226
+ ],
227
+ cache_examples=True
228
+ )
229
+
230
+ if __name__ == "__main__":
231
+ demo.launch()