# Last commit not found
from os import path | |
import librosa as rosa | |
import matplotlib | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from pytorch_lightning.loggers import TensorBoardLogger | |
from pytorch_lightning.utilities import rank_zero_only | |
from utils.stft import STFTMag | |
# Select the Agg (off-screen raster) backend so figures can be rendered to
# pixel buffers without a display, e.g. on a headless training machine.
matplotlib.use('Agg')
class TensorBoardLoggerExpanded(TensorBoardLogger):
    """TensorBoard logger extended with spectrogram-image and audio logging.

    The logger writes to ``lightning_logs`` with an empty run name and with
    the default hyper-parameter metric disabled.

    Attributes:
        sr: Sample rate passed to ``add_audio`` when logging waveforms.
        stftmag: ``STFTMag`` instance applied to 1-D inputs before plotting
            (presumably a waveform -> magnitude-spectrogram transform;
            confirm against ``utils.stft``).
    """

    def __init__(self, sr=16000):
        """Create the logger.

        Args:
            sr: Audio sample rate in Hz used by :meth:`log_audio`.
        """
        super().__init__(save_dir='lightning_logs', default_hp_metric=False, name='')
        self.sr = sr
        self.stftmag = STFTMag()

    def fig2np(self, fig):
        """Convert an already-drawn Matplotlib figure into an RGB array.

        Args:
            fig: Figure whose canvas has been rendered (``fig.canvas.draw()``).

        Returns:
            ``numpy.ndarray`` of dtype ``uint8`` with shape
            ``(height, width, 3)``.
        """
        canvas = fig.canvas
        width, height = canvas.get_width_height()
        if hasattr(canvas, 'buffer_rgba'):
            # ``tostring_rgb`` is deprecated/removed in recent Matplotlib
            # releases; read the RGBA buffer and drop the alpha channel.
            rgba = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
            return rgba.reshape(height, width, -1)[..., :3].copy()
        # Fallback for canvases without ``buffer_rgba``. ``np.frombuffer``
        # replaces the deprecated binary mode of ``np.fromstring``.
        rgb = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
        return rgb.reshape(height, width, 3).copy()

    def plot_spectrogram_to_numpy(self, y, y_low, y_recon, step):
        """Render three stacked dB-scaled spectrograms and return the image.

        Inputs for which ``dim() == 1`` are first passed through
        ``self.stftmag``; other inputs are plotted as given. Inputs must
        support ``.dim()`` and ``.numpy()`` (presumably CPU torch tensors).

        Args:
            y: Signal shown in the top panel, titled ``'y'``.
            y_low: Signal shown in the middle panel, titled ``'y_low'``.
            y_recon: Signal shown in the bottom panel, titled ``'y_recon'``.
            step: Value embedded in the figure title as ``Epoch_{step}``.

        Returns:
            ``uint8`` RGB image array of shape ``(height, width, 3)``.
        """
        name_list = ['y', 'y_low', 'y_recon']
        fig = plt.figure(figsize=(9, 15))
        try:
            fig.suptitle(f'Epoch_{step}')
            for row, (name, yy) in enumerate(zip(name_list, [y, y_low, y_recon]), start=1):
                if yy.dim() == 1:
                    yy = self.stftmag(yy)
                ax = fig.add_subplot(3, 1, row)
                ax.set_title(name)
                # Each panel is normalised to its own peak (ref=np.max), so
                # 0 dB is the maximum and the floor is clipped at -80 dB.
                spec_db = rosa.amplitude_to_db(yy.numpy(), ref=np.max, top_db=80.)
                image = ax.imshow(spec_db,
                                  vmax=0.,
                                  aspect='auto',
                                  origin='lower',
                                  interpolation='none')
                fig.colorbar(image, ax=ax)
                ax.set_xlabel('Frames')
                ax.set_ylabel('Channels')
            fig.tight_layout()
            fig.canvas.draw()
            return self.fig2np(fig)
        finally:
            # Close this specific figure (not merely the "current" one), and
            # do so even when plotting fails, so figures do not accumulate.
            plt.close(fig)

    def log_spectrogram(self, y, y_low, y_recon, epoch):
        """Plot the three signals as spectrograms and log them as one image.

        Args:
            y: First signal; detached and moved to CPU before plotting.
            y_low: Second signal; detached and moved to CPU before plotting.
            y_recon: Third signal; detached and moved to CPU before plotting.
            epoch: Global step under which the image is recorded.
        """
        y, y_low, y_recon = y.detach().cpu(), y_low.detach().cpu(), y_recon.detach().cpu()
        spec_img = self.plot_spectrogram_to_numpy(y, y_low, y_recon, epoch)
        self.experiment.add_image(path.join(self.save_dir, 'result'),
                                  spec_img,
                                  epoch,
                                  dataformats='HWC')
        self.experiment.flush()

    def log_audio(self, y, y_low, y_recon, epoch):
        """Log the three signals as audio clips tagged ``y``/``y_low``/``y_recon``.

        Args:
            y: First signal; detached and moved to CPU before logging.
            y_low: Second signal; detached and moved to CPU before logging.
            y_recon: Third signal; detached and moved to CPU before logging.
            epoch: Global step under which the clips are recorded.
        """
        y, y_low, y_recon = y.detach().cpu(), y_low.detach().cpu(), y_recon.detach().cpu()
        name_list = ['y', 'y_low', 'y_recon']
        for tag, yy in zip(name_list, [y, y_low, y_recon]):
            self.experiment.add_audio(tag, yy, epoch, self.sr)
        self.experiment.flush()