fix compatibility for hyper config (#1146)
Browse files* fix/hyper
* Hyp giou check to train.py
* restore general.py
* train.py overwrite fix
* restore general.py and pep8 update
Co-authored-by: Glenn Jocher <[email protected]>
- train.py +8 -3
- utils/general.py +2 -2
train.py
CHANGED
@@ -5,6 +5,7 @@ import random
|
|
5 |
import shutil
|
6 |
import time
|
7 |
from pathlib import Path
|
|
|
8 |
|
9 |
import math
|
10 |
import numpy as np
|
@@ -430,9 +431,8 @@ if __name__ == '__main__':
|
|
430 |
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
|
431 |
log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1
|
432 |
|
433 |
-
device = select_device(opt.device, batch_size=opt.batch_size)
|
434 |
-
|
435 |
# DDP mode
|
|
|
436 |
if opt.local_rank != -1:
|
437 |
assert torch.cuda.device_count() > opt.local_rank
|
438 |
torch.cuda.set_device(opt.local_rank)
|
@@ -441,11 +441,16 @@ if __name__ == '__main__':
|
|
441 |
assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
|
442 |
opt.batch_size = opt.total_batch_size // opt.world_size
|
443 |
|
444 |
-
|
445 |
with open(opt.hyp) as f:
|
446 |
hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
|
|
|
|
|
|
|
|
|
447 |
|
448 |
# Train
|
|
|
449 |
if not opt.evolve:
|
450 |
tb_writer = None
|
451 |
if opt.global_rank in [-1, 0]:
|
|
|
5 |
import shutil
|
6 |
import time
|
7 |
from pathlib import Path
|
8 |
+
from warnings import warn
|
9 |
|
10 |
import math
|
11 |
import numpy as np
|
|
|
431 |
opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
|
432 |
log_dir = increment_dir(Path(opt.logdir) / 'exp', opt.name) # runs/exp1
|
433 |
|
|
|
|
|
434 |
# DDP mode
|
435 |
+
device = select_device(opt.device, batch_size=opt.batch_size)
|
436 |
if opt.local_rank != -1:
|
437 |
assert torch.cuda.device_count() > opt.local_rank
|
438 |
torch.cuda.set_device(opt.local_rank)
|
|
|
441 |
assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
|
442 |
opt.batch_size = opt.total_batch_size // opt.world_size
|
443 |
|
444 |
+
# Hyperparameters
|
445 |
with open(opt.hyp) as f:
|
446 |
hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps
|
447 |
+
if 'box' not in hyp:
|
448 |
+
warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' %
|
449 |
+
(opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120'))
|
450 |
+
hyp['box'] = hyp.pop('giou')
|
451 |
|
452 |
# Train
|
453 |
+
logger.info(opt)
|
454 |
if not opt.evolve:
|
455 |
tb_writer = None
|
456 |
if opt.global_rank in [-1, 0]:
|
utils/general.py
CHANGED
@@ -1,18 +1,18 @@
|
|
1 |
import glob
|
2 |
import logging
|
3 |
-
import math
|
4 |
import os
|
5 |
import platform
|
6 |
import random
|
|
|
7 |
import shutil
|
8 |
import subprocess
|
9 |
import time
|
10 |
-
import re
|
11 |
from contextlib import contextmanager
|
12 |
from copy import copy
|
13 |
from pathlib import Path
|
14 |
|
15 |
import cv2
|
|
|
16 |
import matplotlib
|
17 |
import matplotlib.pyplot as plt
|
18 |
import numpy as np
|
|
|
1 |
import glob
|
2 |
import logging
|
|
|
3 |
import os
|
4 |
import platform
|
5 |
import random
|
6 |
+
import re
|
7 |
import shutil
|
8 |
import subprocess
|
9 |
import time
|
|
|
10 |
from contextlib import contextmanager
|
11 |
from copy import copy
|
12 |
from pathlib import Path
|
13 |
|
14 |
import cv2
|
15 |
+
import math
|
16 |
import matplotlib
|
17 |
import matplotlib.pyplot as plt
|
18 |
import numpy as np
|