Commit
·
ce36905
1
Parent(s):
1e84a23
updates
Browse files
test.py
CHANGED
@@ -256,7 +256,7 @@ if __name__ == '__main__':
|
|
256 |
opt.augment)
|
257 |
|
258 |
elif opt.task == 'study': # run over a range of settings and save/plot
|
259 |
-
for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolovl.
|
260 |
f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem) # filename to save to
|
261 |
x = list(range(256, 1024, 32)) # x axis
|
262 |
y = [] # y axis
|
|
|
256 |
opt.augment)
|
257 |
|
258 |
elif opt.task == 'study': # run over a range of settings and save/plot
|
259 |
+
for weights in ['yolov5s.pt', 'yolov5m.pt', 'yolovl.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
|
260 |
f = 'study_%s_%s.txt' % (Path(opt.data).stem, Path(weights).stem) # filename to save to
|
261 |
x = list(range(256, 1024, 32)) # x axis
|
262 |
y = [] # y axis
|
train.py
CHANGED
@@ -108,30 +108,30 @@ def train(hyp):
|
|
108 |
google_utils.attempt_download(weights)
|
109 |
start_epoch, best_fitness = 0, 0.0
|
110 |
if weights.endswith('.pt'): # pytorch format
|
111 |
-
|
112 |
|
113 |
# load model
|
114 |
try:
|
115 |
-
|
116 |
-
{k: v for k, v in
|
117 |
-
model.load_state_dict(
|
118 |
except KeyError as e:
|
119 |
s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
|
120 |
% (opt.weights, opt.cfg, opt.weights)
|
121 |
raise KeyError(s) from e
|
122 |
|
123 |
# load optimizer
|
124 |
-
if
|
125 |
-
optimizer.load_state_dict(
|
126 |
-
best_fitness =
|
127 |
|
128 |
# load results
|
129 |
-
if
|
130 |
with open(results_file, 'w') as file:
|
131 |
-
file.write(
|
132 |
|
133 |
-
start_epoch =
|
134 |
-
del
|
135 |
|
136 |
# Mixed precision training https://github.com/NVIDIA/apex
|
137 |
if mixed_precision:
|
@@ -324,17 +324,17 @@ def train(hyp):
|
|
324 |
save = (not opt.nosave) or (final_epoch and not opt.evolve)
|
325 |
if save:
|
326 |
with open(results_file, 'r') as f: # create checkpoint
|
327 |
-
|
328 |
'best_fitness': best_fitness,
|
329 |
'training_results': f.read(),
|
330 |
'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
|
331 |
'optimizer': None if final_epoch else optimizer.state_dict()}
|
332 |
|
333 |
# Save last, best and delete
|
334 |
-
torch.save(
|
335 |
if (best_fitness == fi) and not final_epoch:
|
336 |
-
torch.save(
|
337 |
-
del
|
338 |
|
339 |
# end epoch ----------------------------------------------------------------------------------------------------
|
340 |
# end training
|
|
|
108 |
google_utils.attempt_download(weights)
|
109 |
start_epoch, best_fitness = 0, 0.0
|
110 |
if weights.endswith('.pt'): # pytorch format
|
111 |
+
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
112 |
|
113 |
# load model
|
114 |
try:
|
115 |
+
ckpt['model'] = \
|
116 |
+
{k: v for k, v in ckpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()}
|
117 |
+
model.load_state_dict(ckpt['model'], strict=False)
|
118 |
except KeyError as e:
|
119 |
s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
|
120 |
% (opt.weights, opt.cfg, opt.weights)
|
121 |
raise KeyError(s) from e
|
122 |
|
123 |
# load optimizer
|
124 |
+
if ckpt['optimizer'] is not None:
|
125 |
+
optimizer.load_state_dict(ckpt['optimizer'])
|
126 |
+
best_fitness = ckpt['best_fitness']
|
127 |
|
128 |
# load results
|
129 |
+
if ckpt.get('training_results') is not None:
|
130 |
with open(results_file, 'w') as file:
|
131 |
+
file.write(ckpt['training_results']) # write results.txt
|
132 |
|
133 |
+
start_epoch = ckpt['epoch'] + 1
|
134 |
+
del ckpt
|
135 |
|
136 |
# Mixed precision training https://github.com/NVIDIA/apex
|
137 |
if mixed_precision:
|
|
|
324 |
save = (not opt.nosave) or (final_epoch and not opt.evolve)
|
325 |
if save:
|
326 |
with open(results_file, 'r') as f: # create checkpoint
|
327 |
+
ckpt = {'epoch': epoch,
|
328 |
'best_fitness': best_fitness,
|
329 |
'training_results': f.read(),
|
330 |
'model': ema.ema.module if hasattr(model, 'module') else ema.ema,
|
331 |
'optimizer': None if final_epoch else optimizer.state_dict()}
|
332 |
|
333 |
# Save last, best and delete
|
334 |
+
torch.save(ckpt, last)
|
335 |
if (best_fitness == fi) and not final_epoch:
|
336 |
+
torch.save(ckpt, best)
|
337 |
+
del ckpt
|
338 |
|
339 |
# end epoch ----------------------------------------------------------------------------------------------------
|
340 |
# end training
|