|
from datetime import datetime |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import wandb |
|
|
|
|
|
class WanDBWriter:
    """Log training artifacts to Weights & Biases (wandb).

    Every logged key is suffixed with the current mode
    (``"<name>_<mode>"``) and logged at ``self.step``; both are updated
    through :meth:`set_step`.
    """

    # wandb.Histogram rejects histograms with more than 512 bins, so larger
    # histograms are re-binned down to this size.
    _MAX_HISTOGRAM_BINS = 512
    # Bin count used when ``add_histogram`` is called with ``bins=None``
    # (numpy rejects ``bins=None``); matches wandb.Histogram's ``num_bins``
    # default.
    _DEFAULT_HISTOGRAM_BINS = 64

    def __init__(self, config, logger):
        """Log in to wandb and start a run.

        Args:
            config: project config object. It is indexed as
                ``config['trainer']`` (a mapping expected to contain
                ``'wandb_project'``), and its ``.config`` attribute is passed
                to ``wandb.init`` as the run config.
            logger: logger used to warn when wandb cannot be imported.

        Raises:
            ValueError: if ``config['trainer'].get('wandb_project')`` is None.
        """
        # Not read by any method of this class; kept so the attribute exists.
        self.writer = None
        self.selected_module = ""
        # Stays None when wandb cannot be imported; the add_* methods need it.
        self.wandb = None

        try:
            import wandb
        except ImportError:
            logger.warning("For use wandb install it via \n\t pip install wandb")
        else:
            project = config['trainer'].get('wandb_project')
            # Validate before wandb.login() so a bad config fails fast.
            if project is None:
                raise ValueError("please specify project name for wandb")
            wandb.login()
            wandb.init(project=project, config=config.config)
            self.wandb = wandb

        self.step = 0
        self.mode = ""
        self.timer = datetime.now()

    def set_step(self, step, mode="train"):
        """Set the current step/mode and log ``steps_per_sec``.

        ``steps_per_sec`` is the inverse of the wall-clock time elapsed since
        the previous ``set_step`` call. It is skipped for ``step == 0`` (which
        only resets the timer) and whenever the measured interval is not
        positive (same clock tick, or the clock moved backwards), which would
        otherwise divide by zero or log a negative rate.
        """
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer = datetime.now()
            return

        elapsed = (datetime.now() - self.timer).total_seconds()
        if elapsed > 0:
            self.add_scalar("steps_per_sec", 1 / elapsed)
        self.timer = datetime.now()

    def _scalar_name(self, scalar_name):
        """Return ``scalar_name`` suffixed with the current mode."""
        return f"{scalar_name}_{self.mode}"

    def add_scalar(self, scalar_name, scalar):
        """Log a single scalar value at the current step."""
        self.wandb.log({
            self._scalar_name(scalar_name): scalar,
        }, step=self.step)

    def add_scalars(self, tag, scalars):
        """Log several scalars as ``"<name>_<tag>_<mode>"`` at the current step.

        Args:
            tag: group label inserted between each name and the mode.
            scalars: mapping of scalar name -> value.
        """
        self.wandb.log({
            self._scalar_name(f"{scalar_name}_{tag}"): scalar
            for scalar_name, scalar in scalars.items()
        }, step=self.step)

    def add_image(self, scalar_name, image):
        """Log ``image`` wrapped in ``wandb.Image`` at the current step."""
        self.wandb.log({
            self._scalar_name(scalar_name): self.wandb.Image(image)
        }, step=self.step)

    def add_audio(self, scalar_name, audio, sample_rate=None):
        """Log ``audio`` wrapped in ``wandb.Audio`` at the current step.

        ``audio`` must provide ``.detach().cpu().numpy()`` (presumably a torch
        tensor). The resulting array is transposed before logging.
        NOTE(review): presumably this turns (channels, samples) into
        (samples, channels) -- confirm against callers.
        """
        audio = audio.detach().cpu().numpy().T
        self.wandb.log({
            self._scalar_name(scalar_name): self.wandb.Audio(audio, sample_rate=sample_rate)
        }, step=self.step)

    def add_text(self, scalar_name, text):
        """Log ``text`` wrapped in ``wandb.Html`` at the current step."""
        self.wandb.log({
            self._scalar_name(scalar_name): self.wandb.Html(text)
        }, step=self.step)

    def add_histogram(self, scalar_name, hist, bins=None):
        """Log a histogram of ``hist``'s values at the current step.

        Args:
            scalar_name: metric name (the mode suffix is appended).
            hist: object providing ``.detach().cpu().numpy()`` (presumably a
                torch tensor) whose values are histogrammed.
            bins: anything ``np.histogram`` accepts as ``bins``. ``None`` uses
                ``_DEFAULT_HISTOGRAM_BINS`` bins.

        If the resulting histogram has more than ``_MAX_HISTOGRAM_BINS`` bins,
        it is recomputed with exactly that many so wandb accepts it.
        """
        values = hist.detach().cpu().numpy()
        if bins is None:
            bins = self._DEFAULT_HISTOGRAM_BINS

        np_hist = np.histogram(values, bins=bins)
        if np_hist[0].shape[0] > self._MAX_HISTOGRAM_BINS:
            np_hist = np.histogram(values, bins=self._MAX_HISTOGRAM_BINS)

        self.wandb.log({
            self._scalar_name(scalar_name): self.wandb.Histogram(np_histogram=np_hist)
        }, step=self.step)

    def add_table(self, table_name, table: pd.DataFrame):
        """Log a DataFrame as a ``wandb.Table`` at the current step."""
        # Use the instance's wandb handle like every other add_* method.
        self.wandb.log({self._scalar_name(table_name): self.wandb.Table(dataframe=table)},
                       step=self.step)

    def add_images(self, scalar_name, images):
        """Not supported by this writer."""
        raise NotImplementedError()

    def add_pr_curve(self, scalar_name, scalar):
        """Not supported by this writer."""
        raise NotImplementedError()

    def add_embedding(self, scalar_name, scalar):
        """Not supported by this writer."""
        raise NotImplementedError()
|
|