Add `@threaded` decorator (#7813)
Browse files* Add `@threaded` decorator
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- train.py +2 -2
- utils/general.py +11 -0
- utils/loggers/__init__.py +3 -4
- utils/plots.py +7 -6
- val.py +2 -5
train.py
CHANGED
@@ -48,8 +48,8 @@ from utils.dataloaders import create_dataloader
|
|
48 |
from utils.downloads import attempt_download
|
49 |
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
|
50 |
check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
|
51 |
-
init_seeds, intersect_dicts,
|
52 |
-
|
53 |
from utils.loggers import Loggers
|
54 |
from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
55 |
from utils.loss import ComputeLoss
|
|
|
48 |
from utils.downloads import attempt_download
|
49 |
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
|
50 |
check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
|
51 |
+
init_seeds, intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
|
52 |
+
one_cycle, print_args, print_mutation, strip_optimizer)
|
53 |
from utils.loggers import Loggers
|
54 |
from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
55 |
from utils.loss import ComputeLoss
|
utils/general.py
CHANGED
@@ -14,6 +14,7 @@ import random
|
|
14 |
import re
|
15 |
import shutil
|
16 |
import signal
|
|
|
17 |
import time
|
18 |
import urllib
|
19 |
from datetime import datetime
|
@@ -167,6 +168,16 @@ def try_except(func):
|
|
167 |
return handler
|
168 |
|
169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
def methods(instance):
|
171 |
# Get class/instance methods
|
172 |
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
|
|
|
14 |
import re
|
15 |
import shutil
|
16 |
import signal
|
17 |
+
import threading
|
18 |
import time
|
19 |
import urllib
|
20 |
from datetime import datetime
|
|
|
168 |
return handler
|
169 |
|
170 |
|
171 |
+
def threaded(func):
|
172 |
+
# Multi-threads a target function and returns thread. Usage: @threaded decorator
|
173 |
+
def wrapper(*args, **kwargs):
|
174 |
+
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
175 |
+
thread.start()
|
176 |
+
return thread
|
177 |
+
|
178 |
+
return wrapper
|
179 |
+
|
180 |
+
|
181 |
def methods(instance):
|
182 |
# Get class/instance methods
|
183 |
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
|
utils/loggers/__init__.py
CHANGED
@@ -5,7 +5,6 @@ Logging utils
|
|
5 |
|
6 |
import os
|
7 |
import warnings
|
8 |
-
from threading import Thread
|
9 |
|
10 |
import pkg_resources as pkg
|
11 |
import torch
|
@@ -109,7 +108,7 @@ class Loggers():
|
|
109 |
self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
|
110 |
if ni < 3:
|
111 |
f = self.save_dir / f'train_batch{ni}.jpg' # filename
|
112 |
-
|
113 |
if self.wandb and ni == 10:
|
114 |
files = sorted(self.save_dir.glob('train*.jpg'))
|
115 |
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
|
@@ -132,7 +131,7 @@ class Loggers():
|
|
132 |
|
133 |
def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
|
134 |
# Callback runs at the end of each fit (train+val) epoch
|
135 |
-
x =
|
136 |
if self.csv:
|
137 |
file = self.save_dir / 'results.csv'
|
138 |
n = len(x) + 1 # number of cols
|
@@ -171,7 +170,7 @@ class Loggers():
|
|
171 |
self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
|
172 |
|
173 |
if self.wandb:
|
174 |
-
self.wandb.log(
|
175 |
self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
|
176 |
# Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
|
177 |
if not self.opt.evolve:
|
|
|
5 |
|
6 |
import os
|
7 |
import warnings
|
|
|
8 |
|
9 |
import pkg_resources as pkg
|
10 |
import torch
|
|
|
108 |
self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
|
109 |
if ni < 3:
|
110 |
f = self.save_dir / f'train_batch{ni}.jpg' # filename
|
111 |
+
plot_images(imgs, targets, paths, f)
|
112 |
if self.wandb and ni == 10:
|
113 |
files = sorted(self.save_dir.glob('train*.jpg'))
|
114 |
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
|
|
|
131 |
|
132 |
def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
|
133 |
# Callback runs at the end of each fit (train+val) epoch
|
134 |
+
x = dict(zip(self.keys, vals))
|
135 |
if self.csv:
|
136 |
file = self.save_dir / 'results.csv'
|
137 |
n = len(x) + 1 # number of cols
|
|
|
170 |
self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')
|
171 |
|
172 |
if self.wandb:
|
173 |
+
self.wandb.log(dict(zip(self.keys[3:10], results)))
|
174 |
self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
|
175 |
# Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
|
176 |
if not self.opt.evolve:
|
utils/plots.py
CHANGED
@@ -19,7 +19,7 @@ import torch
|
|
19 |
from PIL import Image, ImageDraw, ImageFont
|
20 |
|
21 |
from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
|
22 |
-
increment_path, is_ascii, try_except, xywh2xyxy, xyxy2xywh)
|
23 |
from utils.metrics import fitness
|
24 |
|
25 |
# Settings
|
@@ -32,9 +32,9 @@ class Colors:
|
|
32 |
# Ultralytics color palette https://ultralytics.com/
|
33 |
def __init__(self):
|
34 |
# hex = matplotlib.colors.TABLEAU_COLORS.values()
|
35 |
-
|
36 |
-
|
37 |
-
self.palette = [self.hex2rgb('#'
|
38 |
self.n = len(self.palette)
|
39 |
|
40 |
def __call__(self, i, bgr=False):
|
@@ -100,7 +100,7 @@ class Annotator:
|
|
100 |
if label:
|
101 |
tf = max(self.lw - 1, 1) # font thickness
|
102 |
w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
|
103 |
-
outside = p1[1] - h
|
104 |
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
|
105 |
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
|
106 |
cv2.putText(self.im,
|
@@ -184,6 +184,7 @@ def output_to_target(output):
|
|
184 |
return np.array(targets)
|
185 |
|
186 |
|
|
|
187 |
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):
|
188 |
# Plot image grid with labels
|
189 |
if isinstance(images, torch.Tensor):
|
@@ -420,7 +421,7 @@ def plot_results(file='path/to/results.csv', dir=''):
|
|
420 |
ax = ax.ravel()
|
421 |
files = list(save_dir.glob('results*.csv'))
|
422 |
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
|
423 |
-
for
|
424 |
try:
|
425 |
data = pd.read_csv(f)
|
426 |
s = [x.strip() for x in data.columns]
|
|
|
19 |
from PIL import Image, ImageDraw, ImageFont
|
20 |
|
21 |
from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
|
22 |
+
increment_path, is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
|
23 |
from utils.metrics import fitness
|
24 |
|
25 |
# Settings
|
|
|
32 |
# Ultralytics color palette https://ultralytics.com/
|
33 |
def __init__(self):
|
34 |
# hex = matplotlib.colors.TABLEAU_COLORS.values()
|
35 |
+
hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
|
36 |
+
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
|
37 |
+
self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
|
38 |
self.n = len(self.palette)
|
39 |
|
40 |
def __call__(self, i, bgr=False):
|
|
|
100 |
if label:
|
101 |
tf = max(self.lw - 1, 1) # font thickness
|
102 |
w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
|
103 |
+
outside = p1[1] - h >= 3
|
104 |
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
|
105 |
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
|
106 |
cv2.putText(self.im,
|
|
|
184 |
return np.array(targets)
|
185 |
|
186 |
|
187 |
+
@threaded
|
188 |
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):
|
189 |
# Plot image grid with labels
|
190 |
if isinstance(images, torch.Tensor):
|
|
|
421 |
ax = ax.ravel()
|
422 |
files = list(save_dir.glob('results*.csv'))
|
423 |
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
|
424 |
+
for f in files:
|
425 |
try:
|
426 |
data = pd.read_csv(f)
|
427 |
s = [x.strip() for x in data.columns]
|
val.py
CHANGED
@@ -23,7 +23,6 @@ import json
|
|
23 |
import os
|
24 |
import sys
|
25 |
from pathlib import Path
|
26 |
-
from threading import Thread
|
27 |
|
28 |
import numpy as np
|
29 |
import torch
|
@@ -255,10 +254,8 @@ def run(
|
|
255 |
|
256 |
# Plot images
|
257 |
if plots and batch_i < 3:
|
258 |
-
|
259 |
-
|
260 |
-
f = save_dir / f'val_batch{batch_i}_pred.jpg' # predictions
|
261 |
-
Thread(target=plot_images, args=(im, output_to_target(out), paths, f, names), daemon=True).start()
|
262 |
|
263 |
callbacks.run('on_val_batch_end')
|
264 |
|
|
|
23 |
import os
|
24 |
import sys
|
25 |
from pathlib import Path
|
|
|
26 |
|
27 |
import numpy as np
|
28 |
import torch
|
|
|
254 |
|
255 |
# Plot images
|
256 |
if plots and batch_i < 3:
|
257 |
+
plot_images(im, targets, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) # labels
|
258 |
+
plot_images(im, output_to_target(out), paths, save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred
|
|
|
|
|
259 |
|
260 |
callbacks.run('on_val_batch_end')
|
261 |
|