glenn-jocher committed 01a67a9, a merge commit with two parents (d3e786e, 9006b85):

Merge remote-tracking branch 'origin/master'
Files changed (4):
  1. models/export.py +2 -2
  2. train.py +6 -7
  3. utils/datasets.py +52 -54
  4. utils/utils.py +4 -6
models/export.py CHANGED
```diff
@@ -31,7 +31,7 @@ if __name__ == '__main__':
     # TorchScript export
     try:
         print('\nStarting TorchScript export with torch %s...' % torch.__version__)
-        f = opt.weights.replace('.pt', '.torchscript')  # filename
+        f = opt.weights.replace('.pt', '.torchscript.pt')  # filename
         ts = torch.jit.trace(model, img)
         ts.save(f)
         print('TorchScript export success, saved as %s' % f)
@@ -62,7 +62,7 @@ if __name__ == '__main__':
 
         print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
         # convert model from torchscript and apply pixel scaling as per detect.py
-        model = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape, scale=1/255.0, bias=[0, 0, 0])])
+        model = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
         f = opt.weights.replace('.pt', '.mlmodel')  # filename
         model.save(f)
         print('CoreML export success, saved as %s' % f)
```
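For reference, the trace-then-convert pipeline this file drives can be reproduced in isolation. The sketch below is a minimal approximation, not the repo's export.py: the toy model, dummy input shape, and 'yolov5s.pt' filename are illustrative assumptions, and the CoreML step assumes the unified `ct.convert` API of coremltools 4+.

```python
import torch

# stand-in model and tracing input; export.py uses the real YOLOv5 model here
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
img = torch.zeros(1, 3, 640, 640)  # dummy input, used only to record traced ops

weights = 'yolov5s.pt'  # hypothetical weights filename
f = weights.replace('.pt', '.torchscript.pt')  # keep a .pt suffix, as in the diff
ts = torch.jit.trace(model, img)  # trace executes the model once and records the graph
ts.save(f)

# optional CoreML conversion from the traced module, guarded since it may not be installed
try:
    import coremltools as ct
    mlmodel = ct.convert(ts, inputs=[ct.ImageType(name='images', shape=img.shape,
                                                  scale=1 / 255.0, bias=[0, 0, 0])])
    mlmodel.save(weights.replace('.pt', '.mlmodel'))
except ImportError:
    pass  # CoreML export is optional
```

The `scale=1 / 255.0` argument bakes the same pixel normalization used in detect.py into the CoreML model, and the rename to `*.torchscript.pt` presumably keeps the standard `.pt` extension so suffix-based tooling still recognizes the file as a PyTorch artifact.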
train.py CHANGED
```diff
@@ -44,7 +44,7 @@ hyp = {'optimizer': 'SGD',  # ['adam', 'SGD', None] if none, default is SGD
 
 def train(hyp):
     print(f'Hyperparameters {hyp}')
-    log_dir = tb_writer.log_dir  # run directory
+    log_dir = tb_writer.log_dir if tb_writer else 'runs/evolution'  # run directory
     wdir = str(Path(log_dir) / 'weights') + os.sep  # weights directory
 
     os.makedirs(wdir, exist_ok=True)
@@ -387,7 +387,10 @@ if __name__ == '__main__':
     opt.weights = last if opt.resume and not opt.weights else opt.weights
     opt.cfg = check_file(opt.cfg)  # check file
     opt.data = check_file(opt.data)  # check file
-    opt.hyp = check_file(opt.hyp) if opt.hyp else ''  # check file
+    if opt.hyp:  # update hyps
+        opt.hyp = check_file(opt.hyp)  # check file
+        with open(opt.hyp) as f:
+            hyp.update(yaml.load(f, Loader=yaml.FullLoader))  # update hyps
     print(opt)
     opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
     device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
@@ -396,12 +399,8 @@ if __name__ == '__main__':
 
     # Train
     if not opt.evolve:
-        print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
         tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
-        if opt.hyp:  # update hyps
-            with open(opt.hyp) as f:
-                hyp.update(yaml.load(f, Loader=yaml.FullLoader))
-
+        print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
         train(hyp)
 
     # Evolve hyperparameters (optional)
```
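The net effect of these train.py changes: the `--hyp` YAML override now runs in `__main__` before `print(opt)`, so it applies to hyperparameter evolution as well as normal training, and `train()` falls back to 'runs/evolution' when no TensorBoard writer exists. A minimal, self-contained sketch of the override pattern (the hyp subset and 'hyp.custom.yaml' filename are assumptions for illustration):

```python
import yaml

# built-in defaults: a small illustrative subset of train.py's hyp dict
hyp = {'optimizer': 'SGD', 'lr0': 0.01, 'momentum': 0.937}

# hypothetical --hyp file, written here so the sketch runs standalone
with open('hyp.custom.yaml', 'w') as f:
    yaml.dump({'lr0': 0.001}, f)

opt_hyp = 'hyp.custom.yaml'  # stand-in for opt.hyp
if opt_hyp:
    with open(opt_hyp) as f:
        hyp.update(yaml.load(f, Loader=yaml.FullLoader))  # user values win on overlapping keys

print(hyp)  # {'optimizer': 'SGD', 'lr0': 0.001, 'momentum': 0.937}
```

Because `dict.update` only overwrites keys present in the YAML file, a partial hyp file tweaks a few values while leaving the rest of the defaults intact.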
utils/datasets.py CHANGED
```diff
@@ -26,6 +26,11 @@ for orientation in ExifTags.TAGS.keys():
         break
 
 
+def get_hash(files):
+    # Returns a single hash value of a list of files
+    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))
+
+
 def exif_size(img):
     # Returns exif-corrected PIL size
     s = img.size  # (width, height)
@@ -280,7 +285,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
     def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                  cache_images=False, single_cls=False, stride=32, pad=0.0):
         try:
-            f = []
+            f = []  # image files
             for p in path if isinstance(path, list) else [path]:
                 p = str(Path(p))  # os-agnostic
                 parent = str(Path(p).parent) + os.sep
@@ -292,7 +297,6 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                     f += glob.iglob(p + os.sep + '*.*')
                 else:
                     raise Exception('%s does not exist' % p)
-            path = p  # *.npy dir
             self.img_files = [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats]
         except Exception as e:
             raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
@@ -314,20 +318,22 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         self.stride = stride
 
         # Define labels
-        self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt')
-                            for x in self.img_files]
-
-        # Read image shapes (wh)
-        sp = path.replace('.txt', '') + '.shapes'  # shapefile path
-        try:
-            with open(sp, 'r') as f:  # read existing shapefile
-                s = [x.split() for x in f.read().splitlines()]
-                assert len(s) == n, 'Shapefile out of sync'
-        except:
-            s = [exif_size(Image.open(f)) for f in tqdm(self.img_files, desc='Reading image shapes')]
-            np.savetxt(sp, s, fmt='%g')  # overwrites existing (if any)
+        self.label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt') for x in
+                            self.img_files]
+
+        # Check cache
+        cache_path = str(Path(self.label_files[0]).parent) + '.cache'  # cached labels
+        if os.path.isfile(cache_path):
+            cache = torch.load(cache_path)  # load
+            if cache['hash'] != get_hash(self.label_files + self.img_files):  # dataset changed
+                cache = self.cache_labels(cache_path)  # re-cache
+        else:
+            cache = self.cache_labels(cache_path)  # cache
 
-        self.shapes = np.array(s, dtype=np.float64)
+        # Get labels
+        labels, shapes = zip(*[cache[x] for x in self.img_files])
+        self.shapes = np.array(shapes, dtype=np.float64)
+        self.labels = list(labels)
 
         # Rectangular Training  https://github.com/ultralytics/yolov3/issues/232
         if self.rect:
@@ -337,6 +343,7 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
             irect = ar.argsort()
             self.img_files = [self.img_files[i] for i in irect]
             self.label_files = [self.label_files[i] for i in irect]
+            self.labels = [self.labels[i] for i in irect]
             self.shapes = s[irect]  # wh
             ar = ar[irect]
 
@@ -353,33 +360,11 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
             self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(np.int) * stride
 
         # Cache labels
-        self.imgs = [None] * n
-        self.labels = [np.zeros((0, 5), dtype=np.float32)] * n
         create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False
         nm, nf, ne, ns, nd = 0, 0, 0, 0, 0  # number missing, found, empty, datasubset, duplicate
-        np_labels_path = str(Path(self.label_files[0]).parent) + '.npy'  # saved labels in *.npy file
-        if os.path.isfile(np_labels_path):
-            s = np_labels_path  # print string
-            x = np.load(np_labels_path, allow_pickle=True)
-            if len(x) == n:
-                self.labels = x
-                labels_loaded = True
-        else:
-            s = path.replace('images', 'labels')
-
         pbar = tqdm(self.label_files)
         for i, file in enumerate(pbar):
-            if labels_loaded:
-                l = self.labels[i]
-                # np.savetxt(file, l, '%g')  # save *.txt from *.npy file
-            else:
-                try:
-                    with open(file, 'r') as f:
-                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
-                except:
-                    nm += 1  # print('missing labels for image %s' % self.img_files[i])  # file missing
-                    continue
-
+            l = self.labels[i]  # label
             if l.shape[0]:
                 assert l.shape[1] == 5, '> 5 label columns: %s' % file
                 assert (l >= 0).all(), 'negative labels: %s' % file
@@ -425,15 +410,13 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                 ne += 1  # print('empty labels for image %s' % self.img_files[i])  # file empty
                 # os.system("rm '%s' '%s'" % (self.img_files[i], self.label_files[i]))  # remove
 
-            pbar.desc = 'Caching labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
-                s, nf, nm, ne, nd, n)
-        assert nf > 0 or n == 20288, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
-        if not labels_loaded and n > 1000:
-            print('Saving labels to %s for faster future loading' % np_labels_path)
-            np.save(np_labels_path, self.labels)  # save for next time
+            pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
+                cache_path, nf, nm, ne, nd, n)
+        assert nf > 0, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
 
         # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
-        if cache_images:  # if training
+        self.imgs = [None] * n
+        if cache_images:
             gb = 0  # Gigabytes of cached images
             pbar = tqdm(range(len(self.img_files)), desc='Caching images')
             self.img_hw0, self.img_hw = [None] * n, [None] * n
@@ -442,15 +425,30 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
             gb += self.imgs[i].nbytes
             pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)
 
-        # Detect corrupted images https://medium.com/joelthchao/programmatically-detect-corrupted-image-8c1b2006c3d3
-        detect_corrupted_images = False
-        if detect_corrupted_images:
-            from skimage import io  # conda install -c conda-forge scikit-image
-            for file in tqdm(self.img_files, desc='Detecting corrupted images'):
-                try:
-                    _ = io.imread(file)
-                except:
-                    print('Corrupted image detected: %s' % file)
+    def cache_labels(self, path='labels.cache'):
+        # Cache dataset labels, check images and read shapes
+        x = {}  # dict
+        pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
+        for (img, label) in pbar:
+            try:
+                l = []
+                image = Image.open(img)
+                image.verify()  # PIL verify
+                # _ = io.imread(img)  # skimage verify (from skimage import io)
+                shape = exif_size(image)  # image size
+                if os.path.isfile(label):
+                    with open(label, 'r') as f:
+                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)  # labels
+                if len(l) == 0:
+                    l = np.zeros((0, 5), dtype=np.float32)
+                x[img] = [l, shape]
+            except Exception as e:
+                x[img] = None
+                print('WARNING: %s: %s' % (img, e))
+
+        x['hash'] = get_hash(self.label_files + self.img_files)
+        torch.save(x, path)  # save for next time
+        return x
 
     def __len__(self):
         return len(self.img_files)
```
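The core idea in this rewrite is to replace the old `*.shapes` / `*.npy` sidecar files with a single `*.cache` file: a torch-serialized dict mapping each image path to `[labels, shape]`, stamped with a cheap fingerprint from `get_hash()` and rebuilt whenever that fingerprint no longer matches. A self-contained sketch of the invalidation pattern (the file names and toy payload are assumptions, not the repo's actual label format):

```python
import os
import torch

def get_hash(files):
    # cheap dataset fingerprint: total size in bytes of all existing files
    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))

def build_cache(files, cache_path):
    x = {f: os.path.getsize(f) for f in files}  # stand-in for per-image [labels, shape]
    x['hash'] = get_hash(files)                 # stamp the cache with the fingerprint
    torch.save(x, cache_path)                   # torch.save serializes arbitrary picklable dicts
    return x

files = ['a.txt', 'b.txt']  # hypothetical dataset files
for name in files:
    with open(name, 'w') as f:
        f.write(name)

cache_path = 'demo.cache'
if os.path.isfile(cache_path):
    cache = torch.load(cache_path)
    if cache['hash'] != get_hash(files):  # any file added/removed/resized -> rebuild
        cache = build_cache(files, cache_path)
else:
    cache = build_cache(files, cache_path)
```

Summing file sizes misses edits that happen to leave sizes unchanged, but it costs only one stat call per file and catches the common failure modes of files being added, removed, or replaced. Moving image verification (`PIL.Image.verify()`) into `cache_labels` also retires the old opt-in `detect_corrupted_images` block, since every image is now checked once at cache-build time.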
utils/utils.py CHANGED
```diff
@@ -45,7 +45,7 @@ def get_latest_run(search_dir='./runs'):
 
 def check_git_status():
     # Suggest 'git pull' if repo is out of date
-    if platform in ['linux', 'darwin']:
+    if platform in ['linux', 'darwin'] and not os.path.isfile('/.dockerenv'):
         s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
         if 'Your branch is behind' in s:
             print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
@@ -636,14 +636,12 @@ def strip_optimizer(f='weights/best.pt'):  # from utils.utils import *; strip_op
     x['optimizer'] = None
     x['model'].half()  # to FP16
     torch.save(x, f)
-    print('Optimizer stripped from %s' % f)
+    print('Optimizer stripped from %s, %.1fMB' % (f, os.path.getsize(f) / 1E6))
 
 
 def create_pretrained(f='weights/best.pt', s='weights/pretrained.pt'):  # from utils.utils import *; create_pretrained()
     # create pretrained checkpoint 's' from 'f' (create_pretrained(x, x) for x in glob.glob('./*.pt'))
-    device = torch.device('cpu')
-    x = torch.load(s, map_location=device)
-
+    x = torch.load(f, map_location=torch.device('cpu'))
     x['optimizer'] = None
     x['training_results'] = None
     x['epoch'] = -1
@@ -651,7 +649,7 @@ def create_pretrained(f='weights/best.pt', s='weights/pretrained.pt'):  # from u
     for p in x['model'].parameters():
         p.requires_grad = True
     torch.save(x, s)
-    print('%s saved as pretrained checkpoint %s' % (f, s))
+    print('%s saved as pretrained checkpoint %s, %.1fMB' % (f, s, os.path.getsize(s) / 1E6))
 
 
 def coco_class_count(path='../coco/labels/train2014/'):
```
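Besides skipping the git check inside Docker containers and adding file-size reporting to both print statements, this diff fixes a real bug: `create_pretrained` previously loaded from the destination `s` instead of the source `f`. The underlying checkpoint-slimming pattern is sketched below with a toy module standing in for a trained checkpoint (filenames are illustrative):

```python
import os
import torch

# build a toy checkpoint in the repo's dict format; a real one holds the trained model
ckpt = {'model': torch.nn.Linear(4, 2), 'optimizer': {'state': '...'},
        'training_results': 'epoch,loss\n', 'epoch': 42}
torch.save(ckpt, 'best.pt')

# load from the SOURCE file (on PyTorch >= 2.6, full checkpoints also need weights_only=False)
x = torch.load('best.pt', map_location=torch.device('cpu'))
x['optimizer'] = None         # drop optimizer state, often a large share of the file
x['training_results'] = None  # drop training logs
x['epoch'] = -1               # mark as a fresh starting point
x['model'].half()             # FP16 weights roughly halve storage
for p in x['model'].parameters():
    p.requires_grad = True    # leave weights trainable for fine-tuning
torch.save(x, 'pretrained.pt')
print('%s saved as pretrained checkpoint %s, %.1fMB'
      % ('best.pt', 'pretrained.pt', os.path.getsize('pretrained.pt') / 1E6))
```

Reporting `os.path.getsize(...) / 1E6` after saving gives immediate feedback on how much the stripping saved, which is presumably why both helpers gained it in the same commit.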