itsindrabudhik committed
Commit eeac268 · verified · 1 Parent(s): 940d094

Create app.py

Files changed (1):
  app.py +302 -0
app.py ADDED
from dataclasses import dataclass, replace
from functools import reduce
from io import BytesIO
import math
import os
from pprint import pprint
import tempfile

from PIL import Image, ImageDraw, ImageFont
import numpy as np
import cv2

# Training-time dependencies carried over from the original notebook; only a
# subset of these is used by this inference app.
import seaborn as sns
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms

import roboflow
from roboflow import Roboflow
import supervision as sv
import albumentations as A

import gradio as gr
import requests

# from torchmetrics.detection.mean_ap import MeanAveragePrecision
# from torchmetrics.detection.iou import IntersectionOverUnion
# import evaluate
# from datasets import load_metric

from transformers import pipeline
from transformers import (
    AutoProcessor,
    AutoImageProcessor,
    AutoModel,
    AutoModelForObjectDetection,
    RTDetrForObjectDetection,
    RTDetrImageProcessor,
    TrainingArguments,
    Trainer,
)
from huggingface_hub import hf_hub_download

from safetensors.torch import load_file

# Utilities
PALETTE = {
    0: {"color": (255, 0, 0), "name": "Ambulance"},
    1: {"color": (0, 191, 0), "name": "Firetruck"},
    2: {"color": (0, 0, 255), "name": "Police"},
    3: {"color": (255, 0, 255), "name": "Non-EV"},
}
label2id = {val["name"]: idx for (idx, val) in PALETTE.items()}
id2label = {idx: name for (name, idx) in label2id.items()}

print(label2id)
print(id2label)

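# Expected mappings (derived directly from PALETTE above):
# label2id -> {'Ambulance': 0, 'Firetruck': 1, 'Police': 2, 'Non-EV': 3}
# id2label -> {0: 'Ambulance', 1: 'Firetruck', 2: 'Police', 3: 'Non-EV'}
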
def unnormalize_bbox(img_h, img_w, bbox):
    """Convert a normalized (cx, cy, w, h) box to pixel (x_min, y_min, x_max, y_max)."""
    x_min = bbox[0] - bbox[2] / 2
    y_min = bbox[1] - bbox[3] / 2
    x_max = bbox[0] + bbox[2] / 2
    y_max = bbox[1] + bbox[3] / 2

    x_min *= img_w
    y_min *= img_h
    x_max *= img_w
    y_max *= img_h
    x_min, y_min, x_max, y_max = map(int, [x_min, y_min, x_max, y_max])

    return (x_min, y_min, x_max, y_max)

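# Worked example: on a 100x200 (h x w) image, the normalized box
# (cx=0.5, cy=0.5, w=0.2, h=0.4) maps to pixel corners
#   x_min = (0.5 - 0.1) * 200 = 80,  y_min = (0.5 - 0.2) * 100 = 30
#   x_max = (0.5 + 0.1) * 200 = 120, y_max = (0.5 + 0.2) * 100 = 70
# so unnormalize_bbox(100, 200, (0.5, 0.5, 0.2, 0.4)) == (80, 30, 120, 70).
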
def paint_bbox(
    image,
    annotations,
    normalize_labels=True,
    normalize_bbox=True,
):
    bboxes = annotations["boxes"].tolist()
    class_id = annotations["labels"].tolist()
    confidences = annotations["scores"].tolist()

    painted_img = image.copy()  # Work on a copy so the input image is untouched
    for (bbox, label, confidence) in zip(bboxes, class_id, confidences):
        label = (label - 1) if normalize_labels else label
        if normalize_bbox:
            img_h, img_w = image.shape[0], image.shape[1]  # H, W, C
            x_min, y_min, x_max, y_max = unnormalize_bbox(img_h, img_w, bbox)
        else:
            x_min, y_min, x_max, y_max = map(int, bbox)

        box_color = PALETTE[label]["color"]
        label_name = PALETTE[label]["name"]

        if confidence != -1:
            label_name = f"{label_name} ({confidence:.2f})"

        # Box outline
        cv2.rectangle(painted_img,
                      (x_min, y_min),
                      (x_max, y_max),
                      color=box_color,
                      thickness=2)
        # Filled label background; width roughly scales with text length
        cv2.rectangle(painted_img,
                      (x_min, y_min),
                      (x_min + 5 + len(label_name) * 10, y_min + 17),
                      color=box_color,
                      thickness=-1)
        # Label text in white on the filled background
        cv2.putText(painted_img,
                    label_name,
                    (x_min + 2, y_min + 12),
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=0.5,
                    color=(255, 255, 255),
                    thickness=1)
    return painted_img

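# Usage sketch (hypothetical tensors, mirroring a DETR-style target dict):
# img = np.zeros((480, 640, 3), dtype=np.uint8)
# ann = {"boxes": torch.tensor([[0.5, 0.5, 0.2, 0.3]]),  # normalized (cx, cy, w, h)
#        "labels": torch.tensor([1]),                    # 1 -> class 0 after the -1 shift
#        "scores": torch.tensor([0.95])}
# painted = paint_bbox(img, ann)  # returns a copy with the box and label drawn
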
# Intersection over Union (IoU) between two [xmin, ymin, xmax, ymax] boxes
def calculate_iou(truth_bbx, pred_bbx):
    x1, y1, x2, y2 = truth_bbx
    x1_p, y1_p, x2_p, y2_p = pred_bbx

    # Intersection rectangle
    ixmin = max(x1, x1_p)
    iymin = max(y1, y1_p)
    ixmax = min(x2, x2_p)
    iymax = min(y2, y2_p)

    iw = max(0, ixmax - ixmin)
    ih = max(0, iymax - iymin)

    intersection = iw * ih
    area1 = (x2 - x1) * (y2 - y1)
    area2 = (x2_p - x1_p) * (y2_p - y1_p)
    union = area1 + area2 - intersection
    iou = intersection / union if union != 0 else 0
    return iou

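# Sanity checks:
# calculate_iou([0, 0, 10, 10], [0, 0, 10, 10])   -> 1.0  (identical boxes)
# calculate_iou([0, 0, 10, 10], [5, 5, 15, 15])   -> 25 / 175 = ~0.143 (partial overlap)
# calculate_iou([0, 0, 10, 10], [20, 20, 30, 30]) -> 0    (disjoint boxes)
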
# Example: emotion_classifier = pipeline("image-classification", model="itsindrabudhik/emotion_classification")
# Load the fine-tuned detector once at startup.
DETECTOR = pipeline("object-detection", model="itsindrabudhik/finalProjectCV2425")  # fine-tuned model hosted on the Hugging Face Hub
tensor_file = hf_hub_download(repo_id="itsindrabudhik/finalProjectCV2425",
                              filename="model.safetensors")

# Manually assign the classification-head weights, since the pipeline did not
# appear to load them; kept disabled for now.
# weights = load_file(tensor_file)
# DETECTOR.model.class_labels_classifier.weight.data = weights["class_labels_classifier.weight"]
# DETECTOR.model.class_labels_classifier.bias.data = weights["class_labels_classifier.bias"]
# del weights

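# Each DETECTOR(...) call returns a list of detections shaped like:
# [{'score': 0.97, 'label': 'Ambulance',
#   'box': {'xmin': 12, 'ymin': 34, 'xmax': 150, 'ymax': 210}}, ...]
# (illustrative values; the keys follow the transformers object-detection pipeline)
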
def detect_ev_nev(image, confidence_threshold=0.5, iou_threshold=0.5):
    # Run the detector pipeline on the image (accepts a path, URL, or PIL image)
    results = DETECTOR(image)

    # Open the image
    if isinstance(image, str):  # URL or file path
        if image.startswith("http"):
            response = requests.get(image)
            img = Image.open(BytesIO(response.content))
        else:
            img = Image.open(image)
    else:
        img = image

    # Font for the labels; fall back to PIL's built-in font if the DejaVuSans
    # file bundled with OpenCV's Qt build is not present in this environment.
    try:
        font_path = os.path.join(cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
        font = ImageFont.truetype(font_path, size=32)
    except OSError:
        font = ImageFont.load_default()
    draw = ImageDraw.Draw(img)

    details = []  # Collect details for the text output
    # Greedy non-maximum suppression: visit detections from highest to lowest
    # score and drop any box that overlaps an already-kept box too much.
    results = sorted(results, key=lambda r: r['score'], reverse=True)
    kept_boxes = []
    for result in results:
        score = result['score']
        label = result['label']
        box = result['box']

        # Apply confidence threshold
        if score < confidence_threshold:
            continue

        # Suppress detections that overlap a higher-scoring kept box
        keep = True
        for prev_box in kept_boxes:
            iou = calculate_iou([box['xmin'], box['ymin'], box['xmax'], box['ymax']],
                                [prev_box['xmin'], prev_box['ymin'], prev_box['xmax'], prev_box['ymax']])
            if iou > iou_threshold:
                keep = False
                break

        label_color = PALETTE[label2id[label]]["color"]  # per-class color from PALETTE
        if keep:
            kept_boxes.append(box)
            # Draw the bounding box and label
            xmin, ymin, xmax, ymax = box['xmin'], box['ymin'], box['xmax'], box['ymax']
            draw.rectangle([xmin, ymin, xmax, ymax], outline=label_color, width=3)

            text = f"{label} ({score:.2f})"

            # Measure the text so the label sits just above the box
            text_bbox = draw.textbbox((xmin, ymin - 10), text, font=font)  # (xmin, ymin, xmax, ymax)
            text_height = text_bbox[3] - text_bbox[1]

            draw.text((xmin, ymin - text_height - 5), text, fill=label_color, font=font)

            # Add details to the list
            details.append({
                "Label": label,
                "Confidence": f"{score:.2f}",
                "Bounding Box": f"({xmin}, {ymin}, {xmax}, {ymax})"
            })
    details_text = "\n".join([f"Label: {d['Label']}, Confidence: {d['Confidence']}, Box: {d['Bounding Box']}" for d in details])
    return img, details_text

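# Usage sketch (hypothetical path):
# annotated, details = detect_ev_nev("samples/ambulance.jpg", confidence_threshold=0.6)
# annotated.save("annotated.jpg")
# print(details)  # one "Label: ..., Confidence: ..., Box: ..." line per kept detection
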
def detect_video(video, confidence_threshold=0.5, iou_threshold=0.5):
    video_capture = cv2.VideoCapture(video)
    fps = video_capture.get(cv2.CAP_PROP_FPS)
    frame_width = int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

    temp_output = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(temp_output.name, fourcc, fps, (frame_width, frame_height))

    details = []
    total_frames = 0
    detected_frames = 0

    while True:
        ret, frame = video_capture.read()
        if not ret:
            break

        total_frames += 1
        # OpenCV reads BGR; the detector expects RGB
        image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        annotated_image, frame_details = detect_ev_nev(image, confidence_threshold, iou_threshold)

        # Count frames with detections
        if frame_details.strip():  # Non-empty details indicate detections
            detected_frames += 1

        details.append(frame_details)
        annotated_frame = cv2.cvtColor(np.array(annotated_image), cv2.COLOR_RGB2BGR)
        out.write(annotated_frame)

    video_capture.release()
    out.release()

    details_text = "\n".join(details)
    summary = f"Total Frames: {total_frames}, Frames with Detections: {detected_frames}\n" + details_text
    return temp_output.name, summary

def detect(file, confidence_threshold=0.5, iou_threshold=0.5):
    # gr.File may hand us a tempfile wrapper or a plain path string
    file_path = file.name if hasattr(file, "name") else file
    # Determine whether the input is an image or a video from its extension
    file_ext = file_path.split(".")[-1].lower()
    if file_ext in ["png", "jpg", "jpeg"]:
        # Image processing
        annotated_image, details = detect_ev_nev(file_path, confidence_threshold, iou_threshold)
        return annotated_image, None, details
    elif file_ext in ["mp4", "avi", "mov"]:
        # Video processing
        processed_video, details = detect_video(file_path, confidence_threshold, iou_threshold)
        return None, processed_video, details
    else:
        raise ValueError("Unsupported file format. Please upload an image or video.")

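# Programmatic usage sketch (hypothetical file name):
# img, vid, details = detect("traffic.mp4", confidence_threshold=0.5, iou_threshold=0.5)
# `img` is set for image inputs, `vid` for video inputs; the other is None.
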

interface = gr.Interface(
    fn=detect,
    inputs=[
        gr.File(label="Upload Image or Video", file_types=[".png", ".jpg", ".jpeg", ".mp4", ".avi", ".mov"]),
        gr.Slider(0, 1, value=0.5, label="Confidence Threshold"),
        gr.Slider(0, 1, value=0.5, label="IoU Threshold"),
    ],
    outputs=[
        gr.Image(label="Processed Image"),
        gr.Video(label="Generated Video"),
        gr.Text(label="Detection Details"),
    ],
    title="RT-DETR Object Detection for Images and Videos",
    description="Upload an image or video to detect objects using the fine-tuned RT-DETR model. Results include the annotated image/video and detection details.",
)
interface.launch(debug=True)