|
import logging |
|
import os |
|
import random |
|
import sys |
|
from collections import defaultdict |
|
|
|
import hydra |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange |
|
from einops.layers.torch import Rearrange |
|
from scipy.io.wavfile import read |
|
from scipy.ndimage import gaussian_filter1d |
|
from torch.utils.data import DataLoader, Dataset |
|
from tqdm import tqdm |
|
|
|
dir_path = os.path.dirname(__file__) |
|
resynth_path = os.path.dirname(dir_path) + "/speech-resynthesis" |
|
sys.path.append(resynth_path) |
|
from dataset import parse_speaker, parse_style |
|
from .utils import F0Stat |
|
|
|
MAX_WAV_VALUE = 32768.0 |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def quantize_f0(speaker_to_f0, nbins, normalize, log): |
|
f0_all = [] |
|
for speaker, f0 in speaker_to_f0.items(): |
|
f0 = f0.raw_data |
|
if log: |
|
f0 = f0.log() |
|
mean = speaker_to_f0[speaker].mean_log if log else speaker_to_f0[speaker].mean |
|
std = speaker_to_f0[speaker].std_log if log else speaker_to_f0[speaker].std |
|
if normalize == "mean": |
|
f0 = f0 - mean |
|
elif normalize == "meanstd": |
|
f0 = (f0 - mean) / std |
|
f0_all.extend(f0.tolist()) |
|
|
|
hist, bin_x = np.histogram(f0_all, 100000) |
|
cum_hist = np.cumsum(hist) / len(f0_all) * 100 |
|
|
|
bin_offset = [] |
|
bin_size = 100 / nbins |
|
threshold = bin_size |
|
for i in range(nbins - 1): |
|
index = (np.abs(cum_hist - threshold)).argmin() |
|
bin_offset.append(bin_x[index]) |
|
threshold += bin_size |
|
bins = np.array(bin_offset) |
|
bins = torch.FloatTensor(bins) |
|
|
|
return bins |
|
|
|
|
|
def save_ckpt(model, path, model_class, f0_min, f0_max, f0_bins, speaker_stats): |
|
ckpt = { |
|
"state_dict": model.state_dict(), |
|
"padding_token": model.padding_token, |
|
"model_class": model_class, |
|
"speaker_stats": speaker_stats, |
|
"f0_min": f0_min, |
|
"f0_max": f0_max, |
|
"f0_bins": f0_bins, |
|
} |
|
torch.save(ckpt, path) |
|
|
|
|
|
def load_ckpt(path): |
|
ckpt = torch.load(path) |
|
ckpt["model_class"]["_target_"] = "emotion_models.pitch_predictor.CnnPredictor" |
|
model = hydra.utils.instantiate(ckpt["model_class"]) |
|
model.load_state_dict(ckpt["state_dict"]) |
|
model.setup_f0_stats( |
|
ckpt["f0_min"], |
|
ckpt["f0_max"], |
|
ckpt["f0_bins"], |
|
ckpt["speaker_stats"], |
|
) |
|
return model |
|
|
|
|
|
def freq2bin(f0, f0_min, f0_max, bins): |
|
f0 = f0.clone() |
|
f0[f0 < f0_min] = f0_min |
|
f0[f0 > f0_max] = f0_max |
|
f0 = torch.bucketize(f0, bins) |
|
return f0 |
|
|
|
|
|
def bin2freq(x, f0_min, f0_max, bins, mode): |
|
n_bins = len(bins) + 1 |
|
assert x.shape[-1] == n_bins |
|
bins = torch.cat([torch.tensor([f0_min]), bins]).to(x.device) |
|
if mode == "mean": |
|
f0 = (x * bins).sum(-1, keepdims=True) / x.sum(-1, keepdims=True) |
|
elif mode == "argmax": |
|
idx = F.one_hot(x.argmax(-1), num_classes=n_bins) |
|
f0 = (idx * bins).sum(-1, keepdims=True) |
|
else: |
|
raise NotImplementedError() |
|
return f0[..., 0] |
|
|
|
|
|
def load_wav(full_path): |
|
sampling_rate, data = read(full_path) |
|
return data, sampling_rate |
|
|
|
|
|
def l1_loss(input, target): |
|
return F.l1_loss(input=input.float(), target=target.float(), reduce=False) |
|
|
|
|
|
def l2_loss(input, target): |
|
return F.mse_loss(input=input.float(), target=target.float(), reduce=False) |
|
|
|
|
|
class Collator: |
|
def __init__(self, padding_idx): |
|
self.padding_idx = padding_idx |
|
|
|
def __call__(self, batch): |
|
tokens = [item[0] for item in batch] |
|
lengths = [len(item) for item in tokens] |
|
tokens = torch.nn.utils.rnn.pad_sequence( |
|
tokens, batch_first=True, padding_value=self.padding_idx |
|
) |
|
f0 = [item[1] for item in batch] |
|
f0 = torch.nn.utils.rnn.pad_sequence( |
|
f0, batch_first=True, padding_value=self.padding_idx |
|
) |
|
f0_raw = [item[2] for item in batch] |
|
f0_raw = torch.nn.utils.rnn.pad_sequence( |
|
f0_raw, batch_first=True, padding_value=self.padding_idx |
|
) |
|
spk = [item[3] for item in batch] |
|
spk = torch.LongTensor(spk) |
|
gst = [item[4] for item in batch] |
|
gst = torch.LongTensor(gst) |
|
mask = tokens != self.padding_idx |
|
return tokens, f0, f0_raw, spk, gst, mask, lengths |
|
|
|
|
|
class CnnPredictor(nn.Module): |
|
def __init__( |
|
self, |
|
n_tokens, |
|
emb_dim, |
|
channels, |
|
kernel, |
|
dropout, |
|
n_layers, |
|
spk_emb, |
|
gst_emb, |
|
n_bins, |
|
f0_pred, |
|
f0_log, |
|
f0_norm, |
|
): |
|
super(CnnPredictor, self).__init__() |
|
self.n_tokens = n_tokens |
|
self.emb_dim = emb_dim |
|
self.f0_log = f0_log |
|
self.f0_pred = f0_pred |
|
self.padding_token = n_tokens |
|
self.f0_norm = f0_norm |
|
|
|
|
|
self.token_emb = nn.Embedding( |
|
n_tokens + 1, emb_dim, padding_idx=self.padding_token |
|
) |
|
|
|
self.spk_emb = spk_emb |
|
self.gst_emb = nn.Embedding(20, gst_emb) |
|
self.setup = False |
|
|
|
feats = emb_dim + gst_emb |
|
|
|
layers = [ |
|
nn.Sequential( |
|
Rearrange("b t c -> b c t"), |
|
nn.Conv1d( |
|
feats, channels, kernel_size=kernel, padding=(kernel - 1) // 2 |
|
), |
|
Rearrange("b c t -> b t c"), |
|
nn.ReLU(), |
|
nn.LayerNorm(channels), |
|
nn.Dropout(dropout), |
|
) |
|
] |
|
for _ in range(n_layers - 1): |
|
layers += [ |
|
nn.Sequential( |
|
Rearrange("b t c -> b c t"), |
|
nn.Conv1d( |
|
channels, |
|
channels, |
|
kernel_size=kernel, |
|
padding=(kernel - 1) // 2, |
|
), |
|
Rearrange("b c t -> b t c"), |
|
nn.ReLU(), |
|
nn.LayerNorm(channels), |
|
nn.Dropout(dropout), |
|
) |
|
] |
|
self.conv_layer = nn.ModuleList(layers) |
|
self.proj = nn.Linear(channels, n_bins) |
|
|
|
def forward(self, x, gst=None): |
|
x = self.token_emb(x) |
|
feats = [x] |
|
|
|
if gst is not None: |
|
gst = self.gst_emb(gst) |
|
gst = rearrange(gst, "b c -> b c 1") |
|
gst = F.interpolate(gst, x.shape[1]) |
|
gst = rearrange(gst, "b c t -> b t c") |
|
feats.append(gst) |
|
|
|
x = torch.cat(feats, dim=-1) |
|
|
|
for i, conv in enumerate(self.conv_layer): |
|
if i != 0: |
|
x = conv(x) + x |
|
else: |
|
x = conv(x) |
|
|
|
x = self.proj(x) |
|
x = x.squeeze(-1) |
|
|
|
if self.f0_pred == "mean": |
|
x = torch.sigmoid(x) |
|
elif self.f0_pred == "argmax": |
|
x = torch.softmax(x, dim=-1) |
|
else: |
|
raise NotImplementedError |
|
return x |
|
|
|
def setup_f0_stats(self, f0_min, f0_max, f0_bins, speaker_stats): |
|
self.f0_min = f0_min |
|
self.f0_max = f0_max |
|
self.f0_bins = f0_bins |
|
self.speaker_stats = speaker_stats |
|
self.setup = True |
|
|
|
def inference(self, x, spk_id=None, gst=None): |
|
assert ( |
|
self.setup == True |
|
), "make sure that `setup_f0_stats` was called before inference!" |
|
probs = self(x, gst) |
|
f0 = bin2freq(probs, self.f0_min, self.f0_max, self.f0_bins, self.f0_pred) |
|
for i in range(f0.shape[0]): |
|
mean = ( |
|
self.speaker_stats[spk_id[i].item()].mean_log |
|
if self.f0_log |
|
else self.speaker_stats[spk_id[i].item()].mean |
|
) |
|
std = ( |
|
self.speaker_stats[spk_id[i].item()].std_log |
|
if self.f0_log |
|
else self.speaker_stats[spk_id[i].item()].std |
|
) |
|
if self.f0_norm == "mean": |
|
f0[i] = f0[i] + mean |
|
if self.f0_norm == "meanstd": |
|
f0[i] = (f0[i] * std) + mean |
|
if self.f0_log: |
|
f0 = f0.exp() |
|
return f0 |
|
|
|
|
|
class PitchDataset(Dataset): |
|
def __init__( |
|
self, |
|
tsv_path, |
|
km_path, |
|
substring, |
|
spk, |
|
spk2id, |
|
gst, |
|
gst2id, |
|
f0_bins, |
|
f0_bin_type, |
|
f0_smoothing, |
|
f0_norm, |
|
f0_log, |
|
): |
|
lines = open(tsv_path, "r").readlines() |
|
self.root, self.tsv = lines[0], lines[1:] |
|
self.root = self.root.strip() |
|
self.km = open(km_path, "r").readlines() |
|
print(f"loaded {len(self.km)} files") |
|
|
|
self.spk = spk |
|
self.spk2id = spk2id |
|
self.gst = gst |
|
self.gst2id = gst2id |
|
|
|
self.f0_bins = f0_bins |
|
self.f0_smoothing = f0_smoothing |
|
self.f0_norm = f0_norm |
|
self.f0_log = f0_log |
|
|
|
if substring != "": |
|
tsv, km = [], [] |
|
for tsv_line, km_line in zip(self.tsv, self.km): |
|
if substring.lower() in tsv_line.lower(): |
|
tsv.append(tsv_line) |
|
km.append(km_line) |
|
self.tsv, self.km = tsv, km |
|
print(f"after filtering: {len(self.km)} files") |
|
|
|
self.speaker_stats = self._compute_f0_stats() |
|
self.f0_min, self.f0_max = self._compute_f0_minmax() |
|
if f0_bin_type == "adaptive": |
|
self.f0_bins = quantize_f0( |
|
self.speaker_stats, self.f0_bins, self.f0_norm, self.f0_log |
|
) |
|
elif f0_bin_type == "uniform": |
|
self.f0_bins = torch.linspace(self.f0_min, self.f0_max, self.f0_bins + 1)[ |
|
1:-1 |
|
] |
|
else: |
|
raise NotImplementedError |
|
print(f"f0 min: {self.f0_min}, f0 max: {self.f0_max}") |
|
print(f"bins: {self.f0_bins} (shape: {self.f0_bins.shape})") |
|
|
|
def __len__(self): |
|
return len(self.km) |
|
|
|
def _load_f0(self, tsv_line): |
|
tsv_line = tsv_line.split("\t")[0] |
|
f0 = self.root + "/" + tsv_line.replace(".wav", ".yaapt.f0.npy") |
|
f0 = np.load(f0) |
|
f0 = torch.FloatTensor(f0) |
|
return f0 |
|
|
|
def _preprocess_f0(self, f0, spk): |
|
mask = f0 != -999999 |
|
|
|
mean = ( |
|
self.speaker_stats[spk].mean_log |
|
if self.f0_log |
|
else self.speaker_stats[spk].mean |
|
) |
|
std = ( |
|
self.speaker_stats[spk].std_log |
|
if self.f0_log |
|
else self.speaker_stats[spk].std |
|
) |
|
if self.f0_log: |
|
f0[f0 == 0] = 1e-5 |
|
f0[mask] = f0[mask].log() |
|
if self.f0_norm == "mean": |
|
f0[mask] = f0[mask] - mean |
|
if self.f0_norm == "meanstd": |
|
f0[mask] = (f0[mask] - mean) / std |
|
return f0 |
|
|
|
def _compute_f0_minmax(self): |
|
f0_min, f0_max = float("inf"), -float("inf") |
|
for tsv_line in tqdm(self.tsv, desc="computing f0 minmax"): |
|
spk = self.spk2id[parse_speaker(tsv_line, self.spk)] |
|
f0 = self._load_f0(tsv_line) |
|
f0 = self._preprocess_f0(f0, spk) |
|
f0_min = min(f0_min, f0.min().item()) |
|
f0_max = max(f0_max, f0.max().item()) |
|
return f0_min, f0_max |
|
|
|
def _compute_f0_stats(self): |
|
from functools import partial |
|
|
|
speaker_stats = defaultdict(partial(F0Stat, True)) |
|
for tsv_line in tqdm(self.tsv, desc="computing speaker stats"): |
|
spk = self.spk2id[parse_speaker(tsv_line, self.spk)] |
|
f0 = self._load_f0(tsv_line) |
|
mask = f0 != 0 |
|
f0 = f0[mask] |
|
speaker_stats[spk].update(f0) |
|
return speaker_stats |
|
|
|
def __getitem__(self, i): |
|
x = self.km[i] |
|
x = x.split(" ") |
|
x = list(map(int, x)) |
|
x = torch.LongTensor(x) |
|
|
|
gst = parse_style(self.tsv[i], self.gst) |
|
gst = self.gst2id[gst] |
|
spk = parse_speaker(self.tsv[i], self.spk) |
|
spk = self.spk2id[spk] |
|
|
|
f0_raw = self._load_f0(self.tsv[i]) |
|
f0 = self._preprocess_f0(f0_raw.clone(), spk) |
|
|
|
f0 = F.interpolate(f0.unsqueeze(0).unsqueeze(0), x.shape[0])[0, 0] |
|
f0_raw = F.interpolate(f0_raw.unsqueeze(0).unsqueeze(0), x.shape[0])[0, 0] |
|
|
|
f0 = freq2bin(f0, f0_min=self.f0_min, f0_max=self.f0_max, bins=self.f0_bins) |
|
f0 = F.one_hot(f0.long(), num_classes=len(self.f0_bins) + 1).float() |
|
if self.f0_smoothing > 0: |
|
f0 = torch.tensor( |
|
gaussian_filter1d(f0.float().numpy(), sigma=self.f0_smoothing) |
|
) |
|
return x, f0, f0_raw, spk, gst |
|
|
|
|
|
def train(cfg): |
|
device = "cuda:0" |
|
|
|
|
|
padding_token = cfg.n_tokens |
|
collate_fn = Collator(padding_idx=padding_token) |
|
train_ds = PitchDataset( |
|
cfg.train_tsv, |
|
cfg.train_km, |
|
substring=cfg.substring, |
|
spk=cfg.spk, |
|
spk2id=cfg.spk2id, |
|
gst=cfg.gst, |
|
gst2id=cfg.gst2id, |
|
f0_bins=cfg.f0_bins, |
|
f0_bin_type=cfg.f0_bin_type, |
|
f0_smoothing=cfg.f0_smoothing, |
|
f0_norm=cfg.f0_norm, |
|
f0_log=cfg.f0_log, |
|
) |
|
valid_ds = PitchDataset( |
|
cfg.valid_tsv, |
|
cfg.valid_km, |
|
substring=cfg.substring, |
|
spk=cfg.spk, |
|
spk2id=cfg.spk2id, |
|
gst=cfg.gst, |
|
gst2id=cfg.gst2id, |
|
f0_bins=cfg.f0_bins, |
|
f0_bin_type=cfg.f0_bin_type, |
|
f0_smoothing=cfg.f0_smoothing, |
|
f0_norm=cfg.f0_norm, |
|
f0_log=cfg.f0_log, |
|
) |
|
train_dl = DataLoader( |
|
train_ds, |
|
num_workers=0, |
|
batch_size=cfg.batch_size, |
|
shuffle=True, |
|
collate_fn=collate_fn, |
|
) |
|
valid_dl = DataLoader( |
|
valid_ds, num_workers=0, batch_size=16, shuffle=False, collate_fn=collate_fn |
|
) |
|
|
|
f0_min = train_ds.f0_min |
|
f0_max = train_ds.f0_max |
|
f0_bins = train_ds.f0_bins |
|
speaker_stats = train_ds.speaker_stats |
|
|
|
model = hydra.utils.instantiate(cfg["model"]).to(device) |
|
model.setup_f0_stats(f0_min, f0_max, f0_bins, speaker_stats) |
|
|
|
optimizer = hydra.utils.instantiate(cfg.optimizer, model.parameters()) |
|
|
|
best_loss = float("inf") |
|
for epoch in range(cfg.epochs): |
|
train_loss, train_l2_loss, train_l2_voiced_loss = run_epoch( |
|
model, train_dl, optimizer, device, cfg, mode="train" |
|
) |
|
valid_loss, valid_l2_loss, valid_l2_voiced_loss = run_epoch( |
|
model, valid_dl, None, device, cfg, mode="valid" |
|
) |
|
print( |
|
f"[epoch {epoch}] train loss: {train_loss:.3f}, l2 loss: {train_l2_loss:.3f}, l2 voiced loss: {train_l2_voiced_loss:.3f}" |
|
) |
|
print( |
|
f"[epoch {epoch}] valid loss: {valid_loss:.3f}, l2 loss: {valid_l2_loss:.3f}, l2 voiced loss: {valid_l2_voiced_loss:.3f}" |
|
) |
|
if valid_l2_voiced_loss < best_loss: |
|
path = f"{os.getcwd()}/pitch_predictor.ckpt" |
|
save_ckpt(model, path, cfg["model"], f0_min, f0_max, f0_bins, speaker_stats) |
|
best_loss = valid_l2_voiced_loss |
|
print(f"saved checkpoint: {path}") |
|
print(f"[epoch {epoch}] best loss: {best_loss:.3f}") |
|
|
|
|
|
def run_epoch(model, loader, optimizer, device, cfg, mode): |
|
if mode == "train": |
|
model.train() |
|
else: |
|
model.eval() |
|
|
|
epoch_loss = 0 |
|
l1 = 0 |
|
l1_voiced = 0 |
|
for x, f0_bin, f0_raw, spk_id, gst, mask, _ in tqdm(loader): |
|
x, f0_bin, f0_raw, spk_id, gst, mask = ( |
|
x.to(device), |
|
f0_bin.to(device), |
|
f0_raw.to(device), |
|
spk_id.to(device), |
|
gst.to(device), |
|
mask.to(device), |
|
) |
|
b, t, n_bins = f0_bin.shape |
|
yhat = model(x, gst) |
|
nonzero_mask = (f0_raw != 0).logical_and(mask) |
|
yhat_raw = model.inference(x, spk_id, gst) |
|
expanded_mask = mask.unsqueeze(-1).expand(-1, -1, n_bins) |
|
if cfg.f0_pred == "mean": |
|
loss = F.binary_cross_entropy( |
|
yhat[expanded_mask], f0_bin[expanded_mask] |
|
).mean() |
|
elif cfg.f0_pred == "argmax": |
|
loss = F.cross_entropy( |
|
rearrange(yhat, "b t d -> (b t) d"), |
|
rearrange(f0_bin.argmax(-1), "b t -> (b t)"), |
|
reduce=False, |
|
) |
|
loss = rearrange(loss, "(b t) -> b t", b=b, t=t) |
|
loss = (loss * mask).sum() / mask.float().sum() |
|
else: |
|
raise NotImplementedError |
|
l1 += F.l1_loss(yhat_raw[mask], f0_raw[mask]).item() |
|
l1_voiced += F.l1_loss(yhat_raw[nonzero_mask], f0_raw[nonzero_mask]).item() |
|
epoch_loss += loss.item() |
|
|
|
if mode == "train": |
|
loss.backward() |
|
nn.utils.clip_grad_norm_(model.parameters(), 1.0) |
|
optimizer.step() |
|
|
|
print(f"{mode} example y: {f0_bin.argmax(-1)[0, 50:60].tolist()}") |
|
print(f"{mode} example yhat: {yhat.argmax(-1)[0, 50:60].tolist()}") |
|
print(f"{mode} example y: {f0_raw[0, 50:60].round().tolist()}") |
|
print(f"{mode} example yhat: {yhat_raw[0, 50:60].round().tolist()}") |
|
return epoch_loss / len(loader), l1 / len(loader), l1_voiced / len(loader) |
|
|
|
|
|
@hydra.main(config_path=dir_path, config_name="pitch_predictor.yaml") |
|
def main(cfg): |
|
np.random.seed(1) |
|
random.seed(1) |
|
torch.manual_seed(1) |
|
from hydra.core.hydra_config import HydraConfig |
|
|
|
overrides = { |
|
x.split("=")[0]: x.split("=")[1] |
|
for x in HydraConfig.get().overrides.task |
|
if "/" not in x |
|
} |
|
print(f"{cfg}") |
|
train(cfg) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|