Charles Frye commited on
Commit
19100ba
·
unverified ·
1 Parent(s): dd7f0b7

Improves docs and handling of entities and resuming by WandbLogger (#3264)

Browse files

* adds latest tag to match wandb defaults

* adds entity handling, 'last' tag

* fixes bug causing finished runs to resume

* removes redundant "last" tag for wandb artifact

Files changed (2) hide show
  1. train.py +1 -1
  2. utils/wandb_logging/wandb_utils.py +24 -9
train.py CHANGED
@@ -443,7 +443,7 @@ def train(hyp, opt, device, tb_writer=None):
443
  if wandb_logger.wandb and not opt.evolve: # Log the stripped model
444
  wandb_logger.wandb.log_artifact(str(final), type='model',
445
  name='run_' + wandb_logger.wandb_run.id + '_model',
446
- aliases=['last', 'best', 'stripped'])
447
  wandb_logger.finish_run()
448
  else:
449
  dist.destroy_process_group()
 
443
  if wandb_logger.wandb and not opt.evolve: # Log the stripped model
444
  wandb_logger.wandb.log_artifact(str(final), type='model',
445
  name='run_' + wandb_logger.wandb_run.id + '_model',
446
+ aliases=['latest', 'best', 'stripped'])
447
  wandb_logger.finish_run()
448
  else:
449
  dist.destroy_process_group()
utils/wandb_logging/wandb_utils.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import json
2
  import sys
3
  from pathlib import Path
@@ -35,8 +36,9 @@ def get_run_info(run_path):
35
  run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
36
  run_id = run_path.stem
37
  project = run_path.parent.stem
 
38
  model_artifact_name = 'run_' + run_id + '_model'
39
- return run_id, project, model_artifact_name
40
 
41
 
42
  def check_wandb_resume(opt):
@@ -44,9 +46,9 @@ def check_wandb_resume(opt):
44
  if isinstance(opt.resume, str):
45
  if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
46
  if opt.global_rank not in [-1, 0]: # For resuming DDP runs
47
- run_id, project, model_artifact_name = get_run_info(opt.resume)
48
  api = wandb.Api()
49
- artifact = api.artifact(project + '/' + model_artifact_name + ':latest')
50
  modeldir = artifact.download()
51
  opt.weights = str(Path(modeldir) / "last.pt")
52
  return True
@@ -78,6 +80,18 @@ def process_wandb_config_ddp_mode(opt):
78
 
79
 
80
  class WandbLogger():
 
 
 
 
 
 
 
 
 
 
 
 
81
  def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
82
  # Pre-training routine --
83
  self.job_type = job_type
@@ -85,16 +99,17 @@ class WandbLogger():
85
  # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
86
  if isinstance(opt.resume, str): # checks resume from artifact
87
  if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
88
- run_id, project, model_artifact_name = get_run_info(opt.resume)
89
  model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
90
  assert wandb, 'install wandb to resume wandb runs'
91
  # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
92
- self.wandb_run = wandb.init(id=run_id, project=project, resume='allow')
93
  opt.resume = model_artifact_name
94
  elif self.wandb:
95
  self.wandb_run = wandb.init(config=opt,
96
  resume="allow",
97
  project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
 
98
  name=name,
99
  job_type=job_type,
100
  id=run_id) if not wandb.run else wandb.run
@@ -172,8 +187,8 @@ class WandbLogger():
172
  modeldir = model_artifact.download()
173
  epochs_trained = model_artifact.metadata.get('epochs_trained')
174
  total_epochs = model_artifact.metadata.get('total_epochs')
175
- assert epochs_trained < total_epochs, 'training to %g epochs is finished, nothing to resume.' % (
176
- total_epochs)
177
  return modeldir, model_artifact
178
  return None, None
179
 
@@ -188,7 +203,7 @@ class WandbLogger():
188
  })
189
  model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
190
  wandb.log_artifact(model_artifact,
191
- aliases=['latest', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
192
  print("Saving model artifact on epoch ", epoch + 1)
193
 
194
  def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
@@ -291,7 +306,7 @@ class WandbLogger():
291
  if self.result_artifact:
292
  train_results = wandb.JoinedTable(self.val_table, self.result_table, "id")
293
  self.result_artifact.add(train_results, 'result')
294
- wandb.log_artifact(self.result_artifact, aliases=['latest', 'epoch ' + str(self.current_epoch),
295
  ('best' if best_result else '')])
296
  self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
297
  self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
 
1
+ """Utilities and tools for tracking runs with Weights & Biases."""
2
  import json
3
  import sys
4
  from pathlib import Path
 
36
  run_path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
37
  run_id = run_path.stem
38
  project = run_path.parent.stem
39
+ entity = run_path.parent.parent.stem
40
  model_artifact_name = 'run_' + run_id + '_model'
41
+ return entity, project, run_id, model_artifact_name
42
 
43
 
44
  def check_wandb_resume(opt):
 
46
  if isinstance(opt.resume, str):
47
  if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
48
  if opt.global_rank not in [-1, 0]: # For resuming DDP runs
49
+ entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
50
  api = wandb.Api()
51
+ artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')
52
  modeldir = artifact.download()
53
  opt.weights = str(Path(modeldir) / "last.pt")
54
  return True
 
80
 
81
 
82
  class WandbLogger():
83
+ """Log training runs, datasets, models, and predictions to Weights & Biases.
84
+
85
+ This logger sends information to W&B at wandb.ai. By default, this information
86
+ includes hyperparameters, system configuration and metrics, model metrics,
87
+ and basic data metrics and analyses.
88
+
89
+ By providing additional command line arguments to train.py, datasets,
90
+ models and predictions can also be logged.
91
+
92
+ For more on how this logger is used, see the Weights & Biases documentation:
93
+ https://docs.wandb.com/guides/integrations/yolov5
94
+ """
95
  def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
96
  # Pre-training routine --
97
  self.job_type = job_type
 
99
  # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
100
  if isinstance(opt.resume, str): # checks resume from artifact
101
  if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
102
+ entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
103
  model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
104
  assert wandb, 'install wandb to resume wandb runs'
105
  # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
106
+ self.wandb_run = wandb.init(id=run_id, project=project, entity=entity, resume='allow')
107
  opt.resume = model_artifact_name
108
  elif self.wandb:
109
  self.wandb_run = wandb.init(config=opt,
110
  resume="allow",
111
  project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
112
+ entity=opt.entity,
113
  name=name,
114
  job_type=job_type,
115
  id=run_id) if not wandb.run else wandb.run
 
187
  modeldir = model_artifact.download()
188
  epochs_trained = model_artifact.metadata.get('epochs_trained')
189
  total_epochs = model_artifact.metadata.get('total_epochs')
190
+ is_finished = total_epochs is None
191
+ assert not is_finished, 'training is finished, can only resume incomplete runs.'
192
  return modeldir, model_artifact
193
  return None, None
194
 
 
203
  })
204
  model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
205
  wandb.log_artifact(model_artifact,
206
+ aliases=['latest', 'last', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
207
  print("Saving model artifact on epoch ", epoch + 1)
208
 
209
  def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=False):
 
306
  if self.result_artifact:
307
  train_results = wandb.JoinedTable(self.val_table, self.result_table, "id")
308
  self.result_artifact.add(train_results, 'result')
309
+ wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
310
  ('best' if best_result else '')])
311
  self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
312
  self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")