AkashDataScience committed on
Commit ec32911 · 1 Parent(s): 67fa42b

First commit

app.py ADDED
@@ -0,0 +1,95 @@
+ import torch
+ import numpy as np
+ import gradio as gr
+ from PIL import Image
+ from models.common import DetectMultiBackend
+ from utils.plots import Annotator, colors
+ from utils.torch_utils import select_device, smart_inference_mode
+ from utils.general import check_img_size, Profile, non_max_suppression, scale_boxes
+
+ weights = "runs/train/best_striped.pt"
+ data = "data.yaml"
+
+ def resize_image_pil(image, new_width, new_height):
+
+     # Convert to PIL image
+     img = Image.fromarray(np.array(image))
+
+     # Get original size
+     width, height = img.size
+
+     # Calculate scale
+     width_scale = new_width / width
+     height_scale = new_height / height
+     scale = min(width_scale, height_scale)
+
+     # Resize
+     resized = img.resize((int(width * scale), int(height * scale)), Image.NEAREST)
+
+     # Crop to exact size
+     resized = resized.crop((0, 0, new_width, new_height))
+
+     return resized
+
+ def inference(input_img, conf_thres, iou_thres):
+     # Resize the input to a stride-friendly 640x640 before inference
+     input_img = np.array(resize_image_pil(input_img, 640, 640))
+     im0 = input_img.copy()
+     # Load model
+     device = select_device('')  # '' selects CUDA if available, otherwise CPU
+     model = DetectMultiBackend(weights, device=device, dnn=False, data=data, fp16=False)
+     stride, names, pt = model.stride, model.names, model.pt
+     imgsz = check_img_size((640, 640), s=stride)  # check image size
+
+     bs = 1
+     # Run inference
+     model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))
+     seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
+
+     with dt[0]:
+         im = torch.from_numpy(input_img.transpose((2, 0, 1)).copy()).to(model.device)  # HWC to CHW
+         im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
+         im /= 255  # 0 - 255 to 0.0 - 1.0
+         if len(im.shape) == 3:
+             im = im[None]  # expand for batch dim
+
+     # Inference
+     with dt[1]:
+         pred = model(im, augment=False, visualize=False)
+
+     # NMS
+     with dt[2]:
+         pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=10)
+
+     # Process predictions
+     for i, det in enumerate(pred):  # per image
+         seen += 1
+         annotator = Annotator(im0, line_width=2, example=str(model.names))
+         if len(det):
+             # Rescale boxes from img_size to im0 size
+             det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
+
+             # Write results
+             for *xyxy, conf, cls in reversed(det):
+                 c = int(cls)  # integer class
+                 label = f'{names[c]} {conf:.2f}'
+                 annotator.box_label(xyxy, label, color=colors(c, True))
+
+     return im0
+
+ title = "YOLOv9 model to detect shirts/t-shirts"
+ description = "A simple Gradio interface to run YOLOv9 inference and detect t-shirts in an image"
+ examples = [["image_1.jpg", 0.25, 0.45], ["image_2.jpg", 0.25, 0.45],
+             ["image_3.jpg", 0.25, 0.45], ["image_4.jpg", 0.25, 0.45],
+             ["image_5.jpg", 0.25, 0.45], ["image_6.jpg", 0.25, 0.45],
+             ["image_7.jpg", 0.25, 0.45], ["image_8.jpg", 0.25, 0.45],
+             ["image_9.jpg", 0.25, 0.45], ["image_10.jpg", 0.25, 0.45]]
+
+ demo = gr.Interface(inference,
+                     inputs=[gr.Image(width=320, height=320, label="Input Image"),
+                             gr.Slider(0, 1, 0.25, label="Confidence Threshold"),
+                             gr.Slider(0, 1, 0.45, label="IoU Threshold")],
+                     outputs=[gr.Image(width=640, height=640, label="Output")],
+                     title=title,
+                     description=description,
+                     examples=examples)
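As committed, app.py constructs the Interface but never starts it. A minimal sketch of the usual closing lines, assuming the demo object defined above (Hugging Face Spaces executes app.py directly, so the __main__ guard is optional there):

if __name__ == "__main__":
    demo.launch()  # start the Gradio server on the default host/port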
data.yaml ADDED
@@ -0,0 +1,6 @@
+ train: data/customdata/train/images
+ val: data/customdata/valid/images
+ test: data/customdata/test/images
+
+ nc: 1
+ names: ['shirt']
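data.yaml only records the split directories and the single 'shirt' class. A quick sketch for reading it with PyYAML (already pinned in requirements.txt), assuming the file sits in the working directory:

import yaml

with open("data.yaml") as f:
    cfg = yaml.safe_load(f)
print(cfg["nc"], cfg["names"])  # expected: 1 ['shirt']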
image_1.jpg ADDED
image_10.jpg ADDED
image_2.jpg ADDED
image_3.jpg ADDED
image_4.jpg ADDED
image_5.jpg ADDED
image_6.jpg ADDED
image_7.jpg ADDED
image_8.jpg ADDED
image_9.jpg ADDED
models/common.py ADDED
@@ -0,0 +1,1212 @@
1
+ import ast
2
+ import contextlib
3
+ import json
4
+ import math
5
+ import platform
6
+ import warnings
7
+ import zipfile
8
+ from collections import OrderedDict, namedtuple
9
+ from copy import copy
10
+ from pathlib import Path
11
+ from urllib.parse import urlparse
12
+
13
+ from typing import Optional
14
+
15
+ import cv2
16
+ import numpy as np
17
+ import pandas as pd
18
+ import requests
19
+ import torch
20
+ import torch.nn as nn
21
+ from IPython.display import display
22
+ from PIL import Image
23
+ from torch.cuda import amp
24
+
25
+ from utils import TryExcept
26
+ from utils.dataloaders import exif_transpose, letterbox
27
+ from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
28
+ increment_path, is_notebook, make_divisible, non_max_suppression, scale_boxes,
29
+ xywh2xyxy, xyxy2xywh, yaml_load)
30
+ from utils.plots import Annotator, colors, save_one_box
31
+ from utils.torch_utils import copy_attr, smart_inference_mode
32
+
33
+
34
+ def autopad(k, p=None, d=1): # kernel, padding, dilation
35
+ # Pad to 'same' shape outputs
36
+ if d > 1:
37
+ k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
38
+ if p is None:
39
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
40
+ return p
41
+
42
+
43
+ class Conv(nn.Module):
44
+ # Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
45
+ default_act = nn.SiLU() # default activation
46
+
47
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
48
+ super().__init__()
49
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
50
+ self.bn = nn.BatchNorm2d(c2)
51
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
52
+
53
+ def forward(self, x):
54
+ return self.act(self.bn(self.conv(x)))
55
+
56
+ def forward_fuse(self, x):
57
+ return self.act(self.conv(x))
58
+
59
+
60
+ class AConv(nn.Module):
61
+ def __init__(self, c1, c2): # ch_in, ch_out, shortcut, kernels, groups, expand
62
+ super().__init__()
63
+ self.cv1 = Conv(c1, c2, 3, 2, 1)
64
+
65
+ def forward(self, x):
66
+ x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
67
+ return self.cv1(x)
68
+
69
+
70
+ class ADown(nn.Module):
71
+ def __init__(self, c1, c2): # ch_in, ch_out, shortcut, kernels, groups, expand
72
+ super().__init__()
73
+ self.c = c2 // 2
74
+ self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1)
75
+ self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0)
76
+
77
+ def forward(self, x):
78
+ x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
79
+ x1,x2 = x.chunk(2, 1)
80
+ x1 = self.cv1(x1)
81
+ x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
82
+ x2 = self.cv2(x2)
83
+ return torch.cat((x1, x2), 1)
84
+
85
+
86
+ class RepConvN(nn.Module):
87
+ """RepConv is a basic rep-style block, including training and deploy status
88
+ This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
89
+ """
90
+ default_act = nn.SiLU() # default activation
91
+
92
+ def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
93
+ super().__init__()
94
+ assert k == 3 and p == 1
95
+ self.g = g
96
+ self.c1 = c1
97
+ self.c2 = c2
98
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
99
+
100
+ self.bn = None
101
+ self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
102
+ self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)
103
+
104
+ def forward_fuse(self, x):
105
+ """Forward process"""
106
+ return self.act(self.conv(x))
107
+
108
+ def forward(self, x):
109
+ """Forward process"""
110
+ id_out = 0 if self.bn is None else self.bn(x)
111
+ return self.act(self.conv1(x) + self.conv2(x) + id_out)
112
+
113
+ def get_equivalent_kernel_bias(self):
114
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
115
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
116
+ kernelid, biasid = self._fuse_bn_tensor(self.bn)
117
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
118
+
119
+ def _avg_to_3x3_tensor(self, avgp):
120
+ channels = self.c1
121
+ groups = self.g
122
+ kernel_size = avgp.kernel_size
123
+ input_dim = channels // groups
124
+ k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
125
+ k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
126
+ return k
127
+
128
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
129
+ if kernel1x1 is None:
130
+ return 0
131
+ else:
132
+ return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
133
+
134
+ def _fuse_bn_tensor(self, branch):
135
+ if branch is None:
136
+ return 0, 0
137
+ if isinstance(branch, Conv):
138
+ kernel = branch.conv.weight
139
+ running_mean = branch.bn.running_mean
140
+ running_var = branch.bn.running_var
141
+ gamma = branch.bn.weight
142
+ beta = branch.bn.bias
143
+ eps = branch.bn.eps
144
+ elif isinstance(branch, nn.BatchNorm2d):
145
+ if not hasattr(self, 'id_tensor'):
146
+ input_dim = self.c1 // self.g
147
+ kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
148
+ for i in range(self.c1):
149
+ kernel_value[i, i % input_dim, 1, 1] = 1
150
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
151
+ kernel = self.id_tensor
152
+ running_mean = branch.running_mean
153
+ running_var = branch.running_var
154
+ gamma = branch.weight
155
+ beta = branch.bias
156
+ eps = branch.eps
157
+ std = (running_var + eps).sqrt()
158
+ t = (gamma / std).reshape(-1, 1, 1, 1)
159
+ return kernel * t, beta - running_mean * gamma / std
160
+
161
+ def fuse_convs(self):
162
+ if hasattr(self, 'conv'):
163
+ return
164
+ kernel, bias = self.get_equivalent_kernel_bias()
165
+ self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels,
166
+ out_channels=self.conv1.conv.out_channels,
167
+ kernel_size=self.conv1.conv.kernel_size,
168
+ stride=self.conv1.conv.stride,
169
+ padding=self.conv1.conv.padding,
170
+ dilation=self.conv1.conv.dilation,
171
+ groups=self.conv1.conv.groups,
172
+ bias=True).requires_grad_(False)
173
+ self.conv.weight.data = kernel
174
+ self.conv.bias.data = bias
175
+ for para in self.parameters():
176
+ para.detach_()
177
+ self.__delattr__('conv1')
178
+ self.__delattr__('conv2')
179
+ if hasattr(self, 'nm'):
180
+ self.__delattr__('nm')
181
+ if hasattr(self, 'bn'):
182
+ self.__delattr__('bn')
183
+ if hasattr(self, 'id_tensor'):
184
+ self.__delattr__('id_tensor')
185
+
186
+
187
+ class SP(nn.Module):
188
+ def __init__(self, k=3, s=1):
189
+ super(SP, self).__init__()
190
+ self.m = nn.MaxPool2d(kernel_size=k, stride=s, padding=k // 2)
191
+
192
+ def forward(self, x):
193
+ return self.m(x)
194
+
195
+
196
+ class MP(nn.Module):
197
+ # Max pooling
198
+ def __init__(self, k=2):
199
+ super(MP, self).__init__()
200
+ self.m = nn.MaxPool2d(kernel_size=k, stride=k)
201
+
202
+ def forward(self, x):
203
+ return self.m(x)
204
+
205
+
206
+ class ConvTranspose(nn.Module):
207
+ # Convolution transpose 2d layer
208
+ default_act = nn.SiLU() # default activation
209
+
210
+ def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
211
+ super().__init__()
212
+ self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
213
+ self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
214
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
215
+
216
+ def forward(self, x):
217
+ return self.act(self.bn(self.conv_transpose(x)))
218
+
219
+
220
+ class DWConv(Conv):
221
+ # Depth-wise convolution
222
+ def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
223
+ super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
224
+
225
+
226
+ class DWConvTranspose2d(nn.ConvTranspose2d):
227
+ # Depth-wise transpose convolution
228
+ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
229
+ super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
230
+
231
+
232
+ class DFL(nn.Module):
233
+ # DFL module
234
+ def __init__(self, c1=17):
235
+ super().__init__()
236
+ self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
237
+ self.conv.weight.data[:] = nn.Parameter(torch.arange(c1, dtype=torch.float).view(1, c1, 1, 1)) # / 120.0
238
+ self.c1 = c1
239
+ # self.bn = nn.BatchNorm2d(4)
240
+
241
+ def forward(self, x):
242
+ b, c, a = x.shape # batch, channels, anchors
243
+ return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
244
+ # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
245
+
246
+
247
+ class BottleneckBase(nn.Module):
248
+ # Standard bottleneck
249
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(1, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
250
+ super().__init__()
251
+ c_ = int(c2 * e) # hidden channels
252
+ self.cv1 = Conv(c1, c_, k[0], 1)
253
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
254
+ self.add = shortcut and c1 == c2
255
+
256
+ def forward(self, x):
257
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
258
+
259
+
260
+ class RBottleneckBase(nn.Module):
261
+ # Standard bottleneck
262
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 1), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
263
+ super().__init__()
264
+ c_ = int(c2 * e) # hidden channels
265
+ self.cv1 = Conv(c1, c_, k[0], 1)
266
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
267
+ self.add = shortcut and c1 == c2
268
+
269
+ def forward(self, x):
270
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
271
+
272
+
273
+ class RepNRBottleneckBase(nn.Module):
274
+ # Standard bottleneck
275
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 1), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
276
+ super().__init__()
277
+ c_ = int(c2 * e) # hidden channels
278
+ self.cv1 = RepConvN(c1, c_, k[0], 1)
279
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
280
+ self.add = shortcut and c1 == c2
281
+
282
+ def forward(self, x):
283
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
284
+
285
+
286
+ class Bottleneck(nn.Module):
287
+ # Standard bottleneck
288
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
289
+ super().__init__()
290
+ c_ = int(c2 * e) # hidden channels
291
+ self.cv1 = Conv(c1, c_, k[0], 1)
292
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
293
+ self.add = shortcut and c1 == c2
294
+
295
+ def forward(self, x):
296
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
297
+
298
+
299
+ class RepNBottleneck(nn.Module):
300
+ # Standard bottleneck
301
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
302
+ super().__init__()
303
+ c_ = int(c2 * e) # hidden channels
304
+ self.cv1 = RepConvN(c1, c_, k[0], 1)
305
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
306
+ self.add = shortcut and c1 == c2
307
+
308
+ def forward(self, x):
309
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
310
+
311
+
312
+ class Res(nn.Module):
313
+ # ResNet bottleneck
314
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
315
+ super(Res, self).__init__()
316
+ c_ = int(c2 * e) # hidden channels
317
+ self.cv1 = Conv(c1, c_, 1, 1)
318
+ self.cv2 = Conv(c_, c_, 3, 1, g=g)
319
+ self.cv3 = Conv(c_, c2, 1, 1)
320
+ self.add = shortcut and c1 == c2
321
+
322
+ def forward(self, x):
323
+ return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
324
+
325
+
326
+ class RepNRes(nn.Module):
327
+ # ResNet bottleneck
328
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
329
+ super(RepNRes, self).__init__()
330
+ c_ = int(c2 * e) # hidden channels
331
+ self.cv1 = Conv(c1, c_, 1, 1)
332
+ self.cv2 = RepConvN(c_, c_, 3, 1, g=g)
333
+ self.cv3 = Conv(c_, c2, 1, 1)
334
+ self.add = shortcut and c1 == c2
335
+
336
+ def forward(self, x):
337
+ return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
338
+
339
+
340
+ class BottleneckCSP(nn.Module):
341
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
342
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
343
+ super().__init__()
344
+ c_ = int(c2 * e) # hidden channels
345
+ self.cv1 = Conv(c1, c_, 1, 1)
346
+ self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
347
+ self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
348
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
349
+ self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
350
+ self.act = nn.SiLU()
351
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
352
+
353
+ def forward(self, x):
354
+ y1 = self.cv3(self.m(self.cv1(x)))
355
+ y2 = self.cv2(x)
356
+ return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
357
+
358
+
359
+ class CSP(nn.Module):
360
+ # CSP Bottleneck with 3 convolutions
361
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
362
+ super().__init__()
363
+ c_ = int(c2 * e) # hidden channels
364
+ self.cv1 = Conv(c1, c_, 1, 1)
365
+ self.cv2 = Conv(c1, c_, 1, 1)
366
+ self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
367
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
368
+
369
+ def forward(self, x):
370
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
371
+
372
+
373
+ class RepNCSP(nn.Module):
374
+ # CSP Bottleneck with 3 convolutions
375
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
376
+ super().__init__()
377
+ c_ = int(c2 * e) # hidden channels
378
+ self.cv1 = Conv(c1, c_, 1, 1)
379
+ self.cv2 = Conv(c1, c_, 1, 1)
380
+ self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
381
+ self.m = nn.Sequential(*(RepNBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
382
+
383
+ def forward(self, x):
384
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
385
+
386
+
387
+ class CSPBase(nn.Module):
388
+ # CSP Bottleneck with 3 convolutions
389
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
390
+ super().__init__()
391
+ c_ = int(c2 * e) # hidden channels
392
+ self.cv1 = Conv(c1, c_, 1, 1)
393
+ self.cv2 = Conv(c1, c_, 1, 1)
394
+ self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
395
+ self.m = nn.Sequential(*(BottleneckBase(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
396
+
397
+ def forward(self, x):
398
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
399
+
400
+
401
+ class SPP(nn.Module):
402
+ # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
403
+ def __init__(self, c1, c2, k=(5, 9, 13)):
404
+ super().__init__()
405
+ c_ = c1 // 2 # hidden channels
406
+ self.cv1 = Conv(c1, c_, 1, 1)
407
+ self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
408
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
409
+
410
+ def forward(self, x):
411
+ x = self.cv1(x)
412
+ with warnings.catch_warnings():
413
+ warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
414
+ return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
415
+
416
+
417
+ class ASPP(torch.nn.Module):
418
+
419
+ def __init__(self, in_channels, out_channels):
420
+ super().__init__()
421
+ kernel_sizes = [1, 3, 3, 1]
422
+ dilations = [1, 3, 6, 1]
423
+ paddings = [0, 3, 6, 0]
424
+ self.aspp = torch.nn.ModuleList()
425
+ for aspp_idx in range(len(kernel_sizes)):
426
+ conv = torch.nn.Conv2d(
427
+ in_channels,
428
+ out_channels,
429
+ kernel_size=kernel_sizes[aspp_idx],
430
+ stride=1,
431
+ dilation=dilations[aspp_idx],
432
+ padding=paddings[aspp_idx],
433
+ bias=True)
434
+ self.aspp.append(conv)
435
+ self.gap = torch.nn.AdaptiveAvgPool2d(1)
436
+ self.aspp_num = len(kernel_sizes)
437
+ for m in self.modules():
438
+ if isinstance(m, torch.nn.Conv2d):
439
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
440
+ m.weight.data.normal_(0, math.sqrt(2. / n))
441
+ m.bias.data.fill_(0)
442
+
443
+ def forward(self, x):
444
+ avg_x = self.gap(x)
445
+ out = []
446
+ for aspp_idx in range(self.aspp_num):
447
+ inp = avg_x if (aspp_idx == self.aspp_num - 1) else x
448
+ out.append(F.relu_(self.aspp[aspp_idx](inp)))
449
+ out[-1] = out[-1].expand_as(out[-2])
450
+ out = torch.cat(out, dim=1)
451
+ return out
452
+
453
+
454
+ class SPPCSPC(nn.Module):
455
+ # CSP SPP https://github.com/WongKinYiu/CrossStagePartialNetworks
456
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
457
+ super(SPPCSPC, self).__init__()
458
+ c_ = int(2 * c2 * e) # hidden channels
459
+ self.cv1 = Conv(c1, c_, 1, 1)
460
+ self.cv2 = Conv(c1, c_, 1, 1)
461
+ self.cv3 = Conv(c_, c_, 3, 1)
462
+ self.cv4 = Conv(c_, c_, 1, 1)
463
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
464
+ self.cv5 = Conv(4 * c_, c_, 1, 1)
465
+ self.cv6 = Conv(c_, c_, 3, 1)
466
+ self.cv7 = Conv(2 * c_, c2, 1, 1)
467
+
468
+ def forward(self, x):
469
+ x1 = self.cv4(self.cv3(self.cv1(x)))
470
+ y1 = self.cv6(self.cv5(torch.cat([x1] + [m(x1) for m in self.m], 1)))
471
+ y2 = self.cv2(x)
472
+ return self.cv7(torch.cat((y1, y2), dim=1))
473
+
474
+
475
+ class SPPF(nn.Module):
476
+ # Spatial Pyramid Pooling - Fast (SPPF) layer by Glenn Jocher
477
+ def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
478
+ super().__init__()
479
+ c_ = c1 // 2 # hidden channels
480
+ self.cv1 = Conv(c1, c_, 1, 1)
481
+ self.cv2 = Conv(c_ * 4, c2, 1, 1)
482
+ self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
483
+ # self.m = SoftPool2d(kernel_size=k, stride=1, padding=k // 2)
484
+
485
+ def forward(self, x):
486
+ x = self.cv1(x)
487
+ with warnings.catch_warnings():
488
+ warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
489
+ y1 = self.m(x)
490
+ y2 = self.m(y1)
491
+ return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
492
+
493
+
494
+ import torch.nn.functional as F
495
+ from torch.nn.modules.utils import _pair
496
+
497
+
498
+ class ReOrg(nn.Module):
499
+ # yolo
500
+ def __init__(self):
501
+ super(ReOrg, self).__init__()
502
+
503
+ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
504
+ return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)
505
+
506
+
507
+ class Contract(nn.Module):
508
+ # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
509
+ def __init__(self, gain=2):
510
+ super().__init__()
511
+ self.gain = gain
512
+
513
+ def forward(self, x):
514
+ b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
515
+ s = self.gain
516
+ x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
517
+ x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
518
+ return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
519
+
520
+
521
+ class Expand(nn.Module):
522
+ # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
523
+ def __init__(self, gain=2):
524
+ super().__init__()
525
+ self.gain = gain
526
+
527
+ def forward(self, x):
528
+ b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
529
+ s = self.gain
530
+ x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
531
+ x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
532
+ return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
533
+
534
+
535
+ class Concat(nn.Module):
536
+ # Concatenate a list of tensors along dimension
537
+ def __init__(self, dimension=1):
538
+ super().__init__()
539
+ self.d = dimension
540
+
541
+ def forward(self, x):
542
+ return torch.cat(x, self.d)
543
+
544
+
545
+ class Shortcut(nn.Module):
546
+ def __init__(self, dimension=0):
547
+ super(Shortcut, self).__init__()
548
+ self.d = dimension
549
+
550
+ def forward(self, x):
551
+ return x[0]+x[1]
552
+
553
+
554
+ class Silence(nn.Module):
555
+ def __init__(self):
556
+ super(Silence, self).__init__()
557
+ def forward(self, x):
558
+ return x
559
+
560
+
561
+ ##### GELAN #####
562
+
563
+ class SPPELAN(nn.Module):
564
+ # spp-elan
565
+ def __init__(self, c1, c2, c3): # ch_in, ch_out, number, shortcut, groups, expansion
566
+ super().__init__()
567
+ self.c = c3
568
+ self.cv1 = Conv(c1, c3, 1, 1)
569
+ self.cv2 = SP(5)
570
+ self.cv3 = SP(5)
571
+ self.cv4 = SP(5)
572
+ self.cv5 = Conv(4*c3, c2, 1, 1)
573
+
574
+ def forward(self, x):
575
+ y = [self.cv1(x)]
576
+ y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4])
577
+ return self.cv5(torch.cat(y, 1))
578
+
579
+
580
+ class RepNCSPELAN4(nn.Module):
581
+ # csp-elan
582
+ def __init__(self, c1, c2, c3, c4, c5=1): # ch_in, ch_out, number, shortcut, groups, expansion
583
+ super().__init__()
584
+ self.c = c3//2
585
+ self.cv1 = Conv(c1, c3, 1, 1)
586
+ self.cv2 = nn.Sequential(RepNCSP(c3//2, c4, c5), Conv(c4, c4, 3, 1))
587
+ self.cv3 = nn.Sequential(RepNCSP(c4, c4, c5), Conv(c4, c4, 3, 1))
588
+ self.cv4 = Conv(c3+(2*c4), c2, 1, 1)
589
+
590
+ def forward(self, x):
591
+ y = list(self.cv1(x).chunk(2, 1))
592
+ y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
593
+ return self.cv4(torch.cat(y, 1))
594
+
595
+ def forward_split(self, x):
596
+ y = list(self.cv1(x).split((self.c, self.c), 1))
597
+ y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
598
+ return self.cv4(torch.cat(y, 1))
599
+
600
+ #################
601
+
602
+
603
+ ##### YOLOR #####
604
+
605
+ class ImplicitA(nn.Module):
606
+ def __init__(self, channel):
607
+ super(ImplicitA, self).__init__()
608
+ self.channel = channel
609
+ self.implicit = nn.Parameter(torch.zeros(1, channel, 1, 1))
610
+ nn.init.normal_(self.implicit, std=.02)
611
+
612
+ def forward(self, x):
613
+ return self.implicit + x
614
+
615
+
616
+ class ImplicitM(nn.Module):
617
+ def __init__(self, channel):
618
+ super(ImplicitM, self).__init__()
619
+ self.channel = channel
620
+ self.implicit = nn.Parameter(torch.ones(1, channel, 1, 1))
621
+ nn.init.normal_(self.implicit, mean=1., std=.02)
622
+
623
+ def forward(self, x):
624
+ return self.implicit * x
625
+
626
+ #################
627
+
628
+
629
+ ##### CBNet #####
630
+
631
+ class CBLinear(nn.Module):
632
+ def __init__(self, c1, c2s, k=1, s=1, p=None, g=1): # ch_in, ch_outs, kernel, stride, padding, groups
633
+ super(CBLinear, self).__init__()
634
+ self.c2s = c2s
635
+ self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True)
636
+
637
+ def forward(self, x):
638
+ outs = self.conv(x).split(self.c2s, dim=1)
639
+ return outs
640
+
641
+ class CBFuse(nn.Module):
642
+ def __init__(self, idx):
643
+ super(CBFuse, self).__init__()
644
+ self.idx = idx
645
+
646
+ def forward(self, xs):
647
+ target_size = xs[-1].shape[2:]
648
+ res = [F.interpolate(x[self.idx[i]], size=target_size, mode='nearest') for i, x in enumerate(xs[:-1])]
649
+ out = torch.sum(torch.stack(res + xs[-1:]), dim=0)
650
+ return out
651
+
652
+ #################
653
+
654
+
655
+ class DetectMultiBackend(nn.Module):
656
+ # YOLO MultiBackend class for python inference on various backends
657
+ def __init__(self, weights='yolo.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
658
+ # Usage:
659
+ # PyTorch: weights = *.pt
660
+ # TorchScript: *.torchscript
661
+ # ONNX Runtime: *.onnx
662
+ # ONNX OpenCV DNN: *.onnx --dnn
663
+ # OpenVINO: *_openvino_model
664
+ # CoreML: *.mlmodel
665
+ # TensorRT: *.engine
666
+ # TensorFlow SavedModel: *_saved_model
667
+ # TensorFlow GraphDef: *.pb
668
+ # TensorFlow Lite: *.tflite
669
+ # TensorFlow Edge TPU: *_edgetpu.tflite
670
+ # PaddlePaddle: *_paddle_model
671
+ from models.experimental import attempt_download, attempt_load # scoped to avoid circular import
672
+
673
+ super().__init__()
674
+ w = str(weights[0] if isinstance(weights, list) else weights)
675
+ pt, jit, onnx, onnx_end2end, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
676
+ fp16 &= pt or jit or onnx or engine # FP16
677
+ nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
678
+ stride = 32 # default stride
679
+ cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
680
+ if not (pt or triton):
681
+ w = attempt_download(w) # download if not local
682
+
683
+ if pt: # PyTorch
684
+ model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
685
+ stride = max(int(model.stride.max()), 32) # model stride
686
+ names = model.module.names if hasattr(model, 'module') else model.names # get class names
687
+ model.half() if fp16 else model.float()
688
+ self.model = model # explicitly assign for to(), cpu(), cuda(), half()
689
+ elif jit: # TorchScript
690
+ LOGGER.info(f'Loading {w} for TorchScript inference...')
691
+ extra_files = {'config.txt': ''} # model metadata
692
+ model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
693
+ model.half() if fp16 else model.float()
694
+ if extra_files['config.txt']: # load metadata dict
695
+ d = json.loads(extra_files['config.txt'],
696
+ object_hook=lambda d: {int(k) if k.isdigit() else k: v
697
+ for k, v in d.items()})
698
+ stride, names = int(d['stride']), d['names']
699
+ elif dnn: # ONNX OpenCV DNN
700
+ LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
701
+ check_requirements('opencv-python>=4.5.4')
702
+ net = cv2.dnn.readNetFromONNX(w)
703
+ elif onnx: # ONNX Runtime
704
+ LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
705
+ check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
706
+ import onnxruntime
707
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
708
+ session = onnxruntime.InferenceSession(w, providers=providers)
709
+ output_names = [x.name for x in session.get_outputs()]
710
+ meta = session.get_modelmeta().custom_metadata_map # metadata
711
+ if 'stride' in meta:
712
+ stride, names = int(meta['stride']), eval(meta['names'])
713
+ elif xml: # OpenVINO
714
+ LOGGER.info(f'Loading {w} for OpenVINO inference...')
715
+ check_requirements('openvino') # requires openvino-dev: https://pypi.org/project/openvino-dev/
716
+ from openvino.runtime import Core, Layout, get_batch
717
+ ie = Core()
718
+ if not Path(w).is_file(): # if not *.xml
719
+ w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
720
+ network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
721
+ if network.get_parameters()[0].get_layout().empty:
722
+ network.get_parameters()[0].set_layout(Layout("NCHW"))
723
+ batch_dim = get_batch(network)
724
+ if batch_dim.is_static:
725
+ batch_size = batch_dim.get_length()
726
+ executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2
727
+ stride, names = self._load_metadata(Path(w).with_suffix('.yaml')) # load metadata
728
+ elif engine: # TensorRT
729
+ LOGGER.info(f'Loading {w} for TensorRT inference...')
730
+ import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
731
+ check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
732
+ if device.type == 'cpu':
733
+ device = torch.device('cuda:0')
734
+ Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
735
+ logger = trt.Logger(trt.Logger.INFO)
736
+ with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
737
+ model = runtime.deserialize_cuda_engine(f.read())
738
+ context = model.create_execution_context()
739
+ bindings = OrderedDict()
740
+ output_names = []
741
+ fp16 = False # default updated below
742
+ dynamic = False
743
+ for i in range(model.num_bindings):
744
+ name = model.get_binding_name(i)
745
+ dtype = trt.nptype(model.get_binding_dtype(i))
746
+ if model.binding_is_input(i):
747
+ if -1 in tuple(model.get_binding_shape(i)): # dynamic
748
+ dynamic = True
749
+ context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
750
+ if dtype == np.float16:
751
+ fp16 = True
752
+ else: # output
753
+ output_names.append(name)
754
+ shape = tuple(context.get_binding_shape(i))
755
+ im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
756
+ bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
757
+ binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
758
+ batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
759
+ elif coreml: # CoreML
760
+ LOGGER.info(f'Loading {w} for CoreML inference...')
761
+ import coremltools as ct
762
+ model = ct.models.MLModel(w)
763
+ elif saved_model: # TF SavedModel
764
+ LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
765
+ import tensorflow as tf
766
+ keras = False # assume TF1 saved_model
767
+ model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
768
+ elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
769
+ LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
770
+ import tensorflow as tf
771
+
772
+ def wrap_frozen_graph(gd, inputs, outputs):
773
+ x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
774
+ ge = x.graph.as_graph_element
775
+ return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
776
+
777
+ def gd_outputs(gd):
778
+ name_list, input_list = [], []
779
+ for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
780
+ name_list.append(node.name)
781
+ input_list.extend(node.input)
782
+ return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
783
+
784
+ gd = tf.Graph().as_graph_def() # TF GraphDef
785
+ with open(w, 'rb') as f:
786
+ gd.ParseFromString(f.read())
787
+ frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
788
+ elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
789
+ try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
790
+ from tflite_runtime.interpreter import Interpreter, load_delegate
791
+ except ImportError:
792
+ import tensorflow as tf
793
+ Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
794
+ if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
795
+ LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
796
+ delegate = {
797
+ 'Linux': 'libedgetpu.so.1',
798
+ 'Darwin': 'libedgetpu.1.dylib',
799
+ 'Windows': 'edgetpu.dll'}[platform.system()]
800
+ interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
801
+ else: # TFLite
802
+ LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
803
+ interpreter = Interpreter(model_path=w) # load TFLite model
804
+ interpreter.allocate_tensors() # allocate
805
+ input_details = interpreter.get_input_details() # inputs
806
+ output_details = interpreter.get_output_details() # outputs
807
+ # load metadata
808
+ with contextlib.suppress(zipfile.BadZipFile):
809
+ with zipfile.ZipFile(w, "r") as model:
810
+ meta_file = model.namelist()[0]
811
+ meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
812
+ stride, names = int(meta['stride']), meta['names']
813
+ elif tfjs: # TF.js
814
+ raise NotImplementedError('ERROR: YOLO TF.js inference is not supported')
815
+ elif paddle: # PaddlePaddle
816
+ LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
817
+ check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
818
+ import paddle.inference as pdi
819
+ if not Path(w).is_file(): # if not *.pdmodel
820
+ w = next(Path(w).rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
821
+ weights = Path(w).with_suffix('.pdiparams')
822
+ config = pdi.Config(str(w), str(weights))
823
+ if cuda:
824
+ config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
825
+ predictor = pdi.create_predictor(config)
826
+ input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
827
+ output_names = predictor.get_output_names()
828
+ elif triton: # NVIDIA Triton Inference Server
829
+ LOGGER.info(f'Using {w} as Triton Inference Server...')
830
+ check_requirements('tritonclient[all]')
831
+ from utils.triton import TritonRemoteModel
832
+ model = TritonRemoteModel(url=w)
833
+ nhwc = model.runtime.startswith("tensorflow")
834
+ else:
835
+ raise NotImplementedError(f'ERROR: {w} is not a supported format')
836
+
837
+ # class names
838
+ if 'names' not in locals():
839
+ names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)}
840
+ if names[0] == 'n01440764' and len(names) == 1000: # ImageNet
841
+ names = yaml_load(ROOT / 'data/ImageNet.yaml')['names'] # human-readable names
842
+
843
+ self.__dict__.update(locals()) # assign all variables to self
844
+
845
+ def forward(self, im, augment=False, visualize=False):
846
+ # YOLO MultiBackend inference
847
+ b, ch, h, w = im.shape # batch, channel, height, width
848
+ if self.fp16 and im.dtype != torch.float16:
849
+ im = im.half() # to FP16
850
+ if self.nhwc:
851
+ im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
852
+
853
+ if self.pt: # PyTorch
854
+ y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
855
+ elif self.jit: # TorchScript
856
+ y = self.model(im)
857
+ elif self.dnn: # ONNX OpenCV DNN
858
+ im = im.cpu().numpy() # torch to numpy
859
+ self.net.setInput(im)
860
+ y = self.net.forward()
861
+ elif self.onnx: # ONNX Runtime
862
+ im = im.cpu().numpy() # torch to numpy
863
+ y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
864
+ elif self.xml: # OpenVINO
865
+ im = im.cpu().numpy() # FP32
866
+ y = list(self.executable_network([im]).values())
867
+ elif self.engine: # TensorRT
868
+ if self.dynamic and im.shape != self.bindings['images'].shape:
869
+ i = self.model.get_binding_index('images')
870
+ self.context.set_binding_shape(i, im.shape) # reshape if dynamic
871
+ self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
872
+ for name in self.output_names:
873
+ i = self.model.get_binding_index(name)
874
+ self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
875
+ s = self.bindings['images'].shape
876
+ assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
877
+ self.binding_addrs['images'] = int(im.data_ptr())
878
+ self.context.execute_v2(list(self.binding_addrs.values()))
879
+ y = [self.bindings[x].data for x in sorted(self.output_names)]
880
+ elif self.coreml: # CoreML
881
+ im = im.cpu().numpy()
882
+ im = Image.fromarray((im[0] * 255).astype('uint8'))
883
+ # im = im.resize((192, 320), Image.ANTIALIAS)
884
+ y = self.model.predict({'image': im}) # coordinates are xywh normalized
885
+ if 'confidence' in y:
886
+ box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
887
+ conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
888
+ y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
889
+ else:
890
+ y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
891
+ elif self.paddle: # PaddlePaddle
892
+ im = im.cpu().numpy().astype(np.float32)
893
+ self.input_handle.copy_from_cpu(im)
894
+ self.predictor.run()
895
+ y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
896
+ elif self.triton: # NVIDIA Triton Inference Server
897
+ y = self.model(im)
898
+ else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
899
+ im = im.cpu().numpy()
900
+ if self.saved_model: # SavedModel
901
+ y = self.model(im, training=False) if self.keras else self.model(im)
902
+ elif self.pb: # GraphDef
903
+ y = self.frozen_func(x=self.tf.constant(im))
904
+ else: # Lite or Edge TPU
905
+ input = self.input_details[0]
906
+ int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
907
+ if int8:
908
+ scale, zero_point = input['quantization']
909
+ im = (im / scale + zero_point).astype(np.uint8) # de-scale
910
+ self.interpreter.set_tensor(input['index'], im)
911
+ self.interpreter.invoke()
912
+ y = []
913
+ for output in self.output_details:
914
+ x = self.interpreter.get_tensor(output['index'])
915
+ if int8:
916
+ scale, zero_point = output['quantization']
917
+ x = (x.astype(np.float32) - zero_point) * scale # re-scale
918
+ y.append(x)
919
+ y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
920
+ y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels
921
+
922
+ if isinstance(y, (list, tuple)):
923
+ return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
924
+ else:
925
+ return self.from_numpy(y)
926
+
927
+ def from_numpy(self, x):
928
+ return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x
929
+
930
+ def warmup(self, imgsz=(1, 3, 640, 640)):
931
+ # Warmup model by running inference once
932
+ warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton
933
+ if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
934
+ im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
935
+ for _ in range(2 if self.jit else 1): #
936
+ self.forward(im) # warmup
937
+
938
+ @staticmethod
939
+ def _model_type(p='path/to/model.pt'):
940
+ # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
941
+ # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
942
+ from export import export_formats
943
+ from utils.downloads import is_url
944
+ sf = list(export_formats().Suffix) # export suffixes
945
+ if not is_url(p, check=False):
946
+ check_suffix(p, sf) # checks
947
+ url = urlparse(p) # if url may be Triton inference server
948
+ types = [s in Path(p).name for s in sf]
949
+ types[8] &= not types[9] # tflite &= not edgetpu
950
+ triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
951
+ return types + [triton]
952
+
953
+ @staticmethod
954
+ def _load_metadata(f=Path('path/to/meta.yaml')):
955
+ # Load metadata from meta.yaml if it exists
956
+ if f.exists():
957
+ d = yaml_load(f)
958
+ return d['stride'], d['names'] # assign stride, names
959
+ return None, None
960
+
961
+
962
+ class AutoShape(nn.Module):
963
+ # YOLO input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
964
+ conf = 0.25 # NMS confidence threshold
965
+ iou = 0.45 # NMS IoU threshold
966
+ agnostic = False # NMS class-agnostic
967
+ multi_label = False # NMS multiple labels per box
968
+ classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
969
+ max_det = 1000 # maximum number of detections per image
970
+ amp = False # Automatic Mixed Precision (AMP) inference
971
+
972
+ def __init__(self, model, verbose=True):
973
+ super().__init__()
974
+ if verbose:
975
+ LOGGER.info('Adding AutoShape... ')
976
+ copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
977
+ self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
978
+ self.pt = not self.dmb or model.pt # PyTorch model
979
+ self.model = model.eval()
980
+ if self.pt:
981
+ m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
982
+ m.inplace = False # Detect.inplace=False for safe multithread inference
983
+ m.export = True # do not output loss values
984
+
985
+ def _apply(self, fn):
986
+ # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
987
+ self = super()._apply(fn)
988
+ from models.yolo import Detect, Segment
989
+ if self.pt:
990
+ m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
991
+ if isinstance(m, (Detect, Segment)):
992
+ for k in 'stride', 'anchor_grid', 'stride_grid', 'grid':
993
+ x = getattr(m, k)
994
+ setattr(m, k, list(map(fn, x))) if isinstance(x, (list, tuple)) else setattr(m, k, fn(x))
995
+ return self
996
+
997
+ @smart_inference_mode()
998
+ def forward(self, ims, size=640, augment=False, profile=False):
999
+ # Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
1000
+ # file: ims = 'data/images/zidane.jpg' # str or PosixPath
1001
+ # URI: = 'https://ultralytics.com/images/zidane.jpg'
1002
+ # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
1003
+ # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
1004
+ # numpy: = np.zeros((640,1280,3)) # HWC
1005
+ # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
1006
+ # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
1007
+
1008
+ dt = (Profile(), Profile(), Profile())
1009
+ with dt[0]:
1010
+ if isinstance(size, int): # expand
1011
+ size = (size, size)
1012
+ p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
1013
+ autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
1014
+ if isinstance(ims, torch.Tensor): # torch
1015
+ with amp.autocast(autocast):
1016
+ return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
1017
+
1018
+ # Pre-process
1019
+ n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
1020
+ shape0, shape1, files = [], [], [] # image and inference shapes, filenames
1021
+ for i, im in enumerate(ims):
1022
+ f = f'image{i}' # filename
1023
+ if isinstance(im, (str, Path)): # filename or uri
1024
+ im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
1025
+ im = np.asarray(exif_transpose(im))
1026
+ elif isinstance(im, Image.Image): # PIL Image
1027
+ im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
1028
+ files.append(Path(f).with_suffix('.jpg').name)
1029
+ if im.shape[0] < 5: # image in CHW
1030
+ im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
1031
+ im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
1032
+ s = im.shape[:2] # HWC
1033
+ shape0.append(s) # image shape
1034
+ g = max(size) / max(s) # gain
1035
+ shape1.append([int(y * g) for y in s])
1036
+ ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
1037
+ shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] # inf shape
1038
+ x = [letterbox(im, shape1, auto=False)[0] for im in ims] # pad
1039
+ x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
1040
+ x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
1041
+
1042
+ with amp.autocast(autocast):
1043
+ # Inference
1044
+ with dt[1]:
1045
+ y = self.model(x, augment=augment) # forward
1046
+
1047
+ # Post-process
1048
+ with dt[2]:
1049
+ y = non_max_suppression(y if self.dmb else y[0],
1050
+ self.conf,
1051
+ self.iou,
1052
+ self.classes,
1053
+ self.agnostic,
1054
+ self.multi_label,
1055
+ max_det=self.max_det) # NMS
1056
+ for i in range(n):
1057
+ scale_boxes(shape1, y[i][:, :4], shape0[i])
1058
+
1059
+ return Detections(ims, y, files, dt, self.names, x.shape)
1060
+
1061
+
1062
+ class Detections:
1063
+ # YOLO detections class for inference results
1064
+ def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
1065
+ super().__init__()
1066
+ d = pred[0].device # device
1067
+ gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
1068
+ self.ims = ims # list of images as numpy arrays
1069
+ self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
1070
+ self.names = names # class names
1071
+ self.files = files # image filenames
1072
+ self.times = times # profiling times
1073
+ self.xyxy = pred # xyxy pixels
1074
+ self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
1075
+ self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
1076
+ self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
1077
+ self.n = len(self.pred) # number of images (batch size)
1078
+ self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
1079
+ self.s = tuple(shape) # inference BCHW shape
1080
+
1081
+ def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
1082
+ s, crops = '', []
1083
+ for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
1084
+ s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
1085
+ if pred.shape[0]:
1086
+ for c in pred[:, -1].unique():
1087
+ n = (pred[:, -1] == c).sum() # detections per class
1088
+ s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
1089
+ s = s.rstrip(', ')
1090
+ if show or save or render or crop:
1091
+ annotator = Annotator(im, example=str(self.names))
1092
+ for *box, conf, cls in reversed(pred): # xyxy, confidence, class
1093
+ label = f'{self.names[int(cls)]} {conf:.2f}'
1094
+ if crop:
1095
+ file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
1096
+ crops.append({
1097
+ 'box': box,
1098
+ 'conf': conf,
1099
+ 'cls': cls,
1100
+ 'label': label,
1101
+ 'im': save_one_box(box, im, file=file, save=save)})
1102
+ else: # all others
1103
+ annotator.box_label(box, label if labels else '', color=colors(cls))
1104
+ im = annotator.im
1105
+ else:
1106
+ s += '(no detections)'
1107
+
1108
+ im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
1109
+ if show:
1110
+ display(im) if is_notebook() else im.show(self.files[i])
1111
+ if save:
1112
+ f = self.files[i]
1113
+ im.save(save_dir / f) # save
1114
+ if i == self.n - 1:
1115
+ LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
1116
+ if render:
1117
+ self.ims[i] = np.asarray(im)
1118
+ if pprint:
1119
+ s = s.lstrip('\n')
1120
+ return f'{s}\nSpeed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t
1121
+ if crop:
1122
+ if save:
1123
+ LOGGER.info(f'Saved results to {save_dir}\n')
1124
+ return crops
1125
+
1126
+ @TryExcept('Showing images is not supported in this environment')
1127
+ def show(self, labels=True):
1128
+ self._run(show=True, labels=labels) # show results
1129
+
1130
+ def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False):
1131
+ save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir
1132
+ self._run(save=True, labels=labels, save_dir=save_dir) # save results
1133
+
1134
+ def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False):
1135
+ save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None
1136
+ return self._run(crop=True, save=save, save_dir=save_dir) # crop results
1137
+
1138
+ def render(self, labels=True):
1139
+ self._run(render=True, labels=labels) # render results
1140
+ return self.ims
1141
+
1142
+ def pandas(self):
1143
+ # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
1144
+ new = copy(self) # return copy
1145
+ ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
1146
+ cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
1147
+ for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
1148
+ a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
1149
+ setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
1150
+ return new
1151
+
1152
+ def tolist(self):
1153
+ # return a list of Detections objects, i.e. 'for result in results.tolist():'
1154
+ r = range(self.n) # iterable
1155
+ x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
1156
+ # for d in x:
1157
+ # for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
1158
+ # setattr(d, k, getattr(d, k)[0]) # pop out of list
1159
+ return x
1160
+
1161
+ def print(self):
1162
+ LOGGER.info(self.__str__())
1163
+
1164
+ def __len__(self): # override len(results)
1165
+ return self.n
1166
+
1167
+ def __str__(self): # override print(results)
1168
+ return self._run(pprint=True) # print results
1169
+
1170
+ def __repr__(self):
1171
+ return f'YOLO {self.__class__} instance\n' + self.__str__()
1172
+
1173
+
1174
+ class Proto(nn.Module):
1175
+ # YOLO mask Proto module for segmentation models
1176
+ def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks
1177
+ super().__init__()
1178
+ self.cv1 = Conv(c1, c_, k=3)
1179
+ self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
1180
+ self.cv2 = Conv(c_, c_, k=3)
1181
+ self.cv3 = Conv(c_, c2)
1182
+
1183
+ def forward(self, x):
1184
+ return self.cv3(self.cv2(self.upsample(self.cv1(x))))
1185
+
1186
+
1187
+ class UConv(nn.Module):
1188
+ def __init__(self, c1, c_=256, c2=256): # ch_in, number of protos, number of masks
1189
+ super().__init__()
1190
+
1191
+ self.cv1 = Conv(c1, c_, k=3)
1192
+ self.cv2 = nn.Conv2d(c_, c2, 1, 1)
1193
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
1194
+
1195
+ def forward(self, x):
1196
+ return self.up(self.cv2(self.cv1(x)))
1197
+
1198
+
1199
+ class Classify(nn.Module):
1200
+ # YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)
1201
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
1202
+ super().__init__()
1203
+ c_ = 1280 # efficientnet_b0 size
1204
+ self.conv = Conv(c1, c_, k, s, autopad(k, p), g)
1205
+ self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)
1206
+ self.drop = nn.Dropout(p=0.0, inplace=True)
1207
+ self.linear = nn.Linear(c_, c2) # to x(b,c2)
1208
+
1209
+ def forward(self, x):
1210
+ if isinstance(x, list):
1211
+ x = torch.cat(x, 1)
1212
+ return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
requirements.txt ADDED
@@ -0,0 +1,47 @@
1
+ # requirements
2
+ # Usage: pip install -r requirements.txt
3
+
4
+ # Base ------------------------------------------------------------------------
5
+ gitpython
6
+ ipython
7
+ matplotlib>=3.2.2
8
+ numpy>=1.18.5
9
+ opencv-python>=4.1.1
10
+ Pillow>=7.1.2
11
+ psutil
12
+ PyYAML>=5.3.1
13
+ requests>=2.23.0
14
+ scipy>=1.4.1
15
+ thop>=0.1.1
16
+ torch>=1.7.0
17
+ torchvision>=0.8.1
18
+ tqdm>=4.64.0
19
+ # protobuf<=3.20.1
20
+
21
+ # Logging ---------------------------------------------------------------------
22
+ tensorboard>=2.4.1
23
+ # clearml>=1.2.0
24
+ # comet
25
+
26
+ # Plotting --------------------------------------------------------------------
27
+ pandas>=1.1.4
28
+ seaborn>=0.11.0
29
+
30
+ # Export ----------------------------------------------------------------------
31
+ # coremltools>=6.0
32
+ # onnx>=1.9.0
33
+ # onnx-simplifier>=0.4.1
34
+ # nvidia-pyindex
35
+ # nvidia-tensorrt
36
+ # scikit-learn<=1.1.2
37
+ # tensorflow>=2.4.1
38
+ # tensorflowjs>=3.9.0
39
+ # openvino-dev
40
+
41
+ # Deploy ----------------------------------------------------------------------
42
+ # tritonclient[all]~=2.24.0
43
+
44
+ # Extras ----------------------------------------------------------------------
45
+ # mss
46
+ albumentations>=1.0.3
47
+ pycocotools>=2.0
runs/train/best_striped.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d5aaa20f90d1c2e2a3206ca2392b1e48e2593f305c54610526317c2d7082d99
3
+ size 51440592
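Note: best_striped.pt is stored via Git LFS, so the hunk above only records the pointer (object hash and size, about 51 MB), not the weights themselves. Cloning without git-lfs yields this 3-line pointer file; running `git lfs install` followed by `git lfs pull` fetches the actual checkpoint.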
utils/general.py ADDED
@@ -0,0 +1,1135 @@
1
+ import contextlib
2
+ import glob
3
+ import inspect
4
+ import logging
5
+ import logging.config
6
+ import math
7
+ import os
8
+ import platform
9
+ import random
10
+ import re
11
+ import signal
12
+ import sys
13
+ import time
14
+ import urllib
15
+ from copy import deepcopy
16
+ from datetime import datetime
17
+ from itertools import repeat
18
+ from multiprocessing.pool import ThreadPool
19
+ from pathlib import Path
20
+ from subprocess import check_output
21
+ from tarfile import is_tarfile
22
+ from typing import Optional
23
+ from zipfile import ZipFile, is_zipfile
24
+
25
+ import cv2
26
+ import IPython
27
+ import numpy as np
28
+ import pandas as pd
29
+ import pkg_resources as pkg
30
+ import torch
31
+ import torchvision
32
+ import yaml
33
+
34
+ from utils import TryExcept, emojis
35
+ from utils.downloads import gsutil_getsize
36
+ from utils.metrics import box_iou, fitness
37
+
38
+ FILE = Path(__file__).resolve()
39
+ ROOT = FILE.parents[1] # YOLO root directory
40
+ RANK = int(os.getenv('RANK', -1))
41
+
42
+ # Settings
43
+ NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
44
+ DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets')) # global datasets directory
45
+ AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
46
+ VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
47
+ TQDM_BAR_FORMAT = '{l_bar}{bar:10}| {n_fmt}/{total_fmt} {elapsed}' # tqdm bar format
48
+ FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
49
+
50
+ torch.set_printoptions(linewidth=320, precision=5, profile='long')
51
+ np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
52
+ pd.options.display.max_columns = 10
53
+ cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
54
+ os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
55
+ os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'darwin' else str(NUM_THREADS) # OpenMP (PyTorch and SciPy)
56
+
57
+
58
+ def is_ascii(s=''):
59
+ # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
60
+ s = str(s) # convert list, tuple, None, etc. to str
61
+ return len(s.encode().decode('ascii', 'ignore')) == len(s)
62
+
63
+
64
+ def is_chinese(s='人工智能'):
65
+ # Is string composed of any Chinese characters?
66
+ return bool(re.search('[\u4e00-\u9fff]', str(s)))
67
+
68
+
69
+ def is_colab():
70
+ # Is environment a Google Colab instance?
71
+ return 'google.colab' in sys.modules
72
+
73
+
74
+ def is_notebook():
75
+ # Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace
76
+ ipython_type = str(type(IPython.get_ipython()))
77
+ return 'colab' in ipython_type or 'zmqshell' in ipython_type
78
+
79
+
80
+ def is_kaggle():
81
+ # Is environment a Kaggle Notebook?
82
+ return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
83
+
84
+
85
+ def is_docker() -> bool:
86
+ """Check if the process runs inside a docker container."""
87
+ if Path("/.dockerenv").exists():
88
+ return True
89
+ try: # check if docker is in control groups
90
+ with open("/proc/self/cgroup") as file:
91
+ return any("docker" in line for line in file)
92
+ except OSError:
93
+ return False
94
+
95
+
96
+ def is_writeable(dir, test=False):
97
+ # Return True if directory has write permissions, test opening a file with write permissions if test=True
98
+ if not test:
99
+ return os.access(dir, os.W_OK) # possible issues on Windows
100
+ file = Path(dir) / 'tmp.txt'
101
+ try:
102
+ with open(file, 'w'): # open file with write permissions
103
+ pass
104
+ file.unlink() # remove file
105
+ return True
106
+ except OSError:
107
+ return False
108
+
109
+
110
+ LOGGING_NAME = "yolov5"
111
+
112
+
113
+ def set_logging(name=LOGGING_NAME, verbose=True):
114
+ # sets up logging for the given name
115
+ rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
116
+ level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
117
+ logging.config.dictConfig({
118
+ "version": 1,
119
+ "disable_existing_loggers": False,
120
+ "formatters": {
121
+ name: {
122
+ "format": "%(message)s"}},
123
+ "handlers": {
124
+ name: {
125
+ "class": "logging.StreamHandler",
126
+ "formatter": name,
127
+ "level": level,}},
128
+ "loggers": {
129
+ name: {
130
+ "level": level,
131
+ "handlers": [name],
132
+ "propagate": False,}}})
133
+
134
+
135
+ set_logging(LOGGING_NAME) # run before defining LOGGER
136
+ LOGGER = logging.getLogger(LOGGING_NAME) # define globally (used in train.py, val.py, detect.py, etc.)
137
+ if platform.system() == 'Windows':
138
+ for fn in LOGGER.info, LOGGER.warning:
139
+ setattr(LOGGER, fn.__name__, lambda x, fn=fn: fn(emojis(x))) # emoji-safe logging (bind fn per iteration to avoid the late-binding closure bug)
140
+
141
+
142
+ def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
143
+ # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
144
+ env = os.getenv(env_var)
145
+ if env:
146
+ path = Path(env) # use environment variable
147
+ else:
148
+ cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
149
+ path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
150
+ path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
151
+ path.mkdir(exist_ok=True) # make if required
152
+ return path
153
+
154
+
155
+ CONFIG_DIR = user_config_dir() # Ultralytics settings dir
156
+
157
+
158
+ class Profile(contextlib.ContextDecorator):
159
+ # YOLO Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
160
+ def __init__(self, t=0.0):
161
+ self.t = t
162
+ self.cuda = torch.cuda.is_available()
163
+
164
+ def __enter__(self):
165
+ self.start = self.time()
166
+ return self
167
+
168
+ def __exit__(self, type, value, traceback):
169
+ self.dt = self.time() - self.start # delta-time
170
+ self.t += self.dt # accumulate dt
171
+
172
+ def time(self):
173
+ if self.cuda:
174
+ torch.cuda.synchronize()
175
+ return time.time()
176
+
177
+
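A minimal sketch of the Profile timer defined above; the timed workload is a placeholder assumption:

    dt = Profile()
    with dt:
        _ = sum(i * i for i in range(10_000))   # any workload to be timed
    print(f'last run {dt.dt:.4f}s, accumulated {dt.t:.4f}s')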
178
+ class Timeout(contextlib.ContextDecorator):
179
+ # YOLO Timeout class. Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
180
+ def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
181
+ self.seconds = int(seconds)
182
+ self.timeout_message = timeout_msg
183
+ self.suppress = bool(suppress_timeout_errors)
184
+
185
+ def _timeout_handler(self, signum, frame):
186
+ raise TimeoutError(self.timeout_message)
187
+
188
+ def __enter__(self):
189
+ if platform.system() != 'Windows': # not supported on Windows
190
+ signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
191
+ signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
192
+
193
+ def __exit__(self, exc_type, exc_val, exc_tb):
194
+ if platform.system() != 'Windows':
195
+ signal.alarm(0) # Cancel SIGALRM if it's scheduled
196
+ if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
197
+ return True
198
+
199
+
200
+ class WorkingDirectory(contextlib.ContextDecorator):
201
+ # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
202
+ def __init__(self, new_dir):
203
+ self.dir = new_dir # new dir
204
+ self.cwd = Path.cwd().resolve() # current dir
205
+
206
+ def __enter__(self):
207
+ os.chdir(self.dir)
208
+
209
+ def __exit__(self, exc_type, exc_val, exc_tb):
210
+ os.chdir(self.cwd)
211
+
212
+
213
+ def methods(instance):
214
+ # Get class/instance methods
215
+ return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
216
+
217
+
218
+ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
219
+ # Print function arguments (optional args dict)
220
+ x = inspect.currentframe().f_back # previous frame
221
+ file, _, func, _, _ = inspect.getframeinfo(x)
222
+ if args is None: # get args automatically
223
+ args, _, _, frm = inspect.getargvalues(x)
224
+ args = {k: v for k, v in frm.items() if k in args}
225
+ try:
226
+ file = Path(file).resolve().relative_to(ROOT).with_suffix('')
227
+ except ValueError:
228
+ file = Path(file).stem
229
+ s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
230
+ LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
231
+
232
+
233
+ def init_seeds(seed=0, deterministic=False):
234
+ # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
235
+ random.seed(seed)
236
+ np.random.seed(seed)
237
+ torch.manual_seed(seed)
238
+ torch.cuda.manual_seed(seed)
239
+ torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
240
+ # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
241
+ if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
242
+ torch.use_deterministic_algorithms(True)
243
+ torch.backends.cudnn.deterministic = True
244
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
245
+ os.environ['PYTHONHASHSEED'] = str(seed)
246
+
247
+
248
+ def intersect_dicts(da, db, exclude=()):
249
+ # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
250
+ return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
251
+
252
+
253
+ def get_default_args(func):
254
+ # Get func() default arguments
255
+ signature = inspect.signature(func)
256
+ return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
257
+
258
+
259
+ def get_latest_run(search_dir='.'):
260
+ # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
261
+ last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
262
+ return max(last_list, key=os.path.getctime) if last_list else ''
263
+
264
+
265
+ def file_age(path=__file__):
266
+ # Return days since last file update
267
+ dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
268
+ return dt.days # + dt.seconds / 86400 # fractional days
269
+
270
+
271
+ def file_date(path=__file__):
272
+ # Return human-readable file modification date, i.e. '2021-3-26'
273
+ t = datetime.fromtimestamp(Path(path).stat().st_mtime)
274
+ return f'{t.year}-{t.month}-{t.day}'
275
+
276
+
277
+ def file_size(path):
278
+ # Return file/dir size (MB)
279
+ mb = 1 << 20 # bytes to MiB (1024 ** 2)
280
+ path = Path(path)
281
+ if path.is_file():
282
+ return path.stat().st_size / mb
283
+ elif path.is_dir():
284
+ return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
285
+ else:
286
+ return 0.0
287
+
288
+
289
+ def check_online():
290
+ # Check internet connectivity
291
+ import socket
292
+
293
+ def run_once():
294
+ # Check once
295
+ try:
296
+ socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
297
+ return True
298
+ except OSError:
299
+ return False
300
+
301
+ return run_once() or run_once() # check twice to increase robustness to intermittent connectivity issues
302
+
303
+
304
+ def git_describe(path=ROOT): # path must be a directory
305
+ # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
306
+ try:
307
+ assert (Path(path) / '.git').is_dir()
308
+ return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
309
+ except Exception:
310
+ return ''
311
+
312
+
313
+ @TryExcept()
314
+ @WorkingDirectory(ROOT)
315
+ def check_git_status(repo='WongKinYiu/yolov9', branch='main'):
316
+ # YOLO status check, recommend 'git pull' if code is out of date
317
+ url = f'https://github.com/{repo}'
318
+ msg = f', for updates see {url}'
319
+ s = colorstr('github: ') # string
320
+ assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
321
+ assert check_online(), s + 'skipping check (offline)' + msg
322
+
323
+ splits = re.split(pattern=r'\s', string=check_output('git remote -v', shell=True).decode())
324
+ matches = [repo in s for s in splits]
325
+ if any(matches):
326
+ remote = splits[matches.index(True) - 1]
327
+ else:
328
+ remote = 'ultralytics'
329
+ check_output(f'git remote add {remote} {url}', shell=True)
330
+ check_output(f'git fetch {remote}', shell=True, timeout=5) # git fetch
331
+ local_branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
332
+ n = int(check_output(f'git rev-list {local_branch}..{remote}/{branch} --count', shell=True)) # commits behind
333
+ if n > 0:
334
+ pull = 'git pull' if remote == 'origin' else f'git pull {remote} {branch}'
335
+ s += f"⚠️ YOLO is out of date by {n} commit{'s' * (n > 1)}. Use `{pull}` or `git clone {url}` to update."
336
+ else:
337
+ s += f'up to date with {url} ✅'
338
+ LOGGER.info(s)
339
+
340
+
341
+ @WorkingDirectory(ROOT)
342
+ def check_git_info(path='.'):
343
+ # YOLO git info check, return {remote, branch, commit}
344
+ check_requirements('gitpython')
345
+ import git
346
+ try:
347
+ repo = git.Repo(path)
348
+ remote = repo.remotes.origin.url.replace('.git', '') # i.e. 'https://github.com/WongKinYiu/yolov9'
349
+ commit = repo.head.commit.hexsha # i.e. '3134699c73af83aac2a481435550b968d5792c0d'
350
+ try:
351
+ branch = repo.active_branch.name # i.e. 'main'
352
+ except TypeError: # not on any branch
353
+ branch = None # i.e. 'detached HEAD' state
354
+ return {'remote': remote, 'branch': branch, 'commit': commit}
355
+ except git.exc.InvalidGitRepositoryError: # path is not a git dir
356
+ return {'remote': None, 'branch': None, 'commit': None}
357
+
358
+
359
+ def check_python(minimum='3.7.0'):
360
+ # Check current python version vs. required python version
361
+ check_version(platform.python_version(), minimum, name='Python ', hard=True)
362
+
363
+
364
+ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
365
+ # Check version vs. required version
366
+ current, minimum = (pkg.parse_version(x) for x in (current, minimum))
367
+ result = (current == minimum) if pinned else (current >= minimum) # bool
368
+ s = f'WARNING ⚠️ {name}{minimum} is required by YOLO, but {name}{current} is currently installed' # string
369
+ if hard:
370
+ assert result, emojis(s) # assert min requirements met
371
+ if verbose and not result:
372
+ LOGGER.warning(s)
373
+ return result
374
+
375
+
376
+ @TryExcept()
377
+ def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=''):
378
+ # Check installed dependencies meet YOLO requirements (pass *.txt file or list of packages or single package str)
379
+ prefix = colorstr('red', 'bold', 'requirements:')
380
+ check_python() # check python version
381
+ if isinstance(requirements, Path): # requirements.txt file
382
+ file = requirements.resolve()
383
+ assert file.exists(), f"{prefix} {file} not found, check failed."
384
+ with file.open() as f:
385
+ requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
386
+ elif isinstance(requirements, str):
387
+ requirements = [requirements]
388
+
389
+ s = ''
390
+ n = 0
391
+ for r in requirements:
392
+ try:
393
+ pkg.require(r)
394
+ except (pkg.VersionConflict, pkg.DistributionNotFound): # exception if requirements not met
395
+ s += f'"{r}" '
396
+ n += 1
397
+
398
+ if s and install and AUTOINSTALL: # check environment variable
399
+ LOGGER.info(f"{prefix} YOLO requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
400
+ try:
401
+ # assert check_online(), "AutoUpdate skipped (offline)"
402
+ LOGGER.info(check_output(f'pip install {s} {cmds}', shell=True).decode())
403
+ source = file if 'file' in locals() else requirements
404
+ s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
405
+ f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
406
+ LOGGER.info(s)
407
+ except Exception as e:
408
+ LOGGER.warning(f'{prefix} ❌ {e}')
409
+
410
+
411
+ def check_img_size(imgsz, s=32, floor=0):
412
+ # Verify image size is a multiple of stride s in each dimension
413
+ if isinstance(imgsz, int): # integer i.e. img_size=640
414
+ new_size = max(make_divisible(imgsz, int(s)), floor)
415
+ else: # list i.e. img_size=[640, 480]
416
+ imgsz = list(imgsz) # convert to list if tuple
417
+ new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
418
+ if new_size != imgsz:
419
+ LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
420
+ return new_size
421
+
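For illustration, check_img_size rounds sizes up to the model stride; the values below assume a stride of 32:

    check_img_size(640, s=32)         # -> 640, already a multiple of the stride
    check_img_size([640, 481], s=32)  # -> [640, 512], warns and rounds the odd dimension up to the stride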
422
+
423
+ def check_imshow(warn=False):
424
+ # Check if environment supports image displays
425
+ try:
426
+ assert not is_notebook()
427
+ assert not is_docker()
428
+ cv2.imshow('test', np.zeros((1, 1, 3)))
429
+ cv2.waitKey(1)
430
+ cv2.destroyAllWindows()
431
+ cv2.waitKey(1)
432
+ return True
433
+ except Exception as e:
434
+ if warn:
435
+ LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
436
+ return False
437
+
438
+
439
+ def check_suffix(file='yolo.pt', suffix=('.pt',), msg=''):
440
+ # Check file(s) for acceptable suffix
441
+ if file and suffix:
442
+ if isinstance(suffix, str):
443
+ suffix = [suffix]
444
+ for f in file if isinstance(file, (list, tuple)) else [file]:
445
+ s = Path(f).suffix.lower() # file suffix
446
+ if len(s):
447
+ assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
448
+
449
+
450
+ def check_yaml(file, suffix=('.yaml', '.yml')):
451
+ # Search/download YAML file (if necessary) and return path, checking suffix
452
+ return check_file(file, suffix)
453
+
454
+
455
+ def check_file(file, suffix=''):
456
+ # Search/download file (if necessary) and return path
457
+ check_suffix(file, suffix) # optional
458
+ file = str(file) # convert to str()
459
+ if os.path.isfile(file) or not file: # exists
460
+ return file
461
+ elif file.startswith(('http:/', 'https:/')): # download
462
+ url = file # warning: Pathlib turns :// -> :/
463
+ file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
464
+ if os.path.isfile(file):
465
+ LOGGER.info(f'Found {url} locally at {file}') # file already exists
466
+ else:
467
+ LOGGER.info(f'Downloading {url} to {file}...')
468
+ torch.hub.download_url_to_file(url, file)
469
+ assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
470
+ return file
471
+ elif file.startswith('clearml://'): # ClearML Dataset ID
472
+ assert 'clearml' in sys.modules, "ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'."
473
+ return file
474
+ else: # search
475
+ files = []
476
+ for d in 'data', 'models', 'utils': # search directories
477
+ files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
478
+ assert len(files), f'File not found: {file}' # assert file was found
479
+ assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
480
+ return files[0] # return file
481
+
482
+
483
+ def check_font(font=FONT, progress=False):
484
+ # Download font to CONFIG_DIR if necessary
485
+ font = Path(font)
486
+ file = CONFIG_DIR / font.name
487
+ if not font.exists() and not file.exists():
488
+ url = f'https://ultralytics.com/assets/{font.name}'
489
+ LOGGER.info(f'Downloading {url} to {file}...')
490
+ torch.hub.download_url_to_file(url, str(file), progress=progress)
491
+
492
+
493
+ def check_dataset(data, autodownload=True):
494
+ # Download, check and/or unzip dataset if not found locally
495
+
496
+ # Download (optional)
497
+ extract_dir = ''
498
+ if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
499
+ download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1)
500
+ data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
501
+ extract_dir, autodownload = data.parent, False
502
+
503
+ # Read yaml (optional)
504
+ if isinstance(data, (str, Path)):
505
+ data = yaml_load(data) # dictionary
506
+
507
+ # Checks
508
+ for k in 'train', 'val', 'names':
509
+ assert k in data, emojis(f"data.yaml '{k}:' field missing ❌")
510
+ if isinstance(data['names'], (list, tuple)): # old array format
511
+ data['names'] = dict(enumerate(data['names'])) # convert to dict
512
+ assert all(isinstance(k, int) for k in data['names'].keys()), 'data.yaml names keys must be integers, i.e. 2: car'
513
+ data['nc'] = len(data['names'])
514
+
515
+ # Resolve paths
516
+ path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
517
+ if not path.is_absolute():
518
+ path = (ROOT / path).resolve()
519
+ data['path'] = path # download scripts
520
+ for k in 'train', 'val', 'test':
521
+ if data.get(k): # prepend path
522
+ if isinstance(data[k], str):
523
+ x = (path / data[k]).resolve()
524
+ if not x.exists() and data[k].startswith('../'):
525
+ x = (path / data[k][3:]).resolve()
526
+ data[k] = str(x)
527
+ else:
528
+ data[k] = [str((path / x).resolve()) for x in data[k]]
529
+
530
+ # Parse yaml
531
+ train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
532
+ if val:
533
+ val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
534
+ if not all(x.exists() for x in val):
535
+ LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()])
536
+ if not s or not autodownload:
537
+ raise Exception('Dataset not found ❌')
538
+ t = time.time()
539
+ if s.startswith('http') and s.endswith('.zip'): # URL
540
+ f = Path(s).name # filename
541
+ LOGGER.info(f'Downloading {s} to {f}...')
542
+ torch.hub.download_url_to_file(s, f)
543
+ Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root
544
+ unzip_file(f, path=DATASETS_DIR) # unzip
545
+ Path(f).unlink() # remove zip
546
+ r = None # success
547
+ elif s.startswith('bash '): # bash script
548
+ LOGGER.info(f'Running {s} ...')
549
+ r = os.system(s)
550
+ else: # python script
551
+ r = exec(s, {'yaml': data}) # return None
552
+ dt = f'({round(time.time() - t, 1)}s)'
553
+ s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
554
+ LOGGER.info(f"Dataset download {s}")
555
+ check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
556
+ return data # dictionary
557
+
558
+
559
+ def check_amp(model):
560
+ # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
561
+ from models.common import AutoShape, DetectMultiBackend
562
+
563
+ def amp_allclose(model, im):
564
+ # All close FP32 vs AMP results
565
+ m = AutoShape(model, verbose=False) # model
566
+ a = m(im).xywhn[0] # FP32 inference
567
+ m.amp = True
568
+ b = m(im).xywhn[0] # AMP inference
569
+ return a.shape == b.shape and torch.allclose(a, b, atol=0.1) # close to 10% absolute tolerance
570
+
571
+ prefix = colorstr('AMP: ')
572
+ device = next(model.parameters()).device # get model device
573
+ if device.type in ('cpu', 'mps'):
574
+ return False # AMP only used on CUDA devices
575
+ f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
576
+ im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
577
+ try:
578
+ #assert amp_allclose(deepcopy(model), im) or amp_allclose(DetectMultiBackend('yolo.pt', device), im)
579
+ LOGGER.info(f'{prefix}checks passed ✅')
580
+ return True
581
+ except Exception:
582
+ help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
583
+ LOGGER.warning(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}')
584
+ return False
585
+
586
+
587
+ def yaml_load(file='data.yaml'):
588
+ # Single-line safe yaml loading
589
+ with open(file, errors='ignore') as f:
590
+ return yaml.safe_load(f)
591
+
592
+
593
+ def yaml_save(file='data.yaml', data={}):
594
+ # Single-line safe yaml saving
595
+ with open(file, 'w') as f:
596
+ yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
597
+
598
+
599
+ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
600
+ # Unzip a *.zip file to path/, excluding files containing strings in exclude list
601
+ if path is None:
602
+ path = Path(file).parent # default path
603
+ with ZipFile(file) as zipObj:
604
+ for f in zipObj.namelist(): # list all archived filenames in the zip
605
+ if all(x not in f for x in exclude):
606
+ zipObj.extract(f, path=path)
607
+
608
+
609
+ def url2file(url):
610
+ # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
611
+ url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
612
+ return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
613
+
614
+
615
+ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
616
+ # Multithreaded file download and unzip function, used in data.yaml for autodownload
617
+ def download_one(url, dir):
618
+ # Download 1 file
619
+ success = True
620
+ if os.path.isfile(url):
621
+ f = Path(url) # filename
622
+ else: # does not exist
623
+ f = dir / Path(url).name
624
+ LOGGER.info(f'Downloading {url} to {f}...')
625
+ for i in range(retry + 1):
626
+ if curl:
627
+ s = 'sS' if threads > 1 else '' # silent
628
+ r = os.system(
629
+ f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
630
+ success = r == 0
631
+ else:
632
+ torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
633
+ success = f.is_file()
634
+ if success:
635
+ break
636
+ elif i < retry:
637
+ LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
638
+ else:
639
+ LOGGER.warning(f'❌ Failed to download {url}...')
640
+
641
+ if unzip and success and (f.suffix == '.gz' or is_zipfile(f) or is_tarfile(f)):
642
+ LOGGER.info(f'Unzipping {f}...')
643
+ if is_zipfile(f):
644
+ unzip_file(f, dir) # unzip
645
+ elif is_tarfile(f):
646
+ os.system(f'tar xf {f} --directory {f.parent}') # unzip
647
+ elif f.suffix == '.gz':
648
+ os.system(f'tar xfz {f} --directory {f.parent}') # unzip
649
+ if delete:
650
+ f.unlink() # remove zip
651
+
652
+ dir = Path(dir)
653
+ dir.mkdir(parents=True, exist_ok=True) # make directory
654
+ if threads > 1:
655
+ pool = ThreadPool(threads)
656
+ pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
657
+ pool.close()
658
+ pool.join()
659
+ else:
660
+ for u in [url] if isinstance(url, (str, Path)) else url:
661
+ download_one(u, dir)
662
+
663
+
664
+ def make_divisible(x, divisor):
665
+ # Returns nearest x divisible by divisor
666
+ if isinstance(divisor, torch.Tensor):
667
+ divisor = int(divisor.max()) # to int
668
+ return math.ceil(x / divisor) * divisor
669
+
670
+
671
+ def clean_str(s):
672
+ # Cleans a string by replacing special characters with underscore _
673
+ return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
674
+
675
+
676
+ def one_cycle(y1=0.0, y2=1.0, steps=100):
677
+ # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
678
+ return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
679
+
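A hedged example of wiring one_cycle into a PyTorch scheduler; the optimizer and epoch count are assumptions:

    lf = one_cycle(1.0, 0.01, steps=100)  # cosine ramp from 1.0 down to 0.01 over 100 epochs
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)  # lr = initial_lr * lf(epoch)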
680
+
681
+ def one_flat_cycle(y1=0.0, y2=1.0, steps=100):
682
+ # lambda function: flat at y1 for the first half of training, then sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
683
+ #return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
684
+ return lambda x: ((1 - math.cos((x - (steps // 2)) * math.pi / (steps // 2))) / 2) * (y2 - y1) + y1 if (x > (steps // 2)) else y1
685
+
686
+
687
+ def colorstr(*input):
688
+ # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
689
+ *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
690
+ colors = {
691
+ 'black': '\033[30m', # basic colors
692
+ 'red': '\033[31m',
693
+ 'green': '\033[32m',
694
+ 'yellow': '\033[33m',
695
+ 'blue': '\033[34m',
696
+ 'magenta': '\033[35m',
697
+ 'cyan': '\033[36m',
698
+ 'white': '\033[37m',
699
+ 'bright_black': '\033[90m', # bright colors
700
+ 'bright_red': '\033[91m',
701
+ 'bright_green': '\033[92m',
702
+ 'bright_yellow': '\033[93m',
703
+ 'bright_blue': '\033[94m',
704
+ 'bright_magenta': '\033[95m',
705
+ 'bright_cyan': '\033[96m',
706
+ 'bright_white': '\033[97m',
707
+ 'end': '\033[0m', # misc
708
+ 'bold': '\033[1m',
709
+ 'underline': '\033[4m'}
710
+ return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
711
+
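Example calls for colorstr (the output is an ANSI-colored string, assuming a terminal that supports escape codes):

    colorstr('hello world')                    # a single argument defaults to blue + bold
    colorstr('red', 'bold', 'requirements:')   # leading arguments pick colors/styles, the last is the string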
712
+
713
+ def labels_to_class_weights(labels, nc=80):
714
+ # Get class weights (inverse frequency) from training labels
715
+ if labels[0] is None: # no labels loaded
716
+ return torch.Tensor()
717
+
718
+ labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
719
+ classes = labels[:, 0].astype(int) # labels = [class xywh]
720
+ weights = np.bincount(classes, minlength=nc) # occurrences per class
721
+
722
+ # Prepend gridpoint count (for uCE training)
723
+ # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
724
+ # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
725
+
726
+ weights[weights == 0] = 1 # replace empty bins with 1
727
+ weights = 1 / weights # number of targets per class
728
+ weights /= weights.sum() # normalize
729
+ return torch.from_numpy(weights).float()
730
+
731
+
732
+ def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
733
+ # Produces image weights based on class_weights and image contents
734
+ # Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
735
+ class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
736
+ return (class_weights.reshape(1, nc) * class_counts).sum(1)
737
+
738
+
739
+ def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
740
+ # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
741
+ # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
742
+ # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
743
+ # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
744
+ # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
745
+ return [
746
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
747
+ 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
748
+ 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
749
+
750
+
751
+ def xyxy2xywh(x):
752
+ # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
753
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
754
+ y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
755
+ y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
756
+ y[..., 2] = x[..., 2] - x[..., 0] # width
757
+ y[..., 3] = x[..., 3] - x[..., 1] # height
758
+ return y
759
+
760
+
761
+ def xywh2xyxy(x):
762
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
763
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
764
+ y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
765
+ y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
766
+ y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
767
+ y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
768
+ return y
769
+
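A small worked example of the two box conversions above (inside this module, with numpy already imported as np):

    box_xyxy = np.array([[100., 150., 300., 350.]])     # x1, y1, x2, y2
    box_xywh = xyxy2xywh(box_xyxy)                       # -> [[200., 250., 200., 200.]] center x, center y, w, h
    assert np.allclose(xywh2xyxy(box_xywh), box_xyxy)    # round trip recovers the original corners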
770
+
771
+ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
772
+ # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
773
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
774
+ y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
775
+ y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
776
+ y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
777
+ y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y
778
+ return y
779
+
780
+
781
+ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
782
+ # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
783
+ if clip:
784
+ clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
785
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
786
+ y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
787
+ y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
788
+ y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
789
+ y[..., 3] = (x[..., 3] - x[..., 1]) / h # height
790
+ return y
791
+
792
+
793
+ def xyn2xy(x, w=640, h=640, padw=0, padh=0):
794
+ # Convert normalized segments into pixel segments, shape (n,2)
795
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
796
+ y[..., 0] = w * x[..., 0] + padw # top left x
797
+ y[..., 1] = h * x[..., 1] + padh # top left y
798
+ return y
799
+
800
+
801
+ def segment2box(segment, width=640, height=640):
802
+ # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
803
+ x, y = segment.T # segment xy
804
+ inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
805
+ x, y = x[inside], y[inside]
806
+ return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
807
+
808
+
809
+ def segments2boxes(segments):
810
+ # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
811
+ boxes = []
812
+ for s in segments:
813
+ x, y = s.T # segment xy
814
+ boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
815
+ return xyxy2xywh(np.array(boxes)) # cls, xywh
816
+
817
+
818
+ def resample_segments(segments, n=1000):
819
+ # Up-sample an (n,2) segment
820
+ for i, s in enumerate(segments):
821
+ s = np.concatenate((s, s[0:1, :]), axis=0)
822
+ x = np.linspace(0, len(s) - 1, n)
823
+ xp = np.arange(len(s))
824
+ segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
825
+ return segments
826
+
827
+
828
+ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
829
+ # Rescale boxes (xyxy) from img1_shape to img0_shape
830
+ if ratio_pad is None: # calculate from img0_shape
831
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
832
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
833
+ else:
834
+ gain = ratio_pad[0][0]
835
+ pad = ratio_pad[1]
836
+
837
+ boxes[:, [0, 2]] -= pad[0] # x padding
838
+ boxes[:, [1, 3]] -= pad[1] # y padding
839
+ boxes[:, :4] /= gain
840
+ clip_boxes(boxes, img0_shape)
841
+ return boxes
842
+
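A hedged example of undoing letterbox padding with scale_boxes; the shapes are assumptions (640x640 inference image, 480x640 original):

    pred_boxes = torch.tensor([[120., 200., 520., 560.]])             # xyxy in the 640x640 letterboxed space
    orig_boxes = scale_boxes((640, 640), pred_boxes.clone(), (480, 640))
    # -> tensor([[120., 120., 520., 480.]]) once the 80 px of vertical padding is removed
Note that scale_boxes modifies its boxes argument in place, hence the .clone().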
843
+
844
+ def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
845
+ # Rescale coords (xyxy) from img1_shape to img0_shape
846
+ if ratio_pad is None: # calculate from img0_shape
847
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
848
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
849
+ else:
850
+ gain = ratio_pad[0][0]
851
+ pad = ratio_pad[1]
852
+
853
+ segments[:, 0] -= pad[0] # x padding
854
+ segments[:, 1] -= pad[1] # y padding
855
+ segments /= gain
856
+ clip_segments(segments, img0_shape)
857
+ if normalize:
858
+ segments[:, 0] /= img0_shape[1] # width
859
+ segments[:, 1] /= img0_shape[0] # height
860
+ return segments
861
+
862
+
863
+ def clip_boxes(boxes, shape):
864
+ # Clip boxes (xyxy) to image shape (height, width)
865
+ if isinstance(boxes, torch.Tensor): # faster individually
866
+ boxes[:, 0].clamp_(0, shape[1]) # x1
867
+ boxes[:, 1].clamp_(0, shape[0]) # y1
868
+ boxes[:, 2].clamp_(0, shape[1]) # x2
869
+ boxes[:, 3].clamp_(0, shape[0]) # y2
870
+ else: # np.array (faster grouped)
871
+ boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
872
+ boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
873
+
874
+
875
+ def clip_segments(segments, shape):
876
+ # Clip segments (xy1,xy2,...) to image shape (height, width)
877
+ if isinstance(segments, torch.Tensor): # faster individually
878
+ segments[:, 0].clamp_(0, shape[1]) # x
879
+ segments[:, 1].clamp_(0, shape[0]) # y
880
+ else: # np.array (faster grouped)
881
+ segments[:, 0] = segments[:, 0].clip(0, shape[1]) # x
882
+ segments[:, 1] = segments[:, 1].clip(0, shape[0]) # y
883
+
884
+
885
+ def non_max_suppression(
886
+ prediction,
887
+ conf_thres=0.25,
888
+ iou_thres=0.45,
889
+ classes=None,
890
+ agnostic=False,
891
+ multi_label=False,
892
+ labels=(),
893
+ max_det=300,
894
+ nm=0, # number of masks
895
+ ):
896
+ """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
897
+
898
+ Returns:
899
+ list of detections, on (n,6) tensor per image [xyxy, conf, cls]
900
+ """
901
+
902
+ if isinstance(prediction, (list, tuple)): # YOLO model in validation mode, output = (inference_out, loss_out)
903
+ prediction = prediction[0] # select only inference output
904
+
905
+ device = prediction.device
906
+ mps = 'mps' in device.type # Apple MPS
907
+ if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
908
+ prediction = prediction.cpu()
909
+ bs = prediction.shape[0] # batch size
910
+ nc = prediction.shape[1] - nm - 4 # number of classes
911
+ mi = 4 + nc # mask start index
912
+ xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates
913
+
914
+ # Checks
915
+ assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
916
+ assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
917
+
918
+ # Settings
919
+ # min_wh = 2 # (pixels) minimum box width and height
920
+ max_wh = 7680 # (pixels) maximum box width and height
921
+ max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
922
+ time_limit = 2.5 + 0.05 * bs # seconds to quit after
923
+ redundant = True # require redundant detections
924
+ multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
925
+ merge = False # use merge-NMS
926
+
927
+ t = time.time()
928
+ output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
929
+ for xi, x in enumerate(prediction): # image index, image inference
930
+ # Apply constraints
931
+ # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
932
+ x = x.T[xc[xi]] # confidence
933
+
934
+ # Cat apriori labels if autolabelling
935
+ if labels and len(labels[xi]):
936
+ lb = labels[xi]
937
+ v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
938
+ v[:, :4] = lb[:, 1:5] # box
939
+ v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
940
+ x = torch.cat((x, v), 0)
941
+
942
+ # If none remain process next image
943
+ if not x.shape[0]:
944
+ continue
945
+
946
+ # Detections matrix nx6 (xyxy, conf, cls)
947
+ box, cls, mask = x.split((4, nc, nm), 1)
948
+ box = xywh2xyxy(box) # (center_x, center_y, width, height) to (x1, y1, x2, y2)
949
+ if multi_label:
950
+ i, j = (cls > conf_thres).nonzero(as_tuple=False).T
951
+ x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
952
+ else: # best class only
953
+ conf, j = cls.max(1, keepdim=True)
954
+ x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
955
+
956
+ # Filter by class
957
+ if classes is not None:
958
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
959
+
960
+ # Apply finite constraint
961
+ # if not torch.isfinite(x).all():
962
+ # x = x[torch.isfinite(x).all(1)]
963
+
964
+ # Check shape
965
+ n = x.shape[0] # number of boxes
966
+ if not n: # no boxes
967
+ continue
968
+ elif n > max_nms: # excess boxes
969
+ x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
970
+ else:
971
+ x = x[x[:, 4].argsort(descending=True)] # sort by confidence
972
+
973
+ # Batched NMS
974
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
975
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
976
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
977
+ if i.shape[0] > max_det: # limit detections
978
+ i = i[:max_det]
979
+ if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
980
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
981
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
982
+ weights = iou * scores[None] # box weights
983
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
984
+ if redundant:
985
+ i = i[iou.sum(1) > 1] # require redundancy
986
+
987
+ output[xi] = x[i]
988
+ if mps:
989
+ output[xi] = output[xi].to(device)
990
+ if (time.time() - t) > time_limit:
991
+ LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
992
+ break # time limit exceeded
993
+
994
+ return output
995
+
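A minimal call sketch for non_max_suppression; the prediction tensor layout (batch, 4 + num_classes, num_anchors) is inferred from how the function slices it and is an assumption about the model head:

    dets = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, max_det=300)
    for det in dets:                      # one (n, 6) tensor per image
        for *xyxy, conf, cls in det:      # box corners, confidence, class index
            ...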
996
+
997
+ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
998
+ # Strip optimizer from 'f' to finalize training, optionally save as 's'
999
+ x = torch.load(f, map_location=torch.device('cpu'))
1000
+ if x.get('ema'):
1001
+ x['model'] = x['ema'] # replace model with ema
1002
+ for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
1003
+ x[k] = None
1004
+ x['epoch'] = -1
1005
+ x['model'].half() # to FP16
1006
+ for p in x['model'].parameters():
1007
+ p.requires_grad = False
1008
+ torch.save(x, s or f)
1009
+ mb = os.path.getsize(s or f) / 1E6 # filesize
1010
+ LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
1011
+
1012
+
1013
+ def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
1014
+ evolve_csv = save_dir / 'evolve.csv'
1015
+ evolve_yaml = save_dir / 'hyp_evolve.yaml'
1016
+ keys = tuple(keys) + tuple(hyp.keys()) # [results + hyps]
1017
+ keys = tuple(x.strip() for x in keys)
1018
+ vals = results + tuple(hyp.values())
1019
+ n = len(keys)
1020
+
1021
+ # Download (optional)
1022
+ if bucket:
1023
+ url = f'gs://{bucket}/evolve.csv'
1024
+ if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
1025
+ os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
1026
+
1027
+ # Log to evolve.csv
1028
+ s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
1029
+ with open(evolve_csv, 'a') as f:
1030
+ f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
1031
+
1032
+ # Save yaml
1033
+ with open(evolve_yaml, 'w') as f:
1034
+ data = pd.read_csv(evolve_csv)
1035
+ data = data.rename(columns=lambda x: x.strip()) # strip keys
1036
+ i = np.argmax(fitness(data.values[:, :4])) # index of best generation
1037
+ generations = len(data)
1038
+ f.write('# YOLO Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
1039
+ f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
1040
+ '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
1041
+ yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)
1042
+
1043
+ # Print to screen
1044
+ LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
1045
+ ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
1046
+ for x in vals) + '\n\n')
1047
+
1048
+ if bucket:
1049
+ os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
1050
+
1051
+
1052
+ def apply_classifier(x, model, img, im0):
1053
+ # Apply a second stage classifier to YOLO outputs
1054
+ # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
1055
+ im0 = [im0] if isinstance(im0, np.ndarray) else im0
1056
+ for i, d in enumerate(x): # per image
1057
+ if d is not None and len(d):
1058
+ d = d.clone()
1059
+
1060
+ # Reshape and pad cutouts
1061
+ b = xyxy2xywh(d[:, :4]) # boxes
1062
+ b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
1063
+ b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
1064
+ d[:, :4] = xywh2xyxy(b).long()
1065
+
1066
+ # Rescale boxes from img_size to im0 size
1067
+ scale_boxes(img.shape[2:], d[:, :4], im0[i].shape)
1068
+
1069
+ # Classes
1070
+ pred_cls1 = d[:, 5].long()
1071
+ ims = []
1072
+ for a in d:
1073
+ cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
1074
+ im = cv2.resize(cutout, (224, 224)) # BGR
1075
+
1076
+ im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
1077
+ im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
1078
+ im /= 255 # 0 - 255 to 0.0 - 1.0
1079
+ ims.append(im)
1080
+
1081
+ pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
1082
+ x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
1083
+
1084
+ return x
1085
+
1086
+
1087
+ def increment_path(path, exist_ok=False, sep='', mkdir=False):
1088
+ # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
1089
+ path = Path(path) # os-agnostic
1090
+ if path.exists() and not exist_ok:
1091
+ path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
1092
+
1093
+ # Method 1
1094
+ for n in range(2, 9999):
1095
+ p = f'{path}{sep}{n}{suffix}' # increment path
1096
+ if not os.path.exists(p):
1097
+ break
1098
+ path = Path(p)
1099
+
1100
+ # Method 2 (deprecated)
1101
+ # dirs = glob.glob(f"{path}{sep}*") # similar paths
1102
+ # matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
1103
+ # i = [int(m.groups()[0]) for m in matches if m] # indices
1104
+ # n = max(i) + 1 if i else 2 # increment number
1105
+ # path = Path(f"{path}{sep}{n}{suffix}") # increment path
1106
+
1107
+ if mkdir:
1108
+ path.mkdir(parents=True, exist_ok=True) # make directory
1109
+
1110
+ return path
1111
+
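Example behaviour of increment_path, assuming runs/detect/exp already exists on disk and exp2 does not:

    increment_path('runs/detect/exp')              # -> Path('runs/detect/exp2')
    increment_path('runs/detect/exp', mkdir=True)  # same path, but the directory is also created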
1112
+
1113
+ # OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
1114
+ imshow_ = cv2.imshow # copy to avoid recursion errors
1115
+
1116
+
1117
+ def imread(path, flags=cv2.IMREAD_COLOR):
1118
+ return cv2.imdecode(np.fromfile(path, np.uint8), flags)
1119
+
1120
+
1121
+ def imwrite(path, im):
1122
+ try:
1123
+ cv2.imencode(Path(path).suffix, im)[1].tofile(path)
1124
+ return True
1125
+ except Exception:
1126
+ return False
1127
+
1128
+
1129
+ def imshow(path, im):
1130
+ imshow_(path.encode('unicode_escape').decode(), im)
1131
+
1132
+
1133
+ cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow # redefine
1134
+
1135
+ # Variables ------------------------------------------------------------------------------------------------------------
utils/plots.py ADDED
@@ -0,0 +1,570 @@
1
+ import contextlib
2
+ import math
3
+ import os
4
+ from copy import copy
5
+ from pathlib import Path
6
+ from urllib.error import URLError
7
+
8
+ import cv2
9
+ import matplotlib
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import pandas as pd
13
+ import seaborn as sn
14
+ import torch
15
+ from PIL import Image, ImageDraw, ImageFont
16
+
17
+ from utils import TryExcept, threaded
18
+ from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_boxes, increment_path,
19
+ is_ascii, xywh2xyxy, xyxy2xywh)
20
+ from utils.metrics import fitness
21
+ from utils.segment.general import scale_image
22
+
23
+ # Settings
24
+ RANK = int(os.getenv('RANK', -1))
25
+ matplotlib.rc('font', **{'size': 11})
26
+ matplotlib.use('Agg') # for writing to files only
27
+
28
+
29
+ class Colors:
30
+ # Ultralytics color palette https://ultralytics.com/
31
+ def __init__(self):
32
+ # hex = matplotlib.colors.TABLEAU_COLORS.values()
33
+ hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
34
+ '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
35
+ self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
36
+ self.n = len(self.palette)
37
+
38
+ def __call__(self, i, bgr=False):
39
+ c = self.palette[int(i) % self.n]
40
+ return (c[2], c[1], c[0]) if bgr else c
41
+
42
+ @staticmethod
43
+ def hex2rgb(h): # rgb order (PIL)
44
+ return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
45
+
46
+
47
+ colors = Colors() # create instance for 'from utils.plots import colors'
48
+
49
+
50
+ def check_pil_font(font=FONT, size=10):
51
+ # Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary
52
+ font = Path(font)
53
+ font = font if font.exists() else (CONFIG_DIR / font.name)
54
+ try:
55
+ return ImageFont.truetype(str(font) if font.exists() else font.name, size)
56
+ except Exception: # download if missing
57
+ try:
58
+ check_font(font)
59
+ return ImageFont.truetype(str(font), size)
60
+ except TypeError:
61
+ check_requirements('Pillow>=8.4.0') # known issue https://github.com/ultralytics/yolov5/issues/5374
62
+ except URLError: # not online
63
+ return ImageFont.load_default()
64
+
65
+
66
+ class Annotator:
67
+ # YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
68
+ def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
69
+ assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
70
+ non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
71
+ self.pil = pil or non_ascii
72
+ if self.pil: # use PIL
73
+ self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
74
+ self.draw = ImageDraw.Draw(self.im)
75
+ self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font,
76
+ size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
77
+ else: # use cv2
78
+ self.im = im
79
+ self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
80
+
81
+ def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
82
+ # Add one xyxy box to image with label
83
+ if self.pil or not is_ascii(label):
84
+ self.draw.rectangle(box, width=self.lw, outline=color) # box
85
+ if label:
86
+ w, h = self.font.getsize(label) # text width, height
87
+ outside = box[1] - h >= 0 # label fits outside box
88
+ self.draw.rectangle(
89
+ (box[0], box[1] - h if outside else box[1], box[0] + w + 1,
90
+ box[1] + 1 if outside else box[1] + h + 1),
91
+ fill=color,
92
+ )
93
+ # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
94
+ self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font)
95
+ else: # cv2
96
+ p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
97
+ cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
98
+ if label:
99
+ tf = max(self.lw - 1, 1) # font thickness
100
+ w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
101
+ outside = p1[1] - h >= 3
102
+ p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
103
+ cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
104
+ cv2.putText(self.im,
105
+ label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
106
+ 0,
107
+ self.lw / 3,
108
+ txt_color,
109
+ thickness=tf,
110
+ lineType=cv2.LINE_AA)
111
+
112
+ def masks(self, masks, colors, im_gpu=None, alpha=0.5):
113
+ """Plot masks at once.
114
+ Args:
115
+ masks (tensor): predicted masks on cuda, shape: [n, h, w]
116
+ colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
117
+ im_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
118
+ alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
119
+ """
120
+ if self.pil:
121
+ # convert to numpy first
122
+ self.im = np.asarray(self.im).copy()
123
+ if im_gpu is None:
124
+ # Add multiple masks of shape(h,w,n) with colors list([r,g,b], [r,g,b], ...)
125
+ if len(masks) == 0:
126
+ return
127
+ if isinstance(masks, torch.Tensor):
128
+ masks = torch.as_tensor(masks, dtype=torch.uint8)
129
+ masks = masks.permute(1, 2, 0).contiguous()
130
+ masks = masks.cpu().numpy()
131
+ # masks = np.ascontiguousarray(masks.transpose(1, 2, 0))
132
+ masks = scale_image(masks.shape[:2], masks, self.im.shape)
133
+ masks = np.asarray(masks, dtype=np.float32)
134
+ colors = np.asarray(colors, dtype=np.float32) # shape(n,3)
135
+ s = masks.sum(2, keepdims=True).clip(0, 1) # add all masks together
136
+ masks = (masks @ colors).clip(0, 255) # (h,w,n) @ (n,3) = (h,w,3)
137
+ self.im[:] = masks * alpha + self.im * (1 - s * alpha)
138
+ else:
139
+ if len(masks) == 0:
140
+ self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
141
+ colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
142
+ colors = colors[:, None, None] # shape(n,1,1,3)
143
+ masks = masks.unsqueeze(3) # shape(n,h,w,1)
144
+ masks_color = masks * (colors * alpha) # shape(n,h,w,3)
145
+
146
+ inv_alph_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
147
+ mcs = (masks_color * inv_alph_masks).sum(0) * 2 # mask color summand shape(n,h,w,3)
148
+
149
+ im_gpu = im_gpu.flip(dims=[0]) # flip channel
150
+ im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
151
+ im_gpu = im_gpu * inv_alph_masks[-1] + mcs
152
+ im_mask = (im_gpu * 255).byte().cpu().numpy()
153
+ self.im[:] = scale_image(im_gpu.shape, im_mask, self.im.shape)
154
+ if self.pil:
155
+ # convert im back to PIL and update draw
156
+ self.fromarray(self.im)
157
+
158
+ def rectangle(self, xy, fill=None, outline=None, width=1):
159
+ # Add rectangle to image (PIL-only)
160
+ self.draw.rectangle(xy, fill, outline, width)
161
+
162
+ def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'):
163
+ # Add text to image (PIL-only)
164
+ if anchor == 'bottom': # start y from font bottom
165
+ w, h = self.font.getsize(text) # text width, height
166
+ xy[1] += 1 - h
167
+ self.draw.text(xy, text, fill=txt_color, font=self.font)
168
+
169
+ def fromarray(self, im):
170
+ # Update self.im from a numpy array
171
+ self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
172
+ self.draw = ImageDraw.Draw(self.im)
173
+
174
+ def result(self):
175
+ # Return annotated image as array
176
+ return np.asarray(self.im)
177
+
178
+
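A minimal sketch of how `Annotator` is typically driven for a single cv2-style image; the blank image, box coordinates, and label text are placeholder assumptions:

import numpy as np
from utils.plots import Annotator, colors

im = np.zeros((640, 640, 3), dtype=np.uint8)               # contiguous uint8 image (BGR)
annotator = Annotator(im, line_width=2, example='person')  # ASCII example -> cv2 backend
annotator.box_label([50, 60, 200, 220], 'person 0.91', color=colors(0, True))
annotated = annotator.result()                             # numpy array with box + label drawn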
179
+ def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
180
+ """
181
+ x: Features to be visualized
182
+ module_type: Module type
183
+ stage: Module stage within model
184
+ n: Maximum number of feature maps to plot
185
+ save_dir: Directory to save results
186
+ """
187
+ if 'Detect' not in module_type:
188
+ batch, channels, height, width = x.shape # batch, channels, height, width
189
+ if height > 1 and width > 1:
190
+ f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
191
+
192
+ blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
193
+ n = min(n, channels) # number of plots
194
+ fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # n/8 rows x 8 cols
195
+ ax = ax.ravel()
196
+ plt.subplots_adjust(wspace=0.05, hspace=0.05)
197
+ for i in range(n):
198
+ ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
199
+ ax[i].axis('off')
200
+
201
+ LOGGER.info(f'Saving {f}... ({n}/{channels})')
202
+ plt.savefig(f, dpi=300, bbox_inches='tight')
203
+ plt.close()
204
+ np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save
205
+
206
+
207
+ def hist2d(x, y, n=100):
208
+ # 2d histogram used in labels.png and evolve.png
209
+ xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
210
+ hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
211
+ xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
212
+ yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
213
+ return np.log(hist[xidx, yidx])
214
+
215
+
216
+ def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
217
+ from scipy.signal import butter, filtfilt
218
+
219
+ # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
220
+ def butter_lowpass(cutoff, fs, order):
221
+ nyq = 0.5 * fs
222
+ normal_cutoff = cutoff / nyq
223
+ return butter(order, normal_cutoff, btype='low', analog=False)
224
+
225
+ b, a = butter_lowpass(cutoff, fs, order=order)
226
+ return filtfilt(b, a, data) # forward-backward filter
227
+
228
+
229
+ def output_to_target(output, max_det=300):
230
+ # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting
231
+ targets = []
232
+ for i, o in enumerate(output):
233
+ box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
234
+ j = torch.full((conf.shape[0], 1), i)
235
+ targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
236
+ return torch.cat(targets, 0).numpy()
237
+
238
+
239
+ @threaded
240
+ def plot_images(images, targets, paths=None, fname='images.jpg', names=None):
241
+ # Plot image grid with labels
242
+ if isinstance(images, torch.Tensor):
243
+ images = images.cpu().float().numpy()
244
+ if isinstance(targets, torch.Tensor):
245
+ targets = targets.cpu().numpy()
246
+
247
+ max_size = 1920 # max image size
248
+ max_subplots = 16 # max image subplots, i.e. 4x4
249
+ bs, _, h, w = images.shape # batch size, _, height, width
250
+ bs = min(bs, max_subplots) # limit plot images
251
+ ns = np.ceil(bs ** 0.5) # number of subplots (square)
252
+ if np.max(images[0]) <= 1:
253
+ images *= 255 # de-normalise (optional)
254
+
255
+ # Build Image
256
+ mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
257
+ for i, im in enumerate(images):
258
+ if i == max_subplots: # if last batch has fewer images than we expect
259
+ break
260
+ x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
261
+ im = im.transpose(1, 2, 0)
262
+ mosaic[y:y + h, x:x + w, :] = im
263
+
264
+ # Resize (optional)
265
+ scale = max_size / ns / max(h, w)
266
+ if scale < 1:
267
+ h = math.ceil(scale * h)
268
+ w = math.ceil(scale * w)
269
+ mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
270
+
271
+ # Annotate
272
+ fs = int((h + w) * ns * 0.01) # font size
273
+ annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
274
+ for i in range(i + 1):
275
+ x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
276
+ annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
277
+ if paths:
278
+ annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
279
+ if len(targets) > 0:
280
+ ti = targets[targets[:, 0] == i] # image targets
281
+ boxes = xywh2xyxy(ti[:, 2:6]).T
282
+ classes = ti[:, 1].astype('int')
283
+ labels = ti.shape[1] == 6 # labels if no conf column
284
+ conf = None if labels else ti[:, 6] # check for confidence presence (label vs pred)
285
+
286
+ if boxes.shape[1]:
287
+ if boxes.max() <= 1.01: # if normalized with tolerance 0.01
288
+ boxes[[0, 2]] *= w # scale to pixels
289
+ boxes[[1, 3]] *= h
290
+ elif scale < 1: # absolute coords need scale if image scales
291
+ boxes *= scale
292
+ boxes[[0, 2]] += x
293
+ boxes[[1, 3]] += y
294
+ for j, box in enumerate(boxes.T.tolist()):
295
+ cls = classes[j]
296
+ color = colors(cls)
297
+ cls = names[cls] if names else cls
298
+ if labels or conf[j] > 0.25: # 0.25 conf thresh
299
+ label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
300
+ annotator.box_label(box, label, color=color)
301
+ annotator.im.save(fname) # save
302
+
303
+
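A small sketch showing how `output_to_target` and `plot_images` fit together; the random batch and the single hand-written detection per image are assumptions for illustration:

import torch
from utils.plots import output_to_target, plot_images

imgs = torch.rand(4, 3, 320, 320)                              # normalised batch in [0, 1]
preds = [torch.tensor([[10., 20., 100., 120., 0.9, 0.]])] * 4  # one xyxy detection per image
targets = output_to_target(preds)                              # [batch_id, cls, x, y, w, h, conf]
plot_images(imgs, targets, fname='pred_mosaic.jpg', names={0: 'tshirt'})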
304
+ def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
305
+ # Plot LR simulating training for full epochs
306
+ optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
307
+ y = []
308
+ for _ in range(epochs):
309
+ scheduler.step()
310
+ y.append(optimizer.param_groups[0]['lr'])
311
+ plt.plot(y, '.-', label='LR')
312
+ plt.xlabel('epoch')
313
+ plt.ylabel('LR')
314
+ plt.grid()
315
+ plt.xlim(0, epochs)
316
+ plt.ylim(0)
317
+ plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
318
+ plt.close()
319
+
320
+
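A sketch of calling `plot_lr_scheduler` on throwaway copies of an optimizer and scheduler; the linear-decay lambda and the 300-epoch horizon are arbitrary assumptions:

import torch
import torch.nn as nn
from utils.plots import plot_lr_scheduler

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.937)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda e: 1 - 0.9 * e / 300)
plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir='.')  # writes ./LR.png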
321
+ def plot_val_txt(): # from utils.plots import *; plot_val()
322
+ # Plot val.txt histograms
323
+ x = np.loadtxt('val.txt', dtype=np.float32)
324
+ box = xyxy2xywh(x[:, :4])
325
+ cx, cy = box[:, 0], box[:, 1]
326
+
327
+ fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
328
+ ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
329
+ ax.set_aspect('equal')
330
+ plt.savefig('hist2d.png', dpi=300)
331
+
332
+ fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
333
+ ax[0].hist(cx, bins=600)
334
+ ax[1].hist(cy, bins=600)
335
+ plt.savefig('hist1d.png', dpi=200)
336
+
337
+
338
+ def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
339
+ # Plot targets.txt histograms
340
+ x = np.loadtxt('targets.txt', dtype=np.float32).T
341
+ s = ['x targets', 'y targets', 'width targets', 'height targets']
342
+ fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
343
+ ax = ax.ravel()
344
+ for i in range(4):
345
+ ax[i].hist(x[i], bins=100, label=f'{x[i].mean():.3g} +/- {x[i].std():.3g}')
346
+ ax[i].legend()
347
+ ax[i].set_title(s[i])
348
+ plt.savefig('targets.jpg', dpi=200)
349
+
350
+
351
+ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_val_study()
352
+ # Plot file=study.txt generated by val.py (or plot all study*.txt in dir)
353
+ save_dir = Path(file).parent if file else Path(dir)
354
+ plot2 = False # plot additional results
355
+ if plot2:
356
+ ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()
357
+
358
+ fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
359
+ # for f in [save_dir / f'study_coco_{x}.txt' for x in ['yolov5n6', 'yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
360
+ for f in sorted(save_dir.glob('study*.txt')):
361
+ y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
362
+ x = np.arange(y.shape[1]) if x is None else np.array(x)
363
+ if plot2:
364
+ s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
365
+ for i in range(7):
366
+ ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
367
+ ax[i].set_title(s[i])
368
+
369
+ j = y[3].argmax() + 1
370
+ ax2.plot(y[5, 1:j],
371
+ y[3, 1:j] * 1E2,
372
+ '.-',
373
+ linewidth=2,
374
+ markersize=8,
375
+ label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
376
+
377
+ ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
378
+ 'k.-',
379
+ linewidth=2,
380
+ markersize=8,
381
+ alpha=.25,
382
+ label='EfficientDet')
383
+
384
+ ax2.grid(alpha=0.2)
385
+ ax2.set_yticks(np.arange(20, 60, 5))
386
+ ax2.set_xlim(0, 57)
387
+ ax2.set_ylim(25, 55)
388
+ ax2.set_xlabel('GPU Speed (ms/img)')
389
+ ax2.set_ylabel('COCO AP val')
390
+ ax2.legend(loc='lower right')
391
+ f = save_dir / 'study.png'
392
+ print(f'Saving {f}...')
393
+ plt.savefig(f, dpi=300)
394
+
395
+
396
+ @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
397
+ def plot_labels(labels, names=(), save_dir=Path('')):
398
+ # plot dataset labels
399
+ LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
400
+ c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
401
+ nc = int(c.max() + 1) # number of classes
402
+ x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
403
+
404
+ # seaborn correlogram
405
+ sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
406
+ plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
407
+ plt.close()
408
+
409
+ # matplotlib labels
410
+ matplotlib.use('svg') # faster
411
+ ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
412
+ y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
413
+ with contextlib.suppress(Exception): # color histogram bars by class
414
+ [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
415
+ ax[0].set_ylabel('instances')
416
+ if 0 < len(names) < 30:
417
+ ax[0].set_xticks(range(len(names)))
418
+ ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
419
+ else:
420
+ ax[0].set_xlabel('classes')
421
+ sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
422
+ sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
423
+
424
+ # rectangles
425
+ labels[:, 1:3] = 0.5 # center
426
+ labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
427
+ img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
428
+ for cls, *box in labels[:1000]:
429
+ ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
430
+ ax[1].imshow(img)
431
+ ax[1].axis('off')
432
+
433
+ for a in [0, 1, 2, 3]:
434
+ for s in ['top', 'right', 'left', 'bottom']:
435
+ ax[a].spines[s].set_visible(False)
436
+
437
+ plt.savefig(save_dir / 'labels.jpg', dpi=200)
438
+ matplotlib.use('Agg')
439
+ plt.close()
440
+
441
+
442
+ def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f=Path('images.jpg')):
443
+ # Show classification image grid with labels (optional) and predictions (optional)
444
+ from utils.augmentations import denormalize
445
+
446
+ names = names or [f'class{i}' for i in range(1000)]
447
+ blocks = torch.chunk(denormalize(im.clone()).cpu().float(), len(im),
448
+ dim=0) # split the batch into individual images
449
+ n = min(len(blocks), nmax) # number of plots
450
+ m = min(8, round(n ** 0.5)) # 8 x 8 default
451
+ fig, ax = plt.subplots(math.ceil(n / m), m) # n/m rows x m cols
452
+ ax = ax.ravel() if m > 1 else [ax]
453
+ # plt.subplots_adjust(wspace=0.05, hspace=0.05)
454
+ for i in range(n):
455
+ ax[i].imshow(blocks[i].squeeze().permute((1, 2, 0)).numpy().clip(0.0, 1.0))
456
+ ax[i].axis('off')
457
+ if labels is not None:
458
+ s = names[labels[i]] + (f'—{names[pred[i]]}' if pred is not None else '')
459
+ ax[i].set_title(s, fontsize=8, verticalalignment='top')
460
+ plt.savefig(f, dpi=300, bbox_inches='tight')
461
+ plt.close()
462
+ if verbose:
463
+ LOGGER.info(f"Saving {f}")
464
+ if labels is not None:
465
+ LOGGER.info('True: ' + ' '.join(f'{names[i]:3s}' for i in labels[:nmax]))
466
+ if pred is not None:
467
+ LOGGER.info('Predicted:' + ' '.join(f'{names[i]:3s}' for i in pred[:nmax]))
468
+ return f
469
+
470
+
471
+ def plot_evolve(evolve_csv='path/to/evolve.csv'): # from utils.plots import *; plot_evolve()
472
+ # Plot evolve.csv hyp evolution results
473
+ evolve_csv = Path(evolve_csv)
474
+ data = pd.read_csv(evolve_csv)
475
+ keys = [x.strip() for x in data.columns]
476
+ x = data.values
477
+ f = fitness(x)
478
+ j = np.argmax(f) # max fitness index
479
+ plt.figure(figsize=(10, 12), tight_layout=True)
480
+ matplotlib.rc('font', **{'size': 8})
481
+ print(f'Best results from row {j} of {evolve_csv}:')
482
+ for i, k in enumerate(keys[7:]):
483
+ v = x[:, 7 + i]
484
+ mu = v[j] # best single result
485
+ plt.subplot(6, 5, i + 1)
486
+ plt.scatter(v, f, c=hist2d(v, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
487
+ plt.plot(mu, f.max(), 'k+', markersize=15)
488
+ plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9}) # limit to 40 characters
489
+ if i % 5 != 0:
490
+ plt.yticks([])
491
+ print(f'{k:>15}: {mu:.3g}')
492
+ f = evolve_csv.with_suffix('.png') # filename
493
+ plt.savefig(f, dpi=200)
494
+ plt.close()
495
+ print(f'Saved {f}')
496
+
497
+
498
+ def plot_results(file='path/to/results.csv', dir=''):
499
+ # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
500
+ save_dir = Path(file).parent if file else Path(dir)
501
+ fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
502
+ ax = ax.ravel()
503
+ files = list(save_dir.glob('results*.csv'))
504
+ assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
505
+ for f in files:
506
+ try:
507
+ data = pd.read_csv(f)
508
+ s = [x.strip() for x in data.columns]
509
+ x = data.values[:, 0]
510
+ for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
511
+ y = data.values[:, j].astype('float')
512
+ # y[y == 0] = np.nan # don't show zero values
513
+ ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
514
+ ax[i].set_title(s[j], fontsize=12)
515
+ # if j in [8, 9, 10]: # share train and val loss y axes
516
+ # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
517
+ except Exception as e:
518
+ LOGGER.info(f'Warning: Plotting error for {f}: {e}')
519
+ ax[1].legend()
520
+ fig.savefig(save_dir / 'results.png', dpi=200)
521
+ plt.close()
522
+
523
+
524
+ def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
525
+ # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
526
+ ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
527
+ s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
528
+ files = list(Path(save_dir).glob('frames*.txt'))
529
+ for fi, f in enumerate(files):
530
+ try:
531
+ results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows
532
+ n = results.shape[1] # number of logged frames
533
+ x = np.arange(start, min(stop, n) if stop else n)
534
+ results = results[:, x]
535
+ t = (results[0] - results[0].min()) # set t0=0s
536
+ results[0] = x
537
+ for i, a in enumerate(ax):
538
+ if i < len(results):
539
+ label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
540
+ a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
541
+ a.set_title(s[i])
542
+ a.set_xlabel('time (s)')
543
+ # if fi == len(files) - 1:
544
+ # a.set_ylim(bottom=0)
545
+ for side in ['top', 'right']:
546
+ a.spines[side].set_visible(False)
547
+ else:
548
+ a.remove()
549
+ except Exception as e:
550
+ print(f'Warning: Plotting error for {f}; {e}')
551
+ ax[1].legend()
552
+ plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
553
+
554
+
555
+ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
556
+ # Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
557
+ xyxy = torch.tensor(xyxy).view(-1, 4)
558
+ b = xyxy2xywh(xyxy) # boxes
559
+ if square:
560
+ b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
561
+ b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
562
+ xyxy = xywh2xyxy(b).long()
563
+ clip_boxes(xyxy, im.shape)
564
+ crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
565
+ if save:
566
+ file.parent.mkdir(parents=True, exist_ok=True) # make directory
567
+ f = str(increment_path(file).with_suffix('.jpg'))
568
+ # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
569
+ Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
570
+ return crop
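A sketch of cropping one detection with `save_one_box`; the blank frame, box, and output path are placeholder assumptions:

import numpy as np
from pathlib import Path
from utils.plots import save_one_box

frame = np.zeros((480, 640, 3), dtype=np.uint8)   # stand-in BGR frame
crop = save_one_box([100, 80, 300, 260], frame, file=Path('crops/shirt.jpg'), BGR=True)
# crop is the padded BGR patch; a copy is also written to crops/shirt.jpg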
utils/torch_utils.py ADDED
@@ -0,0 +1,529 @@
1
+ import math
2
+ import os
3
+ import platform
4
+ import subprocess
5
+ import time
6
+ import warnings
7
+ from contextlib import contextmanager
8
+ from copy import deepcopy
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ import torch.distributed as dist
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from torch.nn.parallel import DistributedDataParallel as DDP
16
+
17
+ from utils.general import LOGGER, check_version, colorstr, file_date, git_describe
18
+ from utils.lion import Lion
19
+
20
+ LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
21
+ RANK = int(os.getenv('RANK', -1))
22
+ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
23
+
24
+ try:
25
+ import thop # for FLOPs computation
26
+ except ImportError:
27
+ thop = None
28
+
29
+ # Suppress PyTorch warnings
30
+ warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
31
+ warnings.filterwarnings('ignore', category=UserWarning)
32
+
33
+
34
+ def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
35
+ # Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
36
+ def decorate(fn):
37
+ return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
38
+
39
+ return decorate
40
+
41
+
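A minimal sketch of decorating an inference helper with `smart_inference_mode` (the helper itself is hypothetical):

from utils.torch_utils import smart_inference_mode

@smart_inference_mode()
def run(model, x):
    # executes under torch.inference_mode() on torch>=1.9, torch.no_grad() otherwise
    return model(x)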
42
+ def smartCrossEntropyLoss(label_smoothing=0.0):
43
+ # Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0
44
+ if check_version(torch.__version__, '1.10.0'):
45
+ return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
46
+ if label_smoothing > 0:
47
+ LOGGER.warning(f'WARNING ⚠️ label smoothing {label_smoothing} requires torch>=1.10.0')
48
+ return nn.CrossEntropyLoss()
49
+
50
+
51
+ def smart_DDP(model):
52
+ # Model DDP creation with checks
53
+ assert not check_version(torch.__version__, '1.12.0', pinned=True), \
54
+ 'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
55
+ 'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
56
+ if check_version(torch.__version__, '1.11.0'):
57
+ return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
58
+ else:
59
+ return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
60
+
61
+
62
+ def reshape_classifier_output(model, n=1000):
63
+ # Update a TorchVision classification model to class count 'n' if required
64
+ from models.common import Classify
65
+ name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
66
+ if isinstance(m, Classify): # YOLOv5 Classify() head
67
+ if m.linear.out_features != n:
68
+ m.linear = nn.Linear(m.linear.in_features, n)
69
+ elif isinstance(m, nn.Linear): # ResNet, EfficientNet
70
+ if m.out_features != n:
71
+ setattr(model, name, nn.Linear(m.in_features, n))
72
+ elif isinstance(m, nn.Sequential):
73
+ types = [type(x) for x in m]
74
+ if nn.Linear in types:
75
+ i = types.index(nn.Linear) # nn.Linear index
76
+ if m[i].out_features != n:
77
+ m[i] = nn.Linear(m[i].in_features, n)
78
+ elif nn.Conv2d in types:
79
+ i = types.index(nn.Conv2d) # nn.Conv2d index
80
+ if m[i].out_channels != n:
81
+ m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
82
+
83
+
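A sketch of repointing a torchvision classifier head with `reshape_classifier_output`; the resnet18 backbone and 10-class target are assumptions:

import torchvision
from utils.torch_utils import reshape_classifier_output

model = torchvision.models.resnet18(weights=None)
reshape_classifier_output(model, n=10)   # swaps model.fc for nn.Linear(512, 10)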
84
+ @contextmanager
85
+ def torch_distributed_zero_first(local_rank: int):
86
+ # Context manager to make all processes in distributed training wait for each local_master to do something
87
+ if local_rank not in [-1, 0]:
88
+ dist.barrier(device_ids=[local_rank])
89
+ yield
90
+ if local_rank == 0:
91
+ dist.barrier(device_ids=[0])
92
+
93
+
94
+ def device_count():
95
+ # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Supports Linux and Windows
96
+ assert platform.system() in ('Linux', 'Windows'), 'device_count() only supported on Linux or Windows'
97
+ try:
98
+ cmd = 'nvidia-smi -L | wc -l' if platform.system() == 'Linux' else 'nvidia-smi -L | find /c /v ""' # Windows
99
+ return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
100
+ except Exception:
101
+ return 0
102
+
103
+
104
+ def select_device(device='', batch_size=0, newline=True):
105
+ # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
106
+ s = f'YOLO 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
107
+ device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0'
108
+ cpu = device == 'cpu'
109
+ mps = device == 'mps' # Apple Metal Performance Shaders (MPS)
110
+ if cpu or mps:
111
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
112
+ elif device: # non-cpu device requested
113
+ os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
114
+ assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
115
+ f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
116
+
117
+ if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
118
+ devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
119
+ n = len(devices) # device count
120
+ if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
121
+ assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
122
+ space = ' ' * (len(s) + 1)
123
+ for i, d in enumerate(devices):
124
+ p = torch.cuda.get_device_properties(i)
125
+ s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
126
+ arg = 'cuda:0'
127
+ elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available(): # prefer MPS if available
128
+ s += 'MPS\n'
129
+ arg = 'mps'
130
+ else: # revert to CPU
131
+ s += 'CPU\n'
132
+ arg = 'cpu'
133
+
134
+ if not newline:
135
+ s = s.rstrip()
136
+ LOGGER.info(s)
137
+ return torch.device(arg)
138
+
139
+
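A short sketch of the device strings `select_device` accepts; which one is valid depends on the machine:

from utils.torch_utils import select_device

device = select_device('')                      # first CUDA GPU if available, else CPU
# device = select_device('cpu')                 # force CPU
# device = select_device('0,1', batch_size=32)  # multi-GPU; batch size must divide by GPU count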
140
+ def time_sync():
141
+ # PyTorch-accurate time
142
+ if torch.cuda.is_available():
143
+ torch.cuda.synchronize()
144
+ return time.time()
145
+
146
+
147
+ def profile(input, ops, n=10, device=None):
148
+ """ YOLOv5 speed/memory/FLOPs profiler
149
+ Usage:
150
+ input = torch.randn(16, 3, 640, 640)
151
+ m1 = lambda x: x * torch.sigmoid(x)
152
+ m2 = nn.SiLU()
153
+ profile(input, [m1, m2], n=100) # profile over 100 iterations
154
+ """
155
+ results = []
156
+ if not isinstance(device, torch.device):
157
+ device = select_device(device)
158
+ print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
159
+ f"{'input':>24s}{'output':>24s}")
160
+
161
+ for x in input if isinstance(input, list) else [input]:
162
+ x = x.to(device)
163
+ x.requires_grad = True
164
+ for m in ops if isinstance(ops, list) else [ops]:
165
+ m = m.to(device) if hasattr(m, 'to') else m # device
166
+ m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
167
+ tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
168
+ try:
169
+ flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
170
+ except Exception:
171
+ flops = 0
172
+
173
+ try:
174
+ for _ in range(n):
175
+ t[0] = time_sync()
176
+ y = m(x)
177
+ t[1] = time_sync()
178
+ try:
179
+ _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
180
+ t[2] = time_sync()
181
+ except Exception: # no backward method
182
+ # print(e) # for debug
183
+ t[2] = float('nan')
184
+ tf += (t[1] - t[0]) * 1000 / n # ms per op forward
185
+ tb += (t[2] - t[1]) * 1000 / n # ms per op backward
186
+ mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
187
+ s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
188
+ p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
189
+ print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
190
+ results.append([p, flops, mem, tf, tb, s_in, s_out])
191
+ except Exception as e:
192
+ print(e)
193
+ results.append(None)
194
+ torch.cuda.empty_cache()
195
+ return results
196
+
197
+
198
+ def is_parallel(model):
199
+ # Returns True if model is of type DP or DDP
200
+ return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
201
+
202
+
203
+ def de_parallel(model):
204
+ # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
205
+ return model.module if is_parallel(model) else model
206
+
207
+
208
+ def initialize_weights(model):
209
+ for m in model.modules():
210
+ t = type(m)
211
+ if t is nn.Conv2d:
212
+ pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
213
+ elif t is nn.BatchNorm2d:
214
+ m.eps = 1e-3
215
+ m.momentum = 0.03
216
+ elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
217
+ m.inplace = True
218
+
219
+
220
+ def find_modules(model, mclass=nn.Conv2d):
221
+ # Finds layer indices matching module class 'mclass'
222
+ return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
223
+
224
+
225
+ def sparsity(model):
226
+ # Return global model sparsity
227
+ a, b = 0, 0
228
+ for p in model.parameters():
229
+ a += p.numel()
230
+ b += (p == 0).sum()
231
+ return b / a
232
+
233
+
234
+ def prune(model, amount=0.3):
235
+ # Prune model to requested global sparsity
236
+ import torch.nn.utils.prune as prune
237
+ for name, m in model.named_modules():
238
+ if isinstance(m, nn.Conv2d):
239
+ prune.l1_unstructured(m, name='weight', amount=amount) # prune
240
+ prune.remove(m, 'weight') # make permanent
241
+ LOGGER.info(f'Model pruned to {sparsity(model):.3g} global sparsity')
242
+
243
+
244
+ def fuse_conv_and_bn(conv, bn):
245
+ # Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
246
+ fusedconv = nn.Conv2d(conv.in_channels,
247
+ conv.out_channels,
248
+ kernel_size=conv.kernel_size,
249
+ stride=conv.stride,
250
+ padding=conv.padding,
251
+ dilation=conv.dilation,
252
+ groups=conv.groups,
253
+ bias=True).requires_grad_(False).to(conv.weight.device)
254
+
255
+ # Prepare filters
256
+ w_conv = conv.weight.clone().view(conv.out_channels, -1)
257
+ w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
258
+ fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
259
+
260
+ # Prepare spatial bias
261
+ b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
262
+ b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
263
+ fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
264
+
265
+ return fusedconv
266
+
267
+
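A sketch verifying that `fuse_conv_and_bn` is output-preserving: the BatchNorm scale and shift are folded into the convolution weights and bias, so the fused layer should match conv+BN in eval mode. The layer sizes below are arbitrary assumptions:

import torch
import torch.nn as nn
from utils.torch_utils import fuse_conv_and_bn

conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()                    # eval() so running statistics are used
fused = fuse_conv_and_bn(conv, bn)
x = torch.randn(1, 3, 32, 32)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)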
268
+ def model_info(model, verbose=False, imgsz=640):
269
+ # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
270
+ n_p = sum(x.numel() for x in model.parameters()) # number parameters
271
+ n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
272
+ if verbose:
273
+ print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
274
+ for i, (name, p) in enumerate(model.named_parameters()):
275
+ name = name.replace('module_list.', '')
276
+ print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
277
+ (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
278
+
279
+ try: # FLOPs
280
+ p = next(model.parameters())
281
+ stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
282
+ im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
283
+ flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
284
+ imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
285
+ fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs' # 640x640 GFLOPs
286
+ except Exception:
287
+ fs = ''
288
+
289
+ name = Path(model.yaml_file).stem.replace('yolov5', 'YOLOv5') if hasattr(model, 'yaml_file') else 'Model'
290
+ LOGGER.info(f"{name} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
291
+
292
+
293
+ def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
294
+ # Scales img(bs,3,y,x) by ratio constrained to gs-multiple
295
+ if ratio == 1.0:
296
+ return img
297
+ h, w = img.shape[2:]
298
+ s = (int(h * ratio), int(w * ratio)) # new size
299
+ img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
300
+ if not same_shape: # pad/crop img
301
+ h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
302
+ return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
303
+
304
+
305
+ def copy_attr(a, b, include=(), exclude=()):
306
+ # Copy attributes from b to a, options to only include [...] and to exclude [...]
307
+ for k, v in b.__dict__.items():
308
+ if (len(include) and k not in include) or k.startswith('_') or k in exclude:
309
+ continue
310
+ else:
311
+ setattr(a, k, v)
312
+
313
+
314
+ def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
315
+ # YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay
316
+ g = [], [], [] # optimizer parameter groups
317
+ bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
318
+ #for v in model.modules():
319
+ # for p_name, p in v.named_parameters(recurse=0):
320
+ # if p_name == 'bias': # bias (no decay)
321
+ # g[2].append(p)
322
+ # elif p_name == 'weight' and isinstance(v, bn): # weight (no decay)
323
+ # g[1].append(p)
324
+ # else:
325
+ # g[0].append(p) # weight (with decay)
326
+
327
+ for v in model.modules():
328
+ if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias (no decay)
329
+ g[2].append(v.bias)
330
+ if isinstance(v, bn): # weight (no decay)
331
+ g[1].append(v.weight)
332
+ elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
333
+ g[0].append(v.weight)
334
+
335
+ if hasattr(v, 'im'):
336
+ if hasattr(v.im, 'implicit'):
337
+ g[1].append(v.im.implicit)
338
+ else:
339
+ for iv in v.im:
340
+ g[1].append(iv.implicit)
341
+ if hasattr(v, 'ia'):
342
+ if hasattr(v.ia, 'implicit'):
343
+ g[1].append(v.ia.implicit)
344
+ else:
345
+ for iv in v.ia:
346
+ g[1].append(iv.implicit)
347
+
348
+ if hasattr(v, 'im2'):
349
+ if hasattr(v.im2, 'implicit'):
350
+ g[1].append(v.im2.implicit)
351
+ else:
352
+ for iv in v.im2:
353
+ g[1].append(iv.implicit)
354
+ if hasattr(v, 'ia2'):
355
+ if hasattr(v.ia2, 'implicit'):
356
+ g[1].append(v.ia2.implicit)
357
+ else:
358
+ for iv in v.ia2:
359
+ g[1].append(iv.implicit)
360
+
361
+ if hasattr(v, 'im3'):
362
+ if hasattr(v.im3, 'implicit'):
363
+ g[1].append(v.im3.implicit)
364
+ else:
365
+ for iv in v.im3:
366
+ g[1].append(iv.implicit)
367
+ if hasattr(v, 'ia3'):
368
+ if hasattr(v.ia3, 'implicit'):
369
+ g[1].append(v.ia3.implicit)
370
+ else:
371
+ for iv in v.ia3:
372
+ g[1].append(iv.implicit)
373
+
374
+ if hasattr(v, 'im4'):
375
+ if hasattr(v.im4, 'implicit'):
376
+ g[1].append(v.im4.implicit)
377
+ else:
378
+ for iv in v.im4:
379
+ g[1].append(iv.implicit)
380
+ if hasattr(v, 'ia4'):
381
+ if hasattr(v.ia4, 'implicit'):
382
+ g[1].append(v.ia4.implicit)
383
+ else:
384
+ for iv in v.ia4:
385
+ g[1].append(iv.implicit)
386
+
387
+ if hasattr(v, 'im5'):
388
+ if hasattr(v.im5, 'implicit'):
389
+ g[1].append(v.im5.implicit)
390
+ else:
391
+ for iv in v.im5:
392
+ g[1].append(iv.implicit)
393
+ if hasattr(v, 'ia5'):
394
+ if hasattr(v.ia5, 'implicit'):
395
+ g[1].append(v.ia5.implicit)
396
+ else:
397
+ for iv in v.ia5:
398
+ g[1].append(iv.implicit)
399
+
400
+ if hasattr(v, 'im6'):
401
+ if hasattr(v.im6, 'implicit'):
402
+ g[1].append(v.im6.implicit)
403
+ else:
404
+ for iv in v.im6:
405
+ g[1].append(iv.implicit)
406
+ if hasattr(v, 'ia6'):
407
+ if hasattr(v.ia6, 'implicit'):
408
+ g[1].append(v.ia6.implicit)
409
+ else:
410
+ for iv in v.ia6:
411
+ g[1].append(iv.implicit)
412
+
413
+ if hasattr(v, 'im7'):
414
+ if hasattr(v.im7, 'implicit'):
415
+ g[1].append(v.im7.implicit)
416
+ else:
417
+ for iv in v.im7:
418
+ g[1].append(iv.implicit)
419
+ if hasattr(v, 'ia7'):
420
+ if hasattr(v.ia7, 'implicit'):
421
+ g[1].append(v.ia7.implicit)
422
+ else:
423
+ for iv in v.ia7:
424
+ g[1].append(iv.implicit)
425
+
426
+ if name == 'Adam':
427
+ optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999)) # adjust beta1 to momentum
428
+ elif name == 'AdamW':
429
+ optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0, amsgrad=True)
430
+ elif name == 'RMSProp':
431
+ optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
432
+ elif name == 'SGD':
433
+ optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
434
+ elif name == 'LION':
435
+ optimizer = Lion(g[2], lr=lr, betas=(momentum, 0.99), weight_decay=0.0)
436
+ else:
437
+ raise NotImplementedError(f'Optimizer {name} not implemented.')
438
+
439
+ optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
440
+ optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
441
+ LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
442
+ f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
443
+ return optimizer
444
+
445
+
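A sketch of building the three-group optimizer described above for a toy model; the layer sizes and hyper-parameters are arbitrary assumptions:

import torch.nn as nn
from utils.torch_utils import smart_optimizer

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.Conv2d(16, 2, 1))
optimizer = smart_optimizer(model, name='SGD', lr=0.01, momentum=0.937, decay=5e-4)
# biases form the base group (no decay), conv weights get decay=5e-4, BatchNorm weights get no decay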
446
+ def smart_hub_load(repo='ultralytics/yolov5', model='yolov5s', **kwargs):
447
+ # YOLOv5 torch.hub.load() wrapper with smart error/issue handling
448
+ if check_version(torch.__version__, '1.9.1'):
449
+ kwargs['skip_validation'] = True # validation causes GitHub API rate limit errors
450
+ if check_version(torch.__version__, '1.12.0'):
451
+ kwargs['trust_repo'] = True # argument required starting in torch 1.12
452
+ try:
453
+ return torch.hub.load(repo, model, **kwargs)
454
+ except Exception:
455
+ return torch.hub.load(repo, model, force_reload=True, **kwargs)
456
+
457
+
458
+ def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
459
+ # Resume training from a partially trained checkpoint
460
+ best_fitness = 0.0
461
+ start_epoch = ckpt['epoch'] + 1
462
+ if ckpt['optimizer'] is not None:
463
+ optimizer.load_state_dict(ckpt['optimizer']) # optimizer
464
+ best_fitness = ckpt['best_fitness']
465
+ if ema and ckpt.get('ema'):
466
+ ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
467
+ ema.updates = ckpt['updates']
468
+ if resume:
469
+ assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.\n' \
470
+ f"Start a new training without --resume, i.e. 'python train.py --weights {weights}'"
471
+ LOGGER.info(f'Resuming training from {weights} from epoch {start_epoch} to {epochs} total epochs')
472
+ if epochs < start_epoch:
473
+ LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
474
+ epochs += ckpt['epoch'] # finetune additional epochs
475
+ return best_fitness, start_epoch, epochs
476
+
477
+
478
+ class EarlyStopping:
479
+ # YOLOv5 simple early stopper
480
+ def __init__(self, patience=30):
481
+ self.best_fitness = 0.0 # i.e. mAP
482
+ self.best_epoch = 0
483
+ self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
484
+ self.possible_stop = False # possible stop may occur next epoch
485
+
486
+ def __call__(self, epoch, fitness):
487
+ if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
488
+ self.best_epoch = epoch
489
+ self.best_fitness = fitness
490
+ delta = epoch - self.best_epoch # epochs without improvement
491
+ self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
492
+ stop = delta >= self.patience # stop training if patience exceeded
493
+ if stop:
494
+ LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
495
+ f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
496
+ f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
497
+ f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
498
+ return stop
499
+
500
+
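A sketch of driving `EarlyStopping` from a training loop; the synthetic fitness curve below stands in for per-epoch mAP:

from utils.torch_utils import EarlyStopping

stopper = EarlyStopping(patience=30)
for epoch in range(300):
    fitness = 0.5 if epoch < 100 else 0.49   # stand-in for this epoch's fitness/mAP
    if stopper(epoch, fitness):
        break                                # fires ~30 epochs after the plateau starts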
501
+ class ModelEMA:
502
+ """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
503
+ Keeps a moving average of everything in the model state_dict (parameters and buffers)
504
+ For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
505
+ """
506
+
507
+ def __init__(self, model, decay=0.9999, tau=2000, updates=0):
508
+ # Create EMA
509
+ self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
510
+ self.updates = updates # number of EMA updates
511
+ self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
512
+ for p in self.ema.parameters():
513
+ p.requires_grad_(False)
514
+
515
+ def update(self, model):
516
+ # Update EMA parameters
517
+ self.updates += 1
518
+ d = self.decay(self.updates)
519
+
520
+ msd = de_parallel(model).state_dict() # model state_dict
521
+ for k, v in self.ema.state_dict().items():
522
+ if v.dtype.is_floating_point: # true for FP16 and FP32
523
+ v *= d
524
+ v += (1 - d) * msd[k].detach()
525
+ # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
526
+
527
+ def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
528
+ # Update EMA attributes
529
+ copy_attr(self.ema, model, include, exclude)
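A sketch of keeping an EMA shadow of a model during a toy training loop; the linear model, loss, and step count are assumptions:

import torch
import torch.nn as nn
from utils.torch_utils import ModelEMA

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
ema = ModelEMA(model, decay=0.9999, tau=2000)

for _ in range(10):
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.update(model)        # blend current weights/buffers into the moving average
# ema.ema is the smoothed FP32 copy normally used for validation and export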