Spaces:
Sleeping
Sleeping
Commit
·
ec32911
1
Parent(s):
67fa42b
First commit
Browse files- app.py +93 -0
- data.yaml +6 -0
- image_1.jpg +0 -0
- image_10.jpg +0 -0
- image_2.jpg +0 -0
- image_3.jpg +0 -0
- image_4.jpg +0 -0
- image_5.jpg +0 -0
- image_6.jpg +0 -0
- image_7.jpg +0 -0
- image_8.jpg +0 -0
- image_9.jpg +0 -0
- models/common.py +1212 -0
- requirements.txt +47 -0
- runs/train/best_striped.pt +3 -0
- utils/general.py +1135 -0
- utils/plots.py +570 -0
- utils/torch_utils.py +529 -0
app.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import gradio as gr
|
4 |
+
from PIL import Image
|
5 |
+
from models.common import DetectMultiBackend
|
6 |
+
from utils.plots import Annotator, colors
|
7 |
+
from utils.torch_utils import select_device, smart_inference_mode
|
8 |
+
from utils.general import check_img_size, Profile, non_max_suppression, scale_boxes
|
9 |
+
|
10 |
+
weights = "runs/train/best_striped.pt"
|
11 |
+
data = "data.yaml"
|
12 |
+
|
13 |
+
def resize_image_pil(image, new_width, new_height):
|
14 |
+
|
15 |
+
# Convert to PIL image
|
16 |
+
img = Image.fromarray(np.array(image))
|
17 |
+
|
18 |
+
# Get original size
|
19 |
+
width, height = img.size
|
20 |
+
|
21 |
+
# Calculate scale
|
22 |
+
width_scale = new_width / width
|
23 |
+
height_scale = new_height / height
|
24 |
+
scale = min(width_scale, height_scale)
|
25 |
+
|
26 |
+
# Resize
|
27 |
+
resized = img.resize((int(width*scale), int(height*scale)), Image.NEAREST)
|
28 |
+
|
29 |
+
# Crop to exact size
|
30 |
+
resized = resized.crop((0, 0, new_width, new_height))
|
31 |
+
|
32 |
+
return resized
|
33 |
+
|
34 |
+
def inference(input_img, conf_thres, iou_thres):
|
35 |
+
im0 = input_img.copy()
|
36 |
+
# Load model
|
37 |
+
device = select_device(device)
|
38 |
+
model = DetectMultiBackend(weights, device=device, dnn=False, data=data, fp16=False)
|
39 |
+
stride, names, pt = model.stride, model.names, model.pt
|
40 |
+
imgsz = check_img_size(imgsz, s=stride) # check image size
|
41 |
+
|
42 |
+
bs = 1
|
43 |
+
# Run inference
|
44 |
+
model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))
|
45 |
+
seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
|
46 |
+
|
47 |
+
with dt[0]:
|
48 |
+
im = torch.from_numpy(input_img).to(model.device)
|
49 |
+
im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
|
50 |
+
im /= 255 # 0 - 255 to 0.0 - 1.0
|
51 |
+
if len(im.shape) == 3:
|
52 |
+
im = im[None] # expand for batch dim
|
53 |
+
|
54 |
+
# Inference
|
55 |
+
with dt[1]:
|
56 |
+
pred = model(im, augment=False, visualize=False)
|
57 |
+
|
58 |
+
# NMS
|
59 |
+
with dt[2]:
|
60 |
+
pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=10)
|
61 |
+
|
62 |
+
# Process predictions
|
63 |
+
for i, det in enumerate(pred): # per image
|
64 |
+
seen += 1
|
65 |
+
annotator = Annotator(im0, line_width=2, example=str(model.names))
|
66 |
+
if len(det):
|
67 |
+
# Rescale boxes from img_size to im0 size
|
68 |
+
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
|
69 |
+
|
70 |
+
# Write results
|
71 |
+
for *xyxy, conf, cls in reversed(det):
|
72 |
+
c = int(cls) # integer class
|
73 |
+
label = '{names[c]} {conf:.2f}'
|
74 |
+
annotator.box_label(xyxy, label, color=colors(c, True))
|
75 |
+
|
76 |
+
return im0
|
77 |
+
|
78 |
+
title = "YOLOv9 model to detect shirt/tshirt"
|
79 |
+
description = "A simple Gradio interface to infer on YOLOv9 model and detect tshirt in image"
|
80 |
+
examples = [["image_1.jpg", 0.25, 0.45], ["image_2.jpg", 0.25, 0.45],
|
81 |
+
["image_3.jpg", 0.25, 0.45], ["image_4.jpg", 0.25, 0.45],
|
82 |
+
["image_5.jpg", 0.25, 0.45], ["image_6.jpg", 0.25, 0.45],
|
83 |
+
["image_7.jpg", 0.25, 0.45], ["image_8.jpg", 0.25, 0.45],
|
84 |
+
["image_9.jpg", 0.25, 0.45], ["image_10.jpg", 0.25, 0.45]]
|
85 |
+
|
86 |
+
demo = gr.Interface(inference,
|
87 |
+
inputs = [gr.Image(width=320, height=320, label="Input Image"),
|
88 |
+
gr.Slider(0, 1, 0.25, label="Confidance Thresold"),
|
89 |
+
gr.Slider(0, 1, 0.45, label="IoU Thresold")],
|
90 |
+
outputs= [gr.Image(width=640, height=640, label="Output")],
|
91 |
+
title=title,
|
92 |
+
description=description,
|
93 |
+
examples=examples)
|
data.yaml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
train: data/customdata/train/images
|
2 |
+
val: data/customdata/valid/images
|
3 |
+
test: data/customdata/test/images
|
4 |
+
|
5 |
+
nc: 1
|
6 |
+
names: ['shirt']
|
image_1.jpg
ADDED
![]() |
image_10.jpg
ADDED
![]() |
image_2.jpg
ADDED
![]() |
image_3.jpg
ADDED
![]() |
image_4.jpg
ADDED
![]() |
image_5.jpg
ADDED
![]() |
image_6.jpg
ADDED
![]() |
image_7.jpg
ADDED
![]() |
image_8.jpg
ADDED
![]() |
image_9.jpg
ADDED
![]() |
models/common.py
ADDED
@@ -0,0 +1,1212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import contextlib
|
3 |
+
import json
|
4 |
+
import math
|
5 |
+
import platform
|
6 |
+
import warnings
|
7 |
+
import zipfile
|
8 |
+
from collections import OrderedDict, namedtuple
|
9 |
+
from copy import copy
|
10 |
+
from pathlib import Path
|
11 |
+
from urllib.parse import urlparse
|
12 |
+
|
13 |
+
from typing import Optional
|
14 |
+
|
15 |
+
import cv2
|
16 |
+
import numpy as np
|
17 |
+
import pandas as pd
|
18 |
+
import requests
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
from IPython.display import display
|
22 |
+
from PIL import Image
|
23 |
+
from torch.cuda import amp
|
24 |
+
|
25 |
+
from utils import TryExcept
|
26 |
+
from utils.dataloaders import exif_transpose, letterbox
|
27 |
+
from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
|
28 |
+
increment_path, is_notebook, make_divisible, non_max_suppression, scale_boxes,
|
29 |
+
xywh2xyxy, xyxy2xywh, yaml_load)
|
30 |
+
from utils.plots import Annotator, colors, save_one_box
|
31 |
+
from utils.torch_utils import copy_attr, smart_inference_mode
|
32 |
+
|
33 |
+
|
34 |
+
def autopad(k, p=None, d=1): # kernel, padding, dilation
|
35 |
+
# Pad to 'same' shape outputs
|
36 |
+
if d > 1:
|
37 |
+
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
|
38 |
+
if p is None:
|
39 |
+
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
|
40 |
+
return p
|
41 |
+
|
42 |
+
|
43 |
+
class Conv(nn.Module):
|
44 |
+
# Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
|
45 |
+
default_act = nn.SiLU() # default activation
|
46 |
+
|
47 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
|
48 |
+
super().__init__()
|
49 |
+
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
|
50 |
+
self.bn = nn.BatchNorm2d(c2)
|
51 |
+
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
return self.act(self.bn(self.conv(x)))
|
55 |
+
|
56 |
+
def forward_fuse(self, x):
|
57 |
+
return self.act(self.conv(x))
|
58 |
+
|
59 |
+
|
60 |
+
class AConv(nn.Module):
|
61 |
+
def __init__(self, c1, c2): # ch_in, ch_out, shortcut, kernels, groups, expand
|
62 |
+
super().__init__()
|
63 |
+
self.cv1 = Conv(c1, c2, 3, 2, 1)
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
|
67 |
+
return self.cv1(x)
|
68 |
+
|
69 |
+
|
70 |
+
class ADown(nn.Module):
|
71 |
+
def __init__(self, c1, c2): # ch_in, ch_out, shortcut, kernels, groups, expand
|
72 |
+
super().__init__()
|
73 |
+
self.c = c2 // 2
|
74 |
+
self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1)
|
75 |
+
self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0)
|
76 |
+
|
77 |
+
def forward(self, x):
|
78 |
+
x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
|
79 |
+
x1,x2 = x.chunk(2, 1)
|
80 |
+
x1 = self.cv1(x1)
|
81 |
+
x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
|
82 |
+
x2 = self.cv2(x2)
|
83 |
+
return torch.cat((x1, x2), 1)
|
84 |
+
|
85 |
+
|
86 |
+
class RepConvN(nn.Module):
|
87 |
+
"""RepConv is a basic rep-style block, including training and deploy status
|
88 |
+
This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
|
89 |
+
"""
|
90 |
+
default_act = nn.SiLU() # default activation
|
91 |
+
|
92 |
+
def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
|
93 |
+
super().__init__()
|
94 |
+
assert k == 3 and p == 1
|
95 |
+
self.g = g
|
96 |
+
self.c1 = c1
|
97 |
+
self.c2 = c2
|
98 |
+
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
99 |
+
|
100 |
+
self.bn = None
|
101 |
+
self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
|
102 |
+
self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)
|
103 |
+
|
104 |
+
def forward_fuse(self, x):
|
105 |
+
"""Forward process"""
|
106 |
+
return self.act(self.conv(x))
|
107 |
+
|
108 |
+
def forward(self, x):
|
109 |
+
"""Forward process"""
|
110 |
+
id_out = 0 if self.bn is None else self.bn(x)
|
111 |
+
return self.act(self.conv1(x) + self.conv2(x) + id_out)
|
112 |
+
|
113 |
+
def get_equivalent_kernel_bias(self):
|
114 |
+
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
|
115 |
+
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
|
116 |
+
kernelid, biasid = self._fuse_bn_tensor(self.bn)
|
117 |
+
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
|
118 |
+
|
119 |
+
def _avg_to_3x3_tensor(self, avgp):
|
120 |
+
channels = self.c1
|
121 |
+
groups = self.g
|
122 |
+
kernel_size = avgp.kernel_size
|
123 |
+
input_dim = channels // groups
|
124 |
+
k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
|
125 |
+
k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
|
126 |
+
return k
|
127 |
+
|
128 |
+
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
|
129 |
+
if kernel1x1 is None:
|
130 |
+
return 0
|
131 |
+
else:
|
132 |
+
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
|
133 |
+
|
134 |
+
def _fuse_bn_tensor(self, branch):
|
135 |
+
if branch is None:
|
136 |
+
return 0, 0
|
137 |
+
if isinstance(branch, Conv):
|
138 |
+
kernel = branch.conv.weight
|
139 |
+
running_mean = branch.bn.running_mean
|
140 |
+
running_var = branch.bn.running_var
|
141 |
+
gamma = branch.bn.weight
|
142 |
+
beta = branch.bn.bias
|
143 |
+
eps = branch.bn.eps
|
144 |
+
elif isinstance(branch, nn.BatchNorm2d):
|
145 |
+
if not hasattr(self, 'id_tensor'):
|
146 |
+
input_dim = self.c1 // self.g
|
147 |
+
kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
|
148 |
+
for i in range(self.c1):
|
149 |
+
kernel_value[i, i % input_dim, 1, 1] = 1
|
150 |
+
self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
|
151 |
+
kernel = self.id_tensor
|
152 |
+
running_mean = branch.running_mean
|
153 |
+
running_var = branch.running_var
|
154 |
+
gamma = branch.weight
|
155 |
+
beta = branch.bias
|
156 |
+
eps = branch.eps
|
157 |
+
std = (running_var + eps).sqrt()
|
158 |
+
t = (gamma / std).reshape(-1, 1, 1, 1)
|
159 |
+
return kernel * t, beta - running_mean * gamma / std
|
160 |
+
|
161 |
+
def fuse_convs(self):
|
162 |
+
if hasattr(self, 'conv'):
|
163 |
+
return
|
164 |
+
kernel, bias = self.get_equivalent_kernel_bias()
|
165 |
+
self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels,
|
166 |
+
out_channels=self.conv1.conv.out_channels,
|
167 |
+
kernel_size=self.conv1.conv.kernel_size,
|
168 |
+
stride=self.conv1.conv.stride,
|
169 |
+
padding=self.conv1.conv.padding,
|
170 |
+
dilation=self.conv1.conv.dilation,
|
171 |
+
groups=self.conv1.conv.groups,
|
172 |
+
bias=True).requires_grad_(False)
|
173 |
+
self.conv.weight.data = kernel
|
174 |
+
self.conv.bias.data = bias
|
175 |
+
for para in self.parameters():
|
176 |
+
para.detach_()
|
177 |
+
self.__delattr__('conv1')
|
178 |
+
self.__delattr__('conv2')
|
179 |
+
if hasattr(self, 'nm'):
|
180 |
+
self.__delattr__('nm')
|
181 |
+
if hasattr(self, 'bn'):
|
182 |
+
self.__delattr__('bn')
|
183 |
+
if hasattr(self, 'id_tensor'):
|
184 |
+
self.__delattr__('id_tensor')
|
185 |
+
|
186 |
+
|
187 |
+
class SP(nn.Module):
|
188 |
+
def __init__(self, k=3, s=1):
|
189 |
+
super(SP, self).__init__()
|
190 |
+
self.m = nn.MaxPool2d(kernel_size=k, stride=s, padding=k // 2)
|
191 |
+
|
192 |
+
def forward(self, x):
|
193 |
+
return self.m(x)
|
194 |
+
|
195 |
+
|
196 |
+
class MP(nn.Module):
|
197 |
+
# Max pooling
|
198 |
+
def __init__(self, k=2):
|
199 |
+
super(MP, self).__init__()
|
200 |
+
self.m = nn.MaxPool2d(kernel_size=k, stride=k)
|
201 |
+
|
202 |
+
def forward(self, x):
|
203 |
+
return self.m(x)
|
204 |
+
|
205 |
+
|
206 |
+
class ConvTranspose(nn.Module):
|
207 |
+
# Convolution transpose 2d layer
|
208 |
+
default_act = nn.SiLU() # default activation
|
209 |
+
|
210 |
+
def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
|
211 |
+
super().__init__()
|
212 |
+
self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
|
213 |
+
self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
|
214 |
+
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
|
215 |
+
|
216 |
+
def forward(self, x):
|
217 |
+
return self.act(self.bn(self.conv_transpose(x)))
|
218 |
+
|
219 |
+
|
220 |
+
class DWConv(Conv):
|
221 |
+
# Depth-wise convolution
|
222 |
+
def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
|
223 |
+
super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
|
224 |
+
|
225 |
+
|
226 |
+
class DWConvTranspose2d(nn.ConvTranspose2d):
|
227 |
+
# Depth-wise transpose convolution
|
228 |
+
def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
|
229 |
+
super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
|
230 |
+
|
231 |
+
|
232 |
+
class DFL(nn.Module):
|
233 |
+
# DFL module
|
234 |
+
def __init__(self, c1=17):
|
235 |
+
super().__init__()
|
236 |
+
self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
|
237 |
+
self.conv.weight.data[:] = nn.Parameter(torch.arange(c1, dtype=torch.float).view(1, c1, 1, 1)) # / 120.0
|
238 |
+
self.c1 = c1
|
239 |
+
# self.bn = nn.BatchNorm2d(4)
|
240 |
+
|
241 |
+
def forward(self, x):
|
242 |
+
b, c, a = x.shape # batch, channels, anchors
|
243 |
+
return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
|
244 |
+
# return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
|
245 |
+
|
246 |
+
|
247 |
+
class BottleneckBase(nn.Module):
|
248 |
+
# Standard bottleneck
|
249 |
+
def __init__(self, c1, c2, shortcut=True, g=1, k=(1, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
|
250 |
+
super().__init__()
|
251 |
+
c_ = int(c2 * e) # hidden channels
|
252 |
+
self.cv1 = Conv(c1, c_, k[0], 1)
|
253 |
+
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
|
254 |
+
self.add = shortcut and c1 == c2
|
255 |
+
|
256 |
+
def forward(self, x):
|
257 |
+
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
258 |
+
|
259 |
+
|
260 |
+
class RBottleneckBase(nn.Module):
|
261 |
+
# Standard bottleneck
|
262 |
+
def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 1), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
|
263 |
+
super().__init__()
|
264 |
+
c_ = int(c2 * e) # hidden channels
|
265 |
+
self.cv1 = Conv(c1, c_, k[0], 1)
|
266 |
+
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
|
267 |
+
self.add = shortcut and c1 == c2
|
268 |
+
|
269 |
+
def forward(self, x):
|
270 |
+
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
271 |
+
|
272 |
+
|
273 |
+
class RepNRBottleneckBase(nn.Module):
|
274 |
+
# Standard bottleneck
|
275 |
+
def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 1), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
|
276 |
+
super().__init__()
|
277 |
+
c_ = int(c2 * e) # hidden channels
|
278 |
+
self.cv1 = RepConvN(c1, c_, k[0], 1)
|
279 |
+
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
|
280 |
+
self.add = shortcut and c1 == c2
|
281 |
+
|
282 |
+
def forward(self, x):
|
283 |
+
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
284 |
+
|
285 |
+
|
286 |
+
class Bottleneck(nn.Module):
|
287 |
+
# Standard bottleneck
|
288 |
+
def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
|
289 |
+
super().__init__()
|
290 |
+
c_ = int(c2 * e) # hidden channels
|
291 |
+
self.cv1 = Conv(c1, c_, k[0], 1)
|
292 |
+
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
|
293 |
+
self.add = shortcut and c1 == c2
|
294 |
+
|
295 |
+
def forward(self, x):
|
296 |
+
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
297 |
+
|
298 |
+
|
299 |
+
class RepNBottleneck(nn.Module):
|
300 |
+
# Standard bottleneck
|
301 |
+
def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
|
302 |
+
super().__init__()
|
303 |
+
c_ = int(c2 * e) # hidden channels
|
304 |
+
self.cv1 = RepConvN(c1, c_, k[0], 1)
|
305 |
+
self.cv2 = Conv(c_, c2, k[1], 1, g=g)
|
306 |
+
self.add = shortcut and c1 == c2
|
307 |
+
|
308 |
+
def forward(self, x):
|
309 |
+
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
|
310 |
+
|
311 |
+
|
312 |
+
class Res(nn.Module):
|
313 |
+
# ResNet bottleneck
|
314 |
+
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
|
315 |
+
super(Res, self).__init__()
|
316 |
+
c_ = int(c2 * e) # hidden channels
|
317 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
318 |
+
self.cv2 = Conv(c_, c_, 3, 1, g=g)
|
319 |
+
self.cv3 = Conv(c_, c2, 1, 1)
|
320 |
+
self.add = shortcut and c1 == c2
|
321 |
+
|
322 |
+
def forward(self, x):
|
323 |
+
return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
|
324 |
+
|
325 |
+
|
326 |
+
class RepNRes(nn.Module):
|
327 |
+
# ResNet bottleneck
|
328 |
+
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
|
329 |
+
super(RepNRes, self).__init__()
|
330 |
+
c_ = int(c2 * e) # hidden channels
|
331 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
332 |
+
self.cv2 = RepConvN(c_, c_, 3, 1, g=g)
|
333 |
+
self.cv3 = Conv(c_, c2, 1, 1)
|
334 |
+
self.add = shortcut and c1 == c2
|
335 |
+
|
336 |
+
def forward(self, x):
|
337 |
+
return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
|
338 |
+
|
339 |
+
|
340 |
+
class BottleneckCSP(nn.Module):
|
341 |
+
# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
|
342 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
343 |
+
super().__init__()
|
344 |
+
c_ = int(c2 * e) # hidden channels
|
345 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
346 |
+
self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
|
347 |
+
self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
|
348 |
+
self.cv4 = Conv(2 * c_, c2, 1, 1)
|
349 |
+
self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
|
350 |
+
self.act = nn.SiLU()
|
351 |
+
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
|
352 |
+
|
353 |
+
def forward(self, x):
|
354 |
+
y1 = self.cv3(self.m(self.cv1(x)))
|
355 |
+
y2 = self.cv2(x)
|
356 |
+
return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
|
357 |
+
|
358 |
+
|
359 |
+
class CSP(nn.Module):
|
360 |
+
# CSP Bottleneck with 3 convolutions
|
361 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
362 |
+
super().__init__()
|
363 |
+
c_ = int(c2 * e) # hidden channels
|
364 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
365 |
+
self.cv2 = Conv(c1, c_, 1, 1)
|
366 |
+
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
|
367 |
+
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
|
368 |
+
|
369 |
+
def forward(self, x):
|
370 |
+
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
|
371 |
+
|
372 |
+
|
373 |
+
class RepNCSP(nn.Module):
|
374 |
+
# CSP Bottleneck with 3 convolutions
|
375 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
376 |
+
super().__init__()
|
377 |
+
c_ = int(c2 * e) # hidden channels
|
378 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
379 |
+
self.cv2 = Conv(c1, c_, 1, 1)
|
380 |
+
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
|
381 |
+
self.m = nn.Sequential(*(RepNBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
|
382 |
+
|
383 |
+
def forward(self, x):
|
384 |
+
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
|
385 |
+
|
386 |
+
|
387 |
+
class CSPBase(nn.Module):
|
388 |
+
# CSP Bottleneck with 3 convolutions
|
389 |
+
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
|
390 |
+
super().__init__()
|
391 |
+
c_ = int(c2 * e) # hidden channels
|
392 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
393 |
+
self.cv2 = Conv(c1, c_, 1, 1)
|
394 |
+
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
|
395 |
+
self.m = nn.Sequential(*(BottleneckBase(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
|
396 |
+
|
397 |
+
def forward(self, x):
|
398 |
+
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
|
399 |
+
|
400 |
+
|
401 |
+
class SPP(nn.Module):
|
402 |
+
# Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
|
403 |
+
def __init__(self, c1, c2, k=(5, 9, 13)):
|
404 |
+
super().__init__()
|
405 |
+
c_ = c1 // 2 # hidden channels
|
406 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
407 |
+
self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
|
408 |
+
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
|
409 |
+
|
410 |
+
def forward(self, x):
|
411 |
+
x = self.cv1(x)
|
412 |
+
with warnings.catch_warnings():
|
413 |
+
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
|
414 |
+
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
|
415 |
+
|
416 |
+
|
417 |
+
class ASPP(torch.nn.Module):
|
418 |
+
|
419 |
+
def __init__(self, in_channels, out_channels):
|
420 |
+
super().__init__()
|
421 |
+
kernel_sizes = [1, 3, 3, 1]
|
422 |
+
dilations = [1, 3, 6, 1]
|
423 |
+
paddings = [0, 3, 6, 0]
|
424 |
+
self.aspp = torch.nn.ModuleList()
|
425 |
+
for aspp_idx in range(len(kernel_sizes)):
|
426 |
+
conv = torch.nn.Conv2d(
|
427 |
+
in_channels,
|
428 |
+
out_channels,
|
429 |
+
kernel_size=kernel_sizes[aspp_idx],
|
430 |
+
stride=1,
|
431 |
+
dilation=dilations[aspp_idx],
|
432 |
+
padding=paddings[aspp_idx],
|
433 |
+
bias=True)
|
434 |
+
self.aspp.append(conv)
|
435 |
+
self.gap = torch.nn.AdaptiveAvgPool2d(1)
|
436 |
+
self.aspp_num = len(kernel_sizes)
|
437 |
+
for m in self.modules():
|
438 |
+
if isinstance(m, torch.nn.Conv2d):
|
439 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
440 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
441 |
+
m.bias.data.fill_(0)
|
442 |
+
|
443 |
+
def forward(self, x):
|
444 |
+
avg_x = self.gap(x)
|
445 |
+
out = []
|
446 |
+
for aspp_idx in range(self.aspp_num):
|
447 |
+
inp = avg_x if (aspp_idx == self.aspp_num - 1) else x
|
448 |
+
out.append(F.relu_(self.aspp[aspp_idx](inp)))
|
449 |
+
out[-1] = out[-1].expand_as(out[-2])
|
450 |
+
out = torch.cat(out, dim=1)
|
451 |
+
return out
|
452 |
+
|
453 |
+
|
454 |
+
class SPPCSPC(nn.Module):
|
455 |
+
# CSP SPP https://github.com/WongKinYiu/CrossStagePartialNetworks
|
456 |
+
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
|
457 |
+
super(SPPCSPC, self).__init__()
|
458 |
+
c_ = int(2 * c2 * e) # hidden channels
|
459 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
460 |
+
self.cv2 = Conv(c1, c_, 1, 1)
|
461 |
+
self.cv3 = Conv(c_, c_, 3, 1)
|
462 |
+
self.cv4 = Conv(c_, c_, 1, 1)
|
463 |
+
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
|
464 |
+
self.cv5 = Conv(4 * c_, c_, 1, 1)
|
465 |
+
self.cv6 = Conv(c_, c_, 3, 1)
|
466 |
+
self.cv7 = Conv(2 * c_, c2, 1, 1)
|
467 |
+
|
468 |
+
def forward(self, x):
|
469 |
+
x1 = self.cv4(self.cv3(self.cv1(x)))
|
470 |
+
y1 = self.cv6(self.cv5(torch.cat([x1] + [m(x1) for m in self.m], 1)))
|
471 |
+
y2 = self.cv2(x)
|
472 |
+
return self.cv7(torch.cat((y1, y2), dim=1))
|
473 |
+
|
474 |
+
|
475 |
+
class SPPF(nn.Module):
|
476 |
+
# Spatial Pyramid Pooling - Fast (SPPF) layer by Glenn Jocher
|
477 |
+
def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
|
478 |
+
super().__init__()
|
479 |
+
c_ = c1 // 2 # hidden channels
|
480 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
481 |
+
self.cv2 = Conv(c_ * 4, c2, 1, 1)
|
482 |
+
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
|
483 |
+
# self.m = SoftPool2d(kernel_size=k, stride=1, padding=k // 2)
|
484 |
+
|
485 |
+
def forward(self, x):
|
486 |
+
x = self.cv1(x)
|
487 |
+
with warnings.catch_warnings():
|
488 |
+
warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
|
489 |
+
y1 = self.m(x)
|
490 |
+
y2 = self.m(y1)
|
491 |
+
return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
|
492 |
+
|
493 |
+
|
494 |
+
import torch.nn.functional as F
|
495 |
+
from torch.nn.modules.utils import _pair
|
496 |
+
|
497 |
+
|
498 |
+
class ReOrg(nn.Module):
|
499 |
+
# yolo
|
500 |
+
def __init__(self):
|
501 |
+
super(ReOrg, self).__init__()
|
502 |
+
|
503 |
+
def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
|
504 |
+
return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)
|
505 |
+
|
506 |
+
|
507 |
+
class Contract(nn.Module):
|
508 |
+
# Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
|
509 |
+
def __init__(self, gain=2):
|
510 |
+
super().__init__()
|
511 |
+
self.gain = gain
|
512 |
+
|
513 |
+
def forward(self, x):
|
514 |
+
b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
|
515 |
+
s = self.gain
|
516 |
+
x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
|
517 |
+
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
|
518 |
+
return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
|
519 |
+
|
520 |
+
|
521 |
+
class Expand(nn.Module):
|
522 |
+
# Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
|
523 |
+
def __init__(self, gain=2):
|
524 |
+
super().__init__()
|
525 |
+
self.gain = gain
|
526 |
+
|
527 |
+
def forward(self, x):
|
528 |
+
b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
|
529 |
+
s = self.gain
|
530 |
+
x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
|
531 |
+
x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
|
532 |
+
return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
|
533 |
+
|
534 |
+
|
535 |
+
class Concat(nn.Module):
|
536 |
+
# Concatenate a list of tensors along dimension
|
537 |
+
def __init__(self, dimension=1):
|
538 |
+
super().__init__()
|
539 |
+
self.d = dimension
|
540 |
+
|
541 |
+
def forward(self, x):
|
542 |
+
return torch.cat(x, self.d)
|
543 |
+
|
544 |
+
|
545 |
+
class Shortcut(nn.Module):
|
546 |
+
def __init__(self, dimension=0):
|
547 |
+
super(Shortcut, self).__init__()
|
548 |
+
self.d = dimension
|
549 |
+
|
550 |
+
def forward(self, x):
|
551 |
+
return x[0]+x[1]
|
552 |
+
|
553 |
+
|
554 |
+
class Silence(nn.Module):
|
555 |
+
def __init__(self):
|
556 |
+
super(Silence, self).__init__()
|
557 |
+
def forward(self, x):
|
558 |
+
return x
|
559 |
+
|
560 |
+
|
561 |
+
##### GELAN #####
|
562 |
+
|
563 |
+
class SPPELAN(nn.Module):
|
564 |
+
# spp-elan
|
565 |
+
def __init__(self, c1, c2, c3): # ch_in, ch_out, number, shortcut, groups, expansion
|
566 |
+
super().__init__()
|
567 |
+
self.c = c3
|
568 |
+
self.cv1 = Conv(c1, c3, 1, 1)
|
569 |
+
self.cv2 = SP(5)
|
570 |
+
self.cv3 = SP(5)
|
571 |
+
self.cv4 = SP(5)
|
572 |
+
self.cv5 = Conv(4*c3, c2, 1, 1)
|
573 |
+
|
574 |
+
def forward(self, x):
|
575 |
+
y = [self.cv1(x)]
|
576 |
+
y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4])
|
577 |
+
return self.cv5(torch.cat(y, 1))
|
578 |
+
|
579 |
+
|
580 |
+
class RepNCSPELAN4(nn.Module):
|
581 |
+
# csp-elan
|
582 |
+
def __init__(self, c1, c2, c3, c4, c5=1): # ch_in, ch_out, number, shortcut, groups, expansion
|
583 |
+
super().__init__()
|
584 |
+
self.c = c3//2
|
585 |
+
self.cv1 = Conv(c1, c3, 1, 1)
|
586 |
+
self.cv2 = nn.Sequential(RepNCSP(c3//2, c4, c5), Conv(c4, c4, 3, 1))
|
587 |
+
self.cv3 = nn.Sequential(RepNCSP(c4, c4, c5), Conv(c4, c4, 3, 1))
|
588 |
+
self.cv4 = Conv(c3+(2*c4), c2, 1, 1)
|
589 |
+
|
590 |
+
def forward(self, x):
|
591 |
+
y = list(self.cv1(x).chunk(2, 1))
|
592 |
+
y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
|
593 |
+
return self.cv4(torch.cat(y, 1))
|
594 |
+
|
595 |
+
def forward_split(self, x):
|
596 |
+
y = list(self.cv1(x).split((self.c, self.c), 1))
|
597 |
+
y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
|
598 |
+
return self.cv4(torch.cat(y, 1))
|
599 |
+
|
600 |
+
#################
|
601 |
+
|
602 |
+
|
603 |
+
##### YOLOR #####
|
604 |
+
|
605 |
+
class ImplicitA(nn.Module):
|
606 |
+
def __init__(self, channel):
|
607 |
+
super(ImplicitA, self).__init__()
|
608 |
+
self.channel = channel
|
609 |
+
self.implicit = nn.Parameter(torch.zeros(1, channel, 1, 1))
|
610 |
+
nn.init.normal_(self.implicit, std=.02)
|
611 |
+
|
612 |
+
def forward(self, x):
|
613 |
+
return self.implicit + x
|
614 |
+
|
615 |
+
|
616 |
+
class ImplicitM(nn.Module):
|
617 |
+
def __init__(self, channel):
|
618 |
+
super(ImplicitM, self).__init__()
|
619 |
+
self.channel = channel
|
620 |
+
self.implicit = nn.Parameter(torch.ones(1, channel, 1, 1))
|
621 |
+
nn.init.normal_(self.implicit, mean=1., std=.02)
|
622 |
+
|
623 |
+
def forward(self, x):
|
624 |
+
return self.implicit * x
|
625 |
+
|
626 |
+
#################
|
627 |
+
|
628 |
+
|
629 |
+
##### CBNet #####
|
630 |
+
|
631 |
+
class CBLinear(nn.Module):
|
632 |
+
def __init__(self, c1, c2s, k=1, s=1, p=None, g=1): # ch_in, ch_outs, kernel, stride, padding, groups
|
633 |
+
super(CBLinear, self).__init__()
|
634 |
+
self.c2s = c2s
|
635 |
+
self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True)
|
636 |
+
|
637 |
+
def forward(self, x):
|
638 |
+
outs = self.conv(x).split(self.c2s, dim=1)
|
639 |
+
return outs
|
640 |
+
|
641 |
+
class CBFuse(nn.Module):
|
642 |
+
def __init__(self, idx):
|
643 |
+
super(CBFuse, self).__init__()
|
644 |
+
self.idx = idx
|
645 |
+
|
646 |
+
def forward(self, xs):
|
647 |
+
target_size = xs[-1].shape[2:]
|
648 |
+
res = [F.interpolate(x[self.idx[i]], size=target_size, mode='nearest') for i, x in enumerate(xs[:-1])]
|
649 |
+
out = torch.sum(torch.stack(res + xs[-1:]), dim=0)
|
650 |
+
return out
|
651 |
+
|
652 |
+
#################
|
653 |
+
|
654 |
+
|
655 |
+
class DetectMultiBackend(nn.Module):
|
656 |
+
# YOLO MultiBackend class for python inference on various backends
|
657 |
+
def __init__(self, weights='yolo.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
|
658 |
+
# Usage:
|
659 |
+
# PyTorch: weights = *.pt
|
660 |
+
# TorchScript: *.torchscript
|
661 |
+
# ONNX Runtime: *.onnx
|
662 |
+
# ONNX OpenCV DNN: *.onnx --dnn
|
663 |
+
# OpenVINO: *_openvino_model
|
664 |
+
# CoreML: *.mlmodel
|
665 |
+
# TensorRT: *.engine
|
666 |
+
# TensorFlow SavedModel: *_saved_model
|
667 |
+
# TensorFlow GraphDef: *.pb
|
668 |
+
# TensorFlow Lite: *.tflite
|
669 |
+
# TensorFlow Edge TPU: *_edgetpu.tflite
|
670 |
+
# PaddlePaddle: *_paddle_model
|
671 |
+
from models.experimental import attempt_download, attempt_load # scoped to avoid circular import
|
672 |
+
|
673 |
+
super().__init__()
|
674 |
+
w = str(weights[0] if isinstance(weights, list) else weights)
|
675 |
+
pt, jit, onnx, onnx_end2end, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
|
676 |
+
fp16 &= pt or jit or onnx or engine # FP16
|
677 |
+
nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
|
678 |
+
stride = 32 # default stride
|
679 |
+
cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
|
680 |
+
if not (pt or triton):
|
681 |
+
w = attempt_download(w) # download if not local
|
682 |
+
|
683 |
+
if pt: # PyTorch
|
684 |
+
model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
|
685 |
+
stride = max(int(model.stride.max()), 32) # model stride
|
686 |
+
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
687 |
+
model.half() if fp16 else model.float()
|
688 |
+
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
689 |
+
elif jit: # TorchScript
|
690 |
+
LOGGER.info(f'Loading {w} for TorchScript inference...')
|
691 |
+
extra_files = {'config.txt': ''} # model metadata
|
692 |
+
model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
|
693 |
+
model.half() if fp16 else model.float()
|
694 |
+
if extra_files['config.txt']: # load metadata dict
|
695 |
+
d = json.loads(extra_files['config.txt'],
|
696 |
+
object_hook=lambda d: {int(k) if k.isdigit() else k: v
|
697 |
+
for k, v in d.items()})
|
698 |
+
stride, names = int(d['stride']), d['names']
|
699 |
+
elif dnn: # ONNX OpenCV DNN
|
700 |
+
LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
|
701 |
+
check_requirements('opencv-python>=4.5.4')
|
702 |
+
net = cv2.dnn.readNetFromONNX(w)
|
703 |
+
elif onnx: # ONNX Runtime
|
704 |
+
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
|
705 |
+
check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
|
706 |
+
import onnxruntime
|
707 |
+
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
|
708 |
+
session = onnxruntime.InferenceSession(w, providers=providers)
|
709 |
+
output_names = [x.name for x in session.get_outputs()]
|
710 |
+
meta = session.get_modelmeta().custom_metadata_map # metadata
|
711 |
+
if 'stride' in meta:
|
712 |
+
stride, names = int(meta['stride']), eval(meta['names'])
|
713 |
+
elif xml: # OpenVINO
|
714 |
+
LOGGER.info(f'Loading {w} for OpenVINO inference...')
|
715 |
+
check_requirements('openvino') # requires openvino-dev: https://pypi.org/project/openvino-dev/
|
716 |
+
from openvino.runtime import Core, Layout, get_batch
|
717 |
+
ie = Core()
|
718 |
+
if not Path(w).is_file(): # if not *.xml
|
719 |
+
w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
|
720 |
+
network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
|
721 |
+
if network.get_parameters()[0].get_layout().empty:
|
722 |
+
network.get_parameters()[0].set_layout(Layout("NCHW"))
|
723 |
+
batch_dim = get_batch(network)
|
724 |
+
if batch_dim.is_static:
|
725 |
+
batch_size = batch_dim.get_length()
|
726 |
+
executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2
|
727 |
+
stride, names = self._load_metadata(Path(w).with_suffix('.yaml')) # load metadata
|
728 |
+
elif engine: # TensorRT
|
729 |
+
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
730 |
+
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
731 |
+
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
|
732 |
+
if device.type == 'cpu':
|
733 |
+
device = torch.device('cuda:0')
|
734 |
+
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
735 |
+
logger = trt.Logger(trt.Logger.INFO)
|
736 |
+
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|
737 |
+
model = runtime.deserialize_cuda_engine(f.read())
|
738 |
+
context = model.create_execution_context()
|
739 |
+
bindings = OrderedDict()
|
740 |
+
output_names = []
|
741 |
+
fp16 = False # default updated below
|
742 |
+
dynamic = False
|
743 |
+
for i in range(model.num_bindings):
|
744 |
+
name = model.get_binding_name(i)
|
745 |
+
dtype = trt.nptype(model.get_binding_dtype(i))
|
746 |
+
if model.binding_is_input(i):
|
747 |
+
if -1 in tuple(model.get_binding_shape(i)): # dynamic
|
748 |
+
dynamic = True
|
749 |
+
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
|
750 |
+
if dtype == np.float16:
|
751 |
+
fp16 = True
|
752 |
+
else: # output
|
753 |
+
output_names.append(name)
|
754 |
+
shape = tuple(context.get_binding_shape(i))
|
755 |
+
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
|
756 |
+
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
|
757 |
+
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
|
758 |
+
batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
|
759 |
+
elif coreml: # CoreML
|
760 |
+
LOGGER.info(f'Loading {w} for CoreML inference...')
|
761 |
+
import coremltools as ct
|
762 |
+
model = ct.models.MLModel(w)
|
763 |
+
elif saved_model: # TF SavedModel
|
764 |
+
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
|
765 |
+
import tensorflow as tf
|
766 |
+
keras = False # assume TF1 saved_model
|
767 |
+
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
|
768 |
+
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
|
769 |
+
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
|
770 |
+
import tensorflow as tf
|
771 |
+
|
772 |
+
def wrap_frozen_graph(gd, inputs, outputs):
|
773 |
+
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
|
774 |
+
ge = x.graph.as_graph_element
|
775 |
+
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
|
776 |
+
|
777 |
+
def gd_outputs(gd):
|
778 |
+
name_list, input_list = [], []
|
779 |
+
for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
|
780 |
+
name_list.append(node.name)
|
781 |
+
input_list.extend(node.input)
|
782 |
+
return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
|
783 |
+
|
784 |
+
gd = tf.Graph().as_graph_def() # TF GraphDef
|
785 |
+
with open(w, 'rb') as f:
|
786 |
+
gd.ParseFromString(f.read())
|
787 |
+
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
|
788 |
+
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
789 |
+
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
|
790 |
+
from tflite_runtime.interpreter import Interpreter, load_delegate
|
791 |
+
except ImportError:
|
792 |
+
import tensorflow as tf
|
793 |
+
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
|
794 |
+
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
|
795 |
+
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
|
796 |
+
delegate = {
|
797 |
+
'Linux': 'libedgetpu.so.1',
|
798 |
+
'Darwin': 'libedgetpu.1.dylib',
|
799 |
+
'Windows': 'edgetpu.dll'}[platform.system()]
|
800 |
+
interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
|
801 |
+
else: # TFLite
|
802 |
+
LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
|
803 |
+
interpreter = Interpreter(model_path=w) # load TFLite model
|
804 |
+
interpreter.allocate_tensors() # allocate
|
805 |
+
input_details = interpreter.get_input_details() # inputs
|
806 |
+
output_details = interpreter.get_output_details() # outputs
|
807 |
+
# load metadata
|
808 |
+
with contextlib.suppress(zipfile.BadZipFile):
|
809 |
+
with zipfile.ZipFile(w, "r") as model:
|
810 |
+
meta_file = model.namelist()[0]
|
811 |
+
meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
|
812 |
+
stride, names = int(meta['stride']), meta['names']
|
813 |
+
elif tfjs: # TF.js
|
814 |
+
raise NotImplementedError('ERROR: YOLO TF.js inference is not supported')
|
815 |
+
elif paddle: # PaddlePaddle
|
816 |
+
LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
|
817 |
+
check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
|
818 |
+
import paddle.inference as pdi
|
819 |
+
if not Path(w).is_file(): # if not *.pdmodel
|
820 |
+
w = next(Path(w).rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
|
821 |
+
weights = Path(w).with_suffix('.pdiparams')
|
822 |
+
config = pdi.Config(str(w), str(weights))
|
823 |
+
if cuda:
|
824 |
+
config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
|
825 |
+
predictor = pdi.create_predictor(config)
|
826 |
+
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
|
827 |
+
output_names = predictor.get_output_names()
|
828 |
+
elif triton: # NVIDIA Triton Inference Server
|
829 |
+
LOGGER.info(f'Using {w} as Triton Inference Server...')
|
830 |
+
check_requirements('tritonclient[all]')
|
831 |
+
from utils.triton import TritonRemoteModel
|
832 |
+
model = TritonRemoteModel(url=w)
|
833 |
+
nhwc = model.runtime.startswith("tensorflow")
|
834 |
+
else:
|
835 |
+
raise NotImplementedError(f'ERROR: {w} is not a supported format')
|
836 |
+
|
837 |
+
# class names
|
838 |
+
if 'names' not in locals():
|
839 |
+
names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)}
|
840 |
+
if names[0] == 'n01440764' and len(names) == 1000: # ImageNet
|
841 |
+
names = yaml_load(ROOT / 'data/ImageNet.yaml')['names'] # human-readable names
|
842 |
+
|
843 |
+
self.__dict__.update(locals()) # assign all variables to self
|
844 |
+
|
845 |
+
def forward(self, im, augment=False, visualize=False):
|
846 |
+
# YOLO MultiBackend inference
|
847 |
+
b, ch, h, w = im.shape # batch, channel, height, width
|
848 |
+
if self.fp16 and im.dtype != torch.float16:
|
849 |
+
im = im.half() # to FP16
|
850 |
+
if self.nhwc:
|
851 |
+
im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
|
852 |
+
|
853 |
+
if self.pt: # PyTorch
|
854 |
+
y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
|
855 |
+
elif self.jit: # TorchScript
|
856 |
+
y = self.model(im)
|
857 |
+
elif self.dnn: # ONNX OpenCV DNN
|
858 |
+
im = im.cpu().numpy() # torch to numpy
|
859 |
+
self.net.setInput(im)
|
860 |
+
y = self.net.forward()
|
861 |
+
elif self.onnx: # ONNX Runtime
|
862 |
+
im = im.cpu().numpy() # torch to numpy
|
863 |
+
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
|
864 |
+
elif self.xml: # OpenVINO
|
865 |
+
im = im.cpu().numpy() # FP32
|
866 |
+
y = list(self.executable_network([im]).values())
|
867 |
+
elif self.engine: # TensorRT
|
868 |
+
if self.dynamic and im.shape != self.bindings['images'].shape:
|
869 |
+
i = self.model.get_binding_index('images')
|
870 |
+
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
|
871 |
+
self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
|
872 |
+
for name in self.output_names:
|
873 |
+
i = self.model.get_binding_index(name)
|
874 |
+
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
|
875 |
+
s = self.bindings['images'].shape
|
876 |
+
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
|
877 |
+
self.binding_addrs['images'] = int(im.data_ptr())
|
878 |
+
self.context.execute_v2(list(self.binding_addrs.values()))
|
879 |
+
y = [self.bindings[x].data for x in sorted(self.output_names)]
|
880 |
+
elif self.coreml: # CoreML
|
881 |
+
im = im.cpu().numpy()
|
882 |
+
im = Image.fromarray((im[0] * 255).astype('uint8'))
|
883 |
+
# im = im.resize((192, 320), Image.ANTIALIAS)
|
884 |
+
y = self.model.predict({'image': im}) # coordinates are xywh normalized
|
885 |
+
if 'confidence' in y:
|
886 |
+
box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
|
887 |
+
conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
|
888 |
+
y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
|
889 |
+
else:
|
890 |
+
y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
|
891 |
+
elif self.paddle: # PaddlePaddle
|
892 |
+
im = im.cpu().numpy().astype(np.float32)
|
893 |
+
self.input_handle.copy_from_cpu(im)
|
894 |
+
self.predictor.run()
|
895 |
+
y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
|
896 |
+
elif self.triton: # NVIDIA Triton Inference Server
|
897 |
+
y = self.model(im)
|
898 |
+
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
|
899 |
+
im = im.cpu().numpy()
|
900 |
+
if self.saved_model: # SavedModel
|
901 |
+
y = self.model(im, training=False) if self.keras else self.model(im)
|
902 |
+
elif self.pb: # GraphDef
|
903 |
+
y = self.frozen_func(x=self.tf.constant(im))
|
904 |
+
else: # Lite or Edge TPU
|
905 |
+
input = self.input_details[0]
|
906 |
+
int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
|
907 |
+
if int8:
|
908 |
+
scale, zero_point = input['quantization']
|
909 |
+
im = (im / scale + zero_point).astype(np.uint8) # de-scale
|
910 |
+
self.interpreter.set_tensor(input['index'], im)
|
911 |
+
self.interpreter.invoke()
|
912 |
+
y = []
|
913 |
+
for output in self.output_details:
|
914 |
+
x = self.interpreter.get_tensor(output['index'])
|
915 |
+
if int8:
|
916 |
+
scale, zero_point = output['quantization']
|
917 |
+
x = (x.astype(np.float32) - zero_point) * scale # re-scale
|
918 |
+
y.append(x)
|
919 |
+
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
|
920 |
+
y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels
|
921 |
+
|
922 |
+
if isinstance(y, (list, tuple)):
|
923 |
+
return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
|
924 |
+
else:
|
925 |
+
return self.from_numpy(y)
|
926 |
+
|
927 |
+
def from_numpy(self, x):
|
928 |
+
return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x
|
929 |
+
|
930 |
+
def warmup(self, imgsz=(1, 3, 640, 640)):
|
931 |
+
# Warmup model by running inference once
|
932 |
+
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton
|
933 |
+
if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
|
934 |
+
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
935 |
+
for _ in range(2 if self.jit else 1): #
|
936 |
+
self.forward(im) # warmup
|
937 |
+
|
938 |
+
@staticmethod
|
939 |
+
def _model_type(p='path/to/model.pt'):
|
940 |
+
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
|
941 |
+
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
|
942 |
+
from export import export_formats
|
943 |
+
from utils.downloads import is_url
|
944 |
+
sf = list(export_formats().Suffix) # export suffixes
|
945 |
+
if not is_url(p, check=False):
|
946 |
+
check_suffix(p, sf) # checks
|
947 |
+
url = urlparse(p) # if url may be Triton inference server
|
948 |
+
types = [s in Path(p).name for s in sf]
|
949 |
+
types[8] &= not types[9] # tflite &= not edgetpu
|
950 |
+
triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
|
951 |
+
return types + [triton]
|
952 |
+
|
953 |
+
@staticmethod
|
954 |
+
def _load_metadata(f=Path('path/to/meta.yaml')):
|
955 |
+
# Load metadata from meta.yaml if it exists
|
956 |
+
if f.exists():
|
957 |
+
d = yaml_load(f)
|
958 |
+
return d['stride'], d['names'] # assign stride, names
|
959 |
+
return None, None
|
960 |
+
|
961 |
+
|
962 |
+
class AutoShape(nn.Module):
|
963 |
+
# YOLO input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
|
964 |
+
conf = 0.25 # NMS confidence threshold
|
965 |
+
iou = 0.45 # NMS IoU threshold
|
966 |
+
agnostic = False # NMS class-agnostic
|
967 |
+
multi_label = False # NMS multiple labels per box
|
968 |
+
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
|
969 |
+
max_det = 1000 # maximum number of detections per image
|
970 |
+
amp = False # Automatic Mixed Precision (AMP) inference
|
971 |
+
|
972 |
+
def __init__(self, model, verbose=True):
|
973 |
+
super().__init__()
|
974 |
+
if verbose:
|
975 |
+
LOGGER.info('Adding AutoShape... ')
|
976 |
+
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
|
977 |
+
self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
|
978 |
+
self.pt = not self.dmb or model.pt # PyTorch model
|
979 |
+
self.model = model.eval()
|
980 |
+
if self.pt:
|
981 |
+
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
|
982 |
+
m.inplace = False # Detect.inplace=False for safe multithread inference
|
983 |
+
m.export = True # do not output loss values
|
984 |
+
|
985 |
+
def _apply(self, fn):
|
986 |
+
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
987 |
+
self = super()._apply(fn)
|
988 |
+
from models.yolo import Detect, Segment
|
989 |
+
if self.pt:
|
990 |
+
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
|
991 |
+
if isinstance(m, (Detect, Segment)):
|
992 |
+
for k in 'stride', 'anchor_grid', 'stride_grid', 'grid':
|
993 |
+
x = getattr(m, k)
|
994 |
+
setattr(m, k, list(map(fn, x))) if isinstance(x, (list, tuple)) else setattr(m, k, fn(x))
|
995 |
+
return self
|
996 |
+
|
997 |
+
@smart_inference_mode()
|
998 |
+
def forward(self, ims, size=640, augment=False, profile=False):
|
999 |
+
# Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
|
1000 |
+
# file: ims = 'data/images/zidane.jpg' # str or PosixPath
|
1001 |
+
# URI: = 'https://ultralytics.com/images/zidane.jpg'
|
1002 |
+
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
|
1003 |
+
# PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
|
1004 |
+
# numpy: = np.zeros((640,1280,3)) # HWC
|
1005 |
+
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
|
1006 |
+
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
|
1007 |
+
|
1008 |
+
dt = (Profile(), Profile(), Profile())
|
1009 |
+
with dt[0]:
|
1010 |
+
if isinstance(size, int): # expand
|
1011 |
+
size = (size, size)
|
1012 |
+
p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
|
1013 |
+
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
|
1014 |
+
if isinstance(ims, torch.Tensor): # torch
|
1015 |
+
with amp.autocast(autocast):
|
1016 |
+
return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
|
1017 |
+
|
1018 |
+
# Pre-process
|
1019 |
+
n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
|
1020 |
+
shape0, shape1, files = [], [], [] # image and inference shapes, filenames
|
1021 |
+
for i, im in enumerate(ims):
|
1022 |
+
f = f'image{i}' # filename
|
1023 |
+
if isinstance(im, (str, Path)): # filename or uri
|
1024 |
+
im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
|
1025 |
+
im = np.asarray(exif_transpose(im))
|
1026 |
+
elif isinstance(im, Image.Image): # PIL Image
|
1027 |
+
im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
|
1028 |
+
files.append(Path(f).with_suffix('.jpg').name)
|
1029 |
+
if im.shape[0] < 5: # image in CHW
|
1030 |
+
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
|
1031 |
+
im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
|
1032 |
+
s = im.shape[:2] # HW (height, width)
|
1033 |
+
shape0.append(s) # image shape
|
1034 |
+
g = max(size) / max(s) # gain
|
1035 |
+
shape1.append([int(y * g) for y in s])
|
1036 |
+
ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
|
1037 |
+
shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] # inf shape
|
1038 |
+
x = [letterbox(im, shape1, auto=False)[0] for im in ims] # pad
|
1039 |
+
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
|
1040 |
+
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
|
1041 |
+
|
1042 |
+
with amp.autocast(autocast):
|
1043 |
+
# Inference
|
1044 |
+
with dt[1]:
|
1045 |
+
y = self.model(x, augment=augment) # forward
|
1046 |
+
|
1047 |
+
# Post-process
|
1048 |
+
with dt[2]:
|
1049 |
+
y = non_max_suppression(y if self.dmb else y[0],
|
1050 |
+
self.conf,
|
1051 |
+
self.iou,
|
1052 |
+
self.classes,
|
1053 |
+
self.agnostic,
|
1054 |
+
self.multi_label,
|
1055 |
+
max_det=self.max_det) # NMS
|
1056 |
+
for i in range(n):
|
1057 |
+
scale_boxes(shape1, y[i][:, :4], shape0[i])
|
1058 |
+
|
1059 |
+
return Detections(ims, y, files, dt, self.names, x.shape)
|
1060 |
+
|
1061 |
+
|
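AutoShape wraps a trained model so that forward() accepts filenames, URLs, PIL images, OpenCV/numpy arrays or torch tensors and returns a Detections object. A minimal usage sketch, assuming the DetectMultiBackend defaults select a suitable device; the file names below are illustrative assumptions, not part of this commit:

# Hedged usage sketch for AutoShape; names and paths are illustrative assumptions
from models.common import AutoShape, DetectMultiBackend

backend = DetectMultiBackend('runs/train/best_striped.pt', data='data.yaml')
model = AutoShape(backend)                 # adds pre-processing, NMS and Detections packaging
results = model('image_1.jpg', size=640)   # file path; URLs, PIL, numpy and tensors also work
results.print()                            # one summary line per image via LOGGER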
1062 |
+
class Detections:
|
1063 |
+
# YOLO detections class for inference results
|
1064 |
+
def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
|
1065 |
+
super().__init__()
|
1066 |
+
d = pred[0].device # device
|
1067 |
+
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
|
1068 |
+
self.ims = ims # list of images as numpy arrays
|
1069 |
+
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
|
1070 |
+
self.names = names # class names
|
1071 |
+
self.files = files # image filenames
|
1072 |
+
self.times = times # profiling times
|
1073 |
+
self.xyxy = pred # xyxy pixels
|
1074 |
+
self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
|
1075 |
+
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
|
1076 |
+
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
|
1077 |
+
self.n = len(self.pred) # number of images (batch size)
|
1078 |
+
self.t = tuple(x.t / self.n * 1E3 for x in times) # average time per image (ms)
|
1079 |
+
self.s = tuple(shape) # inference BCHW shape
|
1080 |
+
|
1081 |
+
def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
|
1082 |
+
s, crops = '', []
|
1083 |
+
for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
|
1084 |
+
s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
|
1085 |
+
if pred.shape[0]:
|
1086 |
+
for c in pred[:, -1].unique():
|
1087 |
+
n = (pred[:, -1] == c).sum() # detections per class
|
1088 |
+
s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
|
1089 |
+
s = s.rstrip(', ')
|
1090 |
+
if show or save or render or crop:
|
1091 |
+
annotator = Annotator(im, example=str(self.names))
|
1092 |
+
for *box, conf, cls in reversed(pred): # xyxy, confidence, class
|
1093 |
+
label = f'{self.names[int(cls)]} {conf:.2f}'
|
1094 |
+
if crop:
|
1095 |
+
file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
|
1096 |
+
crops.append({
|
1097 |
+
'box': box,
|
1098 |
+
'conf': conf,
|
1099 |
+
'cls': cls,
|
1100 |
+
'label': label,
|
1101 |
+
'im': save_one_box(box, im, file=file, save=save)})
|
1102 |
+
else: # all others
|
1103 |
+
annotator.box_label(box, label if labels else '', color=colors(cls))
|
1104 |
+
im = annotator.im
|
1105 |
+
else:
|
1106 |
+
s += '(no detections)'
|
1107 |
+
|
1108 |
+
im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
|
1109 |
+
if show:
|
1110 |
+
display(im) if is_notebook() else im.show(self.files[i])
|
1111 |
+
if save:
|
1112 |
+
f = self.files[i]
|
1113 |
+
im.save(save_dir / f) # save
|
1114 |
+
if i == self.n - 1:
|
1115 |
+
LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
|
1116 |
+
if render:
|
1117 |
+
self.ims[i] = np.asarray(im)
|
1118 |
+
if pprint:
|
1119 |
+
s = s.lstrip('\n')
|
1120 |
+
return f'{s}\nSpeed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t
|
1121 |
+
if crop:
|
1122 |
+
if save:
|
1123 |
+
LOGGER.info(f'Saved results to {save_dir}\n')
|
1124 |
+
return crops
|
1125 |
+
|
1126 |
+
@TryExcept('Showing images is not supported in this environment')
|
1127 |
+
def show(self, labels=True):
|
1128 |
+
self._run(show=True, labels=labels) # show results
|
1129 |
+
|
1130 |
+
def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False):
|
1131 |
+
save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir
|
1132 |
+
self._run(save=True, labels=labels, save_dir=save_dir) # save results
|
1133 |
+
|
1134 |
+
def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False):
|
1135 |
+
save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None
|
1136 |
+
return self._run(crop=True, save=save, save_dir=save_dir) # crop results
|
1137 |
+
|
1138 |
+
def render(self, labels=True):
|
1139 |
+
self._run(render=True, labels=labels) # render results
|
1140 |
+
return self.ims
|
1141 |
+
|
1142 |
+
def pandas(self):
|
1143 |
+
# return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
|
1144 |
+
new = copy(self) # return copy
|
1145 |
+
ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
|
1146 |
+
cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
|
1147 |
+
for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
|
1148 |
+
a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
|
1149 |
+
setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
|
1150 |
+
return new
|
1151 |
+
|
1152 |
+
def tolist(self):
|
1153 |
+
# return a list of Detections objects, i.e. 'for result in results.tolist():'
|
1154 |
+
r = range(self.n) # iterable
|
1155 |
+
x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
|
1156 |
+
# for d in x:
|
1157 |
+
# for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
|
1158 |
+
# setattr(d, k, getattr(d, k)[0]) # pop out of list
|
1159 |
+
return x
|
1160 |
+
|
1161 |
+
def print(self):
|
1162 |
+
LOGGER.info(self.__str__())
|
1163 |
+
|
1164 |
+
def __len__(self): # override len(results)
|
1165 |
+
return self.n
|
1166 |
+
|
1167 |
+
def __str__(self): # override print(results)
|
1168 |
+
return self._run(pprint=True) # print results
|
1169 |
+
|
1170 |
+
def __repr__(self):
|
1171 |
+
return f'YOLO {self.__class__} instance\n' + self.__str__()
|
1172 |
+
|
1173 |
+
|
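Detections keeps predictions in pixel and normalized xyxy/xywh form and offers helpers for display, saving, cropping and pandas export. A short sketch of consuming results; it assumes `results` comes from an AutoShape call as above and that pandas is installed:

# Hedged sketch: reading a Detections object; `results` is assumed from a prior AutoShape call
df = results.pandas().xyxy[0]      # DataFrame with xmin, ymin, xmax, ymax, confidence, class, name
norm = results.xywhn[0]            # normalized xywh tensor for the first image
crops = results.crop(save=False)   # list of dicts: 'box', 'conf', 'cls', 'label', 'im'
for r in results.tolist():         # one single-image Detections per input
    r.print()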
1174 |
+
class Proto(nn.Module):
|
1175 |
+
# YOLO mask Proto module for segmentation models
|
1176 |
+
def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks
|
1177 |
+
super().__init__()
|
1178 |
+
self.cv1 = Conv(c1, c_, k=3)
|
1179 |
+
self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
|
1180 |
+
self.cv2 = Conv(c_, c_, k=3)
|
1181 |
+
self.cv3 = Conv(c_, c2)
|
1182 |
+
|
1183 |
+
def forward(self, x):
|
1184 |
+
return self.cv3(self.cv2(self.upsample(self.cv1(x))))
|
1185 |
+
|
1186 |
+
|
1187 |
+
class UConv(nn.Module):
|
1188 |
+
def __init__(self, c1, c_=256, c2=256): # ch_in, number of protos, number of masks
|
1189 |
+
super().__init__()
|
1190 |
+
|
1191 |
+
self.cv1 = Conv(c1, c_, k=3)
|
1192 |
+
self.cv2 = nn.Conv2d(c_, c2, 1, 1)
|
1193 |
+
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
1194 |
+
|
1195 |
+
def forward(self, x):
|
1196 |
+
return self.up(self.cv2(self.cv1(x)))
|
1197 |
+
|
1198 |
+
|
1199 |
+
class Classify(nn.Module):
|
1200 |
+
# YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)
|
1201 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
|
1202 |
+
super().__init__()
|
1203 |
+
c_ = 1280 # efficientnet_b0 size
|
1204 |
+
self.conv = Conv(c1, c_, k, s, autopad(k, p), g)
|
1205 |
+
self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)
|
1206 |
+
self.drop = nn.Dropout(p=0.0, inplace=True)
|
1207 |
+
self.linear = nn.Linear(c_, c2) # to x(b,c2)
|
1208 |
+
|
1209 |
+
def forward(self, x):
|
1210 |
+
if isinstance(x, list):
|
1211 |
+
x = torch.cat(x, 1)
|
1212 |
+
return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
|
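Classify reduces a (b, c1, h, w) feature map to (b, c2) logits through a 1x1 Conv to 1280 channels, adaptive average pooling, dropout and a linear layer. A quick shape check as a sketch, with arbitrarily chosen channel counts:

# Hedged shape check for the Classify head; channel counts are arbitrary assumptions
import torch
from models.common import Classify

head = Classify(c1=512, c2=10)      # 512 input channels, 10 classes
x = torch.zeros(2, 512, 20, 20)     # batch of 2 feature maps
print(head(x).shape)                # expected: torch.Size([2, 10])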
requirements.txt
ADDED
@@ -0,0 +1,47 @@
# requirements
# Usage: pip install -r requirements.txt

# Base ------------------------------------------------------------------------
gitpython
ipython
matplotlib>=3.2.2
numpy>=1.18.5
opencv-python>=4.1.1
Pillow>=7.1.2
psutil
PyYAML>=5.3.1
requests>=2.23.0
scipy>=1.4.1
thop>=0.1.1
torch>=1.7.0
torchvision>=0.8.1
tqdm>=4.64.0
# protobuf<=3.20.1

# Logging ---------------------------------------------------------------------
tensorboard>=2.4.1
# clearml>=1.2.0
# comet

# Plotting --------------------------------------------------------------------
pandas>=1.1.4
seaborn>=0.11.0

# Export ----------------------------------------------------------------------
# coremltools>=6.0
# onnx>=1.9.0
# onnx-simplifier>=0.4.1
# nvidia-pyindex
# nvidia-tensorrt
# scikit-learn<=1.1.2
# tensorflow>=2.4.1
# tensorflowjs>=3.9.0
# openvino-dev

# Deploy ----------------------------------------------------------------------
# tritonclient[all]~=2.24.0

# Extras ----------------------------------------------------------------------
# mss
albumentations>=1.0.3
pycocotools>=2.0
runs/train/best_striped.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0d5aaa20f90d1c2e2a3206ca2392b1e48e2593f305c54610526317c2d7082d99
size 51440592
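The checkpoint above is committed as a Git LFS pointer, so a plain clone contains only these three lines; the roughly 51 MB weight file is fetched by LFS. A hedged sanity check that the real file is present before loading it:

# Hedged sketch: confirm the LFS-tracked weights were materialized, not just the pointer stub
from pathlib import Path

w = Path('runs/train/best_striped.pt')
assert w.stat().st_size > 1_000_000, 'Pointer only; run `git lfs install` and `git lfs pull` first'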
utils/general.py
ADDED
@@ -0,0 +1,1135 @@
1 |
+
import contextlib
|
2 |
+
import glob
|
3 |
+
import inspect
|
4 |
+
import logging
|
5 |
+
import logging.config
|
6 |
+
import math
|
7 |
+
import os
|
8 |
+
import platform
|
9 |
+
import random
|
10 |
+
import re
|
11 |
+
import signal
|
12 |
+
import sys
|
13 |
+
import time
|
14 |
+
import urllib
|
15 |
+
from copy import deepcopy
|
16 |
+
from datetime import datetime
|
17 |
+
from itertools import repeat
|
18 |
+
from multiprocessing.pool import ThreadPool
|
19 |
+
from pathlib import Path
|
20 |
+
from subprocess import check_output
|
21 |
+
from tarfile import is_tarfile
|
22 |
+
from typing import Optional
|
23 |
+
from zipfile import ZipFile, is_zipfile
|
24 |
+
|
25 |
+
import cv2
|
26 |
+
import IPython
|
27 |
+
import numpy as np
|
28 |
+
import pandas as pd
|
29 |
+
import pkg_resources as pkg
|
30 |
+
import torch
|
31 |
+
import torchvision
|
32 |
+
import yaml
|
33 |
+
|
34 |
+
from utils import TryExcept, emojis
|
35 |
+
from utils.downloads import gsutil_getsize
|
36 |
+
from utils.metrics import box_iou, fitness
|
37 |
+
|
38 |
+
FILE = Path(__file__).resolve()
|
39 |
+
ROOT = FILE.parents[1] # YOLO root directory
|
40 |
+
RANK = int(os.getenv('RANK', -1))
|
41 |
+
|
42 |
+
# Settings
|
43 |
+
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
44 |
+
DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets')) # global datasets directory
|
45 |
+
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
46 |
+
VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
|
47 |
+
TQDM_BAR_FORMAT = '{l_bar}{bar:10}| {n_fmt}/{total_fmt} {elapsed}' # tqdm bar format
|
48 |
+
FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
|
49 |
+
|
50 |
+
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
51 |
+
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
52 |
+
pd.options.display.max_columns = 10
|
53 |
+
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
54 |
+
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
55 |
+
os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'darwin' else str(NUM_THREADS) # OpenMP (PyTorch and SciPy)
|
56 |
+
|
57 |
+
|
58 |
+
def is_ascii(s=''):
|
59 |
+
# Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
|
60 |
+
s = str(s) # convert list, tuple, None, etc. to str
|
61 |
+
return len(s.encode().decode('ascii', 'ignore')) == len(s)
|
62 |
+
|
63 |
+
|
64 |
+
def is_chinese(s='人工智能'):
|
65 |
+
# Is string composed of any Chinese characters?
|
66 |
+
return bool(re.search('[\u4e00-\u9fff]', str(s)))
|
67 |
+
|
68 |
+
|
69 |
+
def is_colab():
|
70 |
+
# Is environment a Google Colab instance?
|
71 |
+
return 'google.colab' in sys.modules
|
72 |
+
|
73 |
+
|
74 |
+
def is_notebook():
|
75 |
+
# Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace
|
76 |
+
ipython_type = str(type(IPython.get_ipython()))
|
77 |
+
return 'colab' in ipython_type or 'zmqshell' in ipython_type
|
78 |
+
|
79 |
+
|
80 |
+
def is_kaggle():
|
81 |
+
# Is environment a Kaggle Notebook?
|
82 |
+
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
|
83 |
+
|
84 |
+
|
85 |
+
def is_docker() -> bool:
|
86 |
+
"""Check if the process runs inside a docker container."""
|
87 |
+
if Path("/.dockerenv").exists():
|
88 |
+
return True
|
89 |
+
try: # check if docker is in control groups
|
90 |
+
with open("/proc/self/cgroup") as file:
|
91 |
+
return any("docker" in line for line in file)
|
92 |
+
except OSError:
|
93 |
+
return False
|
94 |
+
|
95 |
+
|
96 |
+
def is_writeable(dir, test=False):
|
97 |
+
# Return True if directory has write permissions, test opening a file with write permissions if test=True
|
98 |
+
if not test:
|
99 |
+
return os.access(dir, os.W_OK) # possible issues on Windows
|
100 |
+
file = Path(dir) / 'tmp.txt'
|
101 |
+
try:
|
102 |
+
with open(file, 'w'): # open file with write permissions
|
103 |
+
pass
|
104 |
+
file.unlink() # remove file
|
105 |
+
return True
|
106 |
+
except OSError:
|
107 |
+
return False
|
108 |
+
|
109 |
+
|
110 |
+
LOGGING_NAME = "yolov5"
|
111 |
+
|
112 |
+
|
113 |
+
def set_logging(name=LOGGING_NAME, verbose=True):
|
114 |
+
# sets up logging for the given name
|
115 |
+
rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
|
116 |
+
level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
|
117 |
+
logging.config.dictConfig({
|
118 |
+
"version": 1,
|
119 |
+
"disable_existing_loggers": False,
|
120 |
+
"formatters": {
|
121 |
+
name: {
|
122 |
+
"format": "%(message)s"}},
|
123 |
+
"handlers": {
|
124 |
+
name: {
|
125 |
+
"class": "logging.StreamHandler",
|
126 |
+
"formatter": name,
|
127 |
+
"level": level,}},
|
128 |
+
"loggers": {
|
129 |
+
name: {
|
130 |
+
"level": level,
|
131 |
+
"handlers": [name],
|
132 |
+
"propagate": False,}}})
|
133 |
+
|
134 |
+
|
135 |
+
set_logging(LOGGING_NAME) # run before defining LOGGER
|
136 |
+
LOGGER = logging.getLogger(LOGGING_NAME) # define globally (used in train.py, val.py, detect.py, etc.)
|
137 |
+
if platform.system() == 'Windows':
|
138 |
+
for fn in LOGGER.info, LOGGER.warning:
|
139 |
+
setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x))) # emoji safe logging
|
140 |
+
|
141 |
+
|
142 |
+
def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
|
143 |
+
# Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
|
144 |
+
env = os.getenv(env_var)
|
145 |
+
if env:
|
146 |
+
path = Path(env) # use environment variable
|
147 |
+
else:
|
148 |
+
cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
|
149 |
+
path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
|
150 |
+
path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
|
151 |
+
path.mkdir(exist_ok=True) # make if required
|
152 |
+
return path
|
153 |
+
|
154 |
+
|
155 |
+
CONFIG_DIR = user_config_dir() # Ultralytics settings dir
|
156 |
+
|
157 |
+
|
158 |
+
class Profile(contextlib.ContextDecorator):
|
159 |
+
# YOLO Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
|
160 |
+
def __init__(self, t=0.0):
|
161 |
+
self.t = t
|
162 |
+
self.cuda = torch.cuda.is_available()
|
163 |
+
|
164 |
+
def __enter__(self):
|
165 |
+
self.start = self.time()
|
166 |
+
return self
|
167 |
+
|
168 |
+
def __exit__(self, type, value, traceback):
|
169 |
+
self.dt = self.time() - self.start # delta-time
|
170 |
+
self.t += self.dt # accumulate dt
|
171 |
+
|
172 |
+
def time(self):
|
173 |
+
if self.cuda:
|
174 |
+
torch.cuda.synchronize()
|
175 |
+
return time.time()
|
176 |
+
|
177 |
+
|
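Profile accumulates wall-clock time (synchronizing CUDA first when available) and is what feeds the pre-process, inference and NMS timing triple reported by Detections. A minimal sketch of both usage forms:

# Hedged sketch of the Profile timer (context-manager and decorator forms)
from utils.general import Profile

dt = Profile()
with dt:                        # each entry/exit adds the elapsed time to dt.t
    _ = sum(range(100_000))
print(f'{dt.t * 1E3:.1f} ms accumulated')

@Profile()                      # decorator form via contextlib.ContextDecorator
def work():
    return sum(range(100_000))

work()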
178 |
+
class Timeout(contextlib.ContextDecorator):
|
179 |
+
# YOLO Timeout class. Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
|
180 |
+
def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
|
181 |
+
self.seconds = int(seconds)
|
182 |
+
self.timeout_message = timeout_msg
|
183 |
+
self.suppress = bool(suppress_timeout_errors)
|
184 |
+
|
185 |
+
def _timeout_handler(self, signum, frame):
|
186 |
+
raise TimeoutError(self.timeout_message)
|
187 |
+
|
188 |
+
def __enter__(self):
|
189 |
+
if platform.system() != 'Windows': # not supported on Windows
|
190 |
+
signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
|
191 |
+
signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
|
192 |
+
|
193 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
194 |
+
if platform.system() != 'Windows':
|
195 |
+
signal.alarm(0) # Cancel SIGALRM if it's scheduled
|
196 |
+
if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
|
197 |
+
return True
|
198 |
+
|
199 |
+
|
200 |
+
class WorkingDirectory(contextlib.ContextDecorator):
|
201 |
+
# Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
|
202 |
+
def __init__(self, new_dir):
|
203 |
+
self.dir = new_dir # new dir
|
204 |
+
self.cwd = Path.cwd().resolve() # current dir
|
205 |
+
|
206 |
+
def __enter__(self):
|
207 |
+
os.chdir(self.dir)
|
208 |
+
|
209 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
210 |
+
os.chdir(self.cwd)
|
211 |
+
|
212 |
+
|
213 |
+
def methods(instance):
|
214 |
+
# Get class/instance methods
|
215 |
+
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
|
216 |
+
|
217 |
+
|
218 |
+
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
|
219 |
+
# Print function arguments (optional args dict)
|
220 |
+
x = inspect.currentframe().f_back # previous frame
|
221 |
+
file, _, func, _, _ = inspect.getframeinfo(x)
|
222 |
+
if args is None: # get args automatically
|
223 |
+
args, _, _, frm = inspect.getargvalues(x)
|
224 |
+
args = {k: v for k, v in frm.items() if k in args}
|
225 |
+
try:
|
226 |
+
file = Path(file).resolve().relative_to(ROOT).with_suffix('')
|
227 |
+
except ValueError:
|
228 |
+
file = Path(file).stem
|
229 |
+
s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
|
230 |
+
LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
|
231 |
+
|
232 |
+
|
233 |
+
def init_seeds(seed=0, deterministic=False):
|
234 |
+
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
|
235 |
+
random.seed(seed)
|
236 |
+
np.random.seed(seed)
|
237 |
+
torch.manual_seed(seed)
|
238 |
+
torch.cuda.manual_seed(seed)
|
239 |
+
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
|
240 |
+
# torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
|
241 |
+
if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
|
242 |
+
torch.use_deterministic_algorithms(True)
|
243 |
+
torch.backends.cudnn.deterministic = True
|
244 |
+
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
245 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
246 |
+
|
247 |
+
|
248 |
+
def intersect_dicts(da, db, exclude=()):
|
249 |
+
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
|
250 |
+
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
|
251 |
+
|
252 |
+
|
253 |
+
def get_default_args(func):
|
254 |
+
# Get func() default arguments
|
255 |
+
signature = inspect.signature(func)
|
256 |
+
return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
|
257 |
+
|
258 |
+
|
259 |
+
def get_latest_run(search_dir='.'):
|
260 |
+
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
|
261 |
+
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
|
262 |
+
return max(last_list, key=os.path.getctime) if last_list else ''
|
263 |
+
|
264 |
+
|
265 |
+
def file_age(path=__file__):
|
266 |
+
# Return days since last file update
|
267 |
+
dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
|
268 |
+
return dt.days # + dt.seconds / 86400 # fractional days
|
269 |
+
|
270 |
+
|
271 |
+
def file_date(path=__file__):
|
272 |
+
# Return human-readable file modification date, i.e. '2021-3-26'
|
273 |
+
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
|
274 |
+
return f'{t.year}-{t.month}-{t.day}'
|
275 |
+
|
276 |
+
|
277 |
+
def file_size(path):
|
278 |
+
# Return file/dir size (MB)
|
279 |
+
mb = 1 << 20 # bytes to MiB (1024 ** 2)
|
280 |
+
path = Path(path)
|
281 |
+
if path.is_file():
|
282 |
+
return path.stat().st_size / mb
|
283 |
+
elif path.is_dir():
|
284 |
+
return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
|
285 |
+
else:
|
286 |
+
return 0.0
|
287 |
+
|
288 |
+
|
289 |
+
def check_online():
|
290 |
+
# Check internet connectivity
|
291 |
+
import socket
|
292 |
+
|
293 |
+
def run_once():
|
294 |
+
# Check once
|
295 |
+
try:
|
296 |
+
socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
|
297 |
+
return True
|
298 |
+
except OSError:
|
299 |
+
return False
|
300 |
+
|
301 |
+
return run_once() or run_once() # check twice to increase robustness to intermittent connectivity issues
|
302 |
+
|
303 |
+
|
304 |
+
def git_describe(path=ROOT): # path must be a directory
|
305 |
+
# Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
|
306 |
+
try:
|
307 |
+
assert (Path(path) / '.git').is_dir()
|
308 |
+
return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
|
309 |
+
except Exception:
|
310 |
+
return ''
|
311 |
+
|
312 |
+
|
313 |
+
@TryExcept()
|
314 |
+
@WorkingDirectory(ROOT)
|
315 |
+
def check_git_status(repo='WongKinYiu/yolov9', branch='main'):
|
316 |
+
# YOLO status check, recommend 'git pull' if code is out of date
|
317 |
+
url = f'https://github.com/{repo}'
|
318 |
+
msg = f', for updates see {url}'
|
319 |
+
s = colorstr('github: ') # string
|
320 |
+
assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
|
321 |
+
assert check_online(), s + 'skipping check (offline)' + msg
|
322 |
+
|
323 |
+
splits = re.split(pattern=r'\s', string=check_output('git remote -v', shell=True).decode())
|
324 |
+
matches = [repo in s for s in splits]
|
325 |
+
if any(matches):
|
326 |
+
remote = splits[matches.index(True) - 1]
|
327 |
+
else:
|
328 |
+
remote = 'ultralytics'
|
329 |
+
check_output(f'git remote add {remote} {url}', shell=True)
|
330 |
+
check_output(f'git fetch {remote}', shell=True, timeout=5) # git fetch
|
331 |
+
local_branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
|
332 |
+
n = int(check_output(f'git rev-list {local_branch}..{remote}/{branch} --count', shell=True)) # commits behind
|
333 |
+
if n > 0:
|
334 |
+
pull = 'git pull' if remote == 'origin' else f'git pull {remote} {branch}'
|
335 |
+
s += f"⚠️ YOLO is out of date by {n} commit{'s' * (n > 1)}. Use `{pull}` or `git clone {url}` to update."
|
336 |
+
else:
|
337 |
+
s += f'up to date with {url} ✅'
|
338 |
+
LOGGER.info(s)
|
339 |
+
|
340 |
+
|
341 |
+
@WorkingDirectory(ROOT)
|
342 |
+
def check_git_info(path='.'):
|
343 |
+
# YOLO git info check, return {remote, branch, commit}
|
344 |
+
check_requirements('gitpython')
|
345 |
+
import git
|
346 |
+
try:
|
347 |
+
repo = git.Repo(path)
|
348 |
+
remote = repo.remotes.origin.url.replace('.git', '') # i.e. 'https://github.com/WongKinYiu/yolov9'
|
349 |
+
commit = repo.head.commit.hexsha # i.e. '3134699c73af83aac2a481435550b968d5792c0d'
|
350 |
+
try:
|
351 |
+
branch = repo.active_branch.name # i.e. 'main'
|
352 |
+
except TypeError: # not on any branch
|
353 |
+
branch = None # i.e. 'detached HEAD' state
|
354 |
+
return {'remote': remote, 'branch': branch, 'commit': commit}
|
355 |
+
except git.exc.InvalidGitRepositoryError: # path is not a git dir
|
356 |
+
return {'remote': None, 'branch': None, 'commit': None}
|
357 |
+
|
358 |
+
|
359 |
+
def check_python(minimum='3.7.0'):
|
360 |
+
# Check current python version vs. required python version
|
361 |
+
check_version(platform.python_version(), minimum, name='Python ', hard=True)
|
362 |
+
|
363 |
+
|
364 |
+
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
|
365 |
+
# Check version vs. required version
|
366 |
+
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
|
367 |
+
result = (current == minimum) if pinned else (current >= minimum) # bool
|
368 |
+
s = f'WARNING ⚠️ {name}{minimum} is required by YOLO, but {name}{current} is currently installed' # string
|
369 |
+
if hard:
|
370 |
+
assert result, emojis(s) # assert min requirements met
|
371 |
+
if verbose and not result:
|
372 |
+
LOGGER.warning(s)
|
373 |
+
return result
|
374 |
+
|
375 |
+
|
376 |
+
@TryExcept()
|
377 |
+
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=''):
|
378 |
+
# Check installed dependencies meet YOLO requirements (pass *.txt file or list of packages or single package str)
|
379 |
+
prefix = colorstr('red', 'bold', 'requirements:')
|
380 |
+
check_python() # check python version
|
381 |
+
if isinstance(requirements, Path): # requirements.txt file
|
382 |
+
file = requirements.resolve()
|
383 |
+
assert file.exists(), f"{prefix} {file} not found, check failed."
|
384 |
+
with file.open() as f:
|
385 |
+
requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
|
386 |
+
elif isinstance(requirements, str):
|
387 |
+
requirements = [requirements]
|
388 |
+
|
389 |
+
s = ''
|
390 |
+
n = 0
|
391 |
+
for r in requirements:
|
392 |
+
try:
|
393 |
+
pkg.require(r)
|
394 |
+
except (pkg.VersionConflict, pkg.DistributionNotFound): # exception if requirements not met
|
395 |
+
s += f'"{r}" '
|
396 |
+
n += 1
|
397 |
+
|
398 |
+
if s and install and AUTOINSTALL: # check environment variable
|
399 |
+
LOGGER.info(f"{prefix} YOLO requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
|
400 |
+
try:
|
401 |
+
# assert check_online(), "AutoUpdate skipped (offline)"
|
402 |
+
LOGGER.info(check_output(f'pip install {s} {cmds}', shell=True).decode())
|
403 |
+
source = file if 'file' in locals() else requirements
|
404 |
+
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
|
405 |
+
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
|
406 |
+
LOGGER.info(s)
|
407 |
+
except Exception as e:
|
408 |
+
LOGGER.warning(f'{prefix} ❌ {e}')
|
409 |
+
|
410 |
+
|
411 |
+
def check_img_size(imgsz, s=32, floor=0):
|
412 |
+
# Verify image size is a multiple of stride s in each dimension
|
413 |
+
if isinstance(imgsz, int): # integer i.e. img_size=640
|
414 |
+
new_size = max(make_divisible(imgsz, int(s)), floor)
|
415 |
+
else: # list i.e. img_size=[640, 480]
|
416 |
+
imgsz = list(imgsz) # convert to list if tuple
|
417 |
+
new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
|
418 |
+
if new_size != imgsz:
|
419 |
+
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
|
420 |
+
return new_size
|
421 |
+
|
422 |
+
|
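check_img_size rounds each requested dimension up to the nearest multiple of the model stride via make_divisible and warns when the value changes. A sketch of the expected behaviour:

# Hedged example: image sizes are rounded up to a multiple of the stride
from utils.general import check_img_size

print(check_img_size(640, s=32))         # 640, already divisible
print(check_img_size(630, s=32))         # 640, logged with a WARNING
print(check_img_size([630, 480], s=32))  # [640, 480]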
423 |
+
def check_imshow(warn=False):
|
424 |
+
# Check if environment supports image displays
|
425 |
+
try:
|
426 |
+
assert not is_notebook()
|
427 |
+
assert not is_docker()
|
428 |
+
cv2.imshow('test', np.zeros((1, 1, 3)))
|
429 |
+
cv2.waitKey(1)
|
430 |
+
cv2.destroyAllWindows()
|
431 |
+
cv2.waitKey(1)
|
432 |
+
return True
|
433 |
+
except Exception as e:
|
434 |
+
if warn:
|
435 |
+
LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
|
436 |
+
return False
|
437 |
+
|
438 |
+
|
439 |
+
def check_suffix(file='yolo.pt', suffix=('.pt',), msg=''):
|
440 |
+
# Check file(s) for acceptable suffix
|
441 |
+
if file and suffix:
|
442 |
+
if isinstance(suffix, str):
|
443 |
+
suffix = [suffix]
|
444 |
+
for f in file if isinstance(file, (list, tuple)) else [file]:
|
445 |
+
s = Path(f).suffix.lower() # file suffix
|
446 |
+
if len(s):
|
447 |
+
assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
|
448 |
+
|
449 |
+
|
450 |
+
def check_yaml(file, suffix=('.yaml', '.yml')):
|
451 |
+
# Search/download YAML file (if necessary) and return path, checking suffix
|
452 |
+
return check_file(file, suffix)
|
453 |
+
|
454 |
+
|
455 |
+
def check_file(file, suffix=''):
|
456 |
+
# Search/download file (if necessary) and return path
|
457 |
+
check_suffix(file, suffix) # optional
|
458 |
+
file = str(file) # convert to str()
|
459 |
+
if os.path.isfile(file) or not file: # exists
|
460 |
+
return file
|
461 |
+
elif file.startswith(('http:/', 'https:/')): # download
|
462 |
+
url = file # warning: Pathlib turns :// -> :/
|
463 |
+
file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
|
464 |
+
if os.path.isfile(file):
|
465 |
+
LOGGER.info(f'Found {url} locally at {file}') # file already exists
|
466 |
+
else:
|
467 |
+
LOGGER.info(f'Downloading {url} to {file}...')
|
468 |
+
torch.hub.download_url_to_file(url, file)
|
469 |
+
assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
|
470 |
+
return file
|
471 |
+
elif file.startswith('clearml://'): # ClearML Dataset ID
|
472 |
+
assert 'clearml' in sys.modules, "ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'."
|
473 |
+
return file
|
474 |
+
else: # search
|
475 |
+
files = []
|
476 |
+
for d in 'data', 'models', 'utils': # search directories
|
477 |
+
files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
|
478 |
+
assert len(files), f'File not found: {file}' # assert file was found
|
479 |
+
assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
|
480 |
+
return files[0] # return file
|
481 |
+
|
482 |
+
|
483 |
+
def check_font(font=FONT, progress=False):
|
484 |
+
# Download font to CONFIG_DIR if necessary
|
485 |
+
font = Path(font)
|
486 |
+
file = CONFIG_DIR / font.name
|
487 |
+
if not font.exists() and not file.exists():
|
488 |
+
url = f'https://ultralytics.com/assets/{font.name}'
|
489 |
+
LOGGER.info(f'Downloading {url} to {file}...')
|
490 |
+
torch.hub.download_url_to_file(url, str(file), progress=progress)
|
491 |
+
|
492 |
+
|
493 |
+
def check_dataset(data, autodownload=True):
|
494 |
+
# Download, check and/or unzip dataset if not found locally
|
495 |
+
|
496 |
+
# Download (optional)
|
497 |
+
extract_dir = ''
|
498 |
+
if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
|
499 |
+
download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1)
|
500 |
+
data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
|
501 |
+
extract_dir, autodownload = data.parent, False
|
502 |
+
|
503 |
+
# Read yaml (optional)
|
504 |
+
if isinstance(data, (str, Path)):
|
505 |
+
data = yaml_load(data) # dictionary
|
506 |
+
|
507 |
+
# Checks
|
508 |
+
for k in 'train', 'val', 'names':
|
509 |
+
assert k in data, emojis(f"data.yaml '{k}:' field missing ❌")
|
510 |
+
if isinstance(data['names'], (list, tuple)): # old array format
|
511 |
+
data['names'] = dict(enumerate(data['names'])) # convert to dict
|
512 |
+
assert all(isinstance(k, int) for k in data['names'].keys()), 'data.yaml names keys must be integers, i.e. 2: car'
|
513 |
+
data['nc'] = len(data['names'])
|
514 |
+
|
515 |
+
# Resolve paths
|
516 |
+
path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
|
517 |
+
if not path.is_absolute():
|
518 |
+
path = (ROOT / path).resolve()
|
519 |
+
data['path'] = path # download scripts
|
520 |
+
for k in 'train', 'val', 'test':
|
521 |
+
if data.get(k): # prepend path
|
522 |
+
if isinstance(data[k], str):
|
523 |
+
x = (path / data[k]).resolve()
|
524 |
+
if not x.exists() and data[k].startswith('../'):
|
525 |
+
x = (path / data[k][3:]).resolve()
|
526 |
+
data[k] = str(x)
|
527 |
+
else:
|
528 |
+
data[k] = [str((path / x).resolve()) for x in data[k]]
|
529 |
+
|
530 |
+
# Parse yaml
|
531 |
+
train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
|
532 |
+
if val:
|
533 |
+
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
|
534 |
+
if not all(x.exists() for x in val):
|
535 |
+
LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()])
|
536 |
+
if not s or not autodownload:
|
537 |
+
raise Exception('Dataset not found ❌')
|
538 |
+
t = time.time()
|
539 |
+
if s.startswith('http') and s.endswith('.zip'): # URL
|
540 |
+
f = Path(s).name # filename
|
541 |
+
LOGGER.info(f'Downloading {s} to {f}...')
|
542 |
+
torch.hub.download_url_to_file(s, f)
|
543 |
+
Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root
|
544 |
+
unzip_file(f, path=DATASETS_DIR) # unzip
|
545 |
+
Path(f).unlink() # remove zip
|
546 |
+
r = None # success
|
547 |
+
elif s.startswith('bash '): # bash script
|
548 |
+
LOGGER.info(f'Running {s} ...')
|
549 |
+
r = os.system(s)
|
550 |
+
else: # python script
|
551 |
+
r = exec(s, {'yaml': data}) # return None
|
552 |
+
dt = f'({round(time.time() - t, 1)}s)'
|
553 |
+
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
|
554 |
+
LOGGER.info(f"Dataset download {s}")
|
555 |
+
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
|
556 |
+
return data # dictionary
|
557 |
+
|
558 |
+
|
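check_dataset requires at least 'train', 'val' and 'names' entries, converts a list of names into an int-keyed dict, resolves relative paths against an optional 'path', and can auto-download via a 'download' field. A hedged sketch of a dict that would pass those checks; this is not the contents of this repo's data.yaml:

# Hedged example of a dataset definition satisfying check_dataset's required fields
data = {
    'path': 'datasets/custom',             # optional root; resolved against ROOT if relative
    'train': 'images/train',               # prepended with 'path'
    'val': 'images/val',
    'names': {0: 'class0', 1: 'class1'},   # list form is converted to this dict form
}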
559 |
+
def check_amp(model):
|
560 |
+
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
|
561 |
+
from models.common import AutoShape, DetectMultiBackend
|
562 |
+
|
563 |
+
def amp_allclose(model, im):
|
564 |
+
# All close FP32 vs AMP results
|
565 |
+
m = AutoShape(model, verbose=False) # model
|
566 |
+
a = m(im).xywhn[0] # FP32 inference
|
567 |
+
m.amp = True
|
568 |
+
b = m(im).xywhn[0] # AMP inference
|
569 |
+
return a.shape == b.shape and torch.allclose(a, b, atol=0.1) # close to 10% absolute tolerance
|
570 |
+
|
571 |
+
prefix = colorstr('AMP: ')
|
572 |
+
device = next(model.parameters()).device # get model device
|
573 |
+
if device.type in ('cpu', 'mps'):
|
574 |
+
return False # AMP only used on CUDA devices
|
575 |
+
f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
|
576 |
+
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
|
577 |
+
try:
|
578 |
+
#assert amp_allclose(deepcopy(model), im) or amp_allclose(DetectMultiBackend('yolo.pt', device), im)
|
579 |
+
LOGGER.info(f'{prefix}checks passed ✅')
|
580 |
+
return True
|
581 |
+
except Exception:
|
582 |
+
help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
|
583 |
+
LOGGER.warning(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}')
|
584 |
+
return False
|
585 |
+
|
586 |
+
|
587 |
+
def yaml_load(file='data.yaml'):
|
588 |
+
# Single-line safe yaml loading
|
589 |
+
with open(file, errors='ignore') as f:
|
590 |
+
return yaml.safe_load(f)
|
591 |
+
|
592 |
+
|
593 |
+
def yaml_save(file='data.yaml', data={}):
|
594 |
+
# Single-line safe yaml saving
|
595 |
+
with open(file, 'w') as f:
|
596 |
+
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
597 |
+
|
598 |
+
|
599 |
+
def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
|
600 |
+
# Unzip a *.zip file to path/, excluding files containing strings in exclude list
|
601 |
+
if path is None:
|
602 |
+
path = Path(file).parent # default path
|
603 |
+
with ZipFile(file) as zipObj:
|
604 |
+
for f in zipObj.namelist(): # list all archived filenames in the zip
|
605 |
+
if all(x not in f for x in exclude):
|
606 |
+
zipObj.extract(f, path=path)
|
607 |
+
|
608 |
+
|
609 |
+
def url2file(url):
|
610 |
+
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
|
611 |
+
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
|
612 |
+
return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
|
613 |
+
|
614 |
+
|
615 |
+
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
|
616 |
+
# Multithreaded file download and unzip function, used in data.yaml for autodownload
|
617 |
+
def download_one(url, dir):
|
618 |
+
# Download 1 file
|
619 |
+
success = True
|
620 |
+
if os.path.isfile(url):
|
621 |
+
f = Path(url) # filename
|
622 |
+
else: # does not exist
|
623 |
+
f = dir / Path(url).name
|
624 |
+
LOGGER.info(f'Downloading {url} to {f}...')
|
625 |
+
for i in range(retry + 1):
|
626 |
+
if curl:
|
627 |
+
s = 'sS' if threads > 1 else '' # silent
|
628 |
+
r = os.system(
|
629 |
+
f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
|
630 |
+
success = r == 0
|
631 |
+
else:
|
632 |
+
torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
|
633 |
+
success = f.is_file()
|
634 |
+
if success:
|
635 |
+
break
|
636 |
+
elif i < retry:
|
637 |
+
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
|
638 |
+
else:
|
639 |
+
LOGGER.warning(f'❌ Failed to download {url}...')
|
640 |
+
|
641 |
+
if unzip and success and (f.suffix == '.gz' or is_zipfile(f) or is_tarfile(f)):
|
642 |
+
LOGGER.info(f'Unzipping {f}...')
|
643 |
+
if is_zipfile(f):
|
644 |
+
unzip_file(f, dir) # unzip
|
645 |
+
elif is_tarfile(f):
|
646 |
+
os.system(f'tar xf {f} --directory {f.parent}') # unzip
|
647 |
+
elif f.suffix == '.gz':
|
648 |
+
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
|
649 |
+
if delete:
|
650 |
+
f.unlink() # remove zip
|
651 |
+
|
652 |
+
dir = Path(dir)
|
653 |
+
dir.mkdir(parents=True, exist_ok=True) # make directory
|
654 |
+
if threads > 1:
|
655 |
+
pool = ThreadPool(threads)
|
656 |
+
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
|
657 |
+
pool.close()
|
658 |
+
pool.join()
|
659 |
+
else:
|
660 |
+
for u in [url] if isinstance(url, (str, Path)) else url:
|
661 |
+
download_one(u, dir)
|
662 |
+
|
663 |
+
|
664 |
+
def make_divisible(x, divisor):
|
665 |
+
# Returns nearest x divisible by divisor
|
666 |
+
if isinstance(divisor, torch.Tensor):
|
667 |
+
divisor = int(divisor.max()) # to int
|
668 |
+
return math.ceil(x / divisor) * divisor
|
669 |
+
|
670 |
+
|
671 |
+
def clean_str(s):
|
672 |
+
# Cleans a string by replacing special characters with underscore _
|
673 |
+
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
|
674 |
+
|
675 |
+
|
676 |
+
def one_cycle(y1=0.0, y2=1.0, steps=100):
|
677 |
+
# lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
|
678 |
+
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
|
679 |
+
|
680 |
+
|
681 |
+
def one_flat_cycle(y1=0.0, y2=1.0, steps=100):
|
682 |
+
# lambda function: flat at y1 for the first half of steps, then sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
|
683 |
+
#return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
|
684 |
+
return lambda x: ((1 - math.cos((x - (steps // 2)) * math.pi / (steps // 2))) / 2) * (y2 - y1) + y1 if (x > (steps // 2)) else y1
|
685 |
+
|
686 |
+
|
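one_cycle returns the cosine ramp lambda(x) = y1 + (y2 - y1) * (1 - cos(pi * x / steps)) / 2, while one_flat_cycle holds y1 for the first half of the schedule and only then applies the same ramp. A few sampled values as a sketch:

# Hedged numeric check of the one_cycle / one_flat_cycle schedules
from utils.general import one_cycle, one_flat_cycle

f = one_cycle(y1=1.0, y2=0.1, steps=100)
print(f(0), f(50), f(100))    # 1.0, 0.55, ~0.1

g = one_flat_cycle(y1=1.0, y2=0.1, steps=100)
print(g(25), g(75))           # 1.0 (flat first half), ~0.55 (halfway down the ramp)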
687 |
+
def colorstr(*input):
|
688 |
+
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
|
689 |
+
*args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
|
690 |
+
colors = {
|
691 |
+
'black': '\033[30m', # basic colors
|
692 |
+
'red': '\033[31m',
|
693 |
+
'green': '\033[32m',
|
694 |
+
'yellow': '\033[33m',
|
695 |
+
'blue': '\033[34m',
|
696 |
+
'magenta': '\033[35m',
|
697 |
+
'cyan': '\033[36m',
|
698 |
+
'white': '\033[37m',
|
699 |
+
'bright_black': '\033[90m', # bright colors
|
700 |
+
'bright_red': '\033[91m',
|
701 |
+
'bright_green': '\033[92m',
|
702 |
+
'bright_yellow': '\033[93m',
|
703 |
+
'bright_blue': '\033[94m',
|
704 |
+
'bright_magenta': '\033[95m',
|
705 |
+
'bright_cyan': '\033[96m',
|
706 |
+
'bright_white': '\033[97m',
|
707 |
+
'end': '\033[0m', # misc
|
708 |
+
'bold': '\033[1m',
|
709 |
+
'underline': '\033[4m'}
|
710 |
+
return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
|
711 |
+
|
712 |
+
|
713 |
+
def labels_to_class_weights(labels, nc=80):
|
714 |
+
# Get class weights (inverse frequency) from training labels
|
715 |
+
if labels[0] is None: # no labels loaded
|
716 |
+
return torch.Tensor()
|
717 |
+
|
718 |
+
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
719 |
+
classes = labels[:, 0].astype(int) # labels = [class xywh]
|
720 |
+
weights = np.bincount(classes, minlength=nc) # occurrences per class
|
721 |
+
|
722 |
+
# Prepend gridpoint count (for uCE training)
|
723 |
+
# gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
|
724 |
+
# weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
|
725 |
+
|
726 |
+
weights[weights == 0] = 1 # replace empty bins with 1
|
727 |
+
weights = 1 / weights # number of targets per class
|
728 |
+
weights /= weights.sum() # normalize
|
729 |
+
return torch.from_numpy(weights).float()
|
730 |
+
|
731 |
+
|
732 |
+
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
|
733 |
+
# Produces image weights based on class_weights and image contents
|
734 |
+
# Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
|
735 |
+
class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
|
736 |
+
return (class_weights.reshape(1, nc) * class_counts).sum(1)
|
737 |
+
|
738 |
+
|
739 |
+
def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
|
740 |
+
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
|
741 |
+
# a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
|
742 |
+
# b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
|
743 |
+
# x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
|
744 |
+
# x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
|
745 |
+
return [
|
746 |
+
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
|
747 |
+
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
|
748 |
+
64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
|
749 |
+
|
750 |
+
|
751 |
+
def xyxy2xywh(x):
|
752 |
+
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
|
753 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
754 |
+
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
|
755 |
+
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
|
756 |
+
y[..., 2] = x[..., 2] - x[..., 0] # width
|
757 |
+
y[..., 3] = x[..., 3] - x[..., 1] # height
|
758 |
+
return y
|
759 |
+
|
760 |
+
|
761 |
+
def xywh2xyxy(x):
|
762 |
+
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
763 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
764 |
+
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
|
765 |
+
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
|
766 |
+
y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
|
767 |
+
y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
|
768 |
+
return y
|
769 |
+
|
770 |
+
|
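xyxy2xywh and xywh2xyxy are exact inverses between corner and centre/size box formats, and both accept torch tensors or numpy arrays. A round-trip sketch:

# Hedged round-trip check: corner (xyxy) <-> centre (xywh) box formats
import torch
from utils.general import xywh2xyxy, xyxy2xywh

b = torch.tensor([[100., 50., 300., 250.]])  # x1, y1, x2, y2
print(xyxy2xywh(b))                          # [[200., 150., 200., 200.]] = cx, cy, w, h
print(xywh2xyxy(xyxy2xywh(b)))               # back to [[100., 50., 300., 250.]]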
771 |
+
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
|
772 |
+
# Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
773 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
774 |
+
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
|
775 |
+
y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
|
776 |
+
y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
|
777 |
+
y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y
|
778 |
+
return y
|
779 |
+
|
780 |
+
|
781 |
+
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
|
782 |
+
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
|
783 |
+
if clip:
|
784 |
+
clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
|
785 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
786 |
+
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
|
787 |
+
y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
|
788 |
+
y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
|
789 |
+
y[..., 3] = (x[..., 3] - x[..., 1]) / h # height
|
790 |
+
return y
|
791 |
+
|
792 |
+
|
793 |
+
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
|
794 |
+
# Convert normalized segments into pixel segments, shape (n,2)
|
795 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
796 |
+
y[..., 0] = w * x[..., 0] + padw # top left x
|
797 |
+
y[..., 1] = h * x[..., 1] + padh # top left y
|
798 |
+
return y
|
799 |
+
|
800 |
+
|
801 |
+
def segment2box(segment, width=640, height=640):
|
802 |
+
# Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
|
803 |
+
x, y = segment.T # segment xy
|
804 |
+
inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
|
805 |
+
x, y, = x[inside], y[inside]
|
806 |
+
return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
|
807 |
+
|
808 |
+
|
809 |
+
def segments2boxes(segments):
|
810 |
+
# Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
|
811 |
+
boxes = []
|
812 |
+
for s in segments:
|
813 |
+
x, y = s.T # segment xy
|
814 |
+
boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
|
815 |
+
return xyxy2xywh(np.array(boxes)) # cls, xywh
|
816 |
+
|
817 |
+
|
818 |
+
def resample_segments(segments, n=1000):
|
819 |
+
# Up-sample an (n,2) segment
|
820 |
+
for i, s in enumerate(segments):
|
821 |
+
s = np.concatenate((s, s[0:1, :]), axis=0)
|
822 |
+
x = np.linspace(0, len(s) - 1, n)
|
823 |
+
xp = np.arange(len(s))
|
824 |
+
segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
|
825 |
+
return segments
|
826 |
+
|
827 |
+
|
828 |
+
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
|
829 |
+
# Rescale boxes (xyxy) from img1_shape to img0_shape
|
830 |
+
if ratio_pad is None: # calculate from img0_shape
|
831 |
+
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
832 |
+
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
|
833 |
+
else:
|
834 |
+
gain = ratio_pad[0][0]
|
835 |
+
pad = ratio_pad[1]
|
836 |
+
|
837 |
+
boxes[:, [0, 2]] -= pad[0] # x padding
|
838 |
+
boxes[:, [1, 3]] -= pad[1] # y padding
|
839 |
+
boxes[:, :4] /= gain
|
840 |
+
clip_boxes(boxes, img0_shape)
|
841 |
+
return boxes
|
842 |
+
|
843 |
+
|
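scale_boxes undoes the letterbox transform used at inference time: it subtracts the symmetric padding, divides by the resize gain, then clips to the original image shape. A sketch for a 640x640 letterboxed inference mapped back to a 1280x720 frame:

# Hedged example: boxes from 640x640 letterbox space back to a 1280x720 source frame
import torch
from utils.general import scale_boxes

det = torch.tensor([[0., 140., 640., 500., 0.9, 0.]])  # xyxy, conf, cls in 640x640 space
scale_boxes((640, 640), det[:, :4], (720, 1280))        # in place: gain=0.5, pad=(0, 140)
print(det[:, :4])                                       # tensor([[0., 0., 1280., 720.]])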
844 |
+
def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
|
845 |
+
# Rescale coords (xyxy) from img1_shape to img0_shape
|
846 |
+
if ratio_pad is None: # calculate from img0_shape
|
847 |
+
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
848 |
+
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
|
849 |
+
else:
|
850 |
+
gain = ratio_pad[0][0]
|
851 |
+
pad = ratio_pad[1]
|
852 |
+
|
853 |
+
segments[:, 0] -= pad[0] # x padding
|
854 |
+
segments[:, 1] -= pad[1] # y padding
|
855 |
+
segments /= gain
|
856 |
+
clip_segments(segments, img0_shape)
|
857 |
+
if normalize:
|
858 |
+
segments[:, 0] /= img0_shape[1] # width
|
859 |
+
segments[:, 1] /= img0_shape[0] # height
|
860 |
+
return segments
|
861 |
+
|
862 |
+
|
863 |
+
def clip_boxes(boxes, shape):
|
864 |
+
# Clip boxes (xyxy) to image shape (height, width)
|
865 |
+
if isinstance(boxes, torch.Tensor): # faster individually
|
866 |
+
boxes[:, 0].clamp_(0, shape[1]) # x1
|
867 |
+
boxes[:, 1].clamp_(0, shape[0]) # y1
|
868 |
+
boxes[:, 2].clamp_(0, shape[1]) # x2
|
869 |
+
boxes[:, 3].clamp_(0, shape[0]) # y2
|
870 |
+
else: # np.array (faster grouped)
|
871 |
+
boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
|
872 |
+
boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
|
873 |
+
|
874 |
+
|
875 |
+
def clip_segments(segments, shape):
|
876 |
+
# Clip segments (xy1,xy2,...) to image shape (height, width)
|
877 |
+
if isinstance(segments, torch.Tensor): # faster individually
|
878 |
+
segments[:, 0].clamp_(0, shape[1]) # x
|
879 |
+
segments[:, 1].clamp_(0, shape[0]) # y
|
880 |
+
else: # np.array (faster grouped)
|
881 |
+
segments[:, 0] = segments[:, 0].clip(0, shape[1]) # x
|
882 |
+
segments[:, 1] = segments[:, 1].clip(0, shape[0]) # y
|
883 |
+
|
884 |
+
|
885 |
+
def non_max_suppression(
|
886 |
+
prediction,
|
887 |
+
conf_thres=0.25,
|
888 |
+
iou_thres=0.45,
|
889 |
+
classes=None,
|
890 |
+
agnostic=False,
|
891 |
+
multi_label=False,
|
892 |
+
labels=(),
|
893 |
+
max_det=300,
|
894 |
+
nm=0, # number of masks
|
895 |
+
):
|
896 |
+
"""Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
|
897 |
+
|
898 |
+
Returns:
|
899 |
+
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
|
900 |
+
"""
|
901 |
+
|
902 |
+
if isinstance(prediction, (list, tuple)): # YOLO model in validation mode, output = (inference_out, loss_out)
|
903 |
+
prediction = prediction[0] # select only inference output
|
904 |
+
|
905 |
+
device = prediction.device
|
906 |
+
mps = 'mps' in device.type # Apple MPS
|
907 |
+
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
|
908 |
+
prediction = prediction.cpu()
|
909 |
+
bs = prediction.shape[0] # batch size
|
910 |
+
nc = prediction.shape[1] - nm - 4 # number of classes
|
911 |
+
mi = 4 + nc # mask start index
|
912 |
+
xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates
|
913 |
+
|
914 |
+
# Checks
|
915 |
+
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
|
916 |
+
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
|
917 |
+
|
918 |
+
# Settings
|
919 |
+
# min_wh = 2 # (pixels) minimum box width and height
|
920 |
+
max_wh = 7680 # (pixels) maximum box width and height
|
921 |
+
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
|
922 |
+
time_limit = 2.5 + 0.05 * bs # seconds to quit after
|
923 |
+
redundant = True # require redundant detections
|
924 |
+
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
925 |
+
merge = False # use merge-NMS
|
926 |
+
|
927 |
+
t = time.time()
|
928 |
+
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
|
929 |
+
for xi, x in enumerate(prediction): # image index, image inference
|
930 |
+
# Apply constraints
|
931 |
+
# x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
|
932 |
+
x = x.T[xc[xi]] # confidence
|
933 |
+
|
934 |
+
# Cat apriori labels if autolabelling
|
935 |
+
if labels and len(labels[xi]):
|
936 |
+
lb = labels[xi]
|
937 |
+
v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
|
938 |
+
v[:, :4] = lb[:, 1:5] # box
|
939 |
+
v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
|
940 |
+
x = torch.cat((x, v), 0)
|
941 |
+
|
942 |
+
# If none remain process next image
|
943 |
+
if not x.shape[0]:
|
944 |
+
continue
|
945 |
+
|
946 |
+
# Detections matrix nx6 (xyxy, conf, cls)
|
947 |
+
box, cls, mask = x.split((4, nc, nm), 1)
|
948 |
+
box = xywh2xyxy(box) # center_x, center_y, width, height) to (x1, y1, x2, y2)
|
949 |
+
if multi_label:
|
950 |
+
i, j = (cls > conf_thres).nonzero(as_tuple=False).T
|
951 |
+
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
|
952 |
+
else: # best class only
|
953 |
+
conf, j = cls.max(1, keepdim=True)
|
954 |
+
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
|
955 |
+
|
956 |
+
# Filter by class
|
957 |
+
if classes is not None:
|
958 |
+
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
|
959 |
+
|
960 |
+
# Apply finite constraint
|
961 |
+
# if not torch.isfinite(x).all():
|
962 |
+
# x = x[torch.isfinite(x).all(1)]
|
963 |
+
|
964 |
+
# Check shape
|
965 |
+
n = x.shape[0] # number of boxes
|
966 |
+
if not n: # no boxes
|
967 |
+
continue
|
968 |
+
elif n > max_nms: # excess boxes
|
969 |
+
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
|
970 |
+
else:
|
971 |
+
x = x[x[:, 4].argsort(descending=True)] # sort by confidence
|
972 |
+
|
973 |
+
# Batched NMS
|
974 |
+
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
975 |
+
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
976 |
+
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
977 |
+
if i.shape[0] > max_det: # limit detections
|
978 |
+
i = i[:max_det]
|
979 |
+
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|
980 |
+
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
|
981 |
+
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
|
982 |
+
weights = iou * scores[None] # box weights
|
983 |
+
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
|
984 |
+
if redundant:
|
985 |
+
i = i[iou.sum(1) > 1] # require redundancy
|
986 |
+
|
987 |
+
output[xi] = x[i]
|
988 |
+
if mps:
|
989 |
+
output[xi] = output[xi].to(device)
|
990 |
+
if (time.time() - t) > time_limit:
|
991 |
+
LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
|
992 |
+
break # time limit exceeded
|
993 |
+
|
994 |
+
return output
|
995 |
+
|
996 |
+
|
997 |
+
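
# Example (illustrative, not part of the original file): typical single-image post-processing,
# assuming `pred` is the raw model output and standard thresholds are acceptable:
#   det = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, max_det=300)[0]

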
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
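
# Example (illustrative, not part of the original file): shrink a training checkpoint for
# inference-only deployment; the output keeps only FP16 weights (hypothetical paths):
#   strip_optimizer('runs/train/exp/weights/best.pt', 'best_stripped.pt')

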
def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
    evolve_csv = save_dir / 'evolve.csv'
    evolve_yaml = save_dir / 'hyp_evolve.yaml'
    keys = tuple(keys) + tuple(hyp.keys())  # [results + hyps]
    keys = tuple(x.strip() for x in keys)
    vals = results + tuple(hyp.values())
    n = len(keys)

    # Download (optional)
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
            os.system(f'gsutil cp {url} {save_dir}')  # download evolve.csv if larger than local

    # Log to evolve.csv
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Save yaml
    with open(evolve_yaml, 'w') as f:
        data = pd.read_csv(evolve_csv)
        data = data.rename(columns=lambda x: x.strip())  # strip keys
        i = np.argmax(fitness(data.values[:, :4]))  #
        generations = len(data)
        f.write('# YOLO Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
                f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
                '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
        yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)

    # Print to screen
    LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
                ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
                                                                                         for x in vals) + '\n\n')

    if bucket:
        os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}')  # upload


def apply_classifier(x, model, img, im0):
    # Apply a second stage classifier to YOLO outputs
    # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_boxes(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()
            ims = []
            for a in d:
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
                im = cv2.resize(cutout, (224, 224))  # BGR

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x


def increment_path(path, exist_ok=False, sep='', mkdir=False):
    # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')

        # Method 1
        for n in range(2, 9999):
            p = f'{path}{sep}{n}{suffix}'  # increment path
            if not os.path.exists(p):  #
                break
        path = Path(p)

        # Method 2 (deprecated)
        # dirs = glob.glob(f"{path}{sep}*")  # similar paths
        # matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
        # i = [int(m.groups()[0]) for m in matches if m]  # indices
        # n = max(i) + 1 if i else 2  # increment number
        # path = Path(f"{path}{sep}{n}{suffix}")  # increment path

    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory

    return path
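
# Example (illustrative, not part of the original file): create a fresh, numbered run directory:
#   save_dir = increment_path(Path('runs/detect') / 'exp', mkdir=True)  # runs/detect/exp, exp2, exp3, ...

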
# OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
imshow_ = cv2.imshow  # copy to avoid recursion errors


def imread(path, flags=cv2.IMREAD_COLOR):
    return cv2.imdecode(np.fromfile(path, np.uint8), flags)


def imwrite(path, im):
    try:
        cv2.imencode(Path(path).suffix, im)[1].tofile(path)
        return True
    except Exception:
        return False


def imshow(path, im):
    imshow_(path.encode('unicode_escape').decode(), im)


cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow  # redefine

# Variables ------------------------------------------------------------------------------------------------------------
utils/plots.py
ADDED
@@ -0,0 +1,570 @@
import contextlib
import math
import os
from copy import copy
from pathlib import Path
from urllib.error import URLError

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
from PIL import Image, ImageDraw, ImageFont

from utils import TryExcept, threaded
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_boxes, increment_path,
                           is_ascii, xywh2xyxy, xyxy2xywh)
from utils.metrics import fitness
from utils.segment.general import scale_image

# Settings
RANK = int(os.getenv('RANK', -1))
matplotlib.rc('font', **{'size': 11})
matplotlib.use('Agg')  # for writing to files only


class Colors:
    # Ultralytics color palette https://ultralytics.com/
    def __init__(self):
        # hex = matplotlib.colors.TABLEAU_COLORS.values()
        hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
                '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
        self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        c = self.palette[int(i) % self.n]
        return (c[2], c[1], c[0]) if bgr else c

    @staticmethod
    def hex2rgb(h):  # rgb order (PIL)
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))


colors = Colors()  # create instance for 'from utils.plots import colors'
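
# Example (illustrative, not part of the original file): fetch a deterministic per-class color,
# e.g. for drawing class index 5 with OpenCV (BGR channel order):
#   color = colors(5, bgr=True)

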
def check_pil_font(font=FONT, size=10):
|
51 |
+
# Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary
|
52 |
+
font = Path(font)
|
53 |
+
font = font if font.exists() else (CONFIG_DIR / font.name)
|
54 |
+
try:
|
55 |
+
return ImageFont.truetype(str(font) if font.exists() else font.name, size)
|
56 |
+
except Exception: # download if missing
|
57 |
+
try:
|
58 |
+
check_font(font)
|
59 |
+
return ImageFont.truetype(str(font), size)
|
60 |
+
except TypeError:
|
61 |
+
check_requirements('Pillow>=8.4.0') # known issue https://github.com/ultralytics/yolov5/issues/5374
|
62 |
+
except URLError: # not online
|
63 |
+
return ImageFont.load_default()
|
64 |
+
|
65 |
+
|
66 |
+
class Annotator:
|
67 |
+
# YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
|
68 |
+
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
69 |
+
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
|
70 |
+
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
|
71 |
+
self.pil = pil or non_ascii
|
72 |
+
if self.pil: # use PIL
|
73 |
+
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
74 |
+
self.draw = ImageDraw.Draw(self.im)
|
75 |
+
self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font,
|
76 |
+
size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
|
77 |
+
else: # use cv2
|
78 |
+
self.im = im
|
79 |
+
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
|
80 |
+
|
81 |
+
def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
|
82 |
+
# Add one xyxy box to image with label
|
83 |
+
if self.pil or not is_ascii(label):
|
84 |
+
self.draw.rectangle(box, width=self.lw, outline=color) # box
|
85 |
+
if label:
|
86 |
+
w, h = self.font.getsize(label) # text width, height
|
87 |
+
outside = box[1] - h >= 0 # label fits outside box
|
88 |
+
self.draw.rectangle(
|
89 |
+
(box[0], box[1] - h if outside else box[1], box[0] + w + 1,
|
90 |
+
box[1] + 1 if outside else box[1] + h + 1),
|
91 |
+
fill=color,
|
92 |
+
)
|
93 |
+
# self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
|
94 |
+
self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font)
|
95 |
+
else: # cv2
|
96 |
+
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
|
97 |
+
cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
|
98 |
+
if label:
|
99 |
+
tf = max(self.lw - 1, 1) # font thickness
|
100 |
+
w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
|
101 |
+
outside = p1[1] - h >= 3
|
102 |
+
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
|
103 |
+
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
|
104 |
+
cv2.putText(self.im,
|
105 |
+
label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
|
106 |
+
0,
|
107 |
+
self.lw / 3,
|
108 |
+
txt_color,
|
109 |
+
thickness=tf,
|
110 |
+
lineType=cv2.LINE_AA)
|
111 |
+
|
112 |
+
def masks(self, masks, colors, im_gpu=None, alpha=0.5):
|
113 |
+
"""Plot masks at once.
|
114 |
+
Args:
|
115 |
+
masks (tensor): predicted masks on cuda, shape: [n, h, w]
|
116 |
+
colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
|
117 |
+
im_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
|
118 |
+
alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
|
119 |
+
"""
|
120 |
+
if self.pil:
|
121 |
+
# convert to numpy first
|
122 |
+
self.im = np.asarray(self.im).copy()
|
123 |
+
if im_gpu is None:
|
124 |
+
# Add multiple masks of shape(h,w,n) with colors list([r,g,b], [r,g,b], ...)
|
125 |
+
if len(masks) == 0:
|
126 |
+
return
|
127 |
+
if isinstance(masks, torch.Tensor):
|
128 |
+
masks = torch.as_tensor(masks, dtype=torch.uint8)
|
129 |
+
masks = masks.permute(1, 2, 0).contiguous()
|
130 |
+
masks = masks.cpu().numpy()
|
131 |
+
# masks = np.ascontiguousarray(masks.transpose(1, 2, 0))
|
132 |
+
masks = scale_image(masks.shape[:2], masks, self.im.shape)
|
133 |
+
masks = np.asarray(masks, dtype=np.float32)
|
134 |
+
colors = np.asarray(colors, dtype=np.float32) # shape(n,3)
|
135 |
+
s = masks.sum(2, keepdims=True).clip(0, 1) # add all masks together
|
136 |
+
masks = (masks @ colors).clip(0, 255) # (h,w,n) @ (n,3) = (h,w,3)
|
137 |
+
self.im[:] = masks * alpha + self.im * (1 - s * alpha)
|
138 |
+
else:
|
139 |
+
if len(masks) == 0:
|
140 |
+
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
|
141 |
+
colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
|
142 |
+
colors = colors[:, None, None] # shape(n,1,1,3)
|
143 |
+
masks = masks.unsqueeze(3) # shape(n,h,w,1)
|
144 |
+
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
|
145 |
+
|
146 |
+
inv_alph_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
|
147 |
+
mcs = (masks_color * inv_alph_masks).sum(0) * 2 # mask color summand shape(n,h,w,3)
|
148 |
+
|
149 |
+
im_gpu = im_gpu.flip(dims=[0]) # flip channel
|
150 |
+
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
|
151 |
+
im_gpu = im_gpu * inv_alph_masks[-1] + mcs
|
152 |
+
im_mask = (im_gpu * 255).byte().cpu().numpy()
|
153 |
+
self.im[:] = scale_image(im_gpu.shape, im_mask, self.im.shape)
|
154 |
+
if self.pil:
|
155 |
+
# convert im back to PIL and update draw
|
156 |
+
self.fromarray(self.im)
|
157 |
+
|
158 |
+
def rectangle(self, xy, fill=None, outline=None, width=1):
|
159 |
+
# Add rectangle to image (PIL-only)
|
160 |
+
self.draw.rectangle(xy, fill, outline, width)
|
161 |
+
|
162 |
+
def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'):
|
163 |
+
# Add text to image (PIL-only)
|
164 |
+
if anchor == 'bottom': # start y from font bottom
|
165 |
+
w, h = self.font.getsize(text) # text width, height
|
166 |
+
xy[1] += 1 - h
|
167 |
+
self.draw.text(xy, text, fill=txt_color, font=self.font)
|
168 |
+
|
169 |
+
def fromarray(self, im):
|
170 |
+
# Update self.im from a numpy array
|
171 |
+
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
172 |
+
self.draw = ImageDraw.Draw(self.im)
|
173 |
+
|
174 |
+
def result(self):
|
175 |
+
# Return annotated image as array
|
176 |
+
return np.asarray(self.im)
|
177 |
+
|
178 |
+
|
179 |
+
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
|
180 |
+
"""
|
181 |
+
x: Features to be visualized
|
182 |
+
module_type: Module type
|
183 |
+
stage: Module stage within model
|
184 |
+
n: Maximum number of feature maps to plot
|
185 |
+
save_dir: Directory to save results
|
186 |
+
"""
|
187 |
+
if 'Detect' not in module_type:
|
188 |
+
batch, channels, height, width = x.shape # batch, channels, height, width
|
189 |
+
if height > 1 and width > 1:
|
190 |
+
f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
|
191 |
+
|
192 |
+
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
|
193 |
+
n = min(n, channels) # number of plots
|
194 |
+
fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
|
195 |
+
ax = ax.ravel()
|
196 |
+
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
197 |
+
for i in range(n):
|
198 |
+
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
|
199 |
+
ax[i].axis('off')
|
200 |
+
|
201 |
+
LOGGER.info(f'Saving {f}... ({n}/{channels})')
|
202 |
+
plt.savefig(f, dpi=300, bbox_inches='tight')
|
203 |
+
plt.close()
|
204 |
+
np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save
|
205 |
+
|
206 |
+
|
207 |
+
def hist2d(x, y, n=100):
|
208 |
+
# 2d histogram used in labels.png and evolve.png
|
209 |
+
xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
|
210 |
+
hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
|
211 |
+
xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
|
212 |
+
yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
|
213 |
+
return np.log(hist[xidx, yidx])
|
214 |
+
|
215 |
+
|
216 |
+
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
|
217 |
+
from scipy.signal import butter, filtfilt
|
218 |
+
|
219 |
+
# https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
|
220 |
+
def butter_lowpass(cutoff, fs, order):
|
221 |
+
nyq = 0.5 * fs
|
222 |
+
normal_cutoff = cutoff / nyq
|
223 |
+
return butter(order, normal_cutoff, btype='low', analog=False)
|
224 |
+
|
225 |
+
b, a = butter_lowpass(cutoff, fs, order=order)
|
226 |
+
return filtfilt(b, a, data) # forward-backward filter
|
227 |
+
|
228 |
+
|
229 |
+
def output_to_target(output, max_det=300):
|
230 |
+
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting
|
231 |
+
targets = []
|
232 |
+
for i, o in enumerate(output):
|
233 |
+
box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
|
234 |
+
j = torch.full((conf.shape[0], 1), i)
|
235 |
+
targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
|
236 |
+
return torch.cat(targets, 0).numpy()
|
237 |
+
|
238 |
+
|
239 |
+
@threaded
|
240 |
+
def plot_images(images, targets, paths=None, fname='images.jpg', names=None):
|
241 |
+
# Plot image grid with labels
|
242 |
+
if isinstance(images, torch.Tensor):
|
243 |
+
images = images.cpu().float().numpy()
|
244 |
+
if isinstance(targets, torch.Tensor):
|
245 |
+
targets = targets.cpu().numpy()
|
246 |
+
|
247 |
+
max_size = 1920 # max image size
|
248 |
+
max_subplots = 16 # max image subplots, i.e. 4x4
|
249 |
+
bs, _, h, w = images.shape # batch size, _, height, width
|
250 |
+
bs = min(bs, max_subplots) # limit plot images
|
251 |
+
ns = np.ceil(bs ** 0.5) # number of subplots (square)
|
252 |
+
if np.max(images[0]) <= 1:
|
253 |
+
images *= 255 # de-normalise (optional)
|
254 |
+
|
255 |
+
# Build Image
|
256 |
+
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
|
257 |
+
for i, im in enumerate(images):
|
258 |
+
if i == max_subplots: # if last batch has fewer images than we expect
|
259 |
+
break
|
260 |
+
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
261 |
+
im = im.transpose(1, 2, 0)
|
262 |
+
mosaic[y:y + h, x:x + w, :] = im
|
263 |
+
|
264 |
+
# Resize (optional)
|
265 |
+
scale = max_size / ns / max(h, w)
|
266 |
+
if scale < 1:
|
267 |
+
h = math.ceil(scale * h)
|
268 |
+
w = math.ceil(scale * w)
|
269 |
+
mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
|
270 |
+
|
271 |
+
# Annotate
|
272 |
+
fs = int((h + w) * ns * 0.01) # font size
|
273 |
+
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
|
274 |
+
for i in range(i + 1):
|
275 |
+
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
276 |
+
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
|
277 |
+
if paths:
|
278 |
+
annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
|
279 |
+
if len(targets) > 0:
|
280 |
+
ti = targets[targets[:, 0] == i] # image targets
|
281 |
+
boxes = xywh2xyxy(ti[:, 2:6]).T
|
282 |
+
classes = ti[:, 1].astype('int')
|
283 |
+
labels = ti.shape[1] == 6 # labels if no conf column
|
284 |
+
conf = None if labels else ti[:, 6] # check for confidence presence (label vs pred)
|
285 |
+
|
286 |
+
if boxes.shape[1]:
|
287 |
+
if boxes.max() <= 1.01: # if normalized with tolerance 0.01
|
288 |
+
boxes[[0, 2]] *= w # scale to pixels
|
289 |
+
boxes[[1, 3]] *= h
|
290 |
+
elif scale < 1: # absolute coords need scale if image scales
|
291 |
+
boxes *= scale
|
292 |
+
boxes[[0, 2]] += x
|
293 |
+
boxes[[1, 3]] += y
|
294 |
+
for j, box in enumerate(boxes.T.tolist()):
|
295 |
+
cls = classes[j]
|
296 |
+
color = colors(cls)
|
297 |
+
cls = names[cls] if names else cls
|
298 |
+
if labels or conf[j] > 0.25: # 0.25 conf thresh
|
299 |
+
label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
|
300 |
+
annotator.box_label(box, label, color=color)
|
301 |
+
annotator.im.save(fname) # save
|
302 |
+
|
303 |
+
|
304 |
+
def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
|
305 |
+
# Plot LR simulating training for full epochs
|
306 |
+
optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
|
307 |
+
y = []
|
308 |
+
for _ in range(epochs):
|
309 |
+
scheduler.step()
|
310 |
+
y.append(optimizer.param_groups[0]['lr'])
|
311 |
+
plt.plot(y, '.-', label='LR')
|
312 |
+
plt.xlabel('epoch')
|
313 |
+
plt.ylabel('LR')
|
314 |
+
plt.grid()
|
315 |
+
plt.xlim(0, epochs)
|
316 |
+
plt.ylim(0)
|
317 |
+
plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
|
318 |
+
plt.close()
|
319 |
+
|
320 |
+
|
321 |
+
def plot_val_txt(): # from utils.plots import *; plot_val()
|
322 |
+
# Plot val.txt histograms
|
323 |
+
x = np.loadtxt('val.txt', dtype=np.float32)
|
324 |
+
box = xyxy2xywh(x[:, :4])
|
325 |
+
cx, cy = box[:, 0], box[:, 1]
|
326 |
+
|
327 |
+
fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
|
328 |
+
ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
|
329 |
+
ax.set_aspect('equal')
|
330 |
+
plt.savefig('hist2d.png', dpi=300)
|
331 |
+
|
332 |
+
fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
|
333 |
+
ax[0].hist(cx, bins=600)
|
334 |
+
ax[1].hist(cy, bins=600)
|
335 |
+
plt.savefig('hist1d.png', dpi=200)
|
336 |
+
|
337 |
+
|
338 |
+
def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
|
339 |
+
# Plot targets.txt histograms
|
340 |
+
x = np.loadtxt('targets.txt', dtype=np.float32).T
|
341 |
+
s = ['x targets', 'y targets', 'width targets', 'height targets']
|
342 |
+
fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
|
343 |
+
ax = ax.ravel()
|
344 |
+
for i in range(4):
|
345 |
+
ax[i].hist(x[i], bins=100, label=f'{x[i].mean():.3g} +/- {x[i].std():.3g}')
|
346 |
+
ax[i].legend()
|
347 |
+
ax[i].set_title(s[i])
|
348 |
+
plt.savefig('targets.jpg', dpi=200)
|
349 |
+
|
350 |
+
|
351 |
+
def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_val_study()
|
352 |
+
# Plot file=study.txt generated by val.py (or plot all study*.txt in dir)
|
353 |
+
save_dir = Path(file).parent if file else Path(dir)
|
354 |
+
plot2 = False # plot additional results
|
355 |
+
if plot2:
|
356 |
+
ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()
|
357 |
+
|
358 |
+
fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
|
359 |
+
# for f in [save_dir / f'study_coco_{x}.txt' for x in ['yolov5n6', 'yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
|
360 |
+
for f in sorted(save_dir.glob('study*.txt')):
|
361 |
+
y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
|
362 |
+
x = np.arange(y.shape[1]) if x is None else np.array(x)
|
363 |
+
if plot2:
|
364 |
+
s = ['P', 'R', '[email protected]', '[email protected]:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
|
365 |
+
for i in range(7):
|
366 |
+
ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
|
367 |
+
ax[i].set_title(s[i])
|
368 |
+
|
369 |
+
j = y[3].argmax() + 1
|
370 |
+
ax2.plot(y[5, 1:j],
|
371 |
+
y[3, 1:j] * 1E2,
|
372 |
+
'.-',
|
373 |
+
linewidth=2,
|
374 |
+
markersize=8,
|
375 |
+
label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
|
376 |
+
|
377 |
+
ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
|
378 |
+
'k.-',
|
379 |
+
linewidth=2,
|
380 |
+
markersize=8,
|
381 |
+
alpha=.25,
|
382 |
+
label='EfficientDet')
|
383 |
+
|
384 |
+
ax2.grid(alpha=0.2)
|
385 |
+
ax2.set_yticks(np.arange(20, 60, 5))
|
386 |
+
ax2.set_xlim(0, 57)
|
387 |
+
ax2.set_ylim(25, 55)
|
388 |
+
ax2.set_xlabel('GPU Speed (ms/img)')
|
389 |
+
ax2.set_ylabel('COCO AP val')
|
390 |
+
ax2.legend(loc='lower right')
|
391 |
+
f = save_dir / 'study.png'
|
392 |
+
print(f'Saving {f}...')
|
393 |
+
plt.savefig(f, dpi=300)
|
394 |
+
|
395 |
+
|
396 |
+
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
|
397 |
+
def plot_labels(labels, names=(), save_dir=Path('')):
|
398 |
+
# plot dataset labels
|
399 |
+
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
400 |
+
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
|
401 |
+
nc = int(c.max() + 1) # number of classes
|
402 |
+
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
|
403 |
+
|
404 |
+
# seaborn correlogram
|
405 |
+
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
|
406 |
+
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
|
407 |
+
plt.close()
|
408 |
+
|
409 |
+
# matplotlib labels
|
410 |
+
matplotlib.use('svg') # faster
|
411 |
+
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
|
412 |
+
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
413 |
+
with contextlib.suppress(Exception): # color histogram bars by class
|
414 |
+
[y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
|
415 |
+
ax[0].set_ylabel('instances')
|
416 |
+
if 0 < len(names) < 30:
|
417 |
+
ax[0].set_xticks(range(len(names)))
|
418 |
+
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
|
419 |
+
else:
|
420 |
+
ax[0].set_xlabel('classes')
|
421 |
+
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
|
422 |
+
sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
|
423 |
+
|
424 |
+
# rectangles
|
425 |
+
labels[:, 1:3] = 0.5 # center
|
426 |
+
labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
|
427 |
+
img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
|
428 |
+
for cls, *box in labels[:1000]:
|
429 |
+
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
|
430 |
+
ax[1].imshow(img)
|
431 |
+
ax[1].axis('off')
|
432 |
+
|
433 |
+
for a in [0, 1, 2, 3]:
|
434 |
+
for s in ['top', 'right', 'left', 'bottom']:
|
435 |
+
ax[a].spines[s].set_visible(False)
|
436 |
+
|
437 |
+
plt.savefig(save_dir / 'labels.jpg', dpi=200)
|
438 |
+
matplotlib.use('Agg')
|
439 |
+
plt.close()
|
440 |
+
|
441 |
+
|
442 |
+
def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f=Path('images.jpg')):
|
443 |
+
# Show classification image grid with labels (optional) and predictions (optional)
|
444 |
+
from utils.augmentations import denormalize
|
445 |
+
|
446 |
+
names = names or [f'class{i}' for i in range(1000)]
|
447 |
+
blocks = torch.chunk(denormalize(im.clone()).cpu().float(), len(im),
|
448 |
+
dim=0) # select batch index 0, block by channels
|
449 |
+
n = min(len(blocks), nmax) # number of plots
|
450 |
+
m = min(8, round(n ** 0.5)) # 8 x 8 default
|
451 |
+
fig, ax = plt.subplots(math.ceil(n / m), m) # 8 rows x n/8 cols
|
452 |
+
ax = ax.ravel() if m > 1 else [ax]
|
453 |
+
# plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
454 |
+
for i in range(n):
|
455 |
+
ax[i].imshow(blocks[i].squeeze().permute((1, 2, 0)).numpy().clip(0.0, 1.0))
|
456 |
+
ax[i].axis('off')
|
457 |
+
if labels is not None:
|
458 |
+
s = names[labels[i]] + (f'—{names[pred[i]]}' if pred is not None else '')
|
459 |
+
ax[i].set_title(s, fontsize=8, verticalalignment='top')
|
460 |
+
plt.savefig(f, dpi=300, bbox_inches='tight')
|
461 |
+
plt.close()
|
462 |
+
if verbose:
|
463 |
+
LOGGER.info(f"Saving {f}")
|
464 |
+
if labels is not None:
|
465 |
+
LOGGER.info('True: ' + ' '.join(f'{names[i]:3s}' for i in labels[:nmax]))
|
466 |
+
if pred is not None:
|
467 |
+
LOGGER.info('Predicted:' + ' '.join(f'{names[i]:3s}' for i in pred[:nmax]))
|
468 |
+
return f
|
469 |
+
|
470 |
+
|
471 |
+
def plot_evolve(evolve_csv='path/to/evolve.csv'): # from utils.plots import *; plot_evolve()
|
472 |
+
# Plot evolve.csv hyp evolution results
|
473 |
+
evolve_csv = Path(evolve_csv)
|
474 |
+
data = pd.read_csv(evolve_csv)
|
475 |
+
keys = [x.strip() for x in data.columns]
|
476 |
+
x = data.values
|
477 |
+
f = fitness(x)
|
478 |
+
j = np.argmax(f) # max fitness index
|
479 |
+
plt.figure(figsize=(10, 12), tight_layout=True)
|
480 |
+
matplotlib.rc('font', **{'size': 8})
|
481 |
+
print(f'Best results from row {j} of {evolve_csv}:')
|
482 |
+
for i, k in enumerate(keys[7:]):
|
483 |
+
v = x[:, 7 + i]
|
484 |
+
mu = v[j] # best single result
|
485 |
+
plt.subplot(6, 5, i + 1)
|
486 |
+
plt.scatter(v, f, c=hist2d(v, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
|
487 |
+
plt.plot(mu, f.max(), 'k+', markersize=15)
|
488 |
+
plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9}) # limit to 40 characters
|
489 |
+
if i % 5 != 0:
|
490 |
+
plt.yticks([])
|
491 |
+
print(f'{k:>15}: {mu:.3g}')
|
492 |
+
f = evolve_csv.with_suffix('.png') # filename
|
493 |
+
plt.savefig(f, dpi=200)
|
494 |
+
plt.close()
|
495 |
+
print(f'Saved {f}')
|
496 |
+
|
497 |
+
|
498 |
+
def plot_results(file='path/to/results.csv', dir=''):
|
499 |
+
# Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
|
500 |
+
save_dir = Path(file).parent if file else Path(dir)
|
501 |
+
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
|
502 |
+
ax = ax.ravel()
|
503 |
+
files = list(save_dir.glob('results*.csv'))
|
504 |
+
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
|
505 |
+
for f in files:
|
506 |
+
try:
|
507 |
+
data = pd.read_csv(f)
|
508 |
+
s = [x.strip() for x in data.columns]
|
509 |
+
x = data.values[:, 0]
|
510 |
+
for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
|
511 |
+
y = data.values[:, j].astype('float')
|
512 |
+
# y[y == 0] = np.nan # don't show zero values
|
513 |
+
ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
|
514 |
+
ax[i].set_title(s[j], fontsize=12)
|
515 |
+
# if j in [8, 9, 10]: # share train and val loss y axes
|
516 |
+
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
|
517 |
+
except Exception as e:
|
518 |
+
LOGGER.info(f'Warning: Plotting error for {f}: {e}')
|
519 |
+
ax[1].legend()
|
520 |
+
fig.savefig(save_dir / 'results.png', dpi=200)
|
521 |
+
plt.close()
|
522 |
+
|
523 |
+
|
524 |
+
def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
|
525 |
+
# Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
|
526 |
+
ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
|
527 |
+
s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
|
528 |
+
files = list(Path(save_dir).glob('frames*.txt'))
|
529 |
+
for fi, f in enumerate(files):
|
530 |
+
try:
|
531 |
+
results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows
|
532 |
+
n = results.shape[1] # number of rows
|
533 |
+
x = np.arange(start, min(stop, n) if stop else n)
|
534 |
+
results = results[:, x]
|
535 |
+
t = (results[0] - results[0].min()) # set t0=0s
|
536 |
+
results[0] = x
|
537 |
+
for i, a in enumerate(ax):
|
538 |
+
if i < len(results):
|
539 |
+
label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
|
540 |
+
a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
|
541 |
+
a.set_title(s[i])
|
542 |
+
a.set_xlabel('time (s)')
|
543 |
+
# if fi == len(files) - 1:
|
544 |
+
# a.set_ylim(bottom=0)
|
545 |
+
for side in ['top', 'right']:
|
546 |
+
a.spines[side].set_visible(False)
|
547 |
+
else:
|
548 |
+
a.remove()
|
549 |
+
except Exception as e:
|
550 |
+
print(f'Warning: Plotting error for {f}; {e}')
|
551 |
+
ax[1].legend()
|
552 |
+
plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
|
553 |
+
|
554 |
+
|
555 |
+
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
|
556 |
+
# Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
|
557 |
+
xyxy = torch.tensor(xyxy).view(-1, 4)
|
558 |
+
b = xyxy2xywh(xyxy) # boxes
|
559 |
+
if square:
|
560 |
+
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
|
561 |
+
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
|
562 |
+
xyxy = xywh2xyxy(b).long()
|
563 |
+
clip_boxes(xyxy, im.shape)
|
564 |
+
crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
|
565 |
+
if save:
|
566 |
+
file.parent.mkdir(parents=True, exist_ok=True) # make directory
|
567 |
+
f = str(increment_path(file).with_suffix('.jpg'))
|
568 |
+
# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
|
569 |
+
Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
|
570 |
+
return crop
|
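# Example (illustrative, not part of the original file): save an enlarged, squared crop of one
# detection from the original BGR image (hypothetical output path):
#   save_one_box(xyxy, im0, file=Path('crops') / 'object.jpg', BGR=True)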
utils/torch_utils.py
ADDED
@@ -0,0 +1,529 @@
import math
import os
import platform
import subprocess
import time
import warnings
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

from utils.general import LOGGER, check_version, colorstr, file_date, git_describe
from utils.lion import Lion

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))

try:
    import thop  # for FLOPs computation
except ImportError:
    thop = None

# Suppress PyTorch warnings
warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
warnings.filterwarnings('ignore', category=UserWarning)


def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
    # Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
    def decorate(fn):
        return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)

    return decorate


def smartCrossEntropyLoss(label_smoothing=0.0):
    # Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0
    if check_version(torch.__version__, '1.10.0'):
        return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    if label_smoothing > 0:
        LOGGER.warning(f'WARNING ⚠️ label smoothing {label_smoothing} requires torch>=1.10.0')
    return nn.CrossEntropyLoss()
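
# Example (illustrative, not part of the original file): run a forward pass without tracking
# gradients, picking torch.inference_mode() when the installed torch supports it:
#   @smart_inference_mode()
#   def run(model, im):
#       return model(im)

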
def smart_DDP(model):
|
52 |
+
# Model DDP creation with checks
|
53 |
+
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
|
54 |
+
'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
|
55 |
+
'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
|
56 |
+
if check_version(torch.__version__, '1.11.0'):
|
57 |
+
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
|
58 |
+
else:
|
59 |
+
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
60 |
+
|
61 |
+
|
62 |
+
def reshape_classifier_output(model, n=1000):
|
63 |
+
# Update a TorchVision classification model to class count 'n' if required
|
64 |
+
from models.common import Classify
|
65 |
+
name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
|
66 |
+
if isinstance(m, Classify): # YOLOv5 Classify() head
|
67 |
+
if m.linear.out_features != n:
|
68 |
+
m.linear = nn.Linear(m.linear.in_features, n)
|
69 |
+
elif isinstance(m, nn.Linear): # ResNet, EfficientNet
|
70 |
+
if m.out_features != n:
|
71 |
+
setattr(model, name, nn.Linear(m.in_features, n))
|
72 |
+
elif isinstance(m, nn.Sequential):
|
73 |
+
types = [type(x) for x in m]
|
74 |
+
if nn.Linear in types:
|
75 |
+
i = types.index(nn.Linear) # nn.Linear index
|
76 |
+
if m[i].out_features != n:
|
77 |
+
m[i] = nn.Linear(m[i].in_features, n)
|
78 |
+
elif nn.Conv2d in types:
|
79 |
+
i = types.index(nn.Conv2d) # nn.Conv2d index
|
80 |
+
if m[i].out_channels != n:
|
81 |
+
m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
|
82 |
+
|
83 |
+
|
84 |
+
@contextmanager
|
85 |
+
def torch_distributed_zero_first(local_rank: int):
|
86 |
+
# Decorator to make all processes in distributed training wait for each local_master to do something
|
87 |
+
if local_rank not in [-1, 0]:
|
88 |
+
dist.barrier(device_ids=[local_rank])
|
89 |
+
yield
|
90 |
+
if local_rank == 0:
|
91 |
+
dist.barrier(device_ids=[0])
|
92 |
+
|
93 |
+
|
94 |
+
def device_count():
|
95 |
+
# Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Supports Linux and Windows
|
96 |
+
assert platform.system() in ('Linux', 'Windows'), 'device_count() only supported on Linux or Windows'
|
97 |
+
try:
|
98 |
+
cmd = 'nvidia-smi -L | wc -l' if platform.system() == 'Linux' else 'nvidia-smi -L | find /c /v ""' # Windows
|
99 |
+
return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
|
100 |
+
except Exception:
|
101 |
+
return 0
|
102 |
+
|
103 |
+
|
104 |
+
def select_device(device='', batch_size=0, newline=True):
    # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
    s = f'YOLO 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
    device = str(device).strip().lower().replace('cuda:', '').replace('none', '')  # to string, 'cuda:0' to '0'
    cpu = device == 'cpu'
    mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
        assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
            f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        devices = device.split(',') if device else '0'  # range(torch.cuda.device_count())  # i.e. 0,1,6,7
        n = len(devices)  # device count
        if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * (len(s) + 1)
        for i, d in enumerate(devices):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MB
        arg = 'cuda:0'
    elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available():  # prefer MPS if available
        s += 'MPS\n'
        arg = 'mps'
    else:  # revert to CPU
        s += 'CPU\n'
        arg = 'cpu'

    if not newline:
        s = s.rstrip()
    LOGGER.info(s)
    return torch.device(arg)


def time_sync():
    # PyTorch-accurate time
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
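
# Example (illustrative, not part of the original file): pick the compute device once at startup;
# an empty string auto-selects CUDA:0 when available and falls back to CPU:
#   device = select_device('')  # or select_device('cpu'), select_device('0,1')

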
def profile(input, ops, n=10, device=None):
|
148 |
+
""" YOLOv5 speed/memory/FLOPs profiler
|
149 |
+
Usage:
|
150 |
+
input = torch.randn(16, 3, 640, 640)
|
151 |
+
m1 = lambda x: x * torch.sigmoid(x)
|
152 |
+
m2 = nn.SiLU()
|
153 |
+
profile(input, [m1, m2], n=100) # profile over 100 iterations
|
154 |
+
"""
|
155 |
+
results = []
|
156 |
+
if not isinstance(device, torch.device):
|
157 |
+
device = select_device(device)
|
158 |
+
print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
|
159 |
+
f"{'input':>24s}{'output':>24s}")
|
160 |
+
|
161 |
+
for x in input if isinstance(input, list) else [input]:
|
162 |
+
x = x.to(device)
|
163 |
+
x.requires_grad = True
|
164 |
+
for m in ops if isinstance(ops, list) else [ops]:
|
165 |
+
m = m.to(device) if hasattr(m, 'to') else m # device
|
166 |
+
m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
|
167 |
+
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
|
168 |
+
try:
|
169 |
+
flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
|
170 |
+
except Exception:
|
171 |
+
flops = 0
|
172 |
+
|
173 |
+
try:
|
174 |
+
for _ in range(n):
|
175 |
+
t[0] = time_sync()
|
176 |
+
y = m(x)
|
177 |
+
t[1] = time_sync()
|
178 |
+
try:
|
179 |
+
_ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
|
180 |
+
t[2] = time_sync()
|
181 |
+
except Exception: # no backward method
|
182 |
+
# print(e) # for debug
|
183 |
+
t[2] = float('nan')
|
184 |
+
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
185 |
+
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
186 |
+
mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
|
187 |
+
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
|
188 |
+
p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
|
189 |
+
print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
|
190 |
+
results.append([p, flops, mem, tf, tb, s_in, s_out])
|
191 |
+
except Exception as e:
|
192 |
+
print(e)
|
193 |
+
results.append(None)
|
194 |
+
torch.cuda.empty_cache()
|
195 |
+
return results
|
196 |
+
|
197 |
+
|
198 |
+
def is_parallel(model):
|
199 |
+
# Returns True if model is of type DP or DDP
|
200 |
+
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
201 |
+
|
202 |
+
|
203 |
+
def de_parallel(model):
|
204 |
+
# De-parallelize a model: returns single-GPU model if model is of type DP or DDP
|
205 |
+
return model.module if is_parallel(model) else model
|
206 |
+
|
207 |
+
|
208 |
+
def initialize_weights(model):
|
209 |
+
for m in model.modules():
|
210 |
+
t = type(m)
|
211 |
+
if t is nn.Conv2d:
|
212 |
+
pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
213 |
+
elif t is nn.BatchNorm2d:
|
214 |
+
m.eps = 1e-3
|
215 |
+
m.momentum = 0.03
|
216 |
+
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
|
217 |
+
m.inplace = True
|
218 |
+
|
219 |
+
|
220 |
+
def find_modules(model, mclass=nn.Conv2d):
|
221 |
+
# Finds layer indices matching module class 'mclass'
|
222 |
+
return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
|
223 |
+
|
224 |
+
|
225 |
+
def sparsity(model):
|
226 |
+
# Return global model sparsity
|
227 |
+
a, b = 0, 0
|
228 |
+
for p in model.parameters():
|
229 |
+
a += p.numel()
|
230 |
+
b += (p == 0).sum()
|
231 |
+
return b / a
|
232 |
+
|
233 |
+
|
234 |
+
def prune(model, amount=0.3):
|
235 |
+
# Prune model to requested global sparsity
|
236 |
+
import torch.nn.utils.prune as prune
|
237 |
+
for name, m in model.named_modules():
|
238 |
+
if isinstance(m, nn.Conv2d):
|
239 |
+
prune.l1_unstructured(m, name='weight', amount=amount) # prune
|
240 |
+
prune.remove(m, 'weight') # make permanent
|
241 |
+
LOGGER.info(f'Model pruned to {sparsity(model):.3g} global sparsity')
|
242 |
+
|
243 |
+
|
244 |
+
def fuse_conv_and_bn(conv, bn):
|
245 |
+
# Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
246 |
+
fusedconv = nn.Conv2d(conv.in_channels,
|
247 |
+
conv.out_channels,
|
248 |
+
kernel_size=conv.kernel_size,
|
249 |
+
stride=conv.stride,
|
250 |
+
padding=conv.padding,
|
251 |
+
dilation=conv.dilation,
|
252 |
+
groups=conv.groups,
|
253 |
+
bias=True).requires_grad_(False).to(conv.weight.device)
|
254 |
+
|
255 |
+
# Prepare filters
|
256 |
+
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
257 |
+
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
258 |
+
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
|
259 |
+
|
260 |
+
# Prepare spatial bias
|
261 |
+
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
|
262 |
+
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
|
263 |
+
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
|
264 |
+
|
265 |
+
return fusedconv
|
266 |
+
|
267 |
+
|
268 |
+
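# Example (illustrative, not part of the original file): fold a BatchNorm layer into its preceding
# convolution at export/inference time (assumes `m` is a module holding matching `conv` and `bn`):
#   m.conv = fuse_conv_and_bn(m.conv, m.bn)
#   delattr(m, 'bn')

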
def model_info(model, verbose=False, imgsz=640):
|
269 |
+
# Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
|
270 |
+
n_p = sum(x.numel() for x in model.parameters()) # number parameters
|
271 |
+
n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
|
272 |
+
if verbose:
|
273 |
+
print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
|
274 |
+
for i, (name, p) in enumerate(model.named_parameters()):
|
275 |
+
name = name.replace('module_list.', '')
|
276 |
+
print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
|
277 |
+
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
|
278 |
+
|
279 |
+
try: # FLOPs
|
280 |
+
p = next(model.parameters())
|
281 |
+
stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
|
282 |
+
im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
|
283 |
+
flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
|
284 |
+
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
|
285 |
+
fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs' # 640x640 GFLOPs
|
286 |
+
except Exception:
|
287 |
+
fs = ''
|
288 |
+
|
289 |
+
name = Path(model.yaml_file).stem.replace('yolov5', 'YOLOv5') if hasattr(model, 'yaml_file') else 'Model'
|
290 |
+
LOGGER.info(f"{name} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
|
291 |
+
|
292 |
+
|
293 |
+
def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
|
294 |
+
# Scales img(bs,3,y,x) by ratio constrained to gs-multiple
|
295 |
+
if ratio == 1.0:
|
296 |
+
return img
|
297 |
+
h, w = img.shape[2:]
|
298 |
+
s = (int(h * ratio), int(w * ratio)) # new size
|
299 |
+
img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
|
300 |
+
if not same_shape: # pad/crop img
|
301 |
+
h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
|
302 |
+
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
|
303 |
+
|
304 |
+
|
305 |
+
def copy_attr(a, b, include=(), exclude=()):
|
306 |
+
# Copy attributes from b to a, options to only include [...] and to exclude [...]
|
307 |
+
for k, v in b.__dict__.items():
|
308 |
+
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
|
309 |
+
continue
|
310 |
+
else:
|
311 |
+
setattr(a, k, v)


def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
    # YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay
    g = [], [], []  # optimizer parameter groups
    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
    #for v in model.modules():
    #    for p_name, p in v.named_parameters(recurse=0):
    #        if p_name == 'bias':  # bias (no decay)
    #            g[2].append(p)
    #        elif p_name == 'weight' and isinstance(v, bn):  # weight (no decay)
    #            g[1].append(p)
    #        else:
    #            g[0].append(p)  # weight (with decay)

    for v in model.modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias (no decay)
            g[2].append(v.bias)
        if isinstance(v, bn):  # weight (no decay)
            g[1].append(v.weight)
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
            g[0].append(v.weight)

        if hasattr(v, 'im'):
            if hasattr(v.im, 'implicit'):
                g[1].append(v.im.implicit)
            else:
                for iv in v.im:
                    g[1].append(iv.implicit)
        if hasattr(v, 'ia'):
            if hasattr(v.ia, 'implicit'):
                g[1].append(v.ia.implicit)
            else:
                for iv in v.ia:
                    g[1].append(iv.implicit)

        if hasattr(v, 'im2'):
            if hasattr(v.im2, 'implicit'):
                g[1].append(v.im2.implicit)
            else:
                for iv in v.im2:
                    g[1].append(iv.implicit)
        if hasattr(v, 'ia2'):
            if hasattr(v.ia2, 'implicit'):
                g[1].append(v.ia2.implicit)
            else:
                for iv in v.ia2:
                    g[1].append(iv.implicit)

        if hasattr(v, 'im3'):
            if hasattr(v.im3, 'implicit'):
                g[1].append(v.im3.implicit)
            else:
                for iv in v.im3:
                    g[1].append(iv.implicit)
        if hasattr(v, 'ia3'):
            if hasattr(v.ia3, 'implicit'):
                g[1].append(v.ia3.implicit)
            else:
                for iv in v.ia3:
                    g[1].append(iv.implicit)

        if hasattr(v, 'im4'):
            if hasattr(v.im4, 'implicit'):
                g[1].append(v.im4.implicit)
            else:
                for iv in v.im4:
                    g[1].append(iv.implicit)
        if hasattr(v, 'ia4'):
            if hasattr(v.ia4, 'implicit'):
                g[1].append(v.ia4.implicit)
            else:
                for iv in v.ia4:
                    g[1].append(iv.implicit)

        if hasattr(v, 'im5'):
            if hasattr(v.im5, 'implicit'):
                g[1].append(v.im5.implicit)
            else:
                for iv in v.im5:
                    g[1].append(iv.implicit)
        if hasattr(v, 'ia5'):
            if hasattr(v.ia5, 'implicit'):
                g[1].append(v.ia5.implicit)
            else:
                for iv in v.ia5:
                    g[1].append(iv.implicit)

        if hasattr(v, 'im6'):
            if hasattr(v.im6, 'implicit'):
                g[1].append(v.im6.implicit)
            else:
                for iv in v.im6:
                    g[1].append(iv.implicit)
        if hasattr(v, 'ia6'):
            if hasattr(v.ia6, 'implicit'):
                g[1].append(v.ia6.implicit)
            else:
                for iv in v.ia6:
                    g[1].append(iv.implicit)

        if hasattr(v, 'im7'):
            if hasattr(v.im7, 'implicit'):
                g[1].append(v.im7.implicit)
            else:
                for iv in v.im7:
                    g[1].append(iv.implicit)
        if hasattr(v, 'ia7'):
            if hasattr(v.ia7, 'implicit'):
                g[1].append(v.ia7.implicit)
            else:
                for iv in v.ia7:
                    g[1].append(iv.implicit)

    if name == 'Adam':
        optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
    elif name == 'AdamW':
        optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0, amsgrad=True)
    elif name == 'RMSProp':
        optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == 'SGD':
        optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    elif name == 'LION':
        optimizer = Lion(g[2], lr=lr, betas=(momentum, 0.99), weight_decay=0.0)
    else:
        raise NotImplementedError(f'Optimizer {name} not implemented.')

    optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
                f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
    return optimizer
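
# Usage sketch for smart_optimizer() (illustrative; the hyper-parameter values below are
# assumptions for the example, not values taken from this repository's training config):
#
#     optimizer = smart_optimizer(model, name='SGD', lr=0.01, momentum=0.937, decay=5e-4)
#
# Only g[0] (conv/linear weights) receives weight decay; normalization weights, implicit
# layers and biases are excluded, which is the usual practice for YOLO-style training.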


def smart_hub_load(repo='ultralytics/yolov5', model='yolov5s', **kwargs):
    # YOLOv5 torch.hub.load() wrapper with smart error/issue handling
    if check_version(torch.__version__, '1.9.1'):
        kwargs['skip_validation'] = True  # validation causes GitHub API rate limit errors
    if check_version(torch.__version__, '1.12.0'):
        kwargs['trust_repo'] = True  # argument required starting in torch 1.12
    try:
        return torch.hub.load(repo, model, **kwargs)
    except Exception:
        return torch.hub.load(repo, model, force_reload=True, **kwargs)
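
# Usage sketch for smart_hub_load() (illustrative; the model is fetched from GitHub via
# torch.hub on first use, so this requires network access):
#
#     hub_model = smart_hub_load('ultralytics/yolov5', 'yolov5s')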


def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
    # Resume training from a partially trained checkpoint
    best_fitness = 0.0
    start_epoch = ckpt['epoch'] + 1
    if ckpt['optimizer'] is not None:
        optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
        best_fitness = ckpt['best_fitness']
    if ema and ckpt.get('ema'):
        ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
        ema.updates = ckpt['updates']
    if resume:
        assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.\n' \
                                f"Start a new training without --resume, i.e. 'python train.py --weights {weights}'"
        LOGGER.info(f'Resuming training from {weights} from epoch {start_epoch} to {epochs} total epochs')
    if epochs < start_epoch:
        LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
        epochs += ckpt['epoch']  # finetune additional epochs
    return best_fitness, start_epoch, epochs
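
# Usage sketch for smart_resume() (illustrative; `ckpt` is a checkpoint dict saved during
# training with 'epoch', 'optimizer', 'best_fitness' and optionally 'ema'/'updates' keys):
#
#     ckpt = torch.load('last.pt', map_location='cpu')   # path is a placeholder
#     best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema,
#                                                      weights='last.pt', epochs=300)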


class EarlyStopping:
    # YOLOv5 simple early stopper
    def __init__(self, patience=30):
        self.best_fitness = 0.0  # i.e. mAP
        self.best_epoch = 0
        self.patience = patience or float('inf')  # epochs to wait after fitness stops improving before stopping
        self.possible_stop = False  # possible stop may occur next epoch

    def __call__(self, epoch, fitness):
        if fitness >= self.best_fitness:  # >= 0 to allow for early zero-fitness stage of training
            self.best_epoch = epoch
            self.best_fitness = fitness
        delta = epoch - self.best_epoch  # epochs without improvement
        self.possible_stop = delta >= (self.patience - 1)  # possible stop may occur next epoch
        stop = delta >= self.patience  # stop training if patience exceeded
        if stop:
            LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
                        f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
                        f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
                        f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
        return stop
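
# Usage sketch for EarlyStopping inside a training loop (illustrative; `validate` is a
# hypothetical helper, and `fitness` would typically be a weighted mAP from validation):
#
#     stopper = EarlyStopping(patience=30)
#     for epoch in range(start_epoch, epochs):
#         fitness = validate(model)       # hypothetical validation step
#         if stopper(epoch, fitness):
#             break                       # stop when no improvement for `patience` epochs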


class ModelEMA:
    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create EMA
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Update EMA parameters
        self.updates += 1
        d = self.decay(self.updates)

        msd = de_parallel(model).state_dict()  # model state_dict
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:  # true for FP16 and FP32
                v *= d
                v += (1 - d) * msd[k].detach()
        # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        copy_attr(self.ema, model, include, exclude)
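
# Usage sketch for ModelEMA (illustrative): create the EMA once after the model is built,
# call update() after every optimizer step, and update_attr() before saving checkpoints.
# The training-loop names and the include list below are assumptions for the example:
#
#     ema = ModelEMA(model)
#     for imgs, targets in dataloader:    # hypothetical training loop
#         loss.backward()
#         optimizer.step()
#         ema.update(model)
#     ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])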