franzi2505 commited on
Commit
c385f38
1 Parent(s): 4b4713e

add optional parameters to the wandb logging functionality

Browse files
Files changed (1) hide show
  1. det-metrics.py +18 -5
det-metrics.py CHANGED
@@ -341,12 +341,15 @@ class DetectionMetric(evaluate.Metric):
341
  )
342
  return fig
343
 
344
- def wandb(self, results , wandb_project='detection_metrics'):
345
  """
346
  Logs metrics to Weights and Biases (wandb) for tracking and visualization.
347
 
348
  This function logs the provided metrics to Weights and Biases (wandb), a platform for tracking machine learning experiments.
349
  Each key in the `results` dictionary represents a separate run and the corresponding value contains the metrics for that run.
 
 
 
350
  The function logs in to wandb using an API key obtained from the secret 'WANDB_API_KEY', initializes a run for
351
  each key in `results` and logs the metrics.
352
 
@@ -357,6 +360,9 @@ class DetectionMetric(evaluate.Metric):
357
  "run1": {"metrics": {"accuracy": 0.9, "loss": 0.1}},
358
  "run2": {"metrics": {"accuracy": 0.85, "loss": 0.15}}
359
  }
 
 
 
360
  wandb_project (str, optional): The name of the wandb project to which the runs will be logged. Defaults to 'detection_metrics'.
361
 
362
  Environment Variables:
@@ -375,10 +381,17 @@ class DetectionMetric(evaluate.Metric):
375
  formatted_datetime = current_datetime.strftime("%Y-%m-%d_%H-%M-%S")
376
  wandb.login(key=os.getenv('WANDB_API_KEY'))
377
 
378
- for k in results.keys():
379
- run = wandb.init(project=wandb_project, name=f"{k}-{formatted_datetime}")
380
- run.log(results[k]['metrics'])
381
- run.finish()
 
 
 
 
 
 
 
382
 
383
  def _generate_sample_data(self):
384
  """
 
341
  )
342
  return fig
343
 
344
+ def wandb(self, results , wandb_runs: list = None, wandb_section: str = None, wandb_project='detection_metrics'):
345
  """
346
  Logs metrics to Weights and Biases (wandb) for tracking and visualization.
347
 
348
  This function logs the provided metrics to Weights and Biases (wandb), a platform for tracking machine learning experiments.
349
  Each key in the `results` dictionary represents a separate run and the corresponding value contains the metrics for that run.
350
+ If a W&B run list is provided, the results of the runs will be added to the passed W&B runs. Otherwise new W&B runs will be created.
351
+ If a W&B section ist provided, the metrics will be logged in this section drop-down. Otherwise no extra W&B section is created
352
+ and the metrics are logged directly.
353
  The function logs in to wandb using an API key obtained from the secret 'WANDB_API_KEY', initializes a run for
354
  each key in `results` and logs the metrics.
355
 
 
360
  "run1": {"metrics": {"accuracy": 0.9, "loss": 0.1}},
361
  "run2": {"metrics": {"accuracy": 0.85, "loss": 0.15}}
362
  }
363
+ wandb_runs (list, optional): A list containing W&B runs where the results should be added
364
+ (e.g. the first item in results will be added to the first run in wandb_runs, etc.)
365
+ wandb_section (str, optional): A string to specify the W&B
366
  wandb_project (str, optional): The name of the wandb project to which the runs will be logged. Defaults to 'detection_metrics'.
367
 
368
  Environment Variables:
 
381
  formatted_datetime = current_datetime.strftime("%Y-%m-%d_%H-%M-%S")
382
  wandb.login(key=os.getenv('WANDB_API_KEY'))
383
 
384
+ if not wandb_runs is None:
385
+ assert len(wandb_runs) == len(results), "runs and results must have the same length"
386
+
387
+ for i, k in enumerate(results.keys()):
388
+ if wandb_runs is None:
389
+ run = wandb.init(project=wandb_project, name=f"{k}-{formatted_datetime}")
390
+ else:
391
+ run = wandb_runs[i]
392
+ run.log({f"{wandb_section}/{m}" : v for m, v in results[k]['metrics'].items()} if wandb_section is not None else results[k]['metrics'])
393
+ if wandb_runs is None:
394
+ run.finish()
395
 
396
  def _generate_sample_data(self):
397
  """