Spaces:
Runtime error
Runtime error
Zengyf-CVer
commited on
Commit
•
fe9546f
1
Parent(s):
8dd584f
v04 add color
Browse files
app.py
CHANGED
@@ -160,23 +160,25 @@ def export_json(results, img_size):
|
|
160 |
|
161 |
|
162 |
# frame conversion
|
163 |
-
def pil_draw(img, countdown_msg, textFont, xyxy, font_size, opt):
|
164 |
|
165 |
img_pil = ImageDraw.Draw(img)
|
166 |
|
167 |
-
img_pil.rectangle(xyxy, fill=None, outline=
|
168 |
|
169 |
if "label" in opt:
|
170 |
text_w, text_h = textFont.getsize(countdown_msg) # Label size
|
|
|
171 |
img_pil.rectangle(
|
172 |
(xyxy[0], xyxy[1], xyxy[0] + text_w, xyxy[1] + text_h),
|
173 |
-
fill=
|
174 |
-
outline=
|
175 |
) # label background
|
|
|
176 |
img_pil.multiline_text(
|
177 |
(xyxy[0], xyxy[1]),
|
178 |
countdown_msg,
|
179 |
-
fill=(
|
180 |
font=textFont,
|
181 |
align="center",
|
182 |
)
|
@@ -184,6 +186,16 @@ def pil_draw(img, countdown_msg, textFont, xyxy, font_size, opt):
|
|
184 |
return img
|
185 |
|
186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
# YOLOv5 image detection function
|
188 |
def yolo_det_img(img, device, model_name, infer_size, conf, iou, max_num, model_cls, opt):
|
189 |
|
@@ -210,6 +222,8 @@ def yolo_det_img(img, device, model_name, infer_size, conf, iou, max_num, model_
|
|
210 |
model.max_det = int(max_num) # Maximum number of detection frames
|
211 |
model.classes = model_cls # model classes
|
212 |
|
|
|
|
|
213 |
img_size = img.size # frame size
|
214 |
|
215 |
results = model(img, size=infer_size) # detection
|
@@ -260,6 +274,8 @@ def yolo_det_img(img, device, model_name, infer_size, conf, iou, max_num, model_
|
|
260 |
[x0, y0, x1, y1],
|
261 |
FONTSIZE,
|
262 |
opt,
|
|
|
|
|
263 |
)
|
264 |
|
265 |
# ----------add object size----------
|
@@ -332,6 +348,8 @@ def yolo_det_video(video, device, model_name, infer_size, conf, iou, max_num, mo
|
|
332 |
model.max_det = int(max_num) # Maximum number of detection frames
|
333 |
model.classes = model_cls # model classes
|
334 |
|
|
|
|
|
335 |
# ----------------Load fonts----------------
|
336 |
yaml_index = cls_name.index(".yaml")
|
337 |
cls_name_lang = cls_name[yaml_index - 2:yaml_index]
|
@@ -393,6 +411,8 @@ def yolo_det_video(video, device, model_name, infer_size, conf, iou, max_num, mo
|
|
393 |
[x0, y0, x1, y1],
|
394 |
FONTSIZE,
|
395 |
opt,
|
|
|
|
|
396 |
)
|
397 |
|
398 |
frame = cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR)
|
|
|
160 |
|
161 |
|
162 |
# frame conversion
|
163 |
+
def pil_draw(img, countdown_msg, textFont, xyxy, font_size, opt, obj_cls_index, color_list):
|
164 |
|
165 |
img_pil = ImageDraw.Draw(img)
|
166 |
|
167 |
+
img_pil.rectangle(xyxy, fill=None, outline=color_list[obj_cls_index]) # bounding box
|
168 |
|
169 |
if "label" in opt:
|
170 |
text_w, text_h = textFont.getsize(countdown_msg) # Label size
|
171 |
+
|
172 |
img_pil.rectangle(
|
173 |
(xyxy[0], xyxy[1], xyxy[0] + text_w, xyxy[1] + text_h),
|
174 |
+
fill=color_list[obj_cls_index],
|
175 |
+
outline=color_list[obj_cls_index],
|
176 |
) # label background
|
177 |
+
|
178 |
img_pil.multiline_text(
|
179 |
(xyxy[0], xyxy[1]),
|
180 |
countdown_msg,
|
181 |
+
fill=(255, 255, 255),
|
182 |
font=textFont,
|
183 |
align="center",
|
184 |
)
|
|
|
186 |
return img
|
187 |
|
188 |
|
189 |
+
def color_set(cls_num):
|
190 |
+
color_list = []
|
191 |
+
for i in range(cls_num):
|
192 |
+
color = tuple(np.random.choice(range(256), size=3))
|
193 |
+
# color = ["#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)])]
|
194 |
+
color_list.append(color)
|
195 |
+
|
196 |
+
return color_list
|
197 |
+
|
198 |
+
|
199 |
# YOLOv5 image detection function
|
200 |
def yolo_det_img(img, device, model_name, infer_size, conf, iou, max_num, model_cls, opt):
|
201 |
|
|
|
222 |
model.max_det = int(max_num) # Maximum number of detection frames
|
223 |
model.classes = model_cls # model classes
|
224 |
|
225 |
+
color_list = color_set(len(model_cls_name_cp)) # 设置颜色
|
226 |
+
|
227 |
img_size = img.size # frame size
|
228 |
|
229 |
results = model(img, size=infer_size) # detection
|
|
|
274 |
[x0, y0, x1, y1],
|
275 |
FONTSIZE,
|
276 |
opt,
|
277 |
+
obj_cls_index,
|
278 |
+
color_list,
|
279 |
)
|
280 |
|
281 |
# ----------add object size----------
|
|
|
348 |
model.max_det = int(max_num) # Maximum number of detection frames
|
349 |
model.classes = model_cls # model classes
|
350 |
|
351 |
+
color_list = color_set(len(model_cls_name_cp)) # 设置颜色
|
352 |
+
|
353 |
# ----------------Load fonts----------------
|
354 |
yaml_index = cls_name.index(".yaml")
|
355 |
cls_name_lang = cls_name[yaml_index - 2:yaml_index]
|
|
|
411 |
[x0, y0, x1, y1],
|
412 |
FONTSIZE,
|
413 |
opt,
|
414 |
+
obj_cls_index,
|
415 |
+
color_list,
|
416 |
)
|
417 |
|
418 |
frame = cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR)
|