# aet_demo/utils.py — logging, vocoder-loading, and plotting utilities.
import librosa.display
import matplotlib.pyplot as plt
import json
import torch
import torchaudio
import hifigan
def manual_logging(logger, item, idx, tag, global_step, data_type, config):
    """Log one element of a batch to a TensorBoard-style writer.

    Args:
        logger: writer exposing ``add_audio`` and ``add_figure``.
        item: batched tensor; element ``idx`` is selected along dim 0.
        idx: batch index to log.
        tag: label under which the entry is written.
        global_step: training step for the log entry.
        data_type: either ``"audio"`` (waveform) or ``"image"`` (spectrogram).
        config: dict providing ``preprocess.sampling_rate`` and
            ``preprocess.frame_shift``.

    Raises:
        NotImplementedError: if ``data_type`` is neither "audio" nor "image".
    """
    if data_type == "audio":
        waveform = item[idx, ...].detach().cpu().numpy()
        logger.add_audio(
            tag,
            waveform,
            global_step,
            sample_rate=config["preprocess"]["sampling_rate"],
        )
        return

    if data_type == "image":
        spectrogram = item[idx, ...].detach().cpu().numpy()
        fig, ax = plt.subplots()
        # Render the spectrogram onto the figure, then hand the figure to
        # the logger (which takes ownership of it).
        _ = librosa.display.specshow(
            spectrogram,
            x_axis="time",
            y_axis="linear",
            sr=config["preprocess"]["sampling_rate"],
            hop_length=config["preprocess"]["frame_shift"],
            fmax=config["preprocess"]["sampling_rate"] // 2,
            ax=ax,
        )
        logger.add_figure(tag, fig, global_step)
        return

    raise NotImplementedError(
        "Data type given to logger should be [audio] or [image]"
    )
def load_vocoder(config):
    """Build a frozen HiFi-GAN generator for inference.

    Reads the HiFi-GAN JSON config matching the configured feature type,
    instantiates the generator, loads the checkpoint weights, removes
    weight normalization, and disables gradients on all parameters.

    Args:
        config: dict providing ``general.feature_type`` (selects the
            hifigan config file) and ``general.hifigan_path`` (checkpoint).

    Returns:
        The frozen ``hifigan.Generator`` instance (on CPU).
    """
    with open(
        "hifigan/config_{}.json".format(config["general"]["feature_type"]), "r"
    ) as f:
        config_hifigan = hifigan.AttrDict(json.load(f))
    vocoder = hifigan.Generator(config_hifigan)
    # map_location="cpu" prevents a crash when the checkpoint was saved on
    # GPU but this process has no CUDA device; the freshly built generator
    # lives on CPU anyway, so this is always safe.
    checkpoint = torch.load(config["general"]["hifigan_path"], map_location="cpu")
    vocoder.load_state_dict(checkpoint["generator"])
    vocoder.remove_weight_norm()
    # Freeze: the vocoder is used for synthesis only, never trained here.
    for param in vocoder.parameters():
        param.requires_grad = False
    return vocoder
def get_conv_padding(kernel_size, dilation=1):
    """Return the padding that keeps a 1-D convolution length-preserving.

    Equivalent to ``dilation * (kernel_size - 1) // 2`` — the standard
    "same" padding for stride-1 dilated convolutions.
    """
    return (kernel_size - 1) * dilation // 2
def plot_and_save_mels(wav, save_path, config):
    """Compute a log-mel spectrogram of ``wav`` and save it as an image.

    Args:
        wav: 1-D waveform tensor on CPU.
        save_path: destination image path passed to ``fig.savefig``.
        config: dict providing the ``preprocess`` parameters (sampling
            rate, FFT/window/hop sizes, mel range, clamp floor, and
            compression factor).
    """
    spec_module = torchaudio.transforms.MelSpectrogram(
        sample_rate=config["preprocess"]["sampling_rate"],
        n_fft=config["preprocess"]["fft_length"],
        win_length=config["preprocess"]["frame_length"],
        hop_length=config["preprocess"]["frame_shift"],
        f_min=config["preprocess"]["fmin"],
        f_max=config["preprocess"]["fmax"],
        n_mels=config["preprocess"]["n_mels"],
        power=1,
        center=True,
        norm="slaney",
        mel_scale="slaney",
    )
    spec = spec_module(wav.unsqueeze(0))
    # Clamp to a magnitude floor before the log to avoid log(0), then apply
    # the configured compression factor.
    log_spec = torch.log(
        torch.clamp_min(spec, config["preprocess"]["min_magnitude"])
        * config["preprocess"]["comp_factor"]
    )
    fig, ax = plt.subplots()
    _ = librosa.display.specshow(
        log_spec.squeeze(0).numpy(),
        x_axis="time",
        y_axis="linear",
        sr=config["preprocess"]["sampling_rate"],
        hop_length=config["preprocess"]["frame_shift"],
        fmax=config["preprocess"]["sampling_rate"] // 2,
        ax=ax,
        cmap="viridis",
    )
    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    # Close the figure explicitly: matplotlib keeps figures alive otherwise,
    # which leaks memory when this is called in a loop.
    plt.close(fig)
