import json

import librosa.display
import matplotlib.pyplot as plt
import torch
import torchaudio

import hifigan

def manual_logging(logger, item, idx, tag, global_step, data_type, config):
    """Log element `idx` of a batched tensor to a TensorBoard-style logger.

    `data_type` selects between raw audio (`logger.add_audio`) and a
    spectrogram image rendered with librosa (`logger.add_figure`).
    """
    if data_type == "audio":
        audio = item[idx, ...].detach().cpu().numpy()
        logger.add_audio(
            tag,
            audio,
            global_step,
            sample_rate=config["preprocess"]["sampling_rate"],
        )
    elif data_type == "image":
        image = item[idx, ...].detach().cpu().numpy()
        fig, ax = plt.subplots()
        _ = librosa.display.specshow(
            image,
            x_axis="time",
            y_axis="linear",
            sr=config["preprocess"]["sampling_rate"],
            hop_length=config["preprocess"]["frame_shift"],
            fmax=config["preprocess"]["sampling_rate"] // 2,
            ax=ax,
        )
        logger.add_figure(tag, fig, global_step)
    else:
        raise NotImplementedError(
            "Data type given to logger should be [audio] or [image]"
        )
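
# A minimal usage sketch (assumes `logger` is a
# torch.utils.tensorboard.SummaryWriter and `mel` is a [batch, n_mels, frames]
# tensor; both are illustrative, not part of this module):
#
#     writer = SummaryWriter("runs/debug")
#     manual_logging(writer, mel, 0, "val/mel", step, "image", config)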

def load_vocoder(config):
    """Build a HiFi-GAN generator and load its pretrained weights.

    Weight norm is removed and all parameters are frozen, so the returned
    vocoder is for inference only.
    """
    with open(
        "hifigan/config_{}.json".format(config["general"]["feature_type"]), "r"
    ) as f:
        config_hifigan = hifigan.AttrDict(json.load(f))
    vocoder = hifigan.Generator(config_hifigan)
    vocoder.load_state_dict(torch.load(config["general"]["hifigan_path"])["generator"])
    vocoder.remove_weight_norm()
    for param in vocoder.parameters():
        param.requires_grad = False
    return vocoder
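
# A minimal usage sketch (the [batch, n_mels, frames] mel layout is the usual
# HiFi-GAN convention, assumed here rather than guaranteed by this module):
#
#     vocoder = load_vocoder(config)
#     vocoder.eval()
#     with torch.no_grad():
#         wav = vocoder(mel).squeeze(1)  # mel: [batch, n_mels, frames]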

def get_conv_padding(kernel_size, dilation=1):
    """Padding that makes a stride-1 dilated 1-D convolution length-preserving."""
    return int((kernel_size * dilation - dilation) / 2)
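
# For example, a kernel of size 7 with dilation 3 spans (7 - 1) * 3 + 1 = 19
# input samples, so get_conv_padding(7, 3) == 9 pads nine samples per side
# and the output length equals the input length at stride 1.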

def plot_and_save_mels(wav, save_path, config):
    """Compute a log-compressed mel spectrogram of `wav` and save the plot."""
    spec_module = torchaudio.transforms.MelSpectrogram(
        sample_rate=config["preprocess"]["sampling_rate"],
        n_fft=config["preprocess"]["fft_length"],
        win_length=config["preprocess"]["frame_length"],
        hop_length=config["preprocess"]["frame_shift"],
        f_min=config["preprocess"]["fmin"],
        f_max=config["preprocess"]["fmax"],
        n_mels=config["preprocess"]["n_mels"],
        power=1,
        center=True,
        norm="slaney",
        mel_scale="slaney",
    )
    spec = spec_module(wav.unsqueeze(0))
    # Dynamic range compression: clamp tiny magnitudes, scale, then take the log.
    log_spec = torch.log(
        torch.clamp_min(spec, config["preprocess"]["min_magnitude"])
        * config["preprocess"]["comp_factor"]
    )
    fig, ax = plt.subplots()
    _ = librosa.display.specshow(
        log_spec.squeeze(0).numpy(),
        x_axis="time",
        y_axis="linear",
        sr=config["preprocess"]["sampling_rate"],
        hop_length=config["preprocess"]["frame_shift"],
        fmax=config["preprocess"]["sampling_rate"] // 2,
        ax=ax,
        cmap="viridis",
    )
    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)  # free the figure so repeated calls do not accumulate memory
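
# A minimal usage sketch (file names are illustrative):
#
#     wav, sr = torchaudio.load("sample.wav")  # wav: [channels, samples]
#     plot_and_save_mels(wav[0], "sample_mel.png", config)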

def plot_and_save_mels_all(wavs, keys, save_path, config):
    """Plot log-mel spectrograms for up to nine waveforms on a 3x3 grid.

    `wavs` maps each key to a batched waveform tensor; the first element of
    each batch is plotted in its own panel, titled with the key.
    """
    spec_module = torchaudio.transforms.MelSpectrogram(
        sample_rate=config["preprocess"]["sampling_rate"],
        n_fft=config["preprocess"]["fft_length"],
        win_length=config["preprocess"]["frame_length"],
        hop_length=config["preprocess"]["frame_shift"],
        f_min=config["preprocess"]["fmin"],
        f_max=config["preprocess"]["fmax"],
        n_mels=config["preprocess"]["n_mels"],
        power=1,
        center=True,
        norm="slaney",
        mel_scale="slaney",
    )
    fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(18, 18))
    for i, key in enumerate(keys):
        wav = wavs[key][0, ...].cpu()
        spec = spec_module(wav.unsqueeze(0))
        log_spec = torch.log(
            torch.clamp_min(spec, config["preprocess"]["min_magnitude"])
            * config["preprocess"]["comp_factor"]
        )
        ax[i // 3, i % 3].set(title=key)
        _ = librosa.display.specshow(
            log_spec.squeeze(0).numpy(),
            x_axis="time",
            y_axis="linear",
            sr=config["preprocess"]["sampling_rate"],
            hop_length=config["preprocess"]["frame_shift"],
            fmax=config["preprocess"]["sampling_rate"] // 2,
            ax=ax[i // 3, i % 3],
            cmap="viridis",
        )
    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)  # free the figure so repeated calls do not accumulate memory

def configure_args(config, args):
    """Override config entries with any command-line arguments that were set.

    Arguments left at None keep the values already present in `config`.
    """
    for key in ["stage", "corpus_type", "source_path", "aux_path", "preprocessed_path"]:
        if getattr(args, key) is not None:
            config["general"][key] = str(getattr(args, key))
    for key in ["n_train", "n_val", "n_test"]:
        if getattr(args, key) is not None:
            config["preprocess"][key] = getattr(args, key)
    for key in ["alpha", "beta", "learning_rate", "epoch"]:
        if getattr(args, key) is not None:
            config["train"][key] = getattr(args, key)
    # Boolean flags are copied unconditionally.
    for key in ["load_pretrained", "early_stopping"]:
        config["train"][key] = getattr(args, key)
    if args.feature_loss_type is not None:
        config["train"]["feature_loss"]["type"] = args.feature_loss_type
    for key in ["pretrained_path"]:
        if getattr(args, key) is not None:
            config["train"][key] = str(getattr(args, key))
    return config, args
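
# A minimal usage sketch (the flags shown mirror the attribute names read
# above; the project's actual CLI definition may differ):
#
#     import argparse
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--stage", type=str, default=None)
#     parser.add_argument("--learning_rate", type=float, default=None)
#     parser.add_argument("--load_pretrained", action="store_true")
#     config, args = configure_args(config, parser.parse_args())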