add util function to get most recent last.pt file
Browse filesadded logic in train.py __main__ to handle resuming from a run
- train.py +10 -3
- utils/utils.py +6 -0
train.py
CHANGED
@@ -198,10 +198,10 @@ def train(hyp):
|
|
198 |
model.names = data_dict['names']
|
199 |
|
200 |
#save hyperparamter and training options in run folder
|
201 |
-
with open(os.path.join(log_dir, 'hyp.yaml', 'w')
|
202 |
yaml.dump(hyp, f)
|
203 |
|
204 |
-
with open(os.path.join(log_dir, 'opt.yaml', 'w')
|
205 |
yaml.dump(opt, f)
|
206 |
|
207 |
# Class frequency
|
@@ -294,7 +294,7 @@ def train(hyp):
|
|
294 |
|
295 |
# Plot
|
296 |
if ni < 3:
|
297 |
-
f = 'train_batch%g.jpg' % i # filename
|
298 |
res = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
|
299 |
if tb_writer:
|
300 |
tb_writer.add_image(f, res, dataformats='HWC', global_step=epoch)
|
@@ -385,6 +385,7 @@ if __name__ == '__main__':
|
|
385 |
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
|
386 |
parser.add_argument('--rect', action='store_true', help='rectangular training')
|
387 |
parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
|
|
|
388 |
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
|
389 |
parser.add_argument('--notest', action='store_true', help='only test final epoch')
|
390 |
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
|
@@ -398,6 +399,12 @@ if __name__ == '__main__':
|
|
398 |
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
399 |
parser.add_argument('--hyp', type=str, default='', help ='path to hyp yaml file')
|
400 |
opt = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
401 |
opt.weights = last if opt.resume else opt.weights
|
402 |
opt.cfg = check_file(opt.cfg) # check file
|
403 |
opt.data = check_file(opt.data) # check file
|
|
|
198 |
model.names = data_dict['names']
|
199 |
|
200 |
#save hyperparamter and training options in run folder
|
201 |
+
with open(os.path.join(log_dir, 'hyp.yaml'), 'w') as f:
|
202 |
yaml.dump(hyp, f)
|
203 |
|
204 |
+
with open(os.path.join(log_dir, 'opt.yaml'), 'w') as f:
|
205 |
yaml.dump(opt, f)
|
206 |
|
207 |
# Class frequency
|
|
|
294 |
|
295 |
# Plot
|
296 |
if ni < 3:
|
297 |
+
f = os.path.join(log_dir, 'train_batch%g.jpg' % i) # filename
|
298 |
res = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
|
299 |
if tb_writer:
|
300 |
tb_writer.add_image(f, res, dataformats='HWC', global_step=epoch)
|
|
|
385 |
parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
|
386 |
parser.add_argument('--rect', action='store_true', help='rectangular training')
|
387 |
parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
|
388 |
+
parser.add_argument('--resume_from_run', type=str, default='', 'resume training from last.pt in this dir')
|
389 |
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
|
390 |
parser.add_argument('--notest', action='store_true', help='only test final epoch')
|
391 |
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
|
|
|
399 |
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
400 |
parser.add_argument('--hyp', type=str, default='', help ='path to hyp yaml file')
|
401 |
opt = parser.parse_args()
|
402 |
+
|
403 |
+
if opt.resume and not opt.resume_from_run:
|
404 |
+
last = get_latest_run()
|
405 |
+
print(f'WARNING: No run provided to resume from. Resuming from most recent run found at {last}')
|
406 |
+
else:
|
407 |
+
last = opt.resume_from_run
|
408 |
opt.weights = last if opt.resume else opt.weights
|
409 |
opt.cfg = check_file(opt.cfg) # check file
|
410 |
opt.data = check_file(opt.data) # check file
|
utils/utils.py
CHANGED
@@ -36,6 +36,12 @@ def init_seeds(seed=0):
|
|
36 |
np.random.seed(seed)
|
37 |
torch_utils.init_seeds(seed=seed)
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
def check_git_status():
|
41 |
# Suggest 'git pull' if repo is out of date
|
|
|
36 |
np.random.seed(seed)
|
37 |
torch_utils.init_seeds(seed=seed)
|
38 |
|
39 |
+
def get_latest_run(search_dir = './runs/'):
|
40 |
+
# get path to most recent 'last.pt' in run dirs
|
41 |
+
# assumes most recently saved 'last.pt' is the desired weights to --resume from
|
42 |
+
last_list = glob.glob('runs/*/last.pt')
|
43 |
+
latest = max(last_list, key = os.path.getctime)
|
44 |
+
return latest
|
45 |
|
46 |
def check_git_status():
|
47 |
# Suggest 'git pull' if repo is out of date
|