def plot_and_save_mels_all(wavs, keys, save_path, config):
    """Plot log-mel spectrograms for several waveforms on a 3x3 grid.

    Args:
        wavs: mapping from key to a batched waveform tensor; the first
            element (``[0, ...]``) of each entry is plotted.
        keys: iterable of up to 9 keys selecting entries of ``wavs``;
            each key is also used as the subplot title.
        save_path: destination image path passed to ``fig.savefig``.
        config: dict providing the ``preprocess`` parameters (sampling
            rate, FFT/window/hop sizes, mel range, clamp floor, and
            compression factor).
    """
    spec_module = torchaudio.transforms.MelSpectrogram(
        sample_rate=config["preprocess"]["sampling_rate"],
        n_fft=config["preprocess"]["fft_length"],
        win_length=config["preprocess"]["frame_length"],
        hop_length=config["preprocess"]["frame_shift"],
        f_min=config["preprocess"]["fmin"],
        f_max=config["preprocess"]["fmax"],
        n_mels=config["preprocess"]["n_mels"],
        power=1,
        center=True,
        norm="slaney",
        mel_scale="slaney",
    )
    fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(18, 18))
    for i, key in enumerate(keys):
        wav = wavs[key][0, ...].cpu()
        spec = spec_module(wav.unsqueeze(0))
        # Clamp before the log to avoid log(0), then apply compression.
        log_spec = torch.log(
            torch.clamp_min(spec, config["preprocess"]["min_magnitude"])
            * config["preprocess"]["comp_factor"]
        )
        ax[i // 3, i % 3].set(title=key)
        _ = librosa.display.specshow(
            log_spec.squeeze(0).numpy(),
            x_axis="time",
            y_axis="linear",
            sr=config["preprocess"]["sampling_rate"],
            hop_length=config["preprocess"]["frame_shift"],
            fmax=config["preprocess"]["sampling_rate"] // 2,
            ax=ax[i // 3, i % 3],
            cmap="viridis",
        )
    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    # Close the figure explicitly: matplotlib keeps figures alive otherwise,
    # which leaks memory when this is called in a loop.
    plt.close(fig)
def configure_args(config, args):
    """Overlay command-line argument values onto the config dict.

    Each argparse attribute that is not None overrides the corresponding
    config entry; ``load_pretrained`` and ``early_stopping`` are copied
    unconditionally (they are flags). Path-like and string options are
    coerced with ``str``.

    Args:
        config: nested config dict with "general", "preprocess", and
            "train" sections (``train`` must contain "feature_loss").
        args: argparse.Namespace (or equivalent) carrying the attributes
            referenced below.

    Returns:
        Tuple of (updated config dict, args) — config is mutated in place.
    """
    for key in ["stage", "corpus_type", "source_path", "aux_path", "preprocessed_path"]:
        # `is not None` (not `!= None`): identity check is the Python idiom
        # and is immune to objects overriding __eq__.
        if getattr(args, key) is not None:
            config["general"][key] = str(getattr(args, key))
    for key in ["n_train", "n_val", "n_test"]:
        if getattr(args, key) is not None:
            config["preprocess"][key] = getattr(args, key)
    for key in ["alpha", "beta", "learning_rate", "epoch"]:
        if getattr(args, key) is not None:
            config["train"][key] = getattr(args, key)
    # Boolean flags: always taken from args, even when False.
    for key in ["load_pretrained", "early_stopping"]:
        config["train"][key] = getattr(args, key)
    if args.feature_loss_type is not None:
        config["train"]["feature_loss"]["type"] = args.feature_loss_type
    for key in ["pretrained_path"]:
        if getattr(args, key) is not None:
            config["train"][key] = str(getattr(args, key))
    return config, args