Commit
·
7765557
1
Parent(s):
4ffd977
update train.py gsutil bucket fix (#463)
Browse files
train.py
CHANGED
@@ -47,11 +47,13 @@ def train(hyp, tb_writer, opt, device):
|
|
47 |
print(f'Hyperparameters {hyp}')
|
48 |
log_dir = tb_writer.log_dir if tb_writer else 'runs/evolution' # run directory
|
49 |
wdir = str(Path(log_dir) / 'weights') + os.sep # weights directory
|
50 |
-
|
51 |
os.makedirs(wdir, exist_ok=True)
|
52 |
last = wdir + 'last.pt'
|
53 |
best = wdir + 'best.pt'
|
54 |
results_file = log_dir + os.sep + 'results.txt'
|
|
|
|
|
|
|
55 |
|
56 |
# Save run settings
|
57 |
with open(Path(log_dir) / 'hyp.yaml', 'w') as f:
|
@@ -59,17 +61,8 @@ def train(hyp, tb_writer, opt, device):
|
|
59 |
with open(Path(log_dir) / 'opt.yaml', 'w') as f:
|
60 |
yaml.dump(vars(opt), f, sort_keys=False)
|
61 |
|
62 |
-
epochs = opt.epochs # 300
|
63 |
-
batch_size = opt.batch_size # batch size per process.
|
64 |
-
total_batch_size = opt.total_batch_size
|
65 |
-
weights = opt.weights # initial training weights
|
66 |
-
local_rank = opt.local_rank
|
67 |
-
|
68 |
-
# TODO: Init DDP logging. Only the first process is allowed to log.
|
69 |
-
# Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.
|
70 |
-
|
71 |
# Configure
|
72 |
-
init_seeds(2 +
|
73 |
with open(opt.data) as f:
|
74 |
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
75 |
train_path = data_dict['train']
|
@@ -78,7 +71,7 @@ def train(hyp, tb_writer, opt, device):
|
|
78 |
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
79 |
|
80 |
# Remove previous results
|
81 |
-
if
|
82 |
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
|
83 |
os.remove(f)
|
84 |
|
@@ -91,7 +84,7 @@ def train(hyp, tb_writer, opt, device):
|
|
91 |
|
92 |
# Optimizer
|
93 |
nbs = 64 # nominal batch size
|
94 |
-
#
|
95 |
# all-reduce operation is carried out during loss.backward().
|
96 |
# Thus, there would be redundant all-reduce communications in a accumulation procedure,
|
97 |
# which means, the result is still right but the training speed gets slower.
|
@@ -121,8 +114,7 @@ def train(hyp, tb_writer, opt, device):
|
|
121 |
del pg0, pg1, pg2
|
122 |
|
123 |
# Load Model
|
124 |
-
|
125 |
-
with torch_distributed_zero_first(local_rank):
|
126 |
google_utils.attempt_download(weights)
|
127 |
start_epoch, best_fitness = 0, 0.0
|
128 |
if weights.endswith('.pt'): # pytorch format
|
@@ -169,32 +161,31 @@ def train(hyp, tb_writer, opt, device):
|
|
169 |
# plot_lr_scheduler(optimizer, scheduler, epochs)
|
170 |
|
171 |
# DP mode
|
172 |
-
if device.type != 'cpu' and
|
173 |
model = torch.nn.DataParallel(model)
|
174 |
|
175 |
-
#
|
176 |
-
|
177 |
-
# "Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper"
|
178 |
-
# chenyzsjtu: ema should be placed before after SyncBN. As SyncBN introduces new modules.
|
179 |
-
if opt.sync_bn and device.type != 'cpu' and local_rank != -1:
|
180 |
-
print("SyncBN activated!")
|
181 |
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
|
182 |
-
|
|
|
|
|
|
|
183 |
|
184 |
# DDP mode
|
185 |
-
if device.type != 'cpu' and
|
186 |
-
model = DDP(model, device_ids=[
|
187 |
|
188 |
# Trainloader
|
189 |
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
|
190 |
-
cache=opt.cache_images, rect=opt.rect, local_rank=
|
191 |
world_size=opt.world_size)
|
192 |
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
193 |
nb = len(dataloader) # number of batches
|
194 |
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
|
195 |
|
196 |
# Testloader
|
197 |
-
if
|
198 |
# local_rank is set to -1. Because only the first process is expected to do evaluation.
|
199 |
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
|
200 |
cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0]
|
@@ -208,8 +199,7 @@ def train(hyp, tb_writer, opt, device):
|
|
208 |
model.names = names
|
209 |
|
210 |
# Class frequency
|
211 |
-
|
212 |
-
if local_rank in [-1, 0]:
|
213 |
labels = np.concatenate(dataset.labels, 0)
|
214 |
c = torch.tensor(labels[:, 0]) # classes
|
215 |
# cf = torch.bincount(c.long(), minlength=nc) + 1.
|
@@ -222,13 +212,14 @@ def train(hyp, tb_writer, opt, device):
|
|
222 |
# Check anchors
|
223 |
if not opt.noautoanchor:
|
224 |
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
|
|
225 |
# Start training
|
226 |
t0 = time.time()
|
227 |
nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
|
228 |
maps = np.zeros(nc) # mAP per class
|
229 |
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
|
230 |
scheduler.last_epoch = start_epoch - 1 # do not move
|
231 |
-
if
|
232 |
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
|
233 |
print('Using %g dataloader workers' % dataloader.num_workers)
|
234 |
print('Starting training for %g epochs...' % epochs)
|
@@ -240,18 +231,18 @@ def train(hyp, tb_writer, opt, device):
|
|
240 |
# When in DDP mode, the generated indices will be broadcasted to synchronize dataset.
|
241 |
if dataset.image_weights:
|
242 |
# Generate indices.
|
243 |
-
if
|
244 |
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
|
245 |
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
|
246 |
dataset.indices = random.choices(range(dataset.n), weights=image_weights,
|
247 |
k=dataset.n) # rand weighted idx
|
248 |
# Broadcast.
|
249 |
-
if
|
250 |
indices = torch.zeros([dataset.n], dtype=torch.int)
|
251 |
-
if
|
252 |
indices[:] = torch.from_tensor(dataset.indices, dtype=torch.int)
|
253 |
dist.broadcast(indices, 0)
|
254 |
-
if
|
255 |
dataset.indices = indices.cpu().numpy()
|
256 |
|
257 |
# Update mosaic border
|
@@ -259,10 +250,10 @@ def train(hyp, tb_writer, opt, device):
|
|
259 |
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders
|
260 |
|
261 |
mloss = torch.zeros(4, device=device) # mean losses
|
262 |
-
if
|
263 |
dataloader.sampler.set_epoch(epoch)
|
264 |
pbar = enumerate(dataloader)
|
265 |
-
if
|
266 |
print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
|
267 |
pbar = tqdm(pbar, total=nb) # progress bar
|
268 |
optimizer.zero_grad()
|
@@ -293,10 +284,9 @@ def train(hyp, tb_writer, opt, device):
|
|
293 |
pred = model(imgs)
|
294 |
|
295 |
# Loss
|
296 |
-
loss, loss_items = compute_loss(pred, targets.to(device), model)
|
297 |
-
|
298 |
-
|
299 |
-
loss *= opt.world_size
|
300 |
if not torch.isfinite(loss):
|
301 |
print('WARNING: non-finite loss, ending training ', loss_items)
|
302 |
return results
|
@@ -316,7 +306,7 @@ def train(hyp, tb_writer, opt, device):
|
|
316 |
ema.update(model)
|
317 |
|
318 |
# Print
|
319 |
-
if
|
320 |
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
|
321 |
mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB)
|
322 |
s = ('%10s' * 2 + '%10.4g' * 6) % (
|
@@ -337,7 +327,7 @@ def train(hyp, tb_writer, opt, device):
|
|
337 |
scheduler.step()
|
338 |
|
339 |
# Only the first process in DDP mode is allowed to log or save checkpoints.
|
340 |
-
if
|
341 |
# mAP
|
342 |
if ema is not None:
|
343 |
ema.update_attr(model, include=['md', 'nc', 'hyp', 'gr', 'names', 'stride'])
|
@@ -351,17 +341,17 @@ def train(hyp, tb_writer, opt, device):
|
|
351 |
single_cls=opt.single_cls,
|
352 |
dataloader=testloader,
|
353 |
save_dir=log_dir)
|
354 |
-
|
355 |
# Write
|
356 |
with open(results_file, 'a') as f:
|
357 |
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
|
358 |
if len(opt.name) and opt.bucket:
|
359 |
-
os.system('gsutil cp
|
360 |
|
361 |
# Tensorboard
|
362 |
if tb_writer:
|
363 |
tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
|
364 |
-
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/
|
365 |
'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
|
366 |
for x, tag in zip(list(mloss[:-1]) + list(results), tags):
|
367 |
tb_writer.add_scalar(tag, x, epoch)
|
@@ -389,7 +379,7 @@ def train(hyp, tb_writer, opt, device):
|
|
389 |
# end epoch ----------------------------------------------------------------------------------------------------
|
390 |
# end training
|
391 |
|
392 |
-
if
|
393 |
# Strip optimizers
|
394 |
n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
|
395 |
fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
|
@@ -401,10 +391,10 @@ def train(hyp, tb_writer, opt, device):
|
|
401 |
os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None # upload
|
402 |
# Finish
|
403 |
if not opt.evolve:
|
404 |
-
plot_results() # save as results.png
|
405 |
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
406 |
|
407 |
-
dist.destroy_process_group() if
|
408 |
torch.cuda.empty_cache()
|
409 |
return results
|
410 |
|
@@ -431,10 +421,8 @@ if __name__ == '__main__':
|
|
431 |
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
432 |
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
|
433 |
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
434 |
-
parser.add_argument(
|
435 |
-
|
436 |
-
parser.add_argument('--local_rank', type=int, default=-1,
|
437 |
-
help="Extra parameter for DDP implementation. Don't use it manually.")
|
438 |
opt = parser.parse_args()
|
439 |
|
440 |
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
|
|
|
47 |
print(f'Hyperparameters {hyp}')
|
48 |
log_dir = tb_writer.log_dir if tb_writer else 'runs/evolution' # run directory
|
49 |
wdir = str(Path(log_dir) / 'weights') + os.sep # weights directory
|
|
|
50 |
os.makedirs(wdir, exist_ok=True)
|
51 |
last = wdir + 'last.pt'
|
52 |
best = wdir + 'best.pt'
|
53 |
results_file = log_dir + os.sep + 'results.txt'
|
54 |
+
epochs, batch_size, total_batch_size, weights, rank = opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.local_rank
|
55 |
+
# TODO: Init DDP logging. Only the first process is allowed to log.
|
56 |
+
# Since I see lots of print here, the logging configuration is skipped here. We may see repeated outputs.
|
57 |
|
58 |
# Save run settings
|
59 |
with open(Path(log_dir) / 'hyp.yaml', 'w') as f:
|
|
|
61 |
with open(Path(log_dir) / 'opt.yaml', 'w') as f:
|
62 |
yaml.dump(vars(opt), f, sort_keys=False)
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
# Configure
|
65 |
+
init_seeds(2 + rank)
|
66 |
with open(opt.data) as f:
|
67 |
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
68 |
train_path = data_dict['train']
|
|
|
71 |
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
|
72 |
|
73 |
# Remove previous results
|
74 |
+
if rank in [-1, 0]:
|
75 |
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
|
76 |
os.remove(f)
|
77 |
|
|
|
84 |
|
85 |
# Optimizer
|
86 |
nbs = 64 # nominal batch size
|
87 |
+
# default DDP implementation is slow for accumulation according to: https://pytorch.org/docs/stable/notes/ddp.html
|
88 |
# all-reduce operation is carried out during loss.backward().
|
89 |
# Thus, there would be redundant all-reduce communications in a accumulation procedure,
|
90 |
# which means, the result is still right but the training speed gets slower.
|
|
|
114 |
del pg0, pg1, pg2
|
115 |
|
116 |
# Load Model
|
117 |
+
with torch_distributed_zero_first(rank):
|
|
|
118 |
google_utils.attempt_download(weights)
|
119 |
start_epoch, best_fitness = 0, 0.0
|
120 |
if weights.endswith('.pt'): # pytorch format
|
|
|
161 |
# plot_lr_scheduler(optimizer, scheduler, epochs)
|
162 |
|
163 |
# DP mode
|
164 |
+
if device.type != 'cpu' and rank == -1 and torch.cuda.device_count() > 1:
|
165 |
model = torch.nn.DataParallel(model)
|
166 |
|
167 |
+
# SyncBatchNorm
|
168 |
+
if opt.sync_bn and device.type != 'cpu' and rank != -1:
|
|
|
|
|
|
|
|
|
169 |
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
|
170 |
+
print('Using SyncBatchNorm()')
|
171 |
+
|
172 |
+
# Exponential moving average
|
173 |
+
ema = torch_utils.ModelEMA(model) if rank in [-1, 0] else None
|
174 |
|
175 |
# DDP mode
|
176 |
+
if device.type != 'cpu' and rank != -1:
|
177 |
+
model = DDP(model, device_ids=[rank], output_device=rank)
|
178 |
|
179 |
# Trainloader
|
180 |
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
|
181 |
+
cache=opt.cache_images, rect=opt.rect, local_rank=rank,
|
182 |
world_size=opt.world_size)
|
183 |
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
184 |
nb = len(dataloader) # number of batches
|
185 |
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
|
186 |
|
187 |
# Testloader
|
188 |
+
if rank in [-1, 0]:
|
189 |
# local_rank is set to -1. Because only the first process is expected to do evaluation.
|
190 |
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
|
191 |
cache=opt.cache_images, rect=True, local_rank=-1, world_size=opt.world_size)[0]
|
|
|
199 |
model.names = names
|
200 |
|
201 |
# Class frequency
|
202 |
+
if rank in [-1, 0]:
|
|
|
203 |
labels = np.concatenate(dataset.labels, 0)
|
204 |
c = torch.tensor(labels[:, 0]) # classes
|
205 |
# cf = torch.bincount(c.long(), minlength=nc) + 1.
|
|
|
212 |
# Check anchors
|
213 |
if not opt.noautoanchor:
|
214 |
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
215 |
+
|
216 |
# Start training
|
217 |
t0 = time.time()
|
218 |
nw = max(3 * nb, 1e3) # number of warmup iterations, max(3 epochs, 1k iterations)
|
219 |
maps = np.zeros(nc) # mAP per class
|
220 |
results = (0, 0, 0, 0, 0, 0, 0) # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
|
221 |
scheduler.last_epoch = start_epoch - 1 # do not move
|
222 |
+
if rank in [0, -1]:
|
223 |
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
|
224 |
print('Using %g dataloader workers' % dataloader.num_workers)
|
225 |
print('Starting training for %g epochs...' % epochs)
|
|
|
231 |
# When in DDP mode, the generated indices will be broadcasted to synchronize dataset.
|
232 |
if dataset.image_weights:
|
233 |
# Generate indices.
|
234 |
+
if rank in [-1, 0]:
|
235 |
w = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights
|
236 |
image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
|
237 |
dataset.indices = random.choices(range(dataset.n), weights=image_weights,
|
238 |
k=dataset.n) # rand weighted idx
|
239 |
# Broadcast.
|
240 |
+
if rank != -1:
|
241 |
indices = torch.zeros([dataset.n], dtype=torch.int)
|
242 |
+
if rank == 0:
|
243 |
indices[:] = torch.from_tensor(dataset.indices, dtype=torch.int)
|
244 |
dist.broadcast(indices, 0)
|
245 |
+
if rank != 0:
|
246 |
dataset.indices = indices.cpu().numpy()
|
247 |
|
248 |
# Update mosaic border
|
|
|
250 |
# dataset.mosaic_border = [b - imgsz, -b] # height, width borders
|
251 |
|
252 |
mloss = torch.zeros(4, device=device) # mean losses
|
253 |
+
if rank != -1:
|
254 |
dataloader.sampler.set_epoch(epoch)
|
255 |
pbar = enumerate(dataloader)
|
256 |
+
if rank in [-1, 0]:
|
257 |
print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
|
258 |
pbar = tqdm(pbar, total=nb) # progress bar
|
259 |
optimizer.zero_grad()
|
|
|
284 |
pred = model(imgs)
|
285 |
|
286 |
# Loss
|
287 |
+
loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size
|
288 |
+
if rank != -1:
|
289 |
+
loss *= opt.world_size # gradient averaged between devices in DDP mode
|
|
|
290 |
if not torch.isfinite(loss):
|
291 |
print('WARNING: non-finite loss, ending training ', loss_items)
|
292 |
return results
|
|
|
306 |
ema.update(model)
|
307 |
|
308 |
# Print
|
309 |
+
if rank in [-1, 0]:
|
310 |
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
|
311 |
mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0) # (GB)
|
312 |
s = ('%10s' * 2 + '%10.4g' * 6) % (
|
|
|
327 |
scheduler.step()
|
328 |
|
329 |
# Only the first process in DDP mode is allowed to log or save checkpoints.
|
330 |
+
if rank in [-1, 0]:
|
331 |
# mAP
|
332 |
if ema is not None:
|
333 |
ema.update_attr(model, include=['md', 'nc', 'hyp', 'gr', 'names', 'stride'])
|
|
|
341 |
single_cls=opt.single_cls,
|
342 |
dataloader=testloader,
|
343 |
save_dir=log_dir)
|
344 |
+
|
345 |
# Write
|
346 |
with open(results_file, 'a') as f:
|
347 |
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
|
348 |
if len(opt.name) and opt.bucket:
|
349 |
+
os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
|
350 |
|
351 |
# Tensorboard
|
352 |
if tb_writer:
|
353 |
tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
|
354 |
+
'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
|
355 |
'val/giou_loss', 'val/obj_loss', 'val/cls_loss']
|
356 |
for x, tag in zip(list(mloss[:-1]) + list(results), tags):
|
357 |
tb_writer.add_scalar(tag, x, epoch)
|
|
|
379 |
# end epoch ----------------------------------------------------------------------------------------------------
|
380 |
# end training
|
381 |
|
382 |
+
if rank in [-1, 0]:
|
383 |
# Strip optimizers
|
384 |
n = ('_' if len(opt.name) and not opt.name.isnumeric() else '') + opt.name
|
385 |
fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
|
|
|
391 |
os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None # upload
|
392 |
# Finish
|
393 |
if not opt.evolve:
|
394 |
+
plot_results(save_dir=log_dir) # save as results.png
|
395 |
print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
|
396 |
|
397 |
+
dist.destroy_process_group() if rank not in [-1, 0] else None
|
398 |
torch.cuda.empty_cache()
|
399 |
return results
|
400 |
|
|
|
421 |
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
|
422 |
parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
|
423 |
parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset')
|
424 |
+
parser.add_argument('--sync-bn', action="store_true", help='use SyncBatchNorm, only available in DDP mode')
|
425 |
+
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
|
|
|
|
|
426 |
opt = parser.parse_args()
|
427 |
|
428 |
last = get_latest_run() if opt.resume == 'get_last' else opt.resume # resume from most recent run
|