Spaces:
Runtime error
Runtime error
Zengyf-CVer
commited on
Commit
·
aa24cb4
1
Parent(s):
112bf3b
app update
Browse files- data/coco128.yaml +81 -10
- export.py +322 -345
- models/common.py +111 -94
- models/experimental.py +14 -5
- models/tf.py +1 -1
- models/yolo.py +91 -71
- requirements.txt +2 -2
- utils/__init__.py +30 -1
- utils/augmentations.py +114 -2
- utils/autoanchor.py +5 -6
- utils/autobatch.py +5 -2
- utils/benchmarks.py +6 -2
- utils/callbacks.py +9 -4
- utils/dataloaders.py +138 -101
- utils/downloads.py +23 -11
- utils/general.py +101 -100
- utils/metrics.py +39 -36
- utils/plots.py +42 -13
- utils/torch_utils.py +54 -14
- val.py +35 -34
data/coco128.yaml
CHANGED
@@ -14,16 +14,87 @@ val: images/train2017 # val images (relative to 'path') 128 images
|
|
14 |
test: # test images (optional)
|
15 |
|
16 |
# Classes
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
|
29 |
# Download script/URL (optional)
|
|
|
14 |
test: # test images (optional)
|
15 |
|
16 |
# Classes
|
17 |
+
names:
|
18 |
+
0: person
|
19 |
+
1: bicycle
|
20 |
+
2: car
|
21 |
+
3: motorcycle
|
22 |
+
4: airplane
|
23 |
+
5: bus
|
24 |
+
6: train
|
25 |
+
7: truck
|
26 |
+
8: boat
|
27 |
+
9: traffic light
|
28 |
+
10: fire hydrant
|
29 |
+
11: stop sign
|
30 |
+
12: parking meter
|
31 |
+
13: bench
|
32 |
+
14: bird
|
33 |
+
15: cat
|
34 |
+
16: dog
|
35 |
+
17: horse
|
36 |
+
18: sheep
|
37 |
+
19: cow
|
38 |
+
20: elephant
|
39 |
+
21: bear
|
40 |
+
22: zebra
|
41 |
+
23: giraffe
|
42 |
+
24: backpack
|
43 |
+
25: umbrella
|
44 |
+
26: handbag
|
45 |
+
27: tie
|
46 |
+
28: suitcase
|
47 |
+
29: frisbee
|
48 |
+
30: skis
|
49 |
+
31: snowboard
|
50 |
+
32: sports ball
|
51 |
+
33: kite
|
52 |
+
34: baseball bat
|
53 |
+
35: baseball glove
|
54 |
+
36: skateboard
|
55 |
+
37: surfboard
|
56 |
+
38: tennis racket
|
57 |
+
39: bottle
|
58 |
+
40: wine glass
|
59 |
+
41: cup
|
60 |
+
42: fork
|
61 |
+
43: knife
|
62 |
+
44: spoon
|
63 |
+
45: bowl
|
64 |
+
46: banana
|
65 |
+
47: apple
|
66 |
+
48: sandwich
|
67 |
+
49: orange
|
68 |
+
50: broccoli
|
69 |
+
51: carrot
|
70 |
+
52: hot dog
|
71 |
+
53: pizza
|
72 |
+
54: donut
|
73 |
+
55: cake
|
74 |
+
56: chair
|
75 |
+
57: couch
|
76 |
+
58: potted plant
|
77 |
+
59: bed
|
78 |
+
60: dining table
|
79 |
+
61: toilet
|
80 |
+
62: tv
|
81 |
+
63: laptop
|
82 |
+
64: mouse
|
83 |
+
65: remote
|
84 |
+
66: keyboard
|
85 |
+
67: cell phone
|
86 |
+
68: microwave
|
87 |
+
69: oven
|
88 |
+
70: toaster
|
89 |
+
71: sink
|
90 |
+
72: refrigerator
|
91 |
+
73: book
|
92 |
+
74: clock
|
93 |
+
75: vase
|
94 |
+
76: scissors
|
95 |
+
77: teddy bear
|
96 |
+
78: hair drier
|
97 |
+
79: toothbrush
|
98 |
|
99 |
|
100 |
# Download script/URL (optional)
|
export.py
CHANGED
@@ -21,19 +21,19 @@ Requirements:
|
|
21 |
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
|
22 |
|
23 |
Usage:
|
24 |
-
$ python
|
25 |
|
26 |
Inference:
|
27 |
-
$ python
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
|
38 |
TensorFlow.js:
|
39 |
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
|
@@ -65,10 +65,10 @@ if platform.system() != 'Windows':
|
|
65 |
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
66 |
|
67 |
from models.experimental import attempt_load
|
68 |
-
from models.yolo import Detect
|
69 |
from utils.dataloaders import LoadImages
|
70 |
-
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version,
|
71 |
-
colorstr, file_size, print_args, url2file)
|
72 |
from utils.torch_utils import select_device, smart_inference_mode
|
73 |
|
74 |
|
@@ -89,200 +89,199 @@ def export_formats():
|
|
89 |
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
|
90 |
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
|
93 |
# YOLOv5 TorchScript model export
|
94 |
-
|
95 |
-
|
96 |
-
f = file.with_suffix('.torchscript')
|
97 |
-
|
98 |
-
ts = torch.jit.trace(model, im, strict=False)
|
99 |
-
d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
|
100 |
-
extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
|
101 |
-
if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
|
102 |
-
optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
|
103 |
-
else:
|
104 |
-
ts.save(str(f), _extra_files=extra_files)
|
105 |
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
110 |
|
111 |
|
|
|
112 |
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
|
113 |
# YOLOv5 ONNX export
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
'
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
'
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
|
166 |
-
return f
|
167 |
-
except Exception as e:
|
168 |
-
LOGGER.info(f'{prefix} export failure: {e}')
|
169 |
|
170 |
|
|
|
171 |
def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
|
172 |
# YOLOv5 OpenVINO export
|
173 |
-
|
174 |
-
|
175 |
-
import openvino.inference_engine as ie
|
176 |
-
|
177 |
-
LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
|
178 |
-
f = str(file).replace('.pt', f'_openvino_model{os.sep}')
|
179 |
|
180 |
-
|
181 |
-
|
182 |
-
with open(Path(f) / file.with_suffix('.yaml').name, 'w') as g:
|
183 |
-
yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g) # add metadata.yaml
|
184 |
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
|
|
189 |
|
190 |
|
|
|
191 |
def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
|
192 |
# YOLOv5 CoreML export
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
if
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
LOGGER.info(f'\n{prefix} export failure: {e}')
|
216 |
-
return None, None
|
217 |
-
|
218 |
-
|
219 |
-
def export_engine(model, im, file, train, half, dynamic, simplify, workspace=4, verbose=False):
|
220 |
# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
|
221 |
-
|
222 |
try:
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
logger = trt.Logger
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
LOGGER.info(f'{prefix}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
for inp in inputs:
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
|
273 |
-
config.add_optimization_profile(profile)
|
274 |
-
|
275 |
-
LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}')
|
276 |
-
if builder.platform_has_fast_fp16 and half:
|
277 |
-
config.set_flag(trt.BuilderFlag.FP16)
|
278 |
-
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
|
279 |
-
t.write(engine.serialize())
|
280 |
-
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
|
281 |
-
return f
|
282 |
-
except Exception as e:
|
283 |
-
LOGGER.info(f'\n{prefix} export failure: {e}')
|
284 |
|
285 |
|
|
|
286 |
def export_saved_model(model,
|
287 |
im,
|
288 |
file,
|
@@ -296,163 +295,142 @@ def export_saved_model(model,
|
|
296 |
keras=False,
|
297 |
prefix=colorstr('TensorFlow SavedModel:')):
|
298 |
# YOLOv5 TensorFlow SavedModel export
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
|
332 |
-
return keras_model, f
|
333 |
-
except Exception as e:
|
334 |
-
LOGGER.info(f'\n{prefix} export failure: {e}')
|
335 |
-
return None, None
|
336 |
|
337 |
|
|
|
338 |
def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
|
339 |
# YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
|
340 |
-
|
341 |
-
|
342 |
-
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
|
343 |
|
344 |
-
|
345 |
-
|
346 |
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
|
354 |
-
return f
|
355 |
-
except Exception as e:
|
356 |
-
LOGGER.info(f'\n{prefix} export failure: {e}')
|
357 |
|
358 |
|
|
|
359 |
def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
|
360 |
# YOLOv5 TensorFlow Lite export
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
LOGGER.info(f'\n{prefix} export failure: {e}')
|
391 |
-
|
392 |
-
|
393 |
def export_edgetpu(file, prefix=colorstr('Edge TPU:')):
|
394 |
# YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
except Exception as e:
|
419 |
-
LOGGER.info(f'\n{prefix} export failure: {e}')
|
420 |
-
|
421 |
-
|
422 |
def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
|
423 |
# YOLOv5 TensorFlow.js export
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
j.write(subst)
|
451 |
-
|
452 |
-
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
|
453 |
-
return f
|
454 |
-
except Exception as e:
|
455 |
-
LOGGER.info(f'\n{prefix} export failure: {e}')
|
456 |
|
457 |
|
458 |
@smart_inference_mode()
|
@@ -495,11 +473,9 @@ def run(
|
|
495 |
assert device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0'
|
496 |
assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
|
497 |
model = attempt_load(weights, device=device, inplace=True, fuse=True) # load FP32 model
|
498 |
-
nc, names = model.nc, model.names # number of classes, class names
|
499 |
|
500 |
# Checks
|
501 |
imgsz *= 2 if len(imgsz) == 1 else 1 # expand
|
502 |
-
assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}'
|
503 |
if optimize:
|
504 |
assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
|
505 |
|
@@ -513,36 +489,37 @@ def run(
|
|
513 |
for k, m in model.named_modules():
|
514 |
if isinstance(m, Detect):
|
515 |
m.inplace = inplace
|
516 |
-
m.
|
517 |
m.export = True
|
518 |
|
519 |
for _ in range(2):
|
520 |
y = model(im) # dry runs
|
521 |
if half and not coreml:
|
522 |
im, model = im.half(), model.half() # to FP16
|
523 |
-
shape = tuple(y[0].shape) # model output shape
|
524 |
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
|
525 |
|
526 |
# Exports
|
527 |
f = [''] * 10 # exported filenames
|
528 |
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
|
529 |
if jit:
|
530 |
-
f[0] = export_torchscript(model, im, file, optimize)
|
531 |
if engine: # TensorRT required before ONNX
|
532 |
-
f[1] = export_engine(model, im, file,
|
533 |
if onnx or xml: # OpenVINO requires ONNX
|
534 |
-
f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
|
535 |
if xml: # OpenVINO
|
536 |
-
f[3] = export_openvino(model, file, half)
|
537 |
if coreml:
|
538 |
-
|
539 |
|
540 |
# TensorFlow Exports
|
541 |
if any((saved_model, pb, tflite, edgetpu, tfjs)):
|
542 |
if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
|
543 |
check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
|
544 |
assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
|
545 |
-
model,
|
|
|
546 |
im,
|
547 |
file,
|
548 |
dynamic,
|
@@ -554,19 +531,19 @@ def run(
|
|
554 |
conf_thres=conf_thres,
|
555 |
keras=keras)
|
556 |
if pb or tfjs: # pb prerequisite to tfjs
|
557 |
-
f[6] = export_pb(model, file)
|
558 |
if tflite or edgetpu:
|
559 |
-
f[7] = export_tflite(model, im, file, int8
|
560 |
if edgetpu:
|
561 |
-
f[8] = export_edgetpu(file)
|
562 |
if tfjs:
|
563 |
-
f[9] = export_tfjs(file)
|
564 |
|
565 |
# Finish
|
566 |
f = [str(x) for x in f if x] # filter out '' and None
|
567 |
if any(f):
|
568 |
h = '--half' if half else '' # --half FP16 inference arg
|
569 |
-
LOGGER.info(f'\nExport complete ({time.time() - t:.
|
570 |
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
571 |
f"\nDetect: python detect.py --weights {f[-1]} {h}"
|
572 |
f"\nValidate: python val.py --weights {f[-1]} {h}"
|
@@ -601,7 +578,7 @@ def parse_opt():
|
|
601 |
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
|
602 |
parser.add_argument('--include',
|
603 |
nargs='+',
|
604 |
-
default=['torchscript'
|
605 |
help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
|
606 |
opt = parser.parse_args()
|
607 |
print_args(vars(opt))
|
|
|
21 |
$ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
|
22 |
|
23 |
Usage:
|
24 |
+
$ python export.py --weights yolov5s.pt --include torchscript onnx openvino engine coreml tflite ...
|
25 |
|
26 |
Inference:
|
27 |
+
$ python detect.py --weights yolov5s.pt # PyTorch
|
28 |
+
yolov5s.torchscript # TorchScript
|
29 |
+
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
|
30 |
+
yolov5s.xml # OpenVINO
|
31 |
+
yolov5s.engine # TensorRT
|
32 |
+
yolov5s.mlmodel # CoreML (macOS-only)
|
33 |
+
yolov5s_saved_model # TensorFlow SavedModel
|
34 |
+
yolov5s.pb # TensorFlow GraphDef
|
35 |
+
yolov5s.tflite # TensorFlow Lite
|
36 |
+
yolov5s_edgetpu.tflite # TensorFlow Edge TPU
|
37 |
|
38 |
TensorFlow.js:
|
39 |
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
|
|
|
65 |
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
66 |
|
67 |
from models.experimental import attempt_load
|
68 |
+
from models.yolo import ClassificationModel, Detect
|
69 |
from utils.dataloaders import LoadImages
|
70 |
+
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
|
71 |
+
check_yaml, colorstr, file_size, get_default_args, print_args, url2file)
|
72 |
from utils.torch_utils import select_device, smart_inference_mode
|
73 |
|
74 |
|
|
|
89 |
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
|
90 |
|
91 |
|
92 |
+
def try_export(inner_func):
|
93 |
+
# YOLOv5 export decorator, i..e @try_export
|
94 |
+
inner_args = get_default_args(inner_func)
|
95 |
+
|
96 |
+
def outer_func(*args, **kwargs):
|
97 |
+
prefix = inner_args['prefix']
|
98 |
+
try:
|
99 |
+
with Profile() as dt:
|
100 |
+
f, model = inner_func(*args, **kwargs)
|
101 |
+
LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
|
102 |
+
return f, model
|
103 |
+
except Exception as e:
|
104 |
+
LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
|
105 |
+
return None, None
|
106 |
+
|
107 |
+
return outer_func
|
108 |
+
|
109 |
+
|
110 |
+
@try_export
|
111 |
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
|
112 |
# YOLOv5 TorchScript model export
|
113 |
+
LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
|
114 |
+
f = file.with_suffix('.torchscript')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
+
ts = torch.jit.trace(model, im, strict=False)
|
117 |
+
d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
|
118 |
+
extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
|
119 |
+
if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
|
120 |
+
optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
|
121 |
+
else:
|
122 |
+
ts.save(str(f), _extra_files=extra_files)
|
123 |
+
return f, None
|
124 |
|
125 |
|
126 |
+
@try_export
|
127 |
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
|
128 |
# YOLOv5 ONNX export
|
129 |
+
check_requirements(('onnx',))
|
130 |
+
import onnx
|
131 |
+
|
132 |
+
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
|
133 |
+
f = file.with_suffix('.onnx')
|
134 |
+
|
135 |
+
torch.onnx.export(
|
136 |
+
model.cpu() if dynamic else model, # --dynamic only compatible with cpu
|
137 |
+
im.cpu() if dynamic else im,
|
138 |
+
f,
|
139 |
+
verbose=False,
|
140 |
+
opset_version=opset,
|
141 |
+
training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
|
142 |
+
do_constant_folding=not train,
|
143 |
+
input_names=['images'],
|
144 |
+
output_names=['output'],
|
145 |
+
dynamic_axes={
|
146 |
+
'images': {
|
147 |
+
0: 'batch',
|
148 |
+
2: 'height',
|
149 |
+
3: 'width'}, # shape(1,3,640,640)
|
150 |
+
'output': {
|
151 |
+
0: 'batch',
|
152 |
+
1: 'anchors'} # shape(1,25200,85)
|
153 |
+
} if dynamic else None)
|
154 |
+
|
155 |
+
# Checks
|
156 |
+
model_onnx = onnx.load(f) # load onnx model
|
157 |
+
onnx.checker.check_model(model_onnx) # check onnx model
|
158 |
+
|
159 |
+
# Metadata
|
160 |
+
d = {'stride': int(max(model.stride)), 'names': model.names}
|
161 |
+
for k, v in d.items():
|
162 |
+
meta = model_onnx.metadata_props.add()
|
163 |
+
meta.key, meta.value = k, str(v)
|
164 |
+
onnx.save(model_onnx, f)
|
165 |
+
|
166 |
+
# Simplify
|
167 |
+
if simplify:
|
168 |
+
try:
|
169 |
+
cuda = torch.cuda.is_available()
|
170 |
+
check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
|
171 |
+
import onnxsim
|
172 |
+
|
173 |
+
LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
|
174 |
+
model_onnx, check = onnxsim.simplify(model_onnx)
|
175 |
+
assert check, 'assert check failed'
|
176 |
+
onnx.save(model_onnx, f)
|
177 |
+
except Exception as e:
|
178 |
+
LOGGER.info(f'{prefix} simplifier failure: {e}')
|
179 |
+
return f, model_onnx
|
|
|
|
|
|
|
|
|
180 |
|
181 |
|
182 |
+
@try_export
|
183 |
def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
|
184 |
# YOLOv5 OpenVINO export
|
185 |
+
check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/
|
186 |
+
import openvino.inference_engine as ie
|
|
|
|
|
|
|
|
|
187 |
|
188 |
+
LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
|
189 |
+
f = str(file).replace('.pt', f'_openvino_model{os.sep}')
|
|
|
|
|
190 |
|
191 |
+
cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
|
192 |
+
subprocess.check_output(cmd.split()) # export
|
193 |
+
with open(Path(f) / file.with_suffix('.yaml').name, 'w') as g:
|
194 |
+
yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g) # add metadata.yaml
|
195 |
+
return f, None
|
196 |
|
197 |
|
198 |
+
@try_export
|
199 |
def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
|
200 |
# YOLOv5 CoreML export
|
201 |
+
check_requirements(('coremltools',))
|
202 |
+
import coremltools as ct
|
203 |
+
|
204 |
+
LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
|
205 |
+
f = file.with_suffix('.mlmodel')
|
206 |
+
|
207 |
+
ts = torch.jit.trace(model, im, strict=False) # TorchScript model
|
208 |
+
ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
|
209 |
+
bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
|
210 |
+
if bits < 32:
|
211 |
+
if platform.system() == 'Darwin': # quantization only supported on macOS
|
212 |
+
with warnings.catch_warnings():
|
213 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning
|
214 |
+
ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
|
215 |
+
else:
|
216 |
+
print(f'{prefix} quantization only supported on macOS, skipping...')
|
217 |
+
ct_model.save(f)
|
218 |
+
return f, ct_model
|
219 |
+
|
220 |
+
|
221 |
+
@try_export
|
222 |
+
def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
|
|
|
|
|
|
|
|
|
|
|
223 |
# YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
|
224 |
+
assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
|
225 |
try:
|
226 |
+
import tensorrt as trt
|
227 |
+
except Exception:
|
228 |
+
if platform.system() == 'Linux':
|
229 |
+
check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',))
|
230 |
+
import tensorrt as trt
|
231 |
+
|
232 |
+
if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
|
233 |
+
grid = model.model[-1].anchor_grid
|
234 |
+
model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
|
235 |
+
export_onnx(model, im, file, 12, False, dynamic, simplify) # opset 12
|
236 |
+
model.model[-1].anchor_grid = grid
|
237 |
+
else: # TensorRT >= 8
|
238 |
+
check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0
|
239 |
+
export_onnx(model, im, file, 13, False, dynamic, simplify) # opset 13
|
240 |
+
onnx = file.with_suffix('.onnx')
|
241 |
+
|
242 |
+
LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
|
243 |
+
assert onnx.exists(), f'failed to export ONNX file: {onnx}'
|
244 |
+
f = file.with_suffix('.engine') # TensorRT engine file
|
245 |
+
logger = trt.Logger(trt.Logger.INFO)
|
246 |
+
if verbose:
|
247 |
+
logger.min_severity = trt.Logger.Severity.VERBOSE
|
248 |
+
|
249 |
+
builder = trt.Builder(logger)
|
250 |
+
config = builder.create_builder_config()
|
251 |
+
config.max_workspace_size = workspace * 1 << 30
|
252 |
+
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
|
253 |
+
|
254 |
+
flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
255 |
+
network = builder.create_network(flag)
|
256 |
+
parser = trt.OnnxParser(network, logger)
|
257 |
+
if not parser.parse_from_file(str(onnx)):
|
258 |
+
raise RuntimeError(f'failed to load ONNX file: {onnx}')
|
259 |
+
|
260 |
+
inputs = [network.get_input(i) for i in range(network.num_inputs)]
|
261 |
+
outputs = [network.get_output(i) for i in range(network.num_outputs)]
|
262 |
+
LOGGER.info(f'{prefix} Network Description:')
|
263 |
+
for inp in inputs:
|
264 |
+
LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
|
265 |
+
for out in outputs:
|
266 |
+
LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
|
267 |
+
|
268 |
+
if dynamic:
|
269 |
+
if im.shape[0] <= 1:
|
270 |
+
LOGGER.warning(f"{prefix}WARNING: --dynamic model requires maximum --batch-size argument")
|
271 |
+
profile = builder.create_optimization_profile()
|
272 |
for inp in inputs:
|
273 |
+
profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
|
274 |
+
config.add_optimization_profile(profile)
|
275 |
+
|
276 |
+
LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}')
|
277 |
+
if builder.platform_has_fast_fp16 and half:
|
278 |
+
config.set_flag(trt.BuilderFlag.FP16)
|
279 |
+
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
|
280 |
+
t.write(engine.serialize())
|
281 |
+
return f, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
|
283 |
|
284 |
+
@try_export
|
285 |
def export_saved_model(model,
|
286 |
im,
|
287 |
file,
|
|
|
295 |
keras=False,
|
296 |
prefix=colorstr('TensorFlow SavedModel:')):
|
297 |
# YOLOv5 TensorFlow SavedModel export
|
298 |
+
import tensorflow as tf
|
299 |
+
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
|
300 |
+
|
301 |
+
from models.tf import TFModel
|
302 |
+
|
303 |
+
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
304 |
+
f = str(file).replace('.pt', '_saved_model')
|
305 |
+
batch_size, ch, *imgsz = list(im.shape) # BCHW
|
306 |
+
|
307 |
+
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
|
308 |
+
im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow
|
309 |
+
_ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
|
310 |
+
inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
|
311 |
+
outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
|
312 |
+
keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
313 |
+
keras_model.trainable = False
|
314 |
+
keras_model.summary()
|
315 |
+
if keras:
|
316 |
+
keras_model.save(f, save_format='tf')
|
317 |
+
else:
|
318 |
+
spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
|
319 |
+
m = tf.function(lambda x: keras_model(x)) # full model
|
320 |
+
m = m.get_concrete_function(spec)
|
321 |
+
frozen_func = convert_variables_to_constants_v2(m)
|
322 |
+
tfm = tf.Module()
|
323 |
+
tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x)[0], [spec])
|
324 |
+
tfm.__call__(im)
|
325 |
+
tf.saved_model.save(tfm,
|
326 |
+
f,
|
327 |
+
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if check_version(
|
328 |
+
tf.__version__, '2.6') else tf.saved_model.SaveOptions())
|
329 |
+
return f, keras_model
|
|
|
|
|
|
|
|
|
|
|
330 |
|
331 |
|
332 |
+
@try_export
|
333 |
def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
|
334 |
# YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
|
335 |
+
import tensorflow as tf
|
336 |
+
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
|
|
|
337 |
|
338 |
+
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
339 |
+
f = file.with_suffix('.pb')
|
340 |
|
341 |
+
m = tf.function(lambda x: keras_model(x)) # full model
|
342 |
+
m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
|
343 |
+
frozen_func = convert_variables_to_constants_v2(m)
|
344 |
+
frozen_func.graph.as_graph_def()
|
345 |
+
tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
|
346 |
+
return f, None
|
|
|
|
|
|
|
|
|
347 |
|
348 |
|
349 |
+
@try_export
|
350 |
def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
|
351 |
# YOLOv5 TensorFlow Lite export
|
352 |
+
import tensorflow as tf
|
353 |
+
|
354 |
+
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
355 |
+
batch_size, ch, *imgsz = list(im.shape) # BCHW
|
356 |
+
f = str(file).replace('.pt', '-fp16.tflite')
|
357 |
+
|
358 |
+
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
|
359 |
+
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
|
360 |
+
converter.target_spec.supported_types = [tf.float16]
|
361 |
+
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
362 |
+
if int8:
|
363 |
+
from models.tf import representative_dataset_gen
|
364 |
+
dataset = LoadImages(check_dataset(check_yaml(data))['train'], img_size=imgsz, auto=False)
|
365 |
+
converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
|
366 |
+
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
367 |
+
converter.target_spec.supported_types = []
|
368 |
+
converter.inference_input_type = tf.uint8 # or tf.int8
|
369 |
+
converter.inference_output_type = tf.uint8 # or tf.int8
|
370 |
+
converter.experimental_new_quantizer = True
|
371 |
+
f = str(file).replace('.pt', '-int8.tflite')
|
372 |
+
if nms or agnostic_nms:
|
373 |
+
converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
|
374 |
+
|
375 |
+
tflite_model = converter.convert()
|
376 |
+
open(f, "wb").write(tflite_model)
|
377 |
+
return f, None
|
378 |
+
|
379 |
+
|
380 |
+
@try_export
|
|
|
|
|
|
|
381 |
def export_edgetpu(file, prefix=colorstr('Edge TPU:')):
|
382 |
# YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
|
383 |
+
cmd = 'edgetpu_compiler --version'
|
384 |
+
help_url = 'https://coral.ai/docs/edgetpu/compiler/'
|
385 |
+
assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
|
386 |
+
if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
|
387 |
+
LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
|
388 |
+
sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
|
389 |
+
for c in (
|
390 |
+
'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
|
391 |
+
'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
|
392 |
+
'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
|
393 |
+
subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
|
394 |
+
ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
|
395 |
+
|
396 |
+
LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
|
397 |
+
f = str(file).replace('.pt', '-int8_edgetpu.tflite') # Edge TPU model
|
398 |
+
f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model
|
399 |
+
|
400 |
+
cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}"
|
401 |
+
subprocess.run(cmd.split(), check=True)
|
402 |
+
return f, None
|
403 |
+
|
404 |
+
|
405 |
+
@try_export
|
|
|
|
|
|
|
|
|
406 |
def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
|
407 |
# YOLOv5 TensorFlow.js export
|
408 |
+
check_requirements(('tensorflowjs',))
|
409 |
+
import re
|
410 |
+
|
411 |
+
import tensorflowjs as tfjs
|
412 |
+
|
413 |
+
LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
|
414 |
+
f = str(file).replace('.pt', '_web_model') # js dir
|
415 |
+
f_pb = file.with_suffix('.pb') # *.pb path
|
416 |
+
f_json = f'{f}/model.json' # *.json path
|
417 |
+
|
418 |
+
cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
|
419 |
+
f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}'
|
420 |
+
subprocess.run(cmd.split())
|
421 |
+
|
422 |
+
json = Path(f_json).read_text()
|
423 |
+
with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
|
424 |
+
subst = re.sub(
|
425 |
+
r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
|
426 |
+
r'"Identity.?.?": {"name": "Identity.?.?"}, '
|
427 |
+
r'"Identity.?.?": {"name": "Identity.?.?"}, '
|
428 |
+
r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
|
429 |
+
r'"Identity_1": {"name": "Identity_1"}, '
|
430 |
+
r'"Identity_2": {"name": "Identity_2"}, '
|
431 |
+
r'"Identity_3": {"name": "Identity_3"}}}', json)
|
432 |
+
j.write(subst)
|
433 |
+
return f, None
|
|
|
|
|
|
|
|
|
|
|
|
|
434 |
|
435 |
|
436 |
@smart_inference_mode()
|
|
|
473 |
assert device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0'
|
474 |
assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
|
475 |
model = attempt_load(weights, device=device, inplace=True, fuse=True) # load FP32 model
|
|
|
476 |
|
477 |
# Checks
|
478 |
imgsz *= 2 if len(imgsz) == 1 else 1 # expand
|
|
|
479 |
if optimize:
|
480 |
assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
|
481 |
|
|
|
489 |
for k, m in model.named_modules():
|
490 |
if isinstance(m, Detect):
|
491 |
m.inplace = inplace
|
492 |
+
m.dynamic = dynamic
|
493 |
m.export = True
|
494 |
|
495 |
for _ in range(2):
|
496 |
y = model(im) # dry runs
|
497 |
if half and not coreml:
|
498 |
im, model = im.half(), model.half() # to FP16
|
499 |
+
shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape
|
500 |
LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
|
501 |
|
502 |
# Exports
|
503 |
f = [''] * 10 # exported filenames
|
504 |
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
|
505 |
if jit:
|
506 |
+
f[0], _ = export_torchscript(model, im, file, optimize)
|
507 |
if engine: # TensorRT required before ONNX
|
508 |
+
f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
|
509 |
if onnx or xml: # OpenVINO requires ONNX
|
510 |
+
f[2], _ = export_onnx(model, im, file, opset, train, dynamic, simplify)
|
511 |
if xml: # OpenVINO
|
512 |
+
f[3], _ = export_openvino(model, file, half)
|
513 |
if coreml:
|
514 |
+
f[4], _ = export_coreml(model, im, file, int8, half)
|
515 |
|
516 |
# TensorFlow Exports
|
517 |
if any((saved_model, pb, tflite, edgetpu, tfjs)):
|
518 |
if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
|
519 |
check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
|
520 |
assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
|
521 |
+
assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.'
|
522 |
+
f[5], model = export_saved_model(model.cpu(),
|
523 |
im,
|
524 |
file,
|
525 |
dynamic,
|
|
|
531 |
conf_thres=conf_thres,
|
532 |
keras=keras)
|
533 |
if pb or tfjs: # pb prerequisite to tfjs
|
534 |
+
f[6], _ = export_pb(model, file)
|
535 |
if tflite or edgetpu:
|
536 |
+
f[7], _ = export_tflite(model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
|
537 |
if edgetpu:
|
538 |
+
f[8], _ = export_edgetpu(file)
|
539 |
if tfjs:
|
540 |
+
f[9], _ = export_tfjs(file)
|
541 |
|
542 |
# Finish
|
543 |
f = [str(x) for x in f if x] # filter out '' and None
|
544 |
if any(f):
|
545 |
h = '--half' if half else '' # --half FP16 inference arg
|
546 |
+
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
|
547 |
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
548 |
f"\nDetect: python detect.py --weights {f[-1]} {h}"
|
549 |
f"\nValidate: python val.py --weights {f[-1]} {h}"
|
|
|
578 |
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
|
579 |
parser.add_argument('--include',
|
580 |
nargs='+',
|
581 |
+
default=['torchscript'],
|
582 |
help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
|
583 |
opt = parser.parse_args()
|
584 |
print_args(vars(opt))
|
models/common.py
CHANGED
@@ -17,15 +17,15 @@ import pandas as pd
|
|
17 |
import requests
|
18 |
import torch
|
19 |
import torch.nn as nn
|
20 |
-
import yaml
|
21 |
from PIL import Image
|
22 |
from torch.cuda import amp
|
23 |
|
24 |
from utils.dataloaders import exif_transpose, letterbox
|
25 |
-
from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr,
|
26 |
-
make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh
|
|
|
27 |
from utils.plots import Annotator, colors, save_one_box
|
28 |
-
from utils.torch_utils import copy_attr, smart_inference_mode
|
29 |
|
30 |
|
31 |
def autopad(k, p=None): # kernel, padding
|
@@ -322,13 +322,10 @@ class DetectMultiBackend(nn.Module):
|
|
322 |
|
323 |
super().__init__()
|
324 |
w = str(weights[0] if isinstance(weights, list) else weights)
|
325 |
-
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.
|
326 |
w = attempt_download(w) # download if not local
|
327 |
-
fp16 &=
|
328 |
-
stride
|
329 |
-
if data: # assign class names (optional)
|
330 |
-
with open(data, errors='ignore') as f:
|
331 |
-
names = yaml.safe_load(f)['names']
|
332 |
|
333 |
if pt: # PyTorch
|
334 |
model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
|
@@ -341,8 +338,10 @@ class DetectMultiBackend(nn.Module):
|
|
341 |
extra_files = {'config.txt': ''} # model metadata
|
342 |
model = torch.jit.load(w, _extra_files=extra_files)
|
343 |
model.half() if fp16 else model.float()
|
344 |
-
if extra_files['config.txt']:
|
345 |
-
d = json.loads(extra_files['config.txt']
|
|
|
|
|
346 |
stride, names = int(d['stride']), d['names']
|
347 |
elif dnn: # ONNX OpenCV DNN
|
348 |
LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
|
@@ -350,11 +349,12 @@ class DetectMultiBackend(nn.Module):
|
|
350 |
net = cv2.dnn.readNetFromONNX(w)
|
351 |
elif onnx: # ONNX Runtime
|
352 |
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
|
353 |
-
cuda = torch.cuda.is_available()
|
354 |
check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
|
355 |
import onnxruntime
|
356 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
|
357 |
session = onnxruntime.InferenceSession(w, providers=providers)
|
|
|
358 |
meta = session.get_modelmeta().custom_metadata_map # metadata
|
359 |
if 'stride' in meta:
|
360 |
stride, names = int(meta['stride']), eval(meta['names'])
|
@@ -373,13 +373,13 @@ class DetectMultiBackend(nn.Module):
|
|
373 |
batch_size = batch_dim.get_length()
|
374 |
executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2
|
375 |
output_layer = next(iter(executable_network.outputs))
|
376 |
-
|
377 |
-
if meta.exists():
|
378 |
-
stride, names = self._load_metadata(meta) # load metadata
|
379 |
elif engine: # TensorRT
|
380 |
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
381 |
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
382 |
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
|
|
|
|
|
383 |
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
384 |
logger = trt.Logger(trt.Logger.INFO)
|
385 |
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|
@@ -398,8 +398,8 @@ class DetectMultiBackend(nn.Module):
|
|
398 |
if dtype == np.float16:
|
399 |
fp16 = True
|
400 |
shape = tuple(context.get_binding_shape(index))
|
401 |
-
|
402 |
-
bindings[name] = Binding(name, dtype, shape,
|
403 |
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
|
404 |
batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
|
405 |
elif coreml: # CoreML
|
@@ -445,28 +445,35 @@ class DetectMultiBackend(nn.Module):
|
|
445 |
input_details = interpreter.get_input_details() # inputs
|
446 |
output_details = interpreter.get_output_details() # outputs
|
447 |
elif tfjs:
|
448 |
-
raise
|
449 |
else:
|
450 |
-
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
self.__dict__.update(locals()) # assign all variables to self
|
452 |
|
453 |
-
def forward(self, im, augment=False, visualize=False
|
454 |
# YOLOv5 MultiBackend inference
|
455 |
b, ch, h, w = im.shape # batch, channel, height, width
|
456 |
if self.fp16 and im.dtype != torch.float16:
|
457 |
im = im.half() # to FP16
|
458 |
|
459 |
if self.pt: # PyTorch
|
460 |
-
y = self.model(im, augment=augment, visualize=visualize)
|
461 |
elif self.jit: # TorchScript
|
462 |
-
y = self.model(im)
|
463 |
elif self.dnn: # ONNX OpenCV DNN
|
464 |
im = im.cpu().numpy() # torch to numpy
|
465 |
self.net.setInput(im)
|
466 |
y = self.net.forward()
|
467 |
elif self.onnx: # ONNX Runtime
|
468 |
im = im.cpu().numpy() # torch to numpy
|
469 |
-
y = self.session.run(
|
470 |
elif self.xml: # OpenVINO
|
471 |
im = im.cpu().numpy() # FP32
|
472 |
y = self.executable_network([im])[self.output_layer]
|
@@ -513,20 +520,24 @@ class DetectMultiBackend(nn.Module):
|
|
513 |
y = (y.astype(np.float32) - zero_point) * scale # re-scale
|
514 |
y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
|
515 |
|
516 |
-
if isinstance(y,
|
517 |
-
y
|
518 |
-
|
|
|
|
|
|
|
|
|
519 |
|
520 |
def warmup(self, imgsz=(1, 3, 640, 640)):
|
521 |
# Warmup model by running inference once
|
522 |
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb
|
523 |
if any(warmup_types) and self.device.type != 'cpu':
|
524 |
-
im = torch.
|
525 |
for _ in range(2 if self.jit else 1): #
|
526 |
self.forward(im) # warmup
|
527 |
|
528 |
@staticmethod
|
529 |
-
def
|
530 |
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
|
531 |
from export import export_formats
|
532 |
suffixes = list(export_formats().Suffix) + ['.xml'] # export suffixes
|
@@ -538,11 +549,12 @@ class DetectMultiBackend(nn.Module):
|
|
538 |
return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs
|
539 |
|
540 |
@staticmethod
|
541 |
-
def _load_metadata(f='path/to/meta.yaml'):
|
542 |
# Load metadata from meta.yaml if it exists
|
543 |
-
|
544 |
-
d =
|
545 |
-
|
|
|
546 |
|
547 |
|
548 |
class AutoShape(nn.Module):
|
@@ -579,9 +591,9 @@ class AutoShape(nn.Module):
|
|
579 |
return self
|
580 |
|
581 |
@smart_inference_mode()
|
582 |
-
def forward(self,
|
583 |
-
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
|
584 |
-
# file:
|
585 |
# URI: = 'https://ultralytics.com/images/zidane.jpg'
|
586 |
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
|
587 |
# PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
|
@@ -589,65 +601,67 @@ class AutoShape(nn.Module):
|
|
589 |
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
|
590 |
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
|
591 |
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
|
|
|
|
623 |
|
624 |
with amp.autocast(autocast):
|
625 |
# Inference
|
626 |
-
|
627 |
-
|
628 |
|
629 |
# Post-process
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
636 |
-
|
637 |
-
|
638 |
-
|
|
|
639 |
|
640 |
-
|
641 |
-
return Detections(imgs, y, files, t, self.names, x.shape)
|
642 |
|
643 |
|
644 |
class Detections:
|
645 |
# YOLOv5 detections class for inference results
|
646 |
-
def __init__(self,
|
647 |
super().__init__()
|
648 |
d = pred[0].device # device
|
649 |
-
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in
|
650 |
-
self.
|
651 |
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
|
652 |
self.names = names # class names
|
653 |
self.files = files # image filenames
|
@@ -657,12 +671,12 @@ class Detections:
|
|
657 |
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
|
658 |
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
|
659 |
self.n = len(self.pred) # number of images (batch size)
|
660 |
-
self.t = tuple(
|
661 |
self.s = shape # inference BCHW shape
|
662 |
|
663 |
def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
|
664 |
crops = []
|
665 |
-
for i, (im, pred) in enumerate(zip(self.
|
666 |
s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
|
667 |
if pred.shape[0]:
|
668 |
for c in pred[:, -1].unique():
|
@@ -697,7 +711,7 @@ class Detections:
|
|
697 |
if i == self.n - 1:
|
698 |
LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
|
699 |
if render:
|
700 |
-
self.
|
701 |
if crop:
|
702 |
if save:
|
703 |
LOGGER.info(f'Saved results to {save_dir}\n')
|
@@ -720,7 +734,7 @@ class Detections:
|
|
720 |
|
721 |
def render(self, labels=True):
|
722 |
self.display(render=True, labels=labels) # render results
|
723 |
-
return self.
|
724 |
|
725 |
def pandas(self):
|
726 |
# return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
|
@@ -735,9 +749,9 @@ class Detections:
|
|
735 |
def tolist(self):
|
736 |
# return a list of Detections objects, i.e. 'for result in results.tolist():'
|
737 |
r = range(self.n) # iterable
|
738 |
-
x = [Detections([self.
|
739 |
# for d in x:
|
740 |
-
# for k in ['
|
741 |
# setattr(d, k, getattr(d, k)[0]) # pop out of list
|
742 |
return x
|
743 |
|
@@ -753,10 +767,13 @@ class Classify(nn.Module):
|
|
753 |
# Classification head, i.e. x(b,c1,20,20) to x(b,c2)
|
754 |
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
|
755 |
super().__init__()
|
756 |
-
|
757 |
-
self.conv =
|
758 |
-
self.
|
|
|
|
|
759 |
|
760 |
def forward(self, x):
|
761 |
-
|
762 |
-
|
|
|
|
17 |
import requests
|
18 |
import torch
|
19 |
import torch.nn as nn
|
|
|
20 |
from PIL import Image
|
21 |
from torch.cuda import amp
|
22 |
|
23 |
from utils.dataloaders import exif_transpose, letterbox
|
24 |
+
from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
|
25 |
+
increment_path, make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh,
|
26 |
+
yaml_load)
|
27 |
from utils.plots import Annotator, colors, save_one_box
|
28 |
+
from utils.torch_utils import copy_attr, smart_inference_mode
|
29 |
|
30 |
|
31 |
def autopad(k, p=None): # kernel, padding
|
|
|
322 |
|
323 |
super().__init__()
|
324 |
w = str(weights[0] if isinstance(weights, list) else weights)
|
325 |
+
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self._model_type(w) # get backend
|
326 |
w = attempt_download(w) # download if not local
|
327 |
+
fp16 &= pt or jit or onnx or engine # FP16
|
328 |
+
stride = 32 # default stride
|
|
|
|
|
|
|
329 |
|
330 |
if pt: # PyTorch
|
331 |
model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
|
|
|
338 |
extra_files = {'config.txt': ''} # model metadata
|
339 |
model = torch.jit.load(w, _extra_files=extra_files)
|
340 |
model.half() if fp16 else model.float()
|
341 |
+
if extra_files['config.txt']: # load metadata dict
|
342 |
+
d = json.loads(extra_files['config.txt'],
|
343 |
+
object_hook=lambda d: {int(k) if k.isdigit() else k: v
|
344 |
+
for k, v in d.items()})
|
345 |
stride, names = int(d['stride']), d['names']
|
346 |
elif dnn: # ONNX OpenCV DNN
|
347 |
LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
|
|
|
349 |
net = cv2.dnn.readNetFromONNX(w)
|
350 |
elif onnx: # ONNX Runtime
|
351 |
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
|
352 |
+
cuda = torch.cuda.is_available() and device.type != 'cpu'
|
353 |
check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
|
354 |
import onnxruntime
|
355 |
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
|
356 |
session = onnxruntime.InferenceSession(w, providers=providers)
|
357 |
+
output_names = [x.name for x in session.get_outputs()]
|
358 |
meta = session.get_modelmeta().custom_metadata_map # metadata
|
359 |
if 'stride' in meta:
|
360 |
stride, names = int(meta['stride']), eval(meta['names'])
|
|
|
373 |
batch_size = batch_dim.get_length()
|
374 |
executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2
|
375 |
output_layer = next(iter(executable_network.outputs))
|
376 |
+
stride, names = self._load_metadata(Path(w).with_suffix('.yaml')) # load metadata
|
|
|
|
|
377 |
elif engine: # TensorRT
|
378 |
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
379 |
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
380 |
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
|
381 |
+
if device.type == 'cpu':
|
382 |
+
device = torch.device('cuda:0')
|
383 |
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
384 |
logger = trt.Logger(trt.Logger.INFO)
|
385 |
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|
|
|
398 |
if dtype == np.float16:
|
399 |
fp16 = True
|
400 |
shape = tuple(context.get_binding_shape(index))
|
401 |
+
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
|
402 |
+
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
|
403 |
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
|
404 |
batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
|
405 |
elif coreml: # CoreML
|
|
|
445 |
input_details = interpreter.get_input_details() # inputs
|
446 |
output_details = interpreter.get_output_details() # outputs
|
447 |
elif tfjs:
|
448 |
+
raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
|
449 |
else:
|
450 |
+
raise NotImplementedError(f'ERROR: {w} is not a supported format')
|
451 |
+
|
452 |
+
# class names
|
453 |
+
if 'names' not in locals():
|
454 |
+
names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)}
|
455 |
+
if names[0] == 'n01440764' and len(names) == 1000: # ImageNet
|
456 |
+
names = yaml_load(ROOT / 'data/ImageNet.yaml')['names'] # human-readable names
|
457 |
+
|
458 |
self.__dict__.update(locals()) # assign all variables to self
|
459 |
|
460 |
+
def forward(self, im, augment=False, visualize=False):
|
461 |
# YOLOv5 MultiBackend inference
|
462 |
b, ch, h, w = im.shape # batch, channel, height, width
|
463 |
if self.fp16 and im.dtype != torch.float16:
|
464 |
im = im.half() # to FP16
|
465 |
|
466 |
if self.pt: # PyTorch
|
467 |
+
y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
|
468 |
elif self.jit: # TorchScript
|
469 |
+
y = self.model(im)
|
470 |
elif self.dnn: # ONNX OpenCV DNN
|
471 |
im = im.cpu().numpy() # torch to numpy
|
472 |
self.net.setInput(im)
|
473 |
y = self.net.forward()
|
474 |
elif self.onnx: # ONNX Runtime
|
475 |
im = im.cpu().numpy() # torch to numpy
|
476 |
+
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
|
477 |
elif self.xml: # OpenVINO
|
478 |
im = im.cpu().numpy() # FP32
|
479 |
y = self.executable_network([im])[self.output_layer]
|
|
|
520 |
y = (y.astype(np.float32) - zero_point) * scale # re-scale
|
521 |
y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
|
522 |
|
523 |
+
if isinstance(y, (list, tuple)):
|
524 |
+
return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
|
525 |
+
else:
|
526 |
+
return self.from_numpy(y)
|
527 |
+
|
528 |
+
def from_numpy(self, x):
|
529 |
+
return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x
|
530 |
|
531 |
def warmup(self, imgsz=(1, 3, 640, 640)):
|
532 |
# Warmup model by running inference once
|
533 |
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb
|
534 |
if any(warmup_types) and self.device.type != 'cpu':
|
535 |
+
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
536 |
for _ in range(2 if self.jit else 1): #
|
537 |
self.forward(im) # warmup
|
538 |
|
539 |
@staticmethod
|
540 |
+
def _model_type(p='path/to/model.pt'):
|
541 |
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
|
542 |
from export import export_formats
|
543 |
suffixes = list(export_formats().Suffix) + ['.xml'] # export suffixes
|
|
|
549 |
return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs
|
550 |
|
551 |
@staticmethod
|
552 |
+
def _load_metadata(f=Path('path/to/meta.yaml')):
|
553 |
# Load metadata from meta.yaml if it exists
|
554 |
+
if f.exists():
|
555 |
+
d = yaml_load(f)
|
556 |
+
return d['stride'], d['names'] # assign stride, names
|
557 |
+
return None, None
|
558 |
|
559 |
|
560 |
class AutoShape(nn.Module):
|
|
|
591 |
return self
|
592 |
|
593 |
@smart_inference_mode()
|
594 |
+
def forward(self, ims, size=640, augment=False, profile=False):
|
595 |
+
# Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
|
596 |
+
# file: ims = 'data/images/zidane.jpg' # str or PosixPath
|
597 |
# URI: = 'https://ultralytics.com/images/zidane.jpg'
|
598 |
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
|
599 |
# PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
|
|
|
601 |
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
|
602 |
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
|
603 |
|
604 |
+
dt = (Profile(), Profile(), Profile())
|
605 |
+
with dt[0]:
|
606 |
+
if isinstance(size, int): # expand
|
607 |
+
size = (size, size)
|
608 |
+
p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
|
609 |
+
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
|
610 |
+
if isinstance(ims, torch.Tensor): # torch
|
611 |
+
with amp.autocast(autocast):
|
612 |
+
return self.model(ims.to(p.device).type_as(p), augment, profile) # inference
|
613 |
+
|
614 |
+
# Pre-process
|
615 |
+
n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
|
616 |
+
shape0, shape1, files = [], [], [] # image and inference shapes, filenames
|
617 |
+
for i, im in enumerate(ims):
|
618 |
+
f = f'image{i}' # filename
|
619 |
+
if isinstance(im, (str, Path)): # filename or uri
|
620 |
+
im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
|
621 |
+
im = np.asarray(exif_transpose(im))
|
622 |
+
elif isinstance(im, Image.Image): # PIL Image
|
623 |
+
im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
|
624 |
+
files.append(Path(f).with_suffix('.jpg').name)
|
625 |
+
if im.shape[0] < 5: # image in CHW
|
626 |
+
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
|
627 |
+
im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
|
628 |
+
s = im.shape[:2] # HWC
|
629 |
+
shape0.append(s) # image shape
|
630 |
+
g = max(size) / max(s) # gain
|
631 |
+
shape1.append([y * g for y in s])
|
632 |
+
ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
|
633 |
+
shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size # inf shape
|
634 |
+
x = [letterbox(im, shape1, auto=False)[0] for im in ims] # pad
|
635 |
+
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
|
636 |
+
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
|
637 |
|
638 |
with amp.autocast(autocast):
|
639 |
# Inference
|
640 |
+
with dt[1]:
|
641 |
+
y = self.model(x, augment, profile) # forward
|
642 |
|
643 |
# Post-process
|
644 |
+
with dt[2]:
|
645 |
+
y = non_max_suppression(y if self.dmb else y[0],
|
646 |
+
self.conf,
|
647 |
+
self.iou,
|
648 |
+
self.classes,
|
649 |
+
self.agnostic,
|
650 |
+
self.multi_label,
|
651 |
+
max_det=self.max_det) # NMS
|
652 |
+
for i in range(n):
|
653 |
+
scale_coords(shape1, y[i][:, :4], shape0[i])
|
654 |
|
655 |
+
return Detections(ims, y, files, dt, self.names, x.shape)
|
|
|
656 |
|
657 |
|
658 |
class Detections:
|
659 |
# YOLOv5 detections class for inference results
|
660 |
+
def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
|
661 |
super().__init__()
|
662 |
d = pred[0].device # device
|
663 |
+
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
|
664 |
+
self.ims = ims # list of images as numpy arrays
|
665 |
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
|
666 |
self.names = names # class names
|
667 |
self.files = files # image filenames
|
|
|
671 |
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
|
672 |
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
|
673 |
self.n = len(self.pred) # number of images (batch size)
|
674 |
+
self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
|
675 |
self.s = shape # inference BCHW shape
|
676 |
|
677 |
def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
|
678 |
crops = []
|
679 |
+
for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
|
680 |
s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
|
681 |
if pred.shape[0]:
|
682 |
for c in pred[:, -1].unique():
|
|
|
711 |
if i == self.n - 1:
|
712 |
LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
|
713 |
if render:
|
714 |
+
self.ims[i] = np.asarray(im)
|
715 |
if crop:
|
716 |
if save:
|
717 |
LOGGER.info(f'Saved results to {save_dir}\n')
|
|
|
734 |
|
735 |
def render(self, labels=True):
|
736 |
self.display(render=True, labels=labels) # render results
|
737 |
+
return self.ims
|
738 |
|
739 |
def pandas(self):
|
740 |
# return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
|
|
|
749 |
def tolist(self):
|
750 |
# return a list of Detections objects, i.e. 'for result in results.tolist():'
|
751 |
r = range(self.n) # iterable
|
752 |
+
x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
|
753 |
# for d in x:
|
754 |
+
# for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
|
755 |
# setattr(d, k, getattr(d, k)[0]) # pop out of list
|
756 |
return x
|
757 |
|
|
|
767 |
# Classification head, i.e. x(b,c1,20,20) to x(b,c2)
|
768 |
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
|
769 |
super().__init__()
|
770 |
+
c_ = 1280 # efficientnet_b0 size
|
771 |
+
self.conv = Conv(c1, c_, k, s, autopad(k, p), g)
|
772 |
+
self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)
|
773 |
+
self.drop = nn.Dropout(p=0.0, inplace=True)
|
774 |
+
self.linear = nn.Linear(c_, c2) # to x(b,c2)
|
775 |
|
776 |
def forward(self, x):
|
777 |
+
if isinstance(x, list):
|
778 |
+
x = torch.cat(x, 1)
|
779 |
+
return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
|
models/experimental.py
CHANGED
@@ -8,7 +8,6 @@ import numpy as np
|
|
8 |
import torch
|
9 |
import torch.nn as nn
|
10 |
|
11 |
-
from models.common import Conv
|
12 |
from utils.downloads import attempt_download
|
13 |
|
14 |
|
@@ -79,9 +78,16 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
|
|
79 |
for w in weights if isinstance(weights, list) else [weights]:
|
80 |
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
81 |
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
82 |
-
model.append(ckpt.fuse().eval() if fuse else ckpt.eval()) # fused or un-fused model in eval mode
|
83 |
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
for m in model.modules():
|
86 |
t = type(m)
|
87 |
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
|
@@ -92,11 +98,14 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
|
|
92 |
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
93 |
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
94 |
|
|
|
95 |
if len(model) == 1:
|
96 |
-
return model[-1]
|
|
|
|
|
97 |
print(f'Ensemble created with {weights}\n')
|
98 |
for k in 'names', 'nc', 'yaml':
|
99 |
setattr(model, k, getattr(model[0], k))
|
100 |
model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
|
101 |
assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
|
102 |
-
return model
|
|
|
8 |
import torch
|
9 |
import torch.nn as nn
|
10 |
|
|
|
11 |
from utils.downloads import attempt_download
|
12 |
|
13 |
|
|
|
78 |
for w in weights if isinstance(weights, list) else [weights]:
|
79 |
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
80 |
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
|
|
81 |
|
82 |
+
# Model compatibility updates
|
83 |
+
if not hasattr(ckpt, 'stride'):
|
84 |
+
ckpt.stride = torch.tensor([32.])
|
85 |
+
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
|
86 |
+
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
|
87 |
+
|
88 |
+
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
|
89 |
+
|
90 |
+
# Module compatibility updates
|
91 |
for m in model.modules():
|
92 |
t = type(m)
|
93 |
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
|
|
|
98 |
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
99 |
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
100 |
|
101 |
+
# Return model
|
102 |
if len(model) == 1:
|
103 |
+
return model[-1]
|
104 |
+
|
105 |
+
# Return detection ensemble
|
106 |
print(f'Ensemble created with {weights}\n')
|
107 |
for k in 'names', 'nc', 'yaml':
|
108 |
setattr(model, k, getattr(model[0], k))
|
109 |
model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
|
110 |
assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
|
111 |
+
return model
|
models/tf.py
CHANGED
@@ -7,7 +7,7 @@ Usage:
|
|
7 |
$ python models/tf.py --weights yolov5s.pt
|
8 |
|
9 |
Export:
|
10 |
-
$ python
|
11 |
"""
|
12 |
|
13 |
import argparse
|
|
|
7 |
$ python models/tf.py --weights yolov5s.pt
|
8 |
|
9 |
Export:
|
10 |
+
$ python export.py --weights yolov5s.pt --include saved_model pb tflite tfjs
|
11 |
"""
|
12 |
|
13 |
import argparse
|
models/yolo.py
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
YOLO-specific modules
|
4 |
|
5 |
Usage:
|
6 |
-
$ python
|
7 |
"""
|
8 |
|
9 |
import argparse
|
@@ -37,7 +37,7 @@ except ImportError:
|
|
37 |
|
38 |
class Detect(nn.Module):
|
39 |
stride = None # strides computed during build
|
40 |
-
|
41 |
export = False # export mode
|
42 |
|
43 |
def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
|
@@ -46,8 +46,8 @@ class Detect(nn.Module):
|
|
46 |
self.no = nc + 5 # number of outputs per anchor
|
47 |
self.nl = len(anchors) # number of detection layers
|
48 |
self.na = len(anchors[0]) // 2 # number of anchors
|
49 |
-
self.grid = [torch.
|
50 |
-
self.anchor_grid = [torch.
|
51 |
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
|
52 |
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
|
53 |
self.inplace = inplace # use inplace ops (e.g. slice assignment)
|
@@ -60,7 +60,7 @@ class Detect(nn.Module):
|
|
60 |
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
|
61 |
|
62 |
if not self.training: # inference
|
63 |
-
if self.
|
64 |
self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
|
65 |
|
66 |
y = x[i].sigmoid()
|
@@ -81,17 +81,70 @@ class Detect(nn.Module):
|
|
81 |
t = self.anchors[i].dtype
|
82 |
shape = 1, self.na, ny, nx, 2 # grid shape
|
83 |
y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
|
84 |
-
if torch_1_10
|
85 |
-
yv, xv = torch.meshgrid(y, x, indexing='ij')
|
86 |
-
else:
|
87 |
-
yv, xv = torch.meshgrid(y, x)
|
88 |
grid = torch.stack((xv, yv), 2).expand(shape) - 0.5 # add grid offset, i.e. y = 2.0 * x - 0.5
|
89 |
anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
|
90 |
return grid, anchor_grid
|
91 |
|
92 |
|
93 |
-
class
|
94 |
-
# YOLOv5 model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
|
96 |
super().__init__()
|
97 |
if isinstance(cfg, dict):
|
@@ -119,7 +172,7 @@ class Model(nn.Module):
|
|
119 |
if isinstance(m, Detect):
|
120 |
s = 256 # 2x min stride
|
121 |
m.inplace = self.inplace
|
122 |
-
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.
|
123 |
check_anchor_order(m) # must be in pixel-space (not grid-space)
|
124 |
m.anchors /= m.stride.view(-1, 1, 1)
|
125 |
self.stride = m.stride
|
@@ -149,19 +202,6 @@ class Model(nn.Module):
|
|
149 |
y = self._clip_augmented(y) # clip augmented tails
|
150 |
return torch.cat(y, 1), None # augmented inference, train
|
151 |
|
152 |
-
def _forward_once(self, x, profile=False, visualize=False):
|
153 |
-
y, dt = [], [] # outputs
|
154 |
-
for m in self.model:
|
155 |
-
if m.f != -1: # if not from previous layer
|
156 |
-
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
|
157 |
-
if profile:
|
158 |
-
self._profile_one_layer(m, x, dt)
|
159 |
-
x = m(x) # run
|
160 |
-
y.append(x if m.i in self.save else None) # save output
|
161 |
-
if visualize:
|
162 |
-
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
163 |
-
return x
|
164 |
-
|
165 |
def _descale_pred(self, p, flips, scale, img_size):
|
166 |
# de-scale predictions following augmented inference (inverse operation)
|
167 |
if self.inplace:
|
@@ -190,19 +230,6 @@ class Model(nn.Module):
|
|
190 |
y[-1] = y[-1][:, i:] # small
|
191 |
return y
|
192 |
|
193 |
-
def _profile_one_layer(self, m, x, dt):
|
194 |
-
c = isinstance(m, Detect) # is final layer, copy input as inplace fix
|
195 |
-
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
196 |
-
t = time_sync()
|
197 |
-
for _ in range(10):
|
198 |
-
m(x.copy() if c else x)
|
199 |
-
dt.append((time_sync() - t) * 100)
|
200 |
-
if m == self.model[0]:
|
201 |
-
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
|
202 |
-
LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
|
203 |
-
if c:
|
204 |
-
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
|
205 |
-
|
206 |
def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
|
207 |
# https://arxiv.org/abs/1708.02002 section 3.3
|
208 |
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
|
@@ -213,41 +240,34 @@ class Model(nn.Module):
|
|
213 |
b[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # cls
|
214 |
mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
|
215 |
|
216 |
-
def _print_biases(self):
|
217 |
-
m = self.model[-1] # Detect() module
|
218 |
-
for mi in m.m: # from
|
219 |
-
b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
|
220 |
-
LOGGER.info(
|
221 |
-
('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
|
222 |
|
223 |
-
|
224 |
-
# for m in self.model.modules():
|
225 |
-
# if type(m) is Bottleneck:
|
226 |
-
# LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
|
227 |
|
228 |
-
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
|
229 |
-
LOGGER.info('Fusing layers... ')
|
230 |
-
for m in self.model.modules():
|
231 |
-
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
|
232 |
-
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
233 |
-
delattr(m, 'bn') # remove batchnorm
|
234 |
-
m.forward = m.forward_fuse # update forward
|
235 |
-
self.info()
|
236 |
-
return self
|
237 |
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
|
252 |
|
253 |
def parse_model(d, ch): # model_dict, input_channels(3)
|
@@ -321,7 +341,7 @@ if __name__ == '__main__':
|
|
321 |
|
322 |
# Options
|
323 |
if opt.line_profile: # profile layer by layer
|
324 |
-
|
325 |
|
326 |
elif opt.profile: # profile forward-backward
|
327 |
results = profile(input=im, ops=[model], n=3)
|
|
|
3 |
YOLO-specific modules
|
4 |
|
5 |
Usage:
|
6 |
+
$ python models/yolo.py --cfg yolov5s.yaml
|
7 |
"""
|
8 |
|
9 |
import argparse
|
|
|
37 |
|
38 |
class Detect(nn.Module):
|
39 |
stride = None # strides computed during build
|
40 |
+
dynamic = False # force grid reconstruction
|
41 |
export = False # export mode
|
42 |
|
43 |
def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
|
|
|
46 |
self.no = nc + 5 # number of outputs per anchor
|
47 |
self.nl = len(anchors) # number of detection layers
|
48 |
self.na = len(anchors[0]) // 2 # number of anchors
|
49 |
+
self.grid = [torch.empty(1)] * self.nl # init grid
|
50 |
+
self.anchor_grid = [torch.empty(1)] * self.nl # init anchor grid
|
51 |
self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
|
52 |
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
|
53 |
self.inplace = inplace # use inplace ops (e.g. slice assignment)
|
|
|
60 |
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
|
61 |
|
62 |
if not self.training: # inference
|
63 |
+
if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
|
64 |
self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
|
65 |
|
66 |
y = x[i].sigmoid()
|
|
|
81 |
t = self.anchors[i].dtype
|
82 |
shape = 1, self.na, ny, nx, 2 # grid shape
|
83 |
y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
|
84 |
+
yv, xv = torch.meshgrid(y, x, indexing='ij') if torch_1_10 else torch.meshgrid(y, x) # torch>=0.7 compatibility
|
|
|
|
|
|
|
85 |
grid = torch.stack((xv, yv), 2).expand(shape) - 0.5 # add grid offset, i.e. y = 2.0 * x - 0.5
|
86 |
anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
|
87 |
return grid, anchor_grid
|
88 |
|
89 |
|
90 |
+
class BaseModel(nn.Module):
|
91 |
+
# YOLOv5 base model
|
92 |
+
def forward(self, x, profile=False, visualize=False):
|
93 |
+
return self._forward_once(x, profile, visualize) # single-scale inference, train
|
94 |
+
|
95 |
+
def _forward_once(self, x, profile=False, visualize=False):
|
96 |
+
y, dt = [], [] # outputs
|
97 |
+
for m in self.model:
|
98 |
+
if m.f != -1: # if not from previous layer
|
99 |
+
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
|
100 |
+
if profile:
|
101 |
+
self._profile_one_layer(m, x, dt)
|
102 |
+
x = m(x) # run
|
103 |
+
y.append(x if m.i in self.save else None) # save output
|
104 |
+
if visualize:
|
105 |
+
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
106 |
+
return x
|
107 |
+
|
108 |
+
def _profile_one_layer(self, m, x, dt):
|
109 |
+
c = m == self.model[-1] # is final layer, copy input as inplace fix
|
110 |
+
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
111 |
+
t = time_sync()
|
112 |
+
for _ in range(10):
|
113 |
+
m(x.copy() if c else x)
|
114 |
+
dt.append((time_sync() - t) * 100)
|
115 |
+
if m == self.model[0]:
|
116 |
+
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
|
117 |
+
LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
|
118 |
+
if c:
|
119 |
+
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
|
120 |
+
|
121 |
+
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
|
122 |
+
LOGGER.info('Fusing layers... ')
|
123 |
+
for m in self.model.modules():
|
124 |
+
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
|
125 |
+
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
126 |
+
delattr(m, 'bn') # remove batchnorm
|
127 |
+
m.forward = m.forward_fuse # update forward
|
128 |
+
self.info()
|
129 |
+
return self
|
130 |
+
|
131 |
+
def info(self, verbose=False, img_size=640): # print model information
|
132 |
+
model_info(self, verbose, img_size)
|
133 |
+
|
134 |
+
def _apply(self, fn):
|
135 |
+
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
|
136 |
+
self = super()._apply(fn)
|
137 |
+
m = self.model[-1] # Detect()
|
138 |
+
if isinstance(m, Detect):
|
139 |
+
m.stride = fn(m.stride)
|
140 |
+
m.grid = list(map(fn, m.grid))
|
141 |
+
if isinstance(m.anchor_grid, list):
|
142 |
+
m.anchor_grid = list(map(fn, m.anchor_grid))
|
143 |
+
return self
|
144 |
+
|
145 |
+
|
146 |
+
class DetectionModel(BaseModel):
|
147 |
+
# YOLOv5 detection model
|
148 |
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
|
149 |
super().__init__()
|
150 |
if isinstance(cfg, dict):
|
|
|
172 |
if isinstance(m, Detect):
|
173 |
s = 256 # 2x min stride
|
174 |
m.inplace = self.inplace
|
175 |
+
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.empty(1, ch, s, s))]) # forward
|
176 |
check_anchor_order(m) # must be in pixel-space (not grid-space)
|
177 |
m.anchors /= m.stride.view(-1, 1, 1)
|
178 |
self.stride = m.stride
|
|
|
202 |
y = self._clip_augmented(y) # clip augmented tails
|
203 |
return torch.cat(y, 1), None # augmented inference, train
|
204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
def _descale_pred(self, p, flips, scale, img_size):
|
206 |
# de-scale predictions following augmented inference (inverse operation)
|
207 |
if self.inplace:
|
|
|
230 |
y[-1] = y[-1][:, i:] # small
|
231 |
return y
|
232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
|
234 |
# https://arxiv.org/abs/1708.02002 section 3.3
|
235 |
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
|
|
|
240 |
b[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # cls
|
241 |
mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
|
242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
|
244 |
+
Model = DetectionModel # retain YOLOv5 'Model' class for backwards compatibility
|
|
|
|
|
|
|
245 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
246 |
|
247 |
+
class ClassificationModel(BaseModel):
|
248 |
+
# YOLOv5 classification model
|
249 |
+
def __init__(self, cfg=None, model=None, nc=1000, cutoff=10): # yaml, model, number of classes, cutoff index
|
250 |
+
super().__init__()
|
251 |
+
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)
|
252 |
+
|
253 |
+
def _from_detection_model(self, model, nc=1000, cutoff=10):
|
254 |
+
# Create a YOLOv5 classification model from a YOLOv5 detection model
|
255 |
+
if isinstance(model, DetectMultiBackend):
|
256 |
+
model = model.model # unwrap DetectMultiBackend
|
257 |
+
model.model = model.model[:cutoff] # backbone
|
258 |
+
m = model.model[-1] # last layer
|
259 |
+
ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels # ch into module
|
260 |
+
c = Classify(ch, nc) # Classify()
|
261 |
+
c.i, c.f, c.type = m.i, m.f, 'models.common.Classify' # index, from, type
|
262 |
+
model.model[-1] = c # replace
|
263 |
+
self.model = model.model
|
264 |
+
self.stride = model.stride
|
265 |
+
self.save = []
|
266 |
+
self.nc = nc
|
267 |
+
|
268 |
+
def _from_yaml(self, cfg):
|
269 |
+
# Create a YOLOv5 classification model from a *.yaml file
|
270 |
+
self.model = None
|
271 |
|
272 |
|
273 |
def parse_model(d, ch): # model_dict, input_channels(3)
|
|
|
341 |
|
342 |
# Options
|
343 |
if opt.line_profile: # profile layer by layer
|
344 |
+
model(im, profile=True)
|
345 |
|
346 |
elif opt.profile: # profile forward-backward
|
347 |
results = profile(input=im, ops=[model], n=3)
|
requirements.txt
CHANGED
@@ -23,9 +23,9 @@ pandas>=1.1.4
|
|
23 |
seaborn>=0.11.0
|
24 |
|
25 |
# Export --------------------------------------
|
26 |
-
coremltools>=
|
27 |
onnx>=1.9.0 # ONNX export
|
28 |
-
onnx-simplifier>=0.
|
29 |
onnxruntime
|
30 |
# nvidia-pyindex # TensorRT export
|
31 |
# nvidia-tensorrt # TensorRT export
|
|
|
23 |
seaborn>=0.11.0
|
24 |
|
25 |
# Export --------------------------------------
|
26 |
+
coremltools>=5.2 # CoreML export
|
27 |
onnx>=1.9.0 # ONNX export
|
28 |
+
onnx-simplifier>=0.4.1 # ONNX simplifier
|
29 |
onnxruntime
|
30 |
# nvidia-pyindex # TensorRT export
|
31 |
# nvidia-tensorrt # TensorRT export
|
utils/__init__.py
CHANGED
@@ -3,6 +3,33 @@
|
|
3 |
utils/initialization
|
4 |
"""
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def notebook_init(verbose=True):
|
8 |
# Check system software and hardware
|
@@ -11,10 +38,12 @@ def notebook_init(verbose=True):
|
|
11 |
import os
|
12 |
import shutil
|
13 |
|
14 |
-
from utils.general import check_requirements, emojis, is_colab
|
15 |
from utils.torch_utils import select_device # imports
|
16 |
|
17 |
check_requirements(('psutil', 'IPython'))
|
|
|
|
|
18 |
import psutil
|
19 |
from IPython import display # to display images and clear console output
|
20 |
|
|
|
3 |
utils/initialization
|
4 |
"""
|
5 |
|
6 |
+
import contextlib
|
7 |
+
import threading
|
8 |
+
|
9 |
+
|
10 |
+
class TryExcept(contextlib.ContextDecorator):
|
11 |
+
# YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
|
12 |
+
def __init__(self, msg=''):
|
13 |
+
self.msg = msg
|
14 |
+
|
15 |
+
def __enter__(self):
|
16 |
+
pass
|
17 |
+
|
18 |
+
def __exit__(self, exc_type, value, traceback):
|
19 |
+
if value:
|
20 |
+
print(f'{self.msg}{value}')
|
21 |
+
return True
|
22 |
+
|
23 |
+
|
24 |
+
def threaded(func):
|
25 |
+
# Multi-threads a target function and returns thread. Usage: @threaded decorator
|
26 |
+
def wrapper(*args, **kwargs):
|
27 |
+
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
28 |
+
thread.start()
|
29 |
+
return thread
|
30 |
+
|
31 |
+
return wrapper
|
32 |
+
|
33 |
|
34 |
def notebook_init(verbose=True):
|
35 |
# Check system software and hardware
|
|
|
38 |
import os
|
39 |
import shutil
|
40 |
|
41 |
+
from utils.general import check_font, check_requirements, emojis, is_colab
|
42 |
from utils.torch_utils import select_device # imports
|
43 |
|
44 |
check_requirements(('psutil', 'IPython'))
|
45 |
+
check_font()
|
46 |
+
|
47 |
import psutil
|
48 |
from IPython import display # to display images and clear console output
|
49 |
|
utils/augmentations.py
CHANGED
@@ -8,15 +8,22 @@ import random
|
|
8 |
|
9 |
import cv2
|
10 |
import numpy as np
|
|
|
|
|
|
|
11 |
|
12 |
from utils.general import LOGGER, check_version, colorstr, resample_segments, segment2box
|
13 |
from utils.metrics import bbox_ioa
|
14 |
|
|
|
|
|
|
|
15 |
|
16 |
class Albumentations:
|
17 |
# YOLOv5 Albumentations class (optional, only used if package is installed)
|
18 |
def __init__(self):
|
19 |
self.transform = None
|
|
|
20 |
try:
|
21 |
import albumentations as A
|
22 |
check_version(A.__version__, '1.0.3', hard=True) # version requirement
|
@@ -31,11 +38,11 @@ class Albumentations:
|
|
31 |
A.ImageCompression(quality_lower=75, p=0.0)] # transforms
|
32 |
self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
|
33 |
|
34 |
-
LOGGER.info(
|
35 |
except ImportError: # package not installed, skip
|
36 |
pass
|
37 |
except Exception as e:
|
38 |
-
LOGGER.info(
|
39 |
|
40 |
def __call__(self, im, labels, p=1.0):
|
41 |
if self.transform and random.random() < p:
|
@@ -44,6 +51,18 @@ class Albumentations:
|
|
44 |
return im, labels
|
45 |
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
|
48 |
# HSV color-space augmentation
|
49 |
if hgain or sgain or vgain:
|
@@ -282,3 +301,96 @@ def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
|
|
282 |
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
|
283 |
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
|
284 |
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
import cv2
|
10 |
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torchvision.transforms as T
|
13 |
+
import torchvision.transforms.functional as TF
|
14 |
|
15 |
from utils.general import LOGGER, check_version, colorstr, resample_segments, segment2box
|
16 |
from utils.metrics import bbox_ioa
|
17 |
|
18 |
+
IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
|
19 |
+
IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
|
20 |
+
|
21 |
|
22 |
class Albumentations:
|
23 |
# YOLOv5 Albumentations class (optional, only used if package is installed)
|
24 |
def __init__(self):
|
25 |
self.transform = None
|
26 |
+
prefix = colorstr('albumentations: ')
|
27 |
try:
|
28 |
import albumentations as A
|
29 |
check_version(A.__version__, '1.0.3', hard=True) # version requirement
|
|
|
38 |
A.ImageCompression(quality_lower=75, p=0.0)] # transforms
|
39 |
self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
|
40 |
|
41 |
+
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
|
42 |
except ImportError: # package not installed, skip
|
43 |
pass
|
44 |
except Exception as e:
|
45 |
+
LOGGER.info(f'{prefix}{e}')
|
46 |
|
47 |
def __call__(self, im, labels, p=1.0):
|
48 |
if self.transform and random.random() < p:
|
|
|
51 |
return im, labels
|
52 |
|
53 |
|
54 |
+
def normalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD, inplace=False):
|
55 |
+
# Denormalize RGB images x per ImageNet stats in BCHW format, i.e. = (x - mean) / std
|
56 |
+
return TF.normalize(x, mean, std, inplace=inplace)
|
57 |
+
|
58 |
+
|
59 |
+
def denormalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD):
|
60 |
+
# Denormalize RGB images x per ImageNet stats in BCHW format, i.e. = x * std + mean
|
61 |
+
for i in range(3):
|
62 |
+
x[:, i] = x[:, i] * std[i] + mean[i]
|
63 |
+
return x
|
64 |
+
|
65 |
+
|
66 |
def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
|
67 |
# HSV color-space augmentation
|
68 |
if hgain or sgain or vgain:
|
|
|
301 |
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
|
302 |
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
|
303 |
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
|
304 |
+
|
305 |
+
|
306 |
+
def classify_albumentations(augment=True,
|
307 |
+
size=224,
|
308 |
+
scale=(0.08, 1.0),
|
309 |
+
hflip=0.5,
|
310 |
+
vflip=0.0,
|
311 |
+
jitter=0.4,
|
312 |
+
mean=IMAGENET_MEAN,
|
313 |
+
std=IMAGENET_STD,
|
314 |
+
auto_aug=False):
|
315 |
+
# YOLOv5 classification Albumentations (optional, only used if package is installed)
|
316 |
+
prefix = colorstr('albumentations: ')
|
317 |
+
try:
|
318 |
+
import albumentations as A
|
319 |
+
from albumentations.pytorch import ToTensorV2
|
320 |
+
check_version(A.__version__, '1.0.3', hard=True) # version requirement
|
321 |
+
if augment: # Resize and crop
|
322 |
+
T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
|
323 |
+
if auto_aug:
|
324 |
+
# TODO: implement AugMix, AutoAug & RandAug in albumentation
|
325 |
+
LOGGER.info(f'{prefix}auto augmentations are currently not supported')
|
326 |
+
else:
|
327 |
+
if hflip > 0:
|
328 |
+
T += [A.HorizontalFlip(p=hflip)]
|
329 |
+
if vflip > 0:
|
330 |
+
T += [A.VerticalFlip(p=vflip)]
|
331 |
+
if jitter > 0:
|
332 |
+
color_jitter = (float(jitter),) * 3 # repeat value for brightness, contrast, satuaration, 0 hue
|
333 |
+
T += [A.ColorJitter(*color_jitter, 0)]
|
334 |
+
else: # Use fixed crop for eval set (reproducibility)
|
335 |
+
T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
|
336 |
+
T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
|
337 |
+
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
|
338 |
+
return A.Compose(T)
|
339 |
+
|
340 |
+
except ImportError: # package not installed, skip
|
341 |
+
pass
|
342 |
+
except Exception as e:
|
343 |
+
LOGGER.info(f'{prefix}{e}')
|
344 |
+
|
345 |
+
|
346 |
+
def classify_transforms(size=224):
|
347 |
+
# Transforms to apply if albumentations not installed
|
348 |
+
assert isinstance(size, int), f'ERROR: classify_transforms size {size} must be integer, not (list, tuple)'
|
349 |
+
# T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
350 |
+
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
351 |
+
|
352 |
+
|
353 |
+
class LetterBox:
|
354 |
+
# YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
|
355 |
+
def __init__(self, size=(640, 640), auto=False, stride=32):
|
356 |
+
super().__init__()
|
357 |
+
self.h, self.w = (size, size) if isinstance(size, int) else size
|
358 |
+
self.auto = auto # pass max size integer, automatically solve for short side using stride
|
359 |
+
self.stride = stride # used with auto
|
360 |
+
|
361 |
+
def __call__(self, im): # im = np.array HWC
|
362 |
+
imh, imw = im.shape[:2]
|
363 |
+
r = min(self.h / imh, self.w / imw) # ratio of new/old
|
364 |
+
h, w = round(imh * r), round(imw * r) # resized image
|
365 |
+
hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w
|
366 |
+
top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
|
367 |
+
im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype)
|
368 |
+
im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
|
369 |
+
return im_out
|
370 |
+
|
371 |
+
|
372 |
+
class CenterCrop:
|
373 |
+
# YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
|
374 |
+
def __init__(self, size=640):
|
375 |
+
super().__init__()
|
376 |
+
self.h, self.w = (size, size) if isinstance(size, int) else size
|
377 |
+
|
378 |
+
def __call__(self, im): # im = np.array HWC
|
379 |
+
imh, imw = im.shape[:2]
|
380 |
+
m = min(imh, imw) # min dimension
|
381 |
+
top, left = (imh - m) // 2, (imw - m) // 2
|
382 |
+
return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
|
383 |
+
|
384 |
+
|
385 |
+
class ToTensor:
|
386 |
+
# YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
|
387 |
+
def __init__(self, half=False):
|
388 |
+
super().__init__()
|
389 |
+
self.half = half
|
390 |
+
|
391 |
+
def __call__(self, im): # im = np.array HWC in BGR order
|
392 |
+
im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
|
393 |
+
im = torch.from_numpy(im) # to torch
|
394 |
+
im = im.half() if self.half else im.float() # uint8 to fp16/32
|
395 |
+
im /= 255.0 # 0-255 to 0.0-1.0
|
396 |
+
return im
|
utils/autoanchor.py
CHANGED
@@ -10,6 +10,7 @@ import torch
|
|
10 |
import yaml
|
11 |
from tqdm import tqdm
|
12 |
|
|
|
13 |
from utils.general import LOGGER, colorstr
|
14 |
|
15 |
PREFIX = colorstr('AutoAnchor: ')
|
@@ -25,6 +26,7 @@ def check_anchor_order(m):
|
|
25 |
m.anchors[:] = m.anchors.flip(0)
|
26 |
|
27 |
|
|
|
28 |
def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
29 |
# Check anchor fit to data, recompute if necessary
|
30 |
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
|
@@ -49,10 +51,7 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
|
49 |
else:
|
50 |
LOGGER.info(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...')
|
51 |
na = m.anchors.numel() // 2 # number of anchors
|
52 |
-
|
53 |
-
anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
|
54 |
-
except Exception as e:
|
55 |
-
LOGGER.info(f'{PREFIX}ERROR: {e}')
|
56 |
new_bpr = metric(anchors)[0]
|
57 |
if new_bpr > bpr: # replace anchors
|
58 |
anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
|
@@ -124,7 +123,7 @@ def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen
|
|
124 |
i = (wh0 < 3.0).any(1).sum()
|
125 |
if i:
|
126 |
LOGGER.info(f'{PREFIX}WARNING: Extremely small objects found: {i} of {len(wh0)} labels are < 3 pixels in size')
|
127 |
-
wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
|
128 |
# wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
|
129 |
|
130 |
# Kmeans init
|
@@ -167,4 +166,4 @@ def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen
|
|
167 |
if verbose:
|
168 |
print_results(k, verbose)
|
169 |
|
170 |
-
return print_results(k)
|
|
|
10 |
import yaml
|
11 |
from tqdm import tqdm
|
12 |
|
13 |
+
from utils import TryExcept
|
14 |
from utils.general import LOGGER, colorstr
|
15 |
|
16 |
PREFIX = colorstr('AutoAnchor: ')
|
|
|
26 |
m.anchors[:] = m.anchors.flip(0)
|
27 |
|
28 |
|
29 |
+
@TryExcept(f'{PREFIX}ERROR: ')
|
30 |
def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
31 |
# Check anchor fit to data, recompute if necessary
|
32 |
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
|
|
|
51 |
else:
|
52 |
LOGGER.info(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...')
|
53 |
na = m.anchors.numel() // 2 # number of anchors
|
54 |
+
anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
|
|
|
|
|
|
|
55 |
new_bpr = metric(anchors)[0]
|
56 |
if new_bpr > bpr: # replace anchors
|
57 |
anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
|
|
|
123 |
i = (wh0 < 3.0).any(1).sum()
|
124 |
if i:
|
125 |
LOGGER.info(f'{PREFIX}WARNING: Extremely small objects found: {i} of {len(wh0)} labels are < 3 pixels in size')
|
126 |
+
wh = wh0[(wh0 >= 2.0).any(1)].astype(np.float32) # filter > 2 pixels
|
127 |
# wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
|
128 |
|
129 |
# Kmeans init
|
|
|
166 |
if verbose:
|
167 |
print_results(k, verbose)
|
168 |
|
169 |
+
return print_results(k).astype(np.float32)
|
utils/autobatch.py
CHANGED
@@ -18,7 +18,7 @@ def check_train_batch_size(model, imgsz=640, amp=True):
|
|
18 |
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
|
19 |
|
20 |
|
21 |
-
def autobatch(model, imgsz=640, fraction=0.
|
22 |
# Automatically estimate best batch size to use `fraction` of available CUDA memory
|
23 |
# Usage:
|
24 |
# import torch
|
@@ -47,7 +47,7 @@ def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
|
|
47 |
# Profile batch sizes
|
48 |
batch_sizes = [1, 2, 4, 8, 16]
|
49 |
try:
|
50 |
-
img = [torch.
|
51 |
results = profile(img, model, n=3, device=device)
|
52 |
except Exception as e:
|
53 |
LOGGER.warning(f'{prefix}{e}')
|
@@ -60,6 +60,9 @@ def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
|
|
60 |
i = results.index(None) # first fail index
|
61 |
if b >= batch_sizes[i]: # y intercept above failure point
|
62 |
b = batch_sizes[max(i - 1, 0)] # select prior safe point
|
|
|
|
|
|
|
63 |
|
64 |
fraction = np.polyval(p, b) / t # actual fraction predicted
|
65 |
LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
|
|
|
18 |
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
|
19 |
|
20 |
|
21 |
+
def autobatch(model, imgsz=640, fraction=0.8, batch_size=16):
|
22 |
# Automatically estimate best batch size to use `fraction` of available CUDA memory
|
23 |
# Usage:
|
24 |
# import torch
|
|
|
47 |
# Profile batch sizes
|
48 |
batch_sizes = [1, 2, 4, 8, 16]
|
49 |
try:
|
50 |
+
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
|
51 |
results = profile(img, model, n=3, device=device)
|
52 |
except Exception as e:
|
53 |
LOGGER.warning(f'{prefix}{e}')
|
|
|
60 |
i = results.index(None) # first fail index
|
61 |
if b >= batch_sizes[i]: # y intercept above failure point
|
62 |
b = batch_sizes[max(i - 1, 0)] # select prior safe point
|
63 |
+
if b < 1 or b > 1024: # b outside of safe range
|
64 |
+
b = batch_size
|
65 |
+
LOGGER.warning(f'{prefix}WARNING: ⚠️ CUDA anomaly detected, recommend restart environment and retry command.')
|
66 |
|
67 |
fraction = np.polyval(p, b) / t # actual fraction predicted
|
68 |
LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
|
utils/benchmarks.py
CHANGED
@@ -92,10 +92,14 @@ def run(
|
|
92 |
LOGGER.info('\n')
|
93 |
parse_opt()
|
94 |
notebook_init() # print system info
|
95 |
-
c = ['Format', 'Size (MB)', '
|
96 |
py = pd.DataFrame(y, columns=c)
|
97 |
LOGGER.info(f'\nBenchmarks complete ({time.time() - t:.2f}s)')
|
98 |
LOGGER.info(str(py if map else py.iloc[:, :2]))
|
|
|
|
|
|
|
|
|
99 |
return py
|
100 |
|
101 |
|
@@ -141,7 +145,7 @@ def parse_opt():
|
|
141 |
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
|
142 |
parser.add_argument('--test', action='store_true', help='test exports only')
|
143 |
parser.add_argument('--pt-only', action='store_true', help='test PyTorch only')
|
144 |
-
parser.add_argument('--hard-fail',
|
145 |
opt = parser.parse_args()
|
146 |
opt.data = check_yaml(opt.data) # check YAML
|
147 |
print_args(vars(opt))
|
|
|
92 |
LOGGER.info('\n')
|
93 |
parse_opt()
|
94 |
notebook_init() # print system info
|
95 |
+
c = ['Format', 'Size (MB)', 'mAP50-95', 'Inference time (ms)'] if map else ['Format', 'Export', '', '']
|
96 |
py = pd.DataFrame(y, columns=c)
|
97 |
LOGGER.info(f'\nBenchmarks complete ({time.time() - t:.2f}s)')
|
98 |
LOGGER.info(str(py if map else py.iloc[:, :2]))
|
99 |
+
if hard_fail and isinstance(hard_fail, str):
|
100 |
+
metrics = py['mAP50-95'].array # values to compare to floor
|
101 |
+
floor = eval(hard_fail) # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
|
102 |
+
assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: mAP50-95 < floor {floor}'
|
103 |
return py
|
104 |
|
105 |
|
|
|
145 |
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
|
146 |
parser.add_argument('--test', action='store_true', help='test exports only')
|
147 |
parser.add_argument('--pt-only', action='store_true', help='test PyTorch only')
|
148 |
+
parser.add_argument('--hard-fail', nargs='?', const=True, default=False, help='Exception on error or < min metric')
|
149 |
opt = parser.parse_args()
|
150 |
opt.data = check_yaml(opt.data) # check YAML
|
151 |
print_args(vars(opt))
|
utils/callbacks.py
CHANGED
@@ -3,6 +3,8 @@
|
|
3 |
Callback utils
|
4 |
"""
|
5 |
|
|
|
|
|
6 |
|
7 |
class Callbacks:
|
8 |
""""
|
@@ -55,17 +57,20 @@ class Callbacks:
|
|
55 |
"""
|
56 |
return self._callbacks[hook] if hook else self._callbacks
|
57 |
|
58 |
-
def run(self, hook, *args, **kwargs):
|
59 |
"""
|
60 |
-
Loop through the registered actions and fire all callbacks
|
61 |
|
62 |
Args:
|
63 |
hook: The name of the hook to check, defaults to all
|
64 |
args: Arguments to receive from YOLOv5
|
|
|
65 |
kwargs: Keyword Arguments to receive from YOLOv5
|
66 |
"""
|
67 |
|
68 |
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
|
69 |
-
|
70 |
for logger in self._callbacks[hook]:
|
71 |
-
|
|
|
|
|
|
|
|
3 |
Callback utils
|
4 |
"""
|
5 |
|
6 |
+
import threading
|
7 |
+
|
8 |
|
9 |
class Callbacks:
|
10 |
""""
|
|
|
57 |
"""
|
58 |
return self._callbacks[hook] if hook else self._callbacks
|
59 |
|
60 |
+
def run(self, hook, *args, thread=False, **kwargs):
|
61 |
"""
|
62 |
+
Loop through the registered actions and fire all callbacks on main thread
|
63 |
|
64 |
Args:
|
65 |
hook: The name of the hook to check, defaults to all
|
66 |
args: Arguments to receive from YOLOv5
|
67 |
+
thread: (boolean) Run callbacks in daemon thread
|
68 |
kwargs: Keyword Arguments to receive from YOLOv5
|
69 |
"""
|
70 |
|
71 |
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
|
|
|
72 |
for logger in self._callbacks[hook]:
|
73 |
+
if thread:
|
74 |
+
threading.Thread(target=logger['callback'], args=args, kwargs=kwargs, daemon=True).start()
|
75 |
+
else:
|
76 |
+
logger['callback'](*args, **kwargs)
|
utils/dataloaders.py
CHANGED
@@ -22,22 +22,25 @@ from zipfile import ZipFile
|
|
22 |
import numpy as np
|
23 |
import torch
|
24 |
import torch.nn.functional as F
|
|
|
25 |
import yaml
|
26 |
from PIL import ExifTags, Image, ImageOps
|
27 |
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
|
28 |
from tqdm import tqdm
|
29 |
|
30 |
-
from utils.augmentations import Albumentations, augment_hsv,
|
|
|
31 |
from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
|
32 |
cv2, is_colab, is_kaggle, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
|
33 |
from utils.torch_utils import torch_distributed_zero_first
|
34 |
|
35 |
# Parameters
|
36 |
-
HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
|
37 |
-
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp' # include image suffixes
|
38 |
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
|
39 |
BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
|
40 |
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
|
|
41 |
|
42 |
# Get orientation exif tag
|
43 |
for orientation in ExifTags.TAGS.keys():
|
@@ -81,7 +84,7 @@ def exif_transpose(image):
|
|
81 |
5: Image.TRANSPOSE,
|
82 |
6: Image.ROTATE_270,
|
83 |
7: Image.TRANSVERSE,
|
84 |
-
8: Image.ROTATE_90
|
85 |
if method is not None:
|
86 |
image = image.transpose(method)
|
87 |
del exif[0x0112]
|
@@ -142,7 +145,7 @@ def create_dataloader(path,
|
|
142 |
shuffle=shuffle and sampler is None,
|
143 |
num_workers=nw,
|
144 |
sampler=sampler,
|
145 |
-
pin_memory=
|
146 |
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
|
147 |
worker_init_fn=seed_worker,
|
148 |
generator=generator), dataset
|
@@ -184,7 +187,7 @@ class _RepeatSampler:
|
|
184 |
|
185 |
class LoadImages:
|
186 |
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
|
187 |
-
def __init__(self, path, img_size=640, stride=32, auto=True):
|
188 |
files = []
|
189 |
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
|
190 |
p = str(Path(p).resolve())
|
@@ -208,8 +211,10 @@ class LoadImages:
|
|
208 |
self.video_flag = [False] * ni + [True] * nv
|
209 |
self.mode = 'image'
|
210 |
self.auto = auto
|
|
|
|
|
211 |
if any(videos):
|
212 |
-
self.
|
213 |
else:
|
214 |
self.cap = None
|
215 |
assert self.nf > 0, f'No images or videos found in {p}. ' \
|
@@ -227,103 +232,71 @@ class LoadImages:
|
|
227 |
if self.video_flag[self.count]:
|
228 |
# Read video
|
229 |
self.mode = 'video'
|
230 |
-
ret_val,
|
|
|
231 |
while not ret_val:
|
232 |
self.count += 1
|
233 |
self.cap.release()
|
234 |
if self.count == self.nf: # last video
|
235 |
raise StopIteration
|
236 |
path = self.files[self.count]
|
237 |
-
self.
|
238 |
-
ret_val,
|
239 |
|
240 |
self.frame += 1
|
|
|
241 |
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
|
242 |
|
243 |
else:
|
244 |
# Read image
|
245 |
self.count += 1
|
246 |
-
|
247 |
-
assert
|
248 |
s = f'image {self.count}/{self.nf} {path}: '
|
249 |
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
|
257 |
-
return path,
|
258 |
|
259 |
-
def
|
|
|
260 |
self.frame = 0
|
261 |
self.cap = cv2.VideoCapture(path)
|
262 |
-
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
|
264 |
def __len__(self):
|
265 |
return self.nf # number of files
|
266 |
|
267 |
|
268 |
-
class LoadWebcam: # for inference
|
269 |
-
# YOLOv5 local webcam dataloader, i.e. `python detect.py --source 0`
|
270 |
-
def __init__(self, pipe='0', img_size=640, stride=32):
|
271 |
-
self.img_size = img_size
|
272 |
-
self.stride = stride
|
273 |
-
self.pipe = eval(pipe) if pipe.isnumeric() else pipe
|
274 |
-
self.cap = cv2.VideoCapture(self.pipe) # video capture object
|
275 |
-
self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
|
276 |
-
|
277 |
-
def __iter__(self):
|
278 |
-
self.count = -1
|
279 |
-
return self
|
280 |
-
|
281 |
-
def __next__(self):
|
282 |
-
self.count += 1
|
283 |
-
if cv2.waitKey(1) == ord('q'): # q to quit
|
284 |
-
self.cap.release()
|
285 |
-
cv2.destroyAllWindows()
|
286 |
-
raise StopIteration
|
287 |
-
|
288 |
-
# Read frame
|
289 |
-
ret_val, img0 = self.cap.read()
|
290 |
-
img0 = cv2.flip(img0, 1) # flip left-right
|
291 |
-
|
292 |
-
# Print
|
293 |
-
assert ret_val, f'Camera Error {self.pipe}'
|
294 |
-
img_path = 'webcam.jpg'
|
295 |
-
s = f'webcam {self.count}: '
|
296 |
-
|
297 |
-
# Padded resize
|
298 |
-
img = letterbox(img0, self.img_size, stride=self.stride)[0]
|
299 |
-
|
300 |
-
# Convert
|
301 |
-
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
302 |
-
img = np.ascontiguousarray(img)
|
303 |
-
|
304 |
-
return img_path, img, img0, None, s
|
305 |
-
|
306 |
-
def __len__(self):
|
307 |
-
return 0
|
308 |
-
|
309 |
-
|
310 |
class LoadStreams:
|
311 |
# YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
|
312 |
-
def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
|
|
|
313 |
self.mode = 'stream'
|
314 |
self.img_size = img_size
|
315 |
self.stride = stride
|
316 |
-
|
317 |
-
|
318 |
-
with open(sources) as f:
|
319 |
-
sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
|
320 |
-
else:
|
321 |
-
sources = [sources]
|
322 |
-
|
323 |
n = len(sources)
|
324 |
-
self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
|
325 |
self.sources = [clean_str(x) for x in sources] # clean source names for later
|
326 |
-
self.
|
327 |
for i, s in enumerate(sources): # index, source
|
328 |
# Start thread to read frames from video stream
|
329 |
st = f'{i + 1}/{n}: {s}... '
|
@@ -350,19 +323,20 @@ class LoadStreams:
|
|
350 |
LOGGER.info('') # newline
|
351 |
|
352 |
# check for common shapes
|
353 |
-
s = np.stack([letterbox(x,
|
354 |
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
|
|
|
|
|
355 |
if not self.rect:
|
356 |
LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
|
357 |
|
358 |
def update(self, i, cap, stream):
|
359 |
# Read stream `i` frames in daemon thread
|
360 |
-
n, f
|
361 |
while cap.isOpened() and n < f:
|
362 |
n += 1
|
363 |
-
#
|
364 |
-
|
365 |
-
if n % read == 0:
|
366 |
success, im = cap.retrieve()
|
367 |
if success:
|
368 |
self.imgs[i] = im
|
@@ -382,18 +356,15 @@ class LoadStreams:
|
|
382 |
cv2.destroyAllWindows()
|
383 |
raise StopIteration
|
384 |
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
# Convert
|
393 |
-
img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
|
394 |
-
img = np.ascontiguousarray(img)
|
395 |
|
396 |
-
return self.sources,
|
397 |
|
398 |
def __len__(self):
|
399 |
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
|
@@ -453,7 +424,7 @@ class LoadImagesAndLabels(Dataset):
|
|
453 |
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
454 |
assert self.im_files, f'{prefix}No images found'
|
455 |
except Exception as e:
|
456 |
-
raise Exception(f'{prefix}Error loading data from {path}: {e}\
|
457 |
|
458 |
# Check cache
|
459 |
self.label_files = img2label_paths(self.im_files) # labels
|
@@ -472,11 +443,13 @@ class LoadImagesAndLabels(Dataset):
|
|
472 |
tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
|
473 |
if cache['msgs']:
|
474 |
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
|
475 |
-
assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}
|
476 |
|
477 |
# Read cache
|
478 |
[cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
|
479 |
labels, shapes, self.segments = zip(*cache.values())
|
|
|
|
|
480 |
self.labels = list(labels)
|
481 |
self.shapes = np.array(shapes)
|
482 |
self.im_files = list(cache.keys()) # update
|
@@ -569,7 +542,7 @@ class LoadImagesAndLabels(Dataset):
|
|
569 |
if msgs:
|
570 |
LOGGER.info('\n'.join(msgs))
|
571 |
if nf == 0:
|
572 |
-
LOGGER.warning(f'{prefix}WARNING: No labels found in {path}.
|
573 |
x['hash'] = get_hash(self.label_files + self.im_files)
|
574 |
x['results'] = nf, nm, ne, nc, len(self.im_files)
|
575 |
x['msgs'] = msgs # warnings
|
@@ -831,7 +804,7 @@ class LoadImagesAndLabels(Dataset):
|
|
831 |
|
832 |
@staticmethod
|
833 |
def collate_fn4(batch):
|
834 |
-
|
835 |
n = len(shapes) // 4
|
836 |
im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
|
837 |
|
@@ -841,13 +814,13 @@ class LoadImagesAndLabels(Dataset):
|
|
841 |
for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
|
842 |
i *= 4
|
843 |
if random.random() < 0.5:
|
844 |
-
|
845 |
-
|
846 |
lb = label[i]
|
847 |
else:
|
848 |
-
|
849 |
lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
|
850 |
-
im4.append(
|
851 |
label4.append(lb)
|
852 |
|
853 |
for i, lb in enumerate(label4):
|
@@ -870,7 +843,7 @@ def flatten_recursive(path=DATASETS_DIR / 'coco128'):
|
|
870 |
def extract_boxes(path=DATASETS_DIR / 'coco128'): # from utils.dataloaders import *; extract_boxes()
|
871 |
# Convert detection dataset into classification dataset, with one directory per class
|
872 |
path = Path(path) # images dir
|
873 |
-
shutil.rmtree(path / '
|
874 |
files = list(path.rglob('*.*'))
|
875 |
n = len(files) # number of files
|
876 |
for im_file in tqdm(files, total=n):
|
@@ -916,7 +889,9 @@ def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), ann
|
|
916 |
indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
|
917 |
|
918 |
txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
|
919 |
-
|
|
|
|
|
920 |
|
921 |
print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
|
922 |
for i, img in tqdm(zip(indices, files), total=n):
|
@@ -962,7 +937,7 @@ def verify_image_label(args):
|
|
962 |
if len(i) < nl: # duplicate row check
|
963 |
lb = lb[i] # remove duplicates
|
964 |
if segments:
|
965 |
-
segments = segments[i]
|
966 |
msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
|
967 |
else:
|
968 |
ne = 1 # label empty
|
@@ -1002,7 +977,7 @@ class HUBDatasetStats():
|
|
1002 |
self.hub_dir = Path(data['path'] + '-hub')
|
1003 |
self.im_dir = self.hub_dir / 'images'
|
1004 |
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
|
1005 |
-
self.stats = {'nc': data['nc'], 'names': data['names']} # statistics dictionary
|
1006 |
self.data = data
|
1007 |
|
1008 |
@staticmethod
|
@@ -1090,3 +1065,65 @@ class HUBDatasetStats():
|
|
1090 |
pass
|
1091 |
print(f'Done. All images saved to {self.im_dir}')
|
1092 |
return self.im_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
import numpy as np
|
23 |
import torch
|
24 |
import torch.nn.functional as F
|
25 |
+
import torchvision
|
26 |
import yaml
|
27 |
from PIL import ExifTags, Image, ImageOps
|
28 |
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
|
29 |
from tqdm import tqdm
|
30 |
|
31 |
+
from utils.augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste,
|
32 |
+
letterbox, mixup, random_perspective)
|
33 |
from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
|
34 |
cv2, is_colab, is_kaggle, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
|
35 |
from utils.torch_utils import torch_distributed_zero_first
|
36 |
|
37 |
# Parameters
|
38 |
+
HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
|
39 |
+
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # include image suffixes
|
40 |
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
|
41 |
BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
|
42 |
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
43 |
+
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
|
44 |
|
45 |
# Get orientation exif tag
|
46 |
for orientation in ExifTags.TAGS.keys():
|
|
|
84 |
5: Image.TRANSPOSE,
|
85 |
6: Image.ROTATE_270,
|
86 |
7: Image.TRANSVERSE,
|
87 |
+
8: Image.ROTATE_90}.get(orientation)
|
88 |
if method is not None:
|
89 |
image = image.transpose(method)
|
90 |
del exif[0x0112]
|
|
|
145 |
shuffle=shuffle and sampler is None,
|
146 |
num_workers=nw,
|
147 |
sampler=sampler,
|
148 |
+
pin_memory=PIN_MEMORY,
|
149 |
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
|
150 |
worker_init_fn=seed_worker,
|
151 |
generator=generator), dataset
|
|
|
187 |
|
188 |
class LoadImages:
|
189 |
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
|
190 |
+
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
191 |
files = []
|
192 |
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
|
193 |
p = str(Path(p).resolve())
|
|
|
211 |
self.video_flag = [False] * ni + [True] * nv
|
212 |
self.mode = 'image'
|
213 |
self.auto = auto
|
214 |
+
self.transforms = transforms # optional
|
215 |
+
self.vid_stride = vid_stride # video frame-rate stride
|
216 |
if any(videos):
|
217 |
+
self._new_video(videos[0]) # new video
|
218 |
else:
|
219 |
self.cap = None
|
220 |
assert self.nf > 0, f'No images or videos found in {p}. ' \
|
|
|
232 |
if self.video_flag[self.count]:
|
233 |
# Read video
|
234 |
self.mode = 'video'
|
235 |
+
ret_val, im0 = self.cap.read()
|
236 |
+
self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.vid_stride * (self.frame + 1)) # read at vid_stride
|
237 |
while not ret_val:
|
238 |
self.count += 1
|
239 |
self.cap.release()
|
240 |
if self.count == self.nf: # last video
|
241 |
raise StopIteration
|
242 |
path = self.files[self.count]
|
243 |
+
self._new_video(path)
|
244 |
+
ret_val, im0 = self.cap.read()
|
245 |
|
246 |
self.frame += 1
|
247 |
+
# im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
|
248 |
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
|
249 |
|
250 |
else:
|
251 |
# Read image
|
252 |
self.count += 1
|
253 |
+
im0 = cv2.imread(path) # BGR
|
254 |
+
assert im0 is not None, f'Image Not Found {path}'
|
255 |
s = f'image {self.count}/{self.nf} {path}: '
|
256 |
|
257 |
+
if self.transforms:
|
258 |
+
im = self.transforms(im0) # transforms
|
259 |
+
else:
|
260 |
+
im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
|
261 |
+
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
262 |
+
im = np.ascontiguousarray(im) # contiguous
|
263 |
|
264 |
+
return path, im, im0, self.cap, s
|
265 |
|
266 |
+
def _new_video(self, path):
|
267 |
+
# Create a new video capture object
|
268 |
self.frame = 0
|
269 |
self.cap = cv2.VideoCapture(path)
|
270 |
+
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
|
271 |
+
self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
|
272 |
+
# self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0) # disable https://github.com/ultralytics/yolov5/issues/8493
|
273 |
+
|
274 |
+
def _cv2_rotate(self, im):
|
275 |
+
# Rotate a cv2 video manually
|
276 |
+
if self.orientation == 0:
|
277 |
+
return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
|
278 |
+
elif self.orientation == 180:
|
279 |
+
return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
280 |
+
elif self.orientation == 90:
|
281 |
+
return cv2.rotate(im, cv2.ROTATE_180)
|
282 |
+
return im
|
283 |
|
284 |
def __len__(self):
|
285 |
return self.nf # number of files
|
286 |
|
287 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
class LoadStreams:
|
289 |
# YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
|
290 |
+
def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
291 |
+
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
|
292 |
self.mode = 'stream'
|
293 |
self.img_size = img_size
|
294 |
self.stride = stride
|
295 |
+
self.vid_stride = vid_stride # video frame-rate stride
|
296 |
+
sources = Path(sources).read_text().rsplit() if Path(sources).is_file() else [sources]
|
|
|
|
|
|
|
|
|
|
|
297 |
n = len(sources)
|
|
|
298 |
self.sources = [clean_str(x) for x in sources] # clean source names for later
|
299 |
+
self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
|
300 |
for i, s in enumerate(sources): # index, source
|
301 |
# Start thread to read frames from video stream
|
302 |
st = f'{i + 1}/{n}: {s}... '
|
|
|
323 |
LOGGER.info('') # newline
|
324 |
|
325 |
# check for common shapes
|
326 |
+
s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
|
327 |
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
|
328 |
+
self.auto = auto and self.rect
|
329 |
+
self.transforms = transforms # optional
|
330 |
if not self.rect:
|
331 |
LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
|
332 |
|
333 |
def update(self, i, cap, stream):
|
334 |
# Read stream `i` frames in daemon thread
|
335 |
+
n, f = 0, self.frames[i] # frame number, frame array
|
336 |
while cap.isOpened() and n < f:
|
337 |
n += 1
|
338 |
+
cap.grab() # .read() = .grab() followed by .retrieve()
|
339 |
+
if n % self.vid_stride == 0:
|
|
|
340 |
success, im = cap.retrieve()
|
341 |
if success:
|
342 |
self.imgs[i] = im
|
|
|
356 |
cv2.destroyAllWindows()
|
357 |
raise StopIteration
|
358 |
|
359 |
+
im0 = self.imgs.copy()
|
360 |
+
if self.transforms:
|
361 |
+
im = np.stack([self.transforms(x) for x in im0]) # transforms
|
362 |
+
else:
|
363 |
+
im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0]) # resize
|
364 |
+
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
|
365 |
+
im = np.ascontiguousarray(im) # contiguous
|
|
|
|
|
|
|
366 |
|
367 |
+
return self.sources, im, im0, None, ''
|
368 |
|
369 |
def __len__(self):
|
370 |
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
|
|
|
424 |
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
425 |
assert self.im_files, f'{prefix}No images found'
|
426 |
except Exception as e:
|
427 |
+
raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}')
|
428 |
|
429 |
# Check cache
|
430 |
self.label_files = img2label_paths(self.im_files) # labels
|
|
|
443 |
tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
|
444 |
if cache['msgs']:
|
445 |
LOGGER.info('\n'.join(cache['msgs'])) # display warnings
|
446 |
+
assert nf > 0 or not augment, f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}'
|
447 |
|
448 |
# Read cache
|
449 |
[cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
|
450 |
labels, shapes, self.segments = zip(*cache.values())
|
451 |
+
nl = len(np.concatenate(labels, 0)) # number of labels
|
452 |
+
assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}'
|
453 |
self.labels = list(labels)
|
454 |
self.shapes = np.array(shapes)
|
455 |
self.im_files = list(cache.keys()) # update
|
|
|
542 |
if msgs:
|
543 |
LOGGER.info('\n'.join(msgs))
|
544 |
if nf == 0:
|
545 |
+
LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. {HELP_URL}')
|
546 |
x['hash'] = get_hash(self.label_files + self.im_files)
|
547 |
x['results'] = nf, nm, ne, nc, len(self.im_files)
|
548 |
x['msgs'] = msgs # warnings
|
|
|
804 |
|
805 |
@staticmethod
|
806 |
def collate_fn4(batch):
|
807 |
+
im, label, path, shapes = zip(*batch) # transposed
|
808 |
n = len(shapes) // 4
|
809 |
im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
|
810 |
|
|
|
814 |
for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
|
815 |
i *= 4
|
816 |
if random.random() < 0.5:
|
817 |
+
im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
|
818 |
+
align_corners=False)[0].type(im[i].type())
|
819 |
lb = label[i]
|
820 |
else:
|
821 |
+
im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
|
822 |
lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
|
823 |
+
im4.append(im1)
|
824 |
label4.append(lb)
|
825 |
|
826 |
for i, lb in enumerate(label4):
|
|
|
843 |
def extract_boxes(path=DATASETS_DIR / 'coco128'): # from utils.dataloaders import *; extract_boxes()
|
844 |
# Convert detection dataset into classification dataset, with one directory per class
|
845 |
path = Path(path) # images dir
|
846 |
+
shutil.rmtree(path / 'classification') if (path / 'classification').is_dir() else None # remove existing
|
847 |
files = list(path.rglob('*.*'))
|
848 |
n = len(files) # number of files
|
849 |
for im_file in tqdm(files, total=n):
|
|
|
889 |
indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
|
890 |
|
891 |
txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
|
892 |
+
for x in txt:
|
893 |
+
if (path.parent / x).exists():
|
894 |
+
(path.parent / x).unlink() # remove existing
|
895 |
|
896 |
print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
|
897 |
for i, img in tqdm(zip(indices, files), total=n):
|
|
|
937 |
if len(i) < nl: # duplicate row check
|
938 |
lb = lb[i] # remove duplicates
|
939 |
if segments:
|
940 |
+
segments = [segments[x] for x in i]
|
941 |
msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
|
942 |
else:
|
943 |
ne = 1 # label empty
|
|
|
977 |
self.hub_dir = Path(data['path'] + '-hub')
|
978 |
self.im_dir = self.hub_dir / 'images'
|
979 |
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
|
980 |
+
self.stats = {'nc': data['nc'], 'names': list(data['names'].values())} # statistics dictionary
|
981 |
self.data = data
|
982 |
|
983 |
@staticmethod
|
|
|
1065 |
pass
|
1066 |
print(f'Done. All images saved to {self.im_dir}')
|
1067 |
return self.im_dir
|
1068 |
+
|
1069 |
+
|
1070 |
+
# Classification dataloaders -------------------------------------------------------------------------------------------
|
1071 |
+
class ClassificationDataset(torchvision.datasets.ImageFolder):
|
1072 |
+
"""
|
1073 |
+
YOLOv5 Classification Dataset.
|
1074 |
+
Arguments
|
1075 |
+
root: Dataset path
|
1076 |
+
transform: torchvision transforms, used by default
|
1077 |
+
album_transform: Albumentations transforms, used if installed
|
1078 |
+
"""
|
1079 |
+
|
1080 |
+
def __init__(self, root, augment, imgsz, cache=False):
|
1081 |
+
super().__init__(root=root)
|
1082 |
+
self.torch_transforms = classify_transforms(imgsz)
|
1083 |
+
self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
|
1084 |
+
self.cache_ram = cache is True or cache == 'ram'
|
1085 |
+
self.cache_disk = cache == 'disk'
|
1086 |
+
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
|
1087 |
+
|
1088 |
+
def __getitem__(self, i):
|
1089 |
+
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
|
1090 |
+
if self.cache_ram and im is None:
|
1091 |
+
im = self.samples[i][3] = cv2.imread(f)
|
1092 |
+
elif self.cache_disk:
|
1093 |
+
if not fn.exists(): # load npy
|
1094 |
+
np.save(fn.as_posix(), cv2.imread(f))
|
1095 |
+
im = np.load(fn)
|
1096 |
+
else: # read image
|
1097 |
+
im = cv2.imread(f) # BGR
|
1098 |
+
if self.album_transforms:
|
1099 |
+
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
|
1100 |
+
else:
|
1101 |
+
sample = self.torch_transforms(im)
|
1102 |
+
return sample, j
|
1103 |
+
|
1104 |
+
|
1105 |
+
def create_classification_dataloader(path,
|
1106 |
+
imgsz=224,
|
1107 |
+
batch_size=16,
|
1108 |
+
augment=True,
|
1109 |
+
cache=False,
|
1110 |
+
rank=-1,
|
1111 |
+
workers=8,
|
1112 |
+
shuffle=True):
|
1113 |
+
# Returns Dataloader object to be used with YOLOv5 Classifier
|
1114 |
+
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
1115 |
+
dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
|
1116 |
+
batch_size = min(batch_size, len(dataset))
|
1117 |
+
nd = torch.cuda.device_count()
|
1118 |
+
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
|
1119 |
+
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
1120 |
+
generator = torch.Generator()
|
1121 |
+
generator.manual_seed(0)
|
1122 |
+
return InfiniteDataLoader(dataset,
|
1123 |
+
batch_size=batch_size,
|
1124 |
+
shuffle=shuffle and sampler is None,
|
1125 |
+
num_workers=nw,
|
1126 |
+
sampler=sampler,
|
1127 |
+
pin_memory=PIN_MEMORY,
|
1128 |
+
worker_init_fn=seed_worker,
|
1129 |
+
generator=generator) # or DataLoader(persistent_workers=True)
|
utils/downloads.py
CHANGED
@@ -33,6 +33,12 @@ def gsutil_getsize(url=''):
|
|
33 |
return eval(s.split(' ')[0]) if len(s) else 0 # bytes
|
34 |
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
37 |
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
|
38 |
from utils.general import LOGGER
|
@@ -44,24 +50,26 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
|
44 |
torch.hub.download_url_to_file(url, str(file), progress=LOGGER.level <= logging.INFO)
|
45 |
assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
|
46 |
except Exception as e: # url2
|
47 |
-
file.
|
|
|
48 |
LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
|
49 |
-
os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
|
50 |
finally:
|
51 |
if not file.exists() or file.stat().st_size < min_bytes: # check
|
52 |
-
file.
|
|
|
53 |
LOGGER.info(f"ERROR: {assert_msg}\n{error_msg}")
|
54 |
LOGGER.info('')
|
55 |
|
56 |
|
57 |
-
def attempt_download(file, repo='ultralytics/yolov5', release='v6.
|
58 |
-
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.
|
59 |
from utils.general import LOGGER
|
60 |
|
61 |
def github_assets(repository, version='latest'):
|
62 |
-
# Return GitHub repo tag (i.e. 'v6.
|
63 |
if version != 'latest':
|
64 |
-
version = f'tags/{version}' # i.e. tags/v6.
|
65 |
response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api
|
66 |
return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
|
67 |
|
@@ -112,8 +120,10 @@ def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
|
|
112 |
file = Path(file)
|
113 |
cookie = Path('cookie') # gdrive cookie
|
114 |
print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... ', end='')
|
115 |
-
file.
|
116 |
-
|
|
|
|
|
117 |
|
118 |
# Attempt file download
|
119 |
out = "NUL" if platform.system() == "Windows" else "/dev/null"
|
@@ -123,11 +133,13 @@ def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
|
|
123 |
else: # small file
|
124 |
s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"'
|
125 |
r = os.system(s) # execute, capture return
|
126 |
-
cookie.
|
|
|
127 |
|
128 |
# Error check
|
129 |
if r != 0:
|
130 |
-
file.
|
|
|
131 |
print('Download error ') # raise Exception('Download error')
|
132 |
return r
|
133 |
|
|
|
33 |
return eval(s.split(' ')[0]) if len(s) else 0 # bytes
|
34 |
|
35 |
|
36 |
+
def url_getsize(url='https://ultralytics.com/images/bus.jpg'):
|
37 |
+
# Return downloadable file size in bytes
|
38 |
+
response = requests.head(url, allow_redirects=True)
|
39 |
+
return int(response.headers.get('content-length', -1))
|
40 |
+
|
41 |
+
|
42 |
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
43 |
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
|
44 |
from utils.general import LOGGER
|
|
|
50 |
torch.hub.download_url_to_file(url, str(file), progress=LOGGER.level <= logging.INFO)
|
51 |
assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
|
52 |
except Exception as e: # url2
|
53 |
+
if file.exists():
|
54 |
+
file.unlink() # remove partial downloads
|
55 |
LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
|
56 |
+
os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
|
57 |
finally:
|
58 |
if not file.exists() or file.stat().st_size < min_bytes: # check
|
59 |
+
if file.exists():
|
60 |
+
file.unlink() # remove partial downloads
|
61 |
LOGGER.info(f"ERROR: {assert_msg}\n{error_msg}")
|
62 |
LOGGER.info('')
|
63 |
|
64 |
|
65 |
+
def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
|
66 |
+
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
|
67 |
from utils.general import LOGGER
|
68 |
|
69 |
def github_assets(repository, version='latest'):
|
70 |
+
# Return GitHub repo tag (i.e. 'v6.2') and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
|
71 |
if version != 'latest':
|
72 |
+
version = f'tags/{version}' # i.e. tags/v6.2
|
73 |
response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api
|
74 |
return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
|
75 |
|
|
|
120 |
file = Path(file)
|
121 |
cookie = Path('cookie') # gdrive cookie
|
122 |
print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... ', end='')
|
123 |
+
if file.exists():
|
124 |
+
file.unlink() # remove existing file
|
125 |
+
if cookie.exists():
|
126 |
+
cookie.unlink() # remove existing cookie
|
127 |
|
128 |
# Attempt file download
|
129 |
out = "NUL" if platform.system() == "Windows" else "/dev/null"
|
|
|
133 |
else: # small file
|
134 |
s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"'
|
135 |
r = os.system(s) # execute, capture return
|
136 |
+
if cookie.exists():
|
137 |
+
cookie.unlink() # remove existing cookie
|
138 |
|
139 |
# Error check
|
140 |
if r != 0:
|
141 |
+
if file.exists():
|
142 |
+
file.unlink() # remove partial
|
143 |
print('Download error ') # raise Exception('Download error')
|
144 |
return r
|
145 |
|
utils/general.py
CHANGED
@@ -15,7 +15,6 @@ import re
|
|
15 |
import shutil
|
16 |
import signal
|
17 |
import sys
|
18 |
-
import threading
|
19 |
import time
|
20 |
import urllib
|
21 |
from datetime import datetime
|
@@ -34,6 +33,7 @@ import torch
|
|
34 |
import torchvision
|
35 |
import yaml
|
36 |
|
|
|
37 |
from utils.downloads import gsutil_getsize
|
38 |
from utils.metrics import box_iou, fitness
|
39 |
|
@@ -56,13 +56,35 @@ os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
|
56 |
os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'darwin' else str(NUM_THREADS) # OpenMP (PyTorch and SciPy)
|
57 |
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
def is_kaggle():
|
60 |
# Is environment a Kaggle Notebook?
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
64 |
return True
|
65 |
-
|
|
|
|
|
|
|
66 |
return False
|
67 |
|
68 |
|
@@ -82,7 +104,7 @@ def is_writeable(dir, test=False):
|
|
82 |
|
83 |
def set_logging(name=None, verbose=VERBOSE):
|
84 |
# Sets level and returns logger
|
85 |
-
if is_kaggle():
|
86 |
for h in logging.root.handlers:
|
87 |
logging.root.removeHandler(h) # remove all handlers associated with the root logger object
|
88 |
rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
|
@@ -119,16 +141,27 @@ CONFIG_DIR = user_config_dir() # Ultralytics settings dir
|
|
119 |
|
120 |
|
121 |
class Profile(contextlib.ContextDecorator):
|
122 |
-
# Usage: @Profile() decorator or 'with Profile():' context manager
|
|
|
|
|
|
|
|
|
123 |
def __enter__(self):
|
124 |
-
self.start =
|
|
|
125 |
|
126 |
def __exit__(self, type, value, traceback):
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
|
130 |
class Timeout(contextlib.ContextDecorator):
|
131 |
-
# Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
|
132 |
def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
|
133 |
self.seconds = int(seconds)
|
134 |
self.timeout_message = timeout_msg
|
@@ -162,64 +195,50 @@ class WorkingDirectory(contextlib.ContextDecorator):
|
|
162 |
os.chdir(self.cwd)
|
163 |
|
164 |
|
165 |
-
def try_except(func):
|
166 |
-
# try-except function. Usage: @try_except decorator
|
167 |
-
def handler(*args, **kwargs):
|
168 |
-
try:
|
169 |
-
func(*args, **kwargs)
|
170 |
-
except Exception as e:
|
171 |
-
print(e)
|
172 |
-
|
173 |
-
return handler
|
174 |
-
|
175 |
-
|
176 |
-
def threaded(func):
|
177 |
-
# Multi-threads a target function and returns thread. Usage: @threaded decorator
|
178 |
-
def wrapper(*args, **kwargs):
|
179 |
-
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
180 |
-
thread.start()
|
181 |
-
return thread
|
182 |
-
|
183 |
-
return wrapper
|
184 |
-
|
185 |
-
|
186 |
def methods(instance):
|
187 |
# Get class/instance methods
|
188 |
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
|
189 |
|
190 |
|
191 |
-
def print_args(args: Optional[dict] = None, show_file=True,
|
192 |
# Print function arguments (optional args dict)
|
193 |
x = inspect.currentframe().f_back # previous frame
|
194 |
-
file, _,
|
195 |
if args is None: # get args automatically
|
196 |
args, _, _, frm = inspect.getargvalues(x)
|
197 |
args = {k: v for k, v in frm.items() if k in args}
|
198 |
-
|
|
|
|
|
|
|
|
|
199 |
LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
|
200 |
|
201 |
|
202 |
def init_seeds(seed=0, deterministic=False):
|
203 |
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
|
204 |
-
# cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
|
205 |
-
import torch.backends.cudnn as cudnn
|
206 |
-
|
207 |
-
if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
|
208 |
-
torch.use_deterministic_algorithms(True)
|
209 |
-
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
210 |
-
os.environ['PYTHONHASHSEED'] = str(seed)
|
211 |
-
|
212 |
random.seed(seed)
|
213 |
np.random.seed(seed)
|
214 |
torch.manual_seed(seed)
|
215 |
-
cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
|
216 |
torch.cuda.manual_seed(seed)
|
217 |
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
|
219 |
|
220 |
def intersect_dicts(da, db, exclude=()):
|
221 |
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
|
222 |
-
return {k: v for k, v in da.items() if k in db and
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
|
224 |
|
225 |
def get_latest_run(search_dir='.'):
|
@@ -228,42 +247,6 @@ def get_latest_run(search_dir='.'):
|
|
228 |
return max(last_list, key=os.path.getctime) if last_list else ''
|
229 |
|
230 |
|
231 |
-
def is_docker() -> bool:
|
232 |
-
"""Check if the process runs inside a docker container."""
|
233 |
-
if Path("/.dockerenv").exists():
|
234 |
-
return True
|
235 |
-
try: # check if docker is in control groups
|
236 |
-
with open("/proc/self/cgroup") as file:
|
237 |
-
return any("docker" in line for line in file)
|
238 |
-
except OSError:
|
239 |
-
return False
|
240 |
-
|
241 |
-
|
242 |
-
def is_colab():
|
243 |
-
# Is environment a Google Colab instance?
|
244 |
-
try:
|
245 |
-
import google.colab
|
246 |
-
return True
|
247 |
-
except ImportError:
|
248 |
-
return False
|
249 |
-
|
250 |
-
|
251 |
-
def is_pip():
|
252 |
-
# Is file in a pip package?
|
253 |
-
return 'site-packages' in Path(__file__).resolve().parts
|
254 |
-
|
255 |
-
|
256 |
-
def is_ascii(s=''):
|
257 |
-
# Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
|
258 |
-
s = str(s) # convert list, tuple, None, etc. to str
|
259 |
-
return len(s.encode().decode('ascii', 'ignore')) == len(s)
|
260 |
-
|
261 |
-
|
262 |
-
def is_chinese(s='人工智能'):
|
263 |
-
# Is string composed of any Chinese characters?
|
264 |
-
return bool(re.search('[\u4e00-\u9fff]', str(s)))
|
265 |
-
|
266 |
-
|
267 |
def emojis(str=''):
|
268 |
# Return platform-dependent emoji-safe version of string
|
269 |
return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
|
@@ -312,9 +295,9 @@ def git_describe(path=ROOT): # path must be a directory
|
|
312 |
return ''
|
313 |
|
314 |
|
315 |
-
@
|
316 |
@WorkingDirectory(ROOT)
|
317 |
-
def check_git_status(repo='ultralytics/yolov5'):
|
318 |
# YOLOv5 status check, recommend 'git pull' if code is out of date
|
319 |
url = f'https://github.com/{repo}'
|
320 |
msg = f', for updates see {url}'
|
@@ -330,10 +313,10 @@ def check_git_status(repo='ultralytics/yolov5'):
|
|
330 |
remote = 'ultralytics'
|
331 |
check_output(f'git remote add {remote} {url}', shell=True)
|
332 |
check_output(f'git fetch {remote}', shell=True, timeout=5) # git fetch
|
333 |
-
|
334 |
-
n = int(check_output(f'git rev-list {
|
335 |
if n > 0:
|
336 |
-
pull = 'git pull' if remote == 'origin' else f'git pull {remote}
|
337 |
s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `{pull}` or `git clone {url}` to update."
|
338 |
else:
|
339 |
s += f'up to date with {url} ✅'
|
@@ -349,17 +332,17 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
|
|
349 |
# Check version vs. required version
|
350 |
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
|
351 |
result = (current == minimum) if pinned else (current >= minimum) # bool
|
352 |
-
s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string
|
353 |
if hard:
|
354 |
-
assert result, s # assert min requirements met
|
355 |
if verbose and not result:
|
356 |
LOGGER.warning(s)
|
357 |
return result
|
358 |
|
359 |
|
360 |
-
@
|
361 |
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
|
362 |
-
# Check installed dependencies meet requirements (pass *.txt file or list of packages)
|
363 |
prefix = colorstr('red', 'bold', 'requirements:')
|
364 |
check_python() # check python version
|
365 |
if isinstance(requirements, (str, Path)): # requirements.txt file
|
@@ -470,7 +453,7 @@ def check_font(font=FONT, progress=False):
|
|
470 |
font = Path(font)
|
471 |
file = CONFIG_DIR / font.name
|
472 |
if not font.exists() and not file.exists():
|
473 |
-
url =
|
474 |
LOGGER.info(f'Downloading {url} to {file}...')
|
475 |
torch.hub.download_url_to_file(url, str(file), progress=progress)
|
476 |
|
@@ -491,11 +474,11 @@ def check_dataset(data, autodownload=True):
|
|
491 |
data = yaml.safe_load(f) # dictionary
|
492 |
|
493 |
# Checks
|
494 |
-
for k in 'train', 'val', '
|
495 |
assert k in data, f"data.yaml '{k}:' field missing ❌"
|
496 |
-
if 'names'
|
497 |
-
|
498 |
-
|
499 |
|
500 |
# Resolve paths
|
501 |
path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
|
@@ -549,8 +532,8 @@ def check_amp(model):
|
|
549 |
|
550 |
prefix = colorstr('AMP: ')
|
551 |
device = next(model.parameters()).device # get model device
|
552 |
-
if device.type
|
553 |
-
return False # AMP
|
554 |
f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
|
555 |
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
|
556 |
try:
|
@@ -563,6 +546,18 @@ def check_amp(model):
|
|
563 |
return False
|
564 |
|
565 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
566 |
def url2file(url):
|
567 |
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
|
568 |
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
|
@@ -570,7 +565,7 @@ def url2file(url):
|
|
570 |
|
571 |
|
572 |
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
|
573 |
-
#
|
574 |
def download_one(url, dir):
|
575 |
# Download 1 file
|
576 |
success = True
|
@@ -582,7 +577,8 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry
|
|
582 |
for i in range(retry + 1):
|
583 |
if curl:
|
584 |
s = 'sS' if threads > 1 else '' # silent
|
585 |
-
r = os.system(
|
|
|
586 |
success = r == 0
|
587 |
else:
|
588 |
torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
|
@@ -594,10 +590,12 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry
|
|
594 |
else:
|
595 |
LOGGER.warning(f'Failed to download {url}...')
|
596 |
|
597 |
-
if unzip and success and f.suffix in ('.zip', '.gz'):
|
598 |
LOGGER.info(f'Unzipping {f}...')
|
599 |
if f.suffix == '.zip':
|
600 |
ZipFile(f).extractall(path=dir) # unzip
|
|
|
|
|
601 |
elif f.suffix == '.gz':
|
602 |
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
|
603 |
if delete:
|
@@ -607,7 +605,7 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry
|
|
607 |
dir.mkdir(parents=True, exist_ok=True) # make directory
|
608 |
if threads > 1:
|
609 |
pool = ThreadPool(threads)
|
610 |
-
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) #
|
611 |
pool.close()
|
612 |
pool.join()
|
613 |
else:
|
@@ -815,6 +813,9 @@ def non_max_suppression(prediction,
|
|
815 |
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
|
816 |
"""
|
817 |
|
|
|
|
|
|
|
818 |
bs = prediction.shape[0] # batch size
|
819 |
nc = prediction.shape[2] - 5 # number of classes
|
820 |
xc = prediction[..., 4] > conf_thres # candidates
|
|
|
15 |
import shutil
|
16 |
import signal
|
17 |
import sys
|
|
|
18 |
import time
|
19 |
import urllib
|
20 |
from datetime import datetime
|
|
|
33 |
import torchvision
|
34 |
import yaml
|
35 |
|
36 |
+
from utils import TryExcept
|
37 |
from utils.downloads import gsutil_getsize
|
38 |
from utils.metrics import box_iou, fitness
|
39 |
|
|
|
56 |
os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'darwin' else str(NUM_THREADS) # OpenMP (PyTorch and SciPy)
|
57 |
|
58 |
|
59 |
+
def is_ascii(s=''):
|
60 |
+
# Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
|
61 |
+
s = str(s) # convert list, tuple, None, etc. to str
|
62 |
+
return len(s.encode().decode('ascii', 'ignore')) == len(s)
|
63 |
+
|
64 |
+
|
65 |
+
def is_chinese(s='人工智能'):
|
66 |
+
# Is string composed of any Chinese characters?
|
67 |
+
return bool(re.search('[\u4e00-\u9fff]', str(s)))
|
68 |
+
|
69 |
+
|
70 |
+
def is_colab():
|
71 |
+
# Is environment a Google Colab instance?
|
72 |
+
return 'COLAB_GPU' in os.environ
|
73 |
+
|
74 |
+
|
75 |
def is_kaggle():
|
76 |
# Is environment a Kaggle Notebook?
|
77 |
+
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
|
78 |
+
|
79 |
+
|
80 |
+
def is_docker() -> bool:
|
81 |
+
"""Check if the process runs inside a docker container."""
|
82 |
+
if Path("/.dockerenv").exists():
|
83 |
return True
|
84 |
+
try: # check if docker is in control groups
|
85 |
+
with open("/proc/self/cgroup") as file:
|
86 |
+
return any("docker" in line for line in file)
|
87 |
+
except OSError:
|
88 |
return False
|
89 |
|
90 |
|
|
|
104 |
|
105 |
def set_logging(name=None, verbose=VERBOSE):
|
106 |
# Sets level and returns logger
|
107 |
+
if is_kaggle() or is_colab():
|
108 |
for h in logging.root.handlers:
|
109 |
logging.root.removeHandler(h) # remove all handlers associated with the root logger object
|
110 |
rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
|
|
|
141 |
|
142 |
|
143 |
class Profile(contextlib.ContextDecorator):
|
144 |
+
# YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
|
145 |
+
def __init__(self, t=0.0):
|
146 |
+
self.t = t
|
147 |
+
self.cuda = torch.cuda.is_available()
|
148 |
+
|
149 |
def __enter__(self):
|
150 |
+
self.start = self.time()
|
151 |
+
return self
|
152 |
|
153 |
def __exit__(self, type, value, traceback):
|
154 |
+
self.dt = self.time() - self.start # delta-time
|
155 |
+
self.t += self.dt # accumulate dt
|
156 |
+
|
157 |
+
def time(self):
|
158 |
+
if self.cuda:
|
159 |
+
torch.cuda.synchronize()
|
160 |
+
return time.time()
|
161 |
|
162 |
|
163 |
class Timeout(contextlib.ContextDecorator):
|
164 |
+
# YOLOv5 Timeout class. Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
|
165 |
def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
|
166 |
self.seconds = int(seconds)
|
167 |
self.timeout_message = timeout_msg
|
|
|
195 |
os.chdir(self.cwd)
|
196 |
|
197 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
def methods(instance):
|
199 |
# Get class/instance methods
|
200 |
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
|
201 |
|
202 |
|
203 |
+
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
|
204 |
# Print function arguments (optional args dict)
|
205 |
x = inspect.currentframe().f_back # previous frame
|
206 |
+
file, _, func, _, _ = inspect.getframeinfo(x)
|
207 |
if args is None: # get args automatically
|
208 |
args, _, _, frm = inspect.getargvalues(x)
|
209 |
args = {k: v for k, v in frm.items() if k in args}
|
210 |
+
try:
|
211 |
+
file = Path(file).resolve().relative_to(ROOT).with_suffix('')
|
212 |
+
except ValueError:
|
213 |
+
file = Path(file).stem
|
214 |
+
s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
|
215 |
LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
|
216 |
|
217 |
|
218 |
def init_seeds(seed=0, deterministic=False):
|
219 |
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
random.seed(seed)
|
221 |
np.random.seed(seed)
|
222 |
torch.manual_seed(seed)
|
|
|
223 |
torch.cuda.manual_seed(seed)
|
224 |
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
|
225 |
+
torch.backends.cudnn.benchmark = True # for faster training
|
226 |
+
if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
|
227 |
+
torch.use_deterministic_algorithms(True)
|
228 |
+
torch.backends.cudnn.deterministic = True
|
229 |
+
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
230 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
231 |
|
232 |
|
233 |
def intersect_dicts(da, db, exclude=()):
|
234 |
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
|
235 |
+
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
|
236 |
+
|
237 |
+
|
238 |
+
def get_default_args(func):
|
239 |
+
# Get func() default arguments
|
240 |
+
signature = inspect.signature(func)
|
241 |
+
return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
|
242 |
|
243 |
|
244 |
def get_latest_run(search_dir='.'):
|
|
|
247 |
return max(last_list, key=os.path.getctime) if last_list else ''
|
248 |
|
249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
def emojis(str=''):
|
251 |
# Return platform-dependent emoji-safe version of string
|
252 |
return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
|
|
|
295 |
return ''
|
296 |
|
297 |
|
298 |
+
@TryExcept()
|
299 |
@WorkingDirectory(ROOT)
|
300 |
+
def check_git_status(repo='ultralytics/yolov5', branch='master'):
|
301 |
# YOLOv5 status check, recommend 'git pull' if code is out of date
|
302 |
url = f'https://github.com/{repo}'
|
303 |
msg = f', for updates see {url}'
|
|
|
313 |
remote = 'ultralytics'
|
314 |
check_output(f'git remote add {remote} {url}', shell=True)
|
315 |
check_output(f'git fetch {remote}', shell=True, timeout=5) # git fetch
|
316 |
+
local_branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
|
317 |
+
n = int(check_output(f'git rev-list {local_branch}..{remote}/{branch} --count', shell=True)) # commits behind
|
318 |
if n > 0:
|
319 |
+
pull = 'git pull' if remote == 'origin' else f'git pull {remote} {branch}'
|
320 |
s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `{pull}` or `git clone {url}` to update."
|
321 |
else:
|
322 |
s += f'up to date with {url} ✅'
|
|
|
332 |
# Check version vs. required version
|
333 |
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
|
334 |
result = (current == minimum) if pinned else (current >= minimum) # bool
|
335 |
+
s = f'WARNING: ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed' # string
|
336 |
if hard:
|
337 |
+
assert result, emojis(s) # assert min requirements met
|
338 |
if verbose and not result:
|
339 |
LOGGER.warning(s)
|
340 |
return result
|
341 |
|
342 |
|
343 |
+
@TryExcept()
|
344 |
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
|
345 |
+
# Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages)
|
346 |
prefix = colorstr('red', 'bold', 'requirements:')
|
347 |
check_python() # check python version
|
348 |
if isinstance(requirements, (str, Path)): # requirements.txt file
|
|
|
453 |
font = Path(font)
|
454 |
file = CONFIG_DIR / font.name
|
455 |
if not font.exists() and not file.exists():
|
456 |
+
url = f'https://ultralytics.com/assets/{font.name}'
|
457 |
LOGGER.info(f'Downloading {url} to {file}...')
|
458 |
torch.hub.download_url_to_file(url, str(file), progress=progress)
|
459 |
|
|
|
474 |
data = yaml.safe_load(f) # dictionary
|
475 |
|
476 |
# Checks
|
477 |
+
for k in 'train', 'val', 'names':
|
478 |
assert k in data, f"data.yaml '{k}:' field missing ❌"
|
479 |
+
if isinstance(data['names'], (list, tuple)): # old array format
|
480 |
+
data['names'] = dict(enumerate(data['names'])) # convert to dict
|
481 |
+
data['nc'] = len(data['names'])
|
482 |
|
483 |
# Resolve paths
|
484 |
path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
|
|
|
532 |
|
533 |
prefix = colorstr('AMP: ')
|
534 |
device = next(model.parameters()).device # get model device
|
535 |
+
if device.type in ('cpu', 'mps'):
|
536 |
+
return False # AMP only used on CUDA devices
|
537 |
f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
|
538 |
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
|
539 |
try:
|
|
|
546 |
return False
|
547 |
|
548 |
|
549 |
+
def yaml_load(file='data.yaml'):
|
550 |
+
# Single-line safe yaml loading
|
551 |
+
with open(file, errors='ignore') as f:
|
552 |
+
return yaml.safe_load(f)
|
553 |
+
|
554 |
+
|
555 |
+
def yaml_save(file='data.yaml', data={}):
|
556 |
+
# Single-line safe yaml saving
|
557 |
+
with open(file, 'w') as f:
|
558 |
+
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
559 |
+
|
560 |
+
|
561 |
def url2file(url):
|
562 |
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
|
563 |
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
|
|
|
565 |
|
566 |
|
567 |
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
|
568 |
+
# Multithreaded file download and unzip function, used in data.yaml for autodownload
|
569 |
def download_one(url, dir):
|
570 |
# Download 1 file
|
571 |
success = True
|
|
|
577 |
for i in range(retry + 1):
|
578 |
if curl:
|
579 |
s = 'sS' if threads > 1 else '' # silent
|
580 |
+
r = os.system(
|
581 |
+
f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
|
582 |
success = r == 0
|
583 |
else:
|
584 |
torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
|
|
|
590 |
else:
|
591 |
LOGGER.warning(f'Failed to download {url}...')
|
592 |
|
593 |
+
if unzip and success and f.suffix in ('.zip', '.tar', '.gz'):
|
594 |
LOGGER.info(f'Unzipping {f}...')
|
595 |
if f.suffix == '.zip':
|
596 |
ZipFile(f).extractall(path=dir) # unzip
|
597 |
+
elif f.suffix == '.tar':
|
598 |
+
os.system(f'tar xf {f} --directory {f.parent}') # unzip
|
599 |
elif f.suffix == '.gz':
|
600 |
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
|
601 |
if delete:
|
|
|
605 |
dir.mkdir(parents=True, exist_ok=True) # make directory
|
606 |
if threads > 1:
|
607 |
pool = ThreadPool(threads)
|
608 |
+
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
|
609 |
pool.close()
|
610 |
pool.join()
|
611 |
else:
|
|
|
813 |
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
|
814 |
"""
|
815 |
|
816 |
+
if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
|
817 |
+
prediction = prediction[0] # select only inference output
|
818 |
+
|
819 |
bs = prediction.shape[0] # batch size
|
820 |
nc = prediction.shape[2] - 5 # number of classes
|
821 |
xc = prediction[..., 4] > conf_thres # candidates
|
utils/metrics.py
CHANGED
@@ -11,6 +11,8 @@ import matplotlib.pyplot as plt
|
|
11 |
import numpy as np
|
12 |
import torch
|
13 |
|
|
|
|
|
14 |
|
15 |
def fitness(x):
|
16 |
# Model fitness as a weighted combination of metrics
|
@@ -141,7 +143,7 @@ class ConfusionMatrix:
|
|
141 |
"""
|
142 |
if detections is None:
|
143 |
gt_classes = labels.int()
|
144 |
-
for
|
145 |
self.matrix[self.nc, gc] += 1 # background FN
|
146 |
return
|
147 |
|
@@ -184,36 +186,35 @@ class ConfusionMatrix:
|
|
184 |
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
|
185 |
return tp[:-1], fp[:-1] # remove background class
|
186 |
|
|
|
187 |
def plot(self, normalize=True, save_dir='', names=()):
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
except Exception as e:
|
216 |
-
print(f'WARNING: ConfusionMatrix plot failure: {e}')
|
217 |
|
218 |
def print(self):
|
219 |
for i in range(self.nc + 1):
|
@@ -320,6 +321,7 @@ def wh_iou(wh1, wh2, eps=1e-7):
|
|
320 |
# Plots ----------------------------------------------------------------------------------------------------------------
|
321 |
|
322 |
|
|
|
323 |
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
|
324 |
# Precision-recall curve
|
325 |
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
@@ -336,12 +338,13 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
|
|
336 |
ax.set_ylabel('Precision')
|
337 |
ax.set_xlim(0, 1)
|
338 |
ax.set_ylim(0, 1)
|
339 |
-
|
340 |
-
|
341 |
fig.savefig(save_dir, dpi=250)
|
342 |
-
plt.close()
|
343 |
|
344 |
|
|
|
345 |
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
|
346 |
# Metric-confidence curve
|
347 |
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
@@ -358,7 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
|
|
358 |
ax.set_ylabel(ylabel)
|
359 |
ax.set_xlim(0, 1)
|
360 |
ax.set_ylim(0, 1)
|
361 |
-
|
362 |
-
|
363 |
fig.savefig(save_dir, dpi=250)
|
364 |
-
plt.close()
|
|
|
11 |
import numpy as np
|
12 |
import torch
|
13 |
|
14 |
+
from utils import TryExcept, threaded
|
15 |
+
|
16 |
|
17 |
def fitness(x):
|
18 |
# Model fitness as a weighted combination of metrics
|
|
|
143 |
"""
|
144 |
if detections is None:
|
145 |
gt_classes = labels.int()
|
146 |
+
for gc in gt_classes:
|
147 |
self.matrix[self.nc, gc] += 1 # background FN
|
148 |
return
|
149 |
|
|
|
186 |
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
|
187 |
return tp[:-1], fp[:-1] # remove background class
|
188 |
|
189 |
+
@TryExcept('WARNING: ConfusionMatrix plot failure: ')
|
190 |
def plot(self, normalize=True, save_dir='', names=()):
|
191 |
+
import seaborn as sn
|
192 |
+
|
193 |
+
array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
|
194 |
+
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
|
195 |
+
|
196 |
+
fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
|
197 |
+
nc, nn = self.nc, len(names) # number of classes, names
|
198 |
+
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
|
199 |
+
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
|
200 |
+
with warnings.catch_warnings():
|
201 |
+
warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
|
202 |
+
sn.heatmap(array,
|
203 |
+
ax=ax,
|
204 |
+
annot=nc < 30,
|
205 |
+
annot_kws={
|
206 |
+
"size": 8},
|
207 |
+
cmap='Blues',
|
208 |
+
fmt='.2f',
|
209 |
+
square=True,
|
210 |
+
vmin=0.0,
|
211 |
+
xticklabels=names + ['background FP'] if labels else "auto",
|
212 |
+
yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
|
213 |
+
ax.set_ylabel('True')
|
214 |
+
ax.set_ylabel('Predicted')
|
215 |
+
ax.set_title('Confusion Matrix')
|
216 |
+
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
|
217 |
+
plt.close(fig)
|
|
|
|
|
218 |
|
219 |
def print(self):
|
220 |
for i in range(self.nc + 1):
|
|
|
321 |
# Plots ----------------------------------------------------------------------------------------------------------------
|
322 |
|
323 |
|
324 |
+
@threaded
|
325 |
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
|
326 |
# Precision-recall curve
|
327 |
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
|
|
338 |
ax.set_ylabel('Precision')
|
339 |
ax.set_xlim(0, 1)
|
340 |
ax.set_ylim(0, 1)
|
341 |
+
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
342 |
+
ax.set_title('Precision-Recall Curve')
|
343 |
fig.savefig(save_dir, dpi=250)
|
344 |
+
plt.close(fig)
|
345 |
|
346 |
|
347 |
+
@threaded
|
348 |
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
|
349 |
# Metric-confidence curve
|
350 |
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
|
|
|
361 |
ax.set_ylabel(ylabel)
|
362 |
ax.set_xlim(0, 1)
|
363 |
ax.set_ylim(0, 1)
|
364 |
+
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
365 |
+
ax.set_title(f'{ylabel}-Confidence Curve')
|
366 |
fig.savefig(save_dir, dpi=250)
|
367 |
+
plt.close(fig)
|
utils/plots.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3 |
Plotting utils
|
4 |
"""
|
5 |
|
|
|
6 |
import math
|
7 |
import os
|
8 |
from copy import copy
|
@@ -18,8 +19,9 @@ import seaborn as sn
|
|
18 |
import torch
|
19 |
from PIL import Image, ImageDraw, ImageFont
|
20 |
|
21 |
-
from utils
|
22 |
-
|
|
|
23 |
from utils.metrics import fitness
|
24 |
|
25 |
# Settings
|
@@ -115,10 +117,12 @@ class Annotator:
|
|
115 |
# Add rectangle to image (PIL-only)
|
116 |
self.draw.rectangle(xy, fill, outline, width)
|
117 |
|
118 |
-
def text(self, xy, text, txt_color=(255, 255, 255)):
|
119 |
# Add text to image (PIL-only)
|
120 |
-
|
121 |
-
|
|
|
|
|
122 |
|
123 |
def result(self):
|
124 |
# Return annotated image as array
|
@@ -180,8 +184,7 @@ def output_to_target(output):
|
|
180 |
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
|
181 |
targets = []
|
182 |
for i, o in enumerate(output):
|
183 |
-
for *box, conf, cls in o.cpu().numpy()
|
184 |
-
targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
|
185 |
return np.array(targets)
|
186 |
|
187 |
|
@@ -221,7 +224,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
|
|
221 |
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
222 |
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
|
223 |
if paths:
|
224 |
-
annotator.text((x + 5, y + 5
|
225 |
if len(targets) > 0:
|
226 |
ti = targets[targets[:, 0] == i] # image targets
|
227 |
boxes = xywh2xyxy(ti[:, 2:6]).T
|
@@ -339,8 +342,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
|
|
339 |
plt.savefig(f, dpi=300)
|
340 |
|
341 |
|
342 |
-
@
|
343 |
-
@Timeout(30) # known issue https://github.com/ultralytics/yolov5/issues/5611
|
344 |
def plot_labels(labels, names=(), save_dir=Path('')):
|
345 |
# plot dataset labels
|
346 |
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
@@ -357,10 +359,8 @@ def plot_labels(labels, names=(), save_dir=Path('')):
|
|
357 |
matplotlib.use('svg') # faster
|
358 |
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
|
359 |
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
360 |
-
|
361 |
[y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
|
362 |
-
except Exception:
|
363 |
-
pass
|
364 |
ax[0].set_ylabel('instances')
|
365 |
if 0 < len(names) < 30:
|
366 |
ax[0].set_xticks(range(len(names)))
|
@@ -388,6 +388,35 @@ def plot_labels(labels, names=(), save_dir=Path('')):
|
|
388 |
plt.close()
|
389 |
|
390 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
391 |
def plot_evolve(evolve_csv='path/to/evolve.csv'): # from utils.plots import *; plot_evolve()
|
392 |
# Plot evolve.csv hyp evolution results
|
393 |
evolve_csv = Path(evolve_csv)
|
|
|
3 |
Plotting utils
|
4 |
"""
|
5 |
|
6 |
+
import contextlib
|
7 |
import math
|
8 |
import os
|
9 |
from copy import copy
|
|
|
19 |
import torch
|
20 |
from PIL import Image, ImageDraw, ImageFont
|
21 |
|
22 |
+
from utils import TryExcept, threaded
|
23 |
+
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
|
24 |
+
is_ascii, xywh2xyxy, xyxy2xywh)
|
25 |
from utils.metrics import fitness
|
26 |
|
27 |
# Settings
|
|
|
117 |
# Add rectangle to image (PIL-only)
|
118 |
self.draw.rectangle(xy, fill, outline, width)
|
119 |
|
120 |
+
def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'):
|
121 |
# Add text to image (PIL-only)
|
122 |
+
if anchor == 'bottom': # start y from font bottom
|
123 |
+
w, h = self.font.getsize(text) # text width, height
|
124 |
+
xy[1] += 1 - h
|
125 |
+
self.draw.text(xy, text, fill=txt_color, font=self.font)
|
126 |
|
127 |
def result(self):
|
128 |
# Return annotated image as array
|
|
|
184 |
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
|
185 |
targets = []
|
186 |
for i, o in enumerate(output):
|
187 |
+
targets.extend([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf] for *box, conf, cls in o.cpu().numpy())
|
|
|
188 |
return np.array(targets)
|
189 |
|
190 |
|
|
|
224 |
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
225 |
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
|
226 |
if paths:
|
227 |
+
annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
|
228 |
if len(targets) > 0:
|
229 |
ti = targets[targets[:, 0] == i] # image targets
|
230 |
boxes = xywh2xyxy(ti[:, 2:6]).T
|
|
|
342 |
plt.savefig(f, dpi=300)
|
343 |
|
344 |
|
345 |
+
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
|
|
|
346 |
def plot_labels(labels, names=(), save_dir=Path('')):
|
347 |
# plot dataset labels
|
348 |
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
|
|
359 |
matplotlib.use('svg') # faster
|
360 |
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
|
361 |
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
362 |
+
with contextlib.suppress(Exception): # color histogram bars by class
|
363 |
[y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
|
|
|
|
|
364 |
ax[0].set_ylabel('instances')
|
365 |
if 0 < len(names) < 30:
|
366 |
ax[0].set_xticks(range(len(names)))
|
|
|
388 |
plt.close()
|
389 |
|
390 |
|
391 |
+
def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f=Path('images.jpg')):
|
392 |
+
# Show classification image grid with labels (optional) and predictions (optional)
|
393 |
+
from utils.augmentations import denormalize
|
394 |
+
|
395 |
+
names = names or [f'class{i}' for i in range(1000)]
|
396 |
+
blocks = torch.chunk(denormalize(im.clone()).cpu().float(), len(im),
|
397 |
+
dim=0) # select batch index 0, block by channels
|
398 |
+
n = min(len(blocks), nmax) # number of plots
|
399 |
+
m = min(8, round(n ** 0.5)) # 8 x 8 default
|
400 |
+
fig, ax = plt.subplots(math.ceil(n / m), m) # 8 rows x n/8 cols
|
401 |
+
ax = ax.ravel() if m > 1 else [ax]
|
402 |
+
# plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
403 |
+
for i in range(n):
|
404 |
+
ax[i].imshow(blocks[i].squeeze().permute((1, 2, 0)).numpy().clip(0.0, 1.0))
|
405 |
+
ax[i].axis('off')
|
406 |
+
if labels is not None:
|
407 |
+
s = names[labels[i]] + (f'—{names[pred[i]]}' if pred is not None else '')
|
408 |
+
ax[i].set_title(s, fontsize=8, verticalalignment='top')
|
409 |
+
plt.savefig(f, dpi=300, bbox_inches='tight')
|
410 |
+
plt.close()
|
411 |
+
if verbose:
|
412 |
+
LOGGER.info(f"Saving {f}")
|
413 |
+
if labels is not None:
|
414 |
+
LOGGER.info('True: ' + ' '.join(f'{names[i]:3s}' for i in labels[:nmax]))
|
415 |
+
if pred is not None:
|
416 |
+
LOGGER.info('Predicted:' + ' '.join(f'{names[i]:3s}' for i in pred[:nmax]))
|
417 |
+
return f
|
418 |
+
|
419 |
+
|
420 |
def plot_evolve(evolve_csv='path/to/evolve.csv'): # from utils.plots import *; plot_evolve()
|
421 |
# Plot evolve.csv hyp evolution results
|
422 |
evolve_csv = Path(evolve_csv)
|
utils/torch_utils.py
CHANGED
@@ -42,6 +42,15 @@ def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
|
|
42 |
return decorate
|
43 |
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
def smart_DDP(model):
|
46 |
# Model DDP creation with checks
|
47 |
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
|
@@ -53,6 +62,28 @@ def smart_DDP(model):
|
|
53 |
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
54 |
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
@contextmanager
|
57 |
def torch_distributed_zero_first(local_rank: int):
|
58 |
# Decorator to make all processes in distributed training wait for each local_master to do something
|
@@ -86,7 +117,7 @@ def select_device(device='', batch_size=0, newline=True):
|
|
86 |
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
|
87 |
f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
|
88 |
|
89 |
-
if not
|
90 |
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
91 |
n = len(devices) # device count
|
92 |
if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
|
@@ -117,14 +148,13 @@ def time_sync():
|
|
117 |
|
118 |
|
119 |
def profile(input, ops, n=10, device=None):
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
results = []
|
129 |
if not isinstance(device, torch.device):
|
130 |
device = select_device(device)
|
@@ -251,7 +281,7 @@ def model_info(model, verbose=False, imgsz=640):
|
|
251 |
try: # FLOPs
|
252 |
p = next(model.parameters())
|
253 |
stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
|
254 |
-
im = torch.
|
255 |
flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
|
256 |
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
|
257 |
fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs' # 640x640 GFLOPs
|
@@ -313,6 +343,18 @@ def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
|
|
313 |
return optimizer
|
314 |
|
315 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
|
317 |
# Resume training from a partially trained checkpoint
|
318 |
best_fitness = 0.0
|
@@ -365,14 +407,11 @@ class ModelEMA:
|
|
365 |
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
366 |
# Create EMA
|
367 |
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
368 |
-
# if next(model.parameters()).device.type != 'cpu':
|
369 |
-
# self.ema.half() # FP16 EMA
|
370 |
self.updates = updates # number of EMA updates
|
371 |
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
372 |
for p in self.ema.parameters():
|
373 |
p.requires_grad_(False)
|
374 |
|
375 |
-
@smart_inference_mode()
|
376 |
def update(self, model):
|
377 |
# Update EMA parameters
|
378 |
self.updates += 1
|
@@ -380,9 +419,10 @@ class ModelEMA:
|
|
380 |
|
381 |
msd = de_parallel(model).state_dict() # model state_dict
|
382 |
for k, v in self.ema.state_dict().items():
|
383 |
-
if v.dtype.is_floating_point:
|
384 |
v *= d
|
385 |
v += (1 - d) * msd[k].detach()
|
|
|
386 |
|
387 |
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
388 |
# Update EMA attributes
|
|
|
42 |
return decorate
|
43 |
|
44 |
|
45 |
+
def smartCrossEntropyLoss(label_smoothing=0.0):
|
46 |
+
# Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0
|
47 |
+
if check_version(torch.__version__, '1.10.0'):
|
48 |
+
return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
|
49 |
+
if label_smoothing > 0:
|
50 |
+
LOGGER.warning(f'WARNING: label smoothing {label_smoothing} requires torch>=1.10.0')
|
51 |
+
return nn.CrossEntropyLoss()
|
52 |
+
|
53 |
+
|
54 |
def smart_DDP(model):
|
55 |
# Model DDP creation with checks
|
56 |
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
|
|
|
62 |
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
63 |
|
64 |
|
65 |
+
def reshape_classifier_output(model, n=1000):
|
66 |
+
# Update a TorchVision classification model to class count 'n' if required
|
67 |
+
from models.common import Classify
|
68 |
+
name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
|
69 |
+
if isinstance(m, Classify): # YOLOv5 Classify() head
|
70 |
+
if m.linear.out_features != n:
|
71 |
+
m.linear = nn.Linear(m.linear.in_features, n)
|
72 |
+
elif isinstance(m, nn.Linear): # ResNet, EfficientNet
|
73 |
+
if m.out_features != n:
|
74 |
+
setattr(model, name, nn.Linear(m.in_features, n))
|
75 |
+
elif isinstance(m, nn.Sequential):
|
76 |
+
types = [type(x) for x in m]
|
77 |
+
if nn.Linear in types:
|
78 |
+
i = types.index(nn.Linear) # nn.Linear index
|
79 |
+
if m[i].out_features != n:
|
80 |
+
m[i] = nn.Linear(m[i].in_features, n)
|
81 |
+
elif nn.Conv2d in types:
|
82 |
+
i = types.index(nn.Conv2d) # nn.Conv2d index
|
83 |
+
if m[i].out_channels != n:
|
84 |
+
m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias)
|
85 |
+
|
86 |
+
|
87 |
@contextmanager
|
88 |
def torch_distributed_zero_first(local_rank: int):
|
89 |
# Decorator to make all processes in distributed training wait for each local_master to do something
|
|
|
117 |
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
|
118 |
f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
|
119 |
|
120 |
+
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
|
121 |
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
122 |
n = len(devices) # device count
|
123 |
if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
|
|
|
148 |
|
149 |
|
150 |
def profile(input, ops, n=10, device=None):
|
151 |
+
""" YOLOv5 speed/memory/FLOPs profiler
|
152 |
+
Usage:
|
153 |
+
input = torch.randn(16, 3, 640, 640)
|
154 |
+
m1 = lambda x: x * torch.sigmoid(x)
|
155 |
+
m2 = nn.SiLU()
|
156 |
+
profile(input, [m1, m2], n=100) # profile over 100 iterations
|
157 |
+
"""
|
|
|
158 |
results = []
|
159 |
if not isinstance(device, torch.device):
|
160 |
device = select_device(device)
|
|
|
281 |
try: # FLOPs
|
282 |
p = next(model.parameters())
|
283 |
stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
|
284 |
+
im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
|
285 |
flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
|
286 |
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
|
287 |
fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs' # 640x640 GFLOPs
|
|
|
343 |
return optimizer
|
344 |
|
345 |
|
346 |
+
def smart_hub_load(repo='ultralytics/yolov5', model='yolov5s', **kwargs):
|
347 |
+
# YOLOv5 torch.hub.load() wrapper with smart error/issue handling
|
348 |
+
if check_version(torch.__version__, '1.9.1'):
|
349 |
+
kwargs['skip_validation'] = True # validation causes GitHub API rate limit errors
|
350 |
+
if check_version(torch.__version__, '1.12.0'):
|
351 |
+
kwargs['trust_repo'] = True # argument required starting in torch 0.12
|
352 |
+
try:
|
353 |
+
return torch.hub.load(repo, model, **kwargs)
|
354 |
+
except Exception:
|
355 |
+
return torch.hub.load(repo, model, force_reload=True, **kwargs)
|
356 |
+
|
357 |
+
|
358 |
def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
|
359 |
# Resume training from a partially trained checkpoint
|
360 |
best_fitness = 0.0
|
|
|
407 |
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
408 |
# Create EMA
|
409 |
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
|
|
|
|
410 |
self.updates = updates # number of EMA updates
|
411 |
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
412 |
for p in self.ema.parameters():
|
413 |
p.requires_grad_(False)
|
414 |
|
|
|
415 |
def update(self, model):
|
416 |
# Update EMA parameters
|
417 |
self.updates += 1
|
|
|
419 |
|
420 |
msd = de_parallel(model).state_dict() # model state_dict
|
421 |
for k, v in self.ema.state_dict().items():
|
422 |
+
if v.dtype.is_floating_point: # true for FP16 and FP32
|
423 |
v *= d
|
424 |
v += (1 - d) * msd[k].detach()
|
425 |
+
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
|
426 |
|
427 |
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
428 |
# Update EMA attributes
|
val.py
CHANGED
@@ -1,21 +1,21 @@
|
|
1 |
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
2 |
"""
|
3 |
-
Validate a trained YOLOv5 model
|
4 |
|
5 |
Usage:
|
6 |
-
$ python
|
7 |
|
8 |
Usage - formats:
|
9 |
-
$ python
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
"""
|
20 |
|
21 |
import argparse
|
@@ -37,12 +37,12 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
|
|
37 |
from models.common import DetectMultiBackend
|
38 |
from utils.callbacks import Callbacks
|
39 |
from utils.dataloaders import create_dataloader
|
40 |
-
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_yaml,
|
41 |
coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
|
42 |
scale_coords, xywh2xyxy, xyxy2xywh)
|
43 |
from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
|
44 |
from utils.plots import output_to_target, plot_images, plot_val_study
|
45 |
-
from utils.torch_utils import select_device, smart_inference_mode
|
46 |
|
47 |
|
48 |
def save_one_txt(predn, save_conf, shape, file):
|
@@ -182,40 +182,39 @@ def run(
|
|
182 |
|
183 |
seen = 0
|
184 |
confusion_matrix = ConfusionMatrix(nc=nc)
|
185 |
-
names =
|
|
|
|
|
186 |
class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
|
187 |
-
s = ('%
|
188 |
-
dt, p, r, f1, mp, mr, map50, map =
|
189 |
loss = torch.zeros(3, device=device)
|
190 |
jdict, stats, ap, ap_class = [], [], [], []
|
191 |
callbacks.run('on_val_start')
|
192 |
pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
|
193 |
for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
|
194 |
callbacks.run('on_val_batch_start')
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
t2 = time_sync()
|
203 |
-
dt[0] += t2 - t1
|
204 |
|
205 |
# Inference
|
206 |
-
|
207 |
-
|
208 |
|
209 |
# Loss
|
210 |
if compute_loss:
|
211 |
-
loss += compute_loss(
|
212 |
|
213 |
# NMS
|
214 |
targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels
|
215 |
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
|
216 |
-
|
217 |
-
|
218 |
-
dt[2] += time_sync() - t3
|
219 |
|
220 |
# Metrics
|
221 |
for si, pred in enumerate(out):
|
@@ -271,7 +270,7 @@ def run(
|
|
271 |
nt = np.bincount(stats[3].astype(int), minlength=nc) # number of targets per class
|
272 |
|
273 |
# Print results
|
274 |
-
pf = '%
|
275 |
LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
|
276 |
if nt.sum() == 0:
|
277 |
LOGGER.warning(f'WARNING: no labels found in {task} set, can not compute metrics without labels ⚠️')
|
@@ -282,7 +281,7 @@ def run(
|
|
282 |
LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
|
283 |
|
284 |
# Print speeds
|
285 |
-
t = tuple(x / seen * 1E3 for x in dt) # speeds per image
|
286 |
if not training:
|
287 |
shape = (batch_size, 3, imgsz, imgsz)
|
288 |
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)
|
@@ -366,6 +365,8 @@ def main(opt):
|
|
366 |
if opt.task in ('train', 'val', 'test'): # run normally
|
367 |
if opt.conf_thres > 0.001: # https://github.com/ultralytics/yolov5/issues/1466
|
368 |
LOGGER.info(f'WARNING: confidence threshold {opt.conf_thres} > 0.001 produces invalid results ⚠️')
|
|
|
|
|
369 |
run(**vars(opt))
|
370 |
|
371 |
else:
|
|
|
1 |
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
2 |
"""
|
3 |
+
Validate a trained YOLOv5 detection model on a detection dataset
|
4 |
|
5 |
Usage:
|
6 |
+
$ python val.py --weights yolov5s.pt --data coco128.yaml --img 640
|
7 |
|
8 |
Usage - formats:
|
9 |
+
$ python val.py --weights yolov5s.pt # PyTorch
|
10 |
+
yolov5s.torchscript # TorchScript
|
11 |
+
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
|
12 |
+
yolov5s.xml # OpenVINO
|
13 |
+
yolov5s.engine # TensorRT
|
14 |
+
yolov5s.mlmodel # CoreML (macOS-only)
|
15 |
+
yolov5s_saved_model # TensorFlow SavedModel
|
16 |
+
yolov5s.pb # TensorFlow GraphDef
|
17 |
+
yolov5s.tflite # TensorFlow Lite
|
18 |
+
yolov5s_edgetpu.tflite # TensorFlow Edge TPU
|
19 |
"""
|
20 |
|
21 |
import argparse
|
|
|
37 |
from models.common import DetectMultiBackend
|
38 |
from utils.callbacks import Callbacks
|
39 |
from utils.dataloaders import create_dataloader
|
40 |
+
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_yaml,
|
41 |
coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
|
42 |
scale_coords, xywh2xyxy, xyxy2xywh)
|
43 |
from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
|
44 |
from utils.plots import output_to_target, plot_images, plot_val_study
|
45 |
+
from utils.torch_utils import select_device, smart_inference_mode
|
46 |
|
47 |
|
48 |
def save_one_txt(predn, save_conf, shape, file):
|
|
|
182 |
|
183 |
seen = 0
|
184 |
confusion_matrix = ConfusionMatrix(nc=nc)
|
185 |
+
names = model.names if hasattr(model, 'names') else model.module.names # get class names
|
186 |
+
if isinstance(names, (list, tuple)): # old format
|
187 |
+
names = dict(enumerate(names))
|
188 |
class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
|
189 |
+
s = ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'P', 'R', '[email protected]', '[email protected]:.95')
|
190 |
+
dt, p, r, f1, mp, mr, map50, map = (Profile(), Profile(), Profile()), 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
|
191 |
loss = torch.zeros(3, device=device)
|
192 |
jdict, stats, ap, ap_class = [], [], [], []
|
193 |
callbacks.run('on_val_start')
|
194 |
pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
|
195 |
for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
|
196 |
callbacks.run('on_val_batch_start')
|
197 |
+
with dt[0]:
|
198 |
+
if cuda:
|
199 |
+
im = im.to(device, non_blocking=True)
|
200 |
+
targets = targets.to(device)
|
201 |
+
im = im.half() if half else im.float() # uint8 to fp16/32
|
202 |
+
im /= 255 # 0 - 255 to 0.0 - 1.0
|
203 |
+
nb, _, height, width = im.shape # batch size, channels, height, width
|
|
|
|
|
204 |
|
205 |
# Inference
|
206 |
+
with dt[1]:
|
207 |
+
out, train_out = model(im) if compute_loss else (model(im, augment=augment), None)
|
208 |
|
209 |
# Loss
|
210 |
if compute_loss:
|
211 |
+
loss += compute_loss(train_out, targets)[1] # box, obj, cls
|
212 |
|
213 |
# NMS
|
214 |
targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels
|
215 |
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
|
216 |
+
with dt[2]:
|
217 |
+
out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
|
|
|
218 |
|
219 |
# Metrics
|
220 |
for si, pred in enumerate(out):
|
|
|
270 |
nt = np.bincount(stats[3].astype(int), minlength=nc) # number of targets per class
|
271 |
|
272 |
# Print results
|
273 |
+
pf = '%22s' + '%11i' * 2 + '%11.3g' * 4 # print format
|
274 |
LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
|
275 |
if nt.sum() == 0:
|
276 |
LOGGER.warning(f'WARNING: no labels found in {task} set, can not compute metrics without labels ⚠️')
|
|
|
281 |
LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
|
282 |
|
283 |
# Print speeds
|
284 |
+
t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
|
285 |
if not training:
|
286 |
shape = (batch_size, 3, imgsz, imgsz)
|
287 |
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)
|
|
|
365 |
if opt.task in ('train', 'val', 'test'): # run normally
|
366 |
if opt.conf_thres > 0.001: # https://github.com/ultralytics/yolov5/issues/1466
|
367 |
LOGGER.info(f'WARNING: confidence threshold {opt.conf_thres} > 0.001 produces invalid results ⚠️')
|
368 |
+
if opt.save_hybrid:
|
369 |
+
LOGGER.info('WARNING: --save-hybrid will return high mAP from hybrid labels, not from predictions alone ⚠️')
|
370 |
run(**vars(opt))
|
371 |
|
372 |
else:
